# Path: OmniPart/modules/part_synthesis/pipelines/omnipart_image_to_parts.py
# Page-header residue kept for reference (not part of the module):
#   "omnipart's picture", "init", "491eded"
from typing import *
from contextlib import contextmanager
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torchvision import transforms
from PIL import Image
import rembg
from transformers import AutoModel
from .base import Pipeline
from . import samplers
from ..modules import sparse as sp
from ..modules.sparse.basic import SparseTensor, sparse_cat
class OmniPartImageTo3DPipeline(Pipeline):
    """
    Pipeline for inferring OmniPart image-to-3D models.

    Every constructor argument defaults to None. When ``models`` is None the
    constructor returns immediately without initializing anything.

    Args:
        models (dict[str, nn.Module]): The models to use in the pipeline.
        sparse_structure_sampler (samplers.Sampler): The sampler for the sparse structure.
        slat_sampler (samplers.Sampler): The sampler for the structured latent.
        slat_normalization (dict): The normalization parameters for the structured latent.
        image_cond_model (str): The name of the image conditioning model.
    """
def __init__(
    self,
    models: Dict[str, nn.Module] = None,
    sparse_structure_sampler: samplers.Sampler = None,
    slat_sampler: samplers.Sampler = None,
    slat_normalization: dict = None,
    image_cond_model: str = None,
):
    """
    Set up the pipeline from already-constructed components.

    Args:
        models (dict[str, nn.Module]): Models handed to the base ``Pipeline``.
            When None, nothing at all is initialized.
        sparse_structure_sampler (samplers.Sampler): Sampler for the sparse structure.
        slat_sampler (samplers.Sampler): Sampler for the structured latent.
        slat_normalization (dict): Normalization parameters for the structured latent.
        image_cond_model (str): Name passed to ``_init_image_cond_model``.
    """
    # from_pretrained() creates an empty instance and fills it in afterwards,
    # so a missing ``models`` means "leave the object untouched".
    if models is not None:
        super().__init__(models)

        # Samplers and their per-call default parameters (empty until configured).
        self.sparse_structure_sampler = sparse_structure_sampler
        self.sparse_structure_sampler_params = {}
        self.slat_sampler = slat_sampler
        self.slat_sampler_params = {}

        self.slat_normalization = slat_normalization
        self.rembg_session = None

        # Loads the conditioning backbone last, once self.models exists.
        self._init_image_cond_model(image_cond_model)
@staticmethod
def from_pretrained(path: str) -> "OmniPartImageTo3DPipeline":
    """
    Build a pipeline from pretrained weights and saved arguments.

    Args:
        path (str): The path to the model. Can be either local path or a Hugging Face repository.

    Returns:
        OmniPartImageTo3DPipeline: Pipeline with samplers, normalization
        parameters and the image conditioning model configured from the
        loaded ``_pretrained_args``.
    """
    loaded = super(OmniPartImageTo3DPipeline, OmniPartImageTo3DPipeline).from_pretrained(path)

    # Re-home the loaded state onto an instance of this class; the two objects
    # share one attribute dictionary from here on.
    result = OmniPartImageTo3DPipeline()
    result.__dict__ = loaded.__dict__
    saved_args = loaded._pretrained_args

    # Each sampler entry carries a class name, constructor args, and default
    # sampling params; both samplers are rebuilt the same way.
    for attr_name in ('sparse_structure_sampler', 'slat_sampler'):
        spec = saved_args[attr_name]
        sampler_cls = getattr(samplers, spec['name'])
        setattr(result, attr_name, sampler_cls(**spec['args']))
        setattr(result, f'{attr_name}_params', spec['params'])

    result.slat_normalization = saved_args['slat_normalization']
    result._init_image_cond_model(saved_args['image_cond_model'])
    return result
def _init_image_cond_model(self, name: str):
    """
    Load the image conditioning backbone and its input normalization.

    The model is fetched through ``torch.hub`` from ``facebookresearch/dinov2``,
    switched to eval mode, and registered as ``self.models['image_cond_model']``.
    The matching per-channel normalization is stored on
    ``self.image_cond_model_transform``.

    Args:
        name (str): Name of the DINOv2 model to load
    """
    backbone = torch.hub.load('facebookresearch/dinov2', name)
    backbone.eval()
    self.models['image_cond_model'] = backbone

    # Per-channel mean/std applied to images before they reach the backbone.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=channel_mean, std=channel_std)
    self.image_cond_model_transform = transforms.Compose([normalize])
def preprocess_image(self, input: Image.Image, size=(518, 518)) -> Image.Image:
    """
    Preprocess the input image for the model.

    If the image has four channels, fully transparent pixels (alpha == 0) are
    painted black and the alpha channel is dropped. The image is then resized
    to ``size`` with bilinear resampling.

    Args:
        input (Image.Image): Input image
        size (tuple): Target size for resizing

    Returns:
        Image.Image: Preprocessed image
    """
    img = np.array(input)

    # Only a 3-D array with four entries on its last axis carries an alpha
    # channel. Without the ndim check, a 2-D single-channel image that happens
    # to be 4 pixels wide would be mistaken for RGBA and corrupted.
    if img.ndim == 3 and img.shape[-1] == 4:
        # Replace fully transparent pixels with opaque black, then drop alpha.
        transparent = img[..., 3] == 0
        img[transparent] = [0, 0, 0, 255]
        img = img[..., :3]

    img_rgb = Image.fromarray(img.astype('uint8'))
    # Resize to target size
    return img_rgb.resize(size, resample=Image.Resampling.BILINEAR)
@torch.no_grad()
def encode_image(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor:
    """
    Turn images into conditioning tokens with the image conditioning model.

    A tensor input must already be a 4-D batch. A list of PIL images is
    resized to 518x518, converted to RGB, scaled to [0, 1] and stacked into a
    channel-first batch on ``self.device``.

    Args:
        image (Union[torch.Tensor, list[Image.Image]]): The image(s) to encode

    Returns:
        torch.Tensor: The model's ``'x_prenorm'`` output after a parameter-free
        layer norm over its last dimension
    """
    if isinstance(image, torch.Tensor):
        assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
        batch = image
    elif isinstance(image, list):
        assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
        per_image = []
        for pil_img in image:
            resized = pil_img.resize((518, 518), Image.LANCZOS)
            pixels = np.array(resized.convert('RGB')).astype(np.float32) / 255
            # HWC array -> CHW tensor
            per_image.append(torch.from_numpy(pixels).permute(2, 0, 1).float())
        batch = torch.stack(per_image).to(self.device)
    else:
        raise ValueError(f"Unsupported type of image: {type(image)}")

    # Normalize, then run the backbone; is_training=True makes it return the
    # dictionary from which the un-normalized tokens are read.
    normalized = self.image_cond_model_transform(batch).to(self.device)
    backbone_out = self.models['image_cond_model'](normalized, is_training=True)
    prenorm_tokens = backbone_out['x_prenorm']
    return F.layer_norm(prenorm_tokens, prenorm_tokens.shape[-1:])
def get_cond(self, image: Union[torch.Tensor, List[Image.Image]]) -> dict:
    """
    Build the positive and negative conditioning for the samplers.

    Args:
        image (Union[torch.Tensor, list[Image.Image]]): The image prompts.

    Returns:
        dict: ``'cond'`` holds the encoded image features; ``'neg_cond'`` is an
        all-zero tensor of the same shape, dtype and device.
    """
    encoded = self.encode_image(image)
    return {
        'cond': encoded,
        # The negative branch is conditioned on zeros.
        'neg_cond': torch.zeros_like(encoded),
    }
def sample_sparse_structure(
    self,
    cond: dict,
    num_samples: int = 1,
    sampler_params: Optional[dict] = None,
    save_coords: bool = False,
) -> torch.Tensor:
    """
    Sample sparse structures with the given conditioning.

    Args:
        cond (dict): The conditioning information; forwarded to the sampler
            as keyword arguments.
        num_samples (int): The number of samples to generate (leading
            dimension of the noise tensor).
        sampler_params (dict, optional): Additional parameters for the sampler.
            They override ``self.sparse_structure_sampler_params``. ``verbose``
            defaults to True and may be overridden here.
        save_coords (bool): Whether to also store the result on
            ``self.save_coordinates``.

    Returns:
        torch.Tensor: Int tensor with one row per position where the decoder
        output is positive, keeping index columns 0, 2, 3 and 4 (presumably the
        batch index plus three spatial indices -- confirm against the decoder's
        output layout).
    """
    # Sample occupancy latent
    flow_model = self.models['sparse_structure_flow_model']
    reso = flow_model.resolution
    noise = torch.randn(num_samples, flow_model.in_channels, reso, reso, reso).to(self.device)

    # Merge defaults and per-call overrides. ``verbose`` lives in the same
    # dictionary so that a caller-supplied value replaces the default instead
    # of colliding with a hard-coded keyword argument.
    merged_params = {
        'verbose': True,
        **self.sparse_structure_sampler_params,
        **(sampler_params or {}),
    }

    # Generate samples using the sampler
    z_s = self.sparse_structure_sampler.sample(
        flow_model,
        noise,
        **cond,
        **merged_params,
    ).samples

    # Decode occupancy latent to get coordinates
    decoder = self.models['sparse_structure_decoder']
    coords = torch.argwhere(decoder(z_s) > 0)[:, [0, 2, 3, 4]].int()

    if save_coords:
        self.save_coordinates = coords
    return coords
@torch.no_grad()
def get_coords(
    self,
    image: Union[Image.Image, List[Image.Image], torch.Tensor],
    num_samples: int = 1,
    seed: int = 42,
    sparse_structure_sampler_params: Optional[dict] = None,
    preprocess_image: bool = True,
    save_coords: bool = False,
) -> torch.Tensor:
    """
    Get coordinates of the sparse structure from an input image.

    Args:
        image: A PIL image, a list of PIL images, or a tensor. A tensor gets a
            leading batch dimension added and is never preprocessed.
        num_samples: Number of samples to generate
        seed: Random seed; applied after the conditioning is computed and
            right before sampling
        sparse_structure_sampler_params: Additional parameters for the sparse structure sampler
        preprocess_image: Whether to run ``self.preprocess_image`` on PIL inputs
        save_coords: Whether to save coordinates internally

    Returns:
        torch.Tensor: Coordinates of the sparse structure

    Raises:
        ValueError: If ``image`` is none of the supported types.
    """
    # Bring every supported input into the form get_cond() accepts.
    if isinstance(image, torch.Tensor):
        cond_input = image.unsqueeze(0)
    elif isinstance(image, (Image.Image, list)):
        cond_input = [image] if isinstance(image, Image.Image) else image
        if preprocess_image:
            cond_input = [self.preprocess_image(i) for i in cond_input]
    else:
        raise ValueError(f"Unsupported type of image: {type(image)}")

    params = {} if sparse_structure_sampler_params is None else sparse_structure_sampler_params

    cond = self.get_cond(cond_input)
    torch.manual_seed(seed)
    return self.sample_sparse_structure(cond, num_samples, params, save_coords)
def sample_slat(
    self,
    cond: dict,
    coords: torch.Tensor,
    part_layouts: Optional[List[slice]] = None,
    masks: Optional[torch.Tensor] = None,
    sampler_params: Optional[dict] = None,
    **kwargs
) -> sp.SparseTensor:
    """
    Sample a structured latent on the given sparse coordinates.

    Args:
        cond (dict): The conditioning information; forwarded to the sampler
            as keyword arguments.
        coords (torch.Tensor): Coordinates of the sparse structure; one noise
            feature row is drawn per coordinate row.
        part_layouts (List[slice], optional): Part layout information,
            forwarded to the sampler as ``part_layouts`` when given.
        masks (torch.Tensor, optional): Forwarded to the sampler as ``masks``
            when given.
        sampler_params (dict, optional): Additional parameters for the sampler.
            They override ``self.slat_sampler_params``. ``verbose`` defaults to
            True and may be overridden here.
        **kwargs: Extra keyword arguments forwarded to the sampler.

    Returns:
        sp.SparseTensor: The sampled latent after ``feats * std + mean`` using
        ``self.slat_normalization``.
    """
    # Sample structured latent
    flow_model = self.models['slat_flow_model']

    # Create noise tensor with same coordinates as the sparse structure
    noise = sp.SparseTensor(
        feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device),
        coords=coords,
    )

    # Merge defaults and per-call overrides. ``verbose`` lives in the same
    # dictionary so that a caller-supplied value replaces the default instead
    # of colliding with a hard-coded keyword argument.
    merged_params = {
        'verbose': True,
        **self.slat_sampler_params,
        **(sampler_params or {}),
    }

    # Add part information if provided
    if part_layouts is not None:
        kwargs['part_layouts'] = part_layouts
    if masks is not None:
        kwargs['masks'] = masks

    # Generate samples
    slat = self.slat_sampler.sample(
        flow_model,
        noise,
        **cond,
        **merged_params,
        **kwargs
    ).samples

    # Build per-channel scale/shift from the stored normalization statistics.
    feat_dim = slat.feats.shape[1]
    base_std = torch.tensor(self.slat_normalization['std']).to(slat.device)
    base_mean = torch.tensor(self.slat_normalization['mean']).to(slat.device)

    if feat_dim == len(base_std):
        # Dimensions match, apply directly
        std = base_std[None, :]
        mean = base_mean[None, :]
    elif feat_dim == 8 and len(base_std) == 9:
        # Use first 8 dimensions when latent is 8-dimensional but normalization is 9-dimensional
        std = base_std[:8][None, :]
        mean = base_mean[:8][None, :]
        print(f"Warning: Normalizing {feat_dim}-dimensional features with first 8 dimensions of 9-dimensional parameters")
    else:
        # General mismatch: channels without statistics keep std=1 / mean=0,
        # i.e. they pass through unchanged.
        std = torch.ones((1, feat_dim), device=slat.device)
        mean = torch.zeros((1, feat_dim), device=slat.device)
        copy_dim = min(feat_dim, len(base_std))
        std[0, :copy_dim] = base_std[:copy_dim]
        mean[0, :copy_dim] = base_mean[:copy_dim]
        print(f"Warning: Feature dimensions mismatch. Using {copy_dim} dimensions for normalization")

    # Scale by std and shift by mean.
    slat = slat * std + mean
    return slat
@torch.no_grad()
def get_slat(
    self,
    image: Union[Image.Image, List[Image.Image], torch.Tensor],
    coords: torch.Tensor,
    part_layouts: List[slice],
    masks: torch.Tensor,
    seed: int = 42,
    slat_sampler_params: Optional[dict] = None,
    formats: Optional[List[str]] = None,
    preprocess_image: bool = True,
) -> dict:
    """
    Sample a structured latent for the given image and decode it.

    Args:
        image: A PIL image, a list of PIL images, or a tensor. A tensor gets a
            leading batch dimension added and is never preprocessed.
        coords (torch.Tensor): Coordinates of the sparse structure
        part_layouts (List[slice]): Layout information for parts; used both
            for sampling and for dividing the sampled latent
        masks (torch.Tensor): Passed through to ``sample_slat``
        seed (int): Random seed; applied after the conditioning is computed
            and right before sampling
        slat_sampler_params (dict, optional): Additional parameters for the
            structured latent sampler
        formats (List[str], optional): The formats to decode to. Defaults to
            ``['mesh', 'gaussian', 'radiance_field']``.
        preprocess_image (bool): Whether to run ``self.preprocess_image`` on
            PIL inputs

    Returns:
        dict: Result of ``decode_slat`` on the divided latent

    Raises:
        ValueError: If ``image`` is none of the supported types.
    """
    # Bring every supported input into the form get_cond() accepts.
    if isinstance(image, torch.Tensor):
        cond_input = image.unsqueeze(0)
    elif isinstance(image, (Image.Image, list)):
        cond_input = [image] if isinstance(image, Image.Image) else image
        if preprocess_image:
            cond_input = [self.preprocess_image(i) for i in cond_input]
    else:
        raise ValueError(f"Unsupported type of image: {type(image)}")

    params = {} if slat_sampler_params is None else slat_sampler_params
    if formats is None:
        formats = ['mesh', 'gaussian', 'radiance_field']

    cond = self.get_cond(cond_input)
    torch.manual_seed(seed)
    slat = self.sample_slat(cond, coords, part_layouts, masks, params)
    return self.decode_slat(self.divide_slat(slat, part_layouts), formats)
def decode_slat(
    self,
    slat: sp.SparseTensor,
    formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
) -> dict:
    """
    Run the requested decoders on a structured latent.

    Args:
        slat (sp.SparseTensor): The structured latent
        formats (List[str]): The formats to decode to; entries other than
            ``'mesh'``, ``'gaussian'`` and ``'radiance_field'`` are ignored

    Returns:
        dict: One entry per requested format, keyed by the format name
    """
    # Output key -> name of the decoder in self.models. Decoders are looked up
    # only for requested formats, in this fixed order.
    decoder_for_format = (
        ('mesh', 'slat_decoder_mesh'),
        ('gaussian', 'slat_decoder_gs'),
        ('radiance_field', 'slat_decoder_rf'),
    )
    return {
        fmt: self.models[model_key](slat)
        for fmt, model_key in decoder_for_format
        if fmt in formats
    }
def divide_slat(
    self,
    slat: sp.SparseTensor,
    part_layouts: List[List[slice]],
) -> sp.SparseTensor:
    """
    Divide the structured latent into parts.

    For every batch item ``i`` and every slice in ``part_layouts[i]``, the
    matching rows of that item's coordinates and features become their own
    sparse tensor. All of them are concatenated and passed through
    ``remove_noise``.

    Args:
        slat (sp.SparseTensor): The structured latent
        part_layouts (List[List[slice]]): For each batch item, the row slices
            of its parts

    Returns:
        sp.SparseTensor: Processed and divided latent
    """
    sparse_part = []
    for part_id, part_layout in enumerate(part_layouts):
        # Index the batch item once; it is the same for every slice below.
        batch_item = slat[part_id]
        item_coords = batch_item.coords
        item_feats = batch_item.feats
        for part_slice in part_layout:
            sparse_part.append(
                SparseTensor(
                    coords=item_coords[part_slice],
                    feats=item_feats[part_slice],
                )
            )
    return self.remove_noise(sparse_cat(sparse_part))
def remove_noise(self, z_batch):
    """
    Filter low-confidence points out of each latent in a batch.

    Entries whose features are not 9-dimensional are kept unchanged. For
    9-dimensional entries the last feature is treated as a confidence logit:
    points with ``sigmoid(logit) < 0.5`` are dropped and the confidence column
    is removed. Entries left with no points are skipped entirely.

    Args:
        z_batch: Latent vectors to process

    Returns:
        sp.SparseTensor: Processed latent with noise removed
    """
    survivors = []
    for part_idx, part in enumerate(z_batch):
        part_coords = part.coords
        part_feats = part.feats

        # No confidence column: pass the entry through untouched.
        if part_feats.shape[1] != 9:
            survivors.append(part)
            continue

        keep_mask = torch.sigmoid(part_feats[:, -1]) >= 0.5
        n_total = part_coords.shape[0]
        n_kept = keep_mask.sum().item()
        n_dropped = n_total - n_kept
        drop_pct = (n_dropped / n_total) * 100 if n_total > 0 else 0

        if n_kept == 0:
            print(f"No points kept for part {part_idx}")
            continue
        print(f"Discarded {n_dropped}/{n_total} points ({drop_pct:.2f}%)")

        # Keep the surviving rows and strip the confidence column.
        survivors.append(
            part.replace(
                coords=part_coords[keep_mask],
                feats=part_feats[keep_mask][:, :-1],
            )
        )
    return sparse_cat(survivors)
@contextmanager
def inject_sampler_multi_image(
    self,
    sampler_name: str,
    num_images: int,
    num_steps: int,
    mode: Literal['stochastic', 'multidiffusion'] = 'stochastic',
):
    """
    Inject a sampler with multiple images as condition.

    While the context is active, the sampler's ``_inference_model`` is
    replaced; the original is restored on exit, even if the body raises.
    Nothing on the sampler is touched when the arguments are invalid.

    Args:
        sampler_name (str): The name of the sampler to inject
        num_images (int): The number of images to condition on
        num_steps (int): The number of steps to run the sampler for; only used
            to warn when there are more images than steps
        mode (str): Sampling strategy ('stochastic' or 'multidiffusion').
            'stochastic' uses one conditioning image per call, round-robin.
            'multidiffusion' averages the predictions over all conditioning
            images on every call.

    Raises:
        ValueError: If ``mode`` is unsupported, or ``num_images`` is smaller
            than 1 in 'stochastic' mode.
    """
    sampler = getattr(self, sampler_name)
    # Held in a closure rather than on the sampler, so nested or failed
    # injections cannot clobber or leak it.
    original_inference_model = sampler._inference_model

    if mode == 'stochastic':
        if num_images < 1:
            raise ValueError(f"num_images must be at least 1, got {num_images}")
        if num_images > num_steps:
            print(f"\033[93mWarning: number of conditioning images is greater than number of steps for {sampler_name}. "
                "This may lead to performance degradation.\033[0m")

        # Round-robin over the images: call k uses image k % num_images. The
        # counter keeps cycling if the sampler calls the model more often than
        # num_steps times.
        call_count = 0

        def _new_inference_model(self, model, x_t, t, cond, **kwargs):
            nonlocal call_count
            cond_idx = call_count % num_images
            call_count += 1
            cond_i = cond[cond_idx:cond_idx+1]
            return original_inference_model(model, x_t, t, cond=cond_i, **kwargs)

    elif mode == 'multidiffusion':
        from .samplers import FlowEulerSampler

        def _new_inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, cfg_interval, **kwargs):
            # Average the predictions of all conditioning images.
            preds = [
                FlowEulerSampler._inference_model(self, model, x_t, t, cond[i:i+1], **kwargs)
                for i in range(len(cond))
            ]
            pred = sum(preds) / len(preds)
            if cfg_interval[0] <= t <= cfg_interval[1]:
                # Inside the CFG interval: combine with the negative prediction.
                neg_pred = FlowEulerSampler._inference_model(self, model, x_t, t, neg_cond, **kwargs)
                return (1 + cfg_strength) * pred - cfg_strength * neg_pred
            return pred

    else:
        raise ValueError(f"Unsupported mode: {mode}")

    sampler._inference_model = _new_inference_model.__get__(sampler, type(sampler))
    try:
        yield
    finally:
        # Restore original inference model
        sampler._inference_model = original_inference_model