OmniPart / modules /inference_utils.py
omnipart's picture
init
491eded
raw
history blame
15.3 kB
import os
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
import numpy as np
from typing import Optional
from PIL import Image, ImageDraw
import torchvision.transforms.functional as TF
import cv2
import torch
import trimesh
import glob
from tqdm import tqdm
def load_img_mask(img_path, mask_path, size=(518, 518)):
    """Load an image and its part mask, crop both around the object, and build model inputs.

    The crop is a square centred on the bounding box of the non-zero alpha
    region, with half-size 1.2x the larger half-extent of that box. Where the
    crop extends past the image border, image and mask are zero-padded first.

    Args:
        img_path: path to the image. It is converted to RGBA, so channel 3 is
            always alpha; images without alpha become fully opaque.
        mask_path: path to the part-id mask, read with ``cv2.IMREAD_UNCHANGED``.
            It may be single-channel (H, W) or multi-channel (H, W, C).
            NOTE(review): assumed to have the same H x W as the image — confirm.
        size: (width, height) that the output images are resized to. It is
            also forwarded to ``load_bottom_up_mask``.

    Returns:
        img_white_bg: ``TF.to_tensor`` of the RGB crop with fully transparent
            pixels set to white.
        img_black_bg: same, with fully transparent pixels set to black.
        ordered_mask_input: first output of ``load_bottom_up_mask``.
        img_mask_vis: side-by-side visualisation from ``vis_mask_on_img``.

    Raises:
        FileNotFoundError: the mask could not be read by OpenCV.
        ValueError: the image has no pixel with non-zero alpha.
    """
    with Image.open(img_path) as src:
        # Normalise the mode so channel 3 is always alpha (RGB/LA/P/CMYK inputs included).
        image = src.convert("RGBA")
    alpha = np.array(image.getchannel(3))
    rows, cols = alpha.nonzero()
    if rows.size == 0:
        raise ValueError(f"Image has no non-transparent pixels: {img_path}")
    bbox = [cols.min(), rows.min(), cols.max(), rows.max()]

    # Square crop centred on the object's bbox, enlarged by aug_size_ratio.
    center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]
    hsize = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2
    aug_size_ratio = 1.2
    aug_hsize = hsize * aug_size_ratio
    aug_bbox = [
        int(center[0] - aug_hsize),
        int(center[1] - aug_hsize),
        int(center[0] + aug_hsize),
        int(center[1] + aug_hsize),
    ]

    img_height, img_width = alpha.shape
    mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
    if mask is None:
        raise FileNotFoundError(f"Could not read mask image: {mask_path}")

    pad_left = max(0, -aug_bbox[0])
    pad_top = max(0, -aug_bbox[1])
    pad_right = max(0, aug_bbox[2] - img_width)
    pad_bottom = max(0, aug_bbox[3] - img_height)
    if pad_left > 0 or pad_top > 0 or pad_right > 0 or pad_bottom > 0:
        spatial_pad = ((pad_top, pad_bottom), (pad_left, pad_right))
        padded_img_array = np.pad(
            np.array(image),
            spatial_pad + ((0, 0),),
            mode='constant',
            constant_values=0,
        )
        # Pad only the two spatial axes so both 2-D and 3-D masks are accepted.
        mask = np.pad(
            mask,
            spatial_pad + ((0, 0),) * (mask.ndim - 2),
            mode='constant',
            constant_values=0,
        )
        image = Image.fromarray(padded_img_array.astype('uint8'))
        # Shift the crop box into the padded coordinate frame.
        aug_bbox = [
            aug_bbox[0] + pad_left,
            aug_bbox[1] + pad_top,
            aug_bbox[2] + pad_left,
            aug_bbox[3] + pad_top,
        ]

    image = image.crop(tuple(aug_bbox))
    mask = mask[aug_bbox[1]:aug_bbox[3], aug_bbox[0]:aug_bbox[2]]
    ordered_mask_input, mask_vis = load_bottom_up_mask(mask, size=size)

    # Composite fully transparent pixels onto white and onto black backgrounds.
    rgba = np.array(image)
    transparent = rgba[..., 3] == 0
    white_bg = rgba.copy()
    black_bg = rgba.copy()
    white_bg[transparent] = [255, 255, 255, 255]
    black_bg[transparent] = [0, 0, 0, 255]

    img_white_bg = Image.fromarray(white_bg[..., :3].astype('uint8'))
    img_black_bg = Image.fromarray(black_bg[..., :3].astype('uint8'))
    img_white_bg = img_white_bg.resize(size, resample=Image.Resampling.LANCZOS)
    img_black_bg = img_black_bg.resize(size, resample=Image.Resampling.LANCZOS)
    img_mask_vis = vis_mask_on_img(img_white_bg, mask_vis)
    img_white_bg = TF.to_tensor(img_white_bg)
    img_black_bg = TF.to_tensor(img_black_bg)
    return img_white_bg, img_black_bg, ordered_mask_input, img_mask_vis
def load_bottom_up_mask(mask, size=(518, 518)):
    """Downsample a part-id mask to a 37x37 grid and renumber its parts bottom-up.

    The 37x37 grid presumably matches the image encoder's patch grid
    (518 / 14 = 37) — NOTE(review): confirm.

    Args:
        mask: part-id mask (0 = background) at crop resolution.
        size: (width, height) of the visualisation mask; ``cv2.resize`` takes
            dsize as (width, height).

    Returns:
        ordered_mask_input: 37x37 ``torch.long`` tensor. Foreground ids are
            remapped to 1..K, ordered by the lowest image row each part
            reaches (largest y first). Ties keep ascending original-id order,
            because the sort is stable.
        mask_vis: int32 array of shape (size[1], size[0]). It is a
            nearest-neighbour upsampling of the downsampled mask with its
            ORIGINAL ids. NOTE(review): so its colours do not follow the
            reordered ids — confirm this is intended.
    """
    mask_input = smart_downsample_mask(mask, (37, 37))
    mask_vis = cv2.resize(mask_input, tuple(size), interpolation=cv2.INTER_NEAREST)
    mask_input = np.array(mask_input, dtype=np.int32)

    # Lowest row (largest y) reached by each foreground part id.
    part_ids = np.unique(mask_input)
    part_positions = {}
    for idx in part_ids[part_ids > 0]:
        y_coords, _ = np.where(mask_input == idx)
        part_positions[idx] = np.max(y_coords)
    sorted_parts = sorted(part_positions.items(), key=lambda item: -item[1])

    # New ids start at 1 because 0 is reserved for background.
    index_map = {old_idx: new_idx for new_idx, (old_idx, _) in enumerate(sorted_parts, 1)}
    ordered_mask_input = np.zeros_like(mask_input)
    for old_idx, new_idx in index_map.items():
        ordered_mask_input[mask_input == old_idx] = new_idx

    mask_vis = np.array(mask_vis, dtype=np.int32)
    ordered_mask_input = torch.from_numpy(ordered_mask_input).long()
    return ordered_mask_input, mask_vis
def smart_downsample_mask(mask, target_size):
    """Resize a label mask by per-cell majority vote, preferring foreground labels.

    The source is split into a ``target_h x target_w`` grid. Row edges are at
    ``int(i * h / target_h)``; columns are split the same way. Each output
    cell takes the most frequent label > 0 inside its cell. If the cell has
    no label > 0, it takes the most frequent value overall. Ties go to the
    smallest value, because ``np.unique`` returns sorted values.

    Each cell spans at least one source pixel. This means upsampling (target
    larger than source) replicates labels instead of leaving cells at 0.

    Args:
        mask: label array of shape (h, w) or (h, w, c). For (h, w, c), all
            channels in a cell vote together.
        target_size: (target_h, target_w) of the output grid.

    Returns:
        Array of shape (target_h, target_w) with ``mask.dtype``.
    """
    h, w = mask.shape[:2]
    target_h, target_w = target_size
    h_ratio = h / target_h
    w_ratio = w / target_w
    downsampled = np.zeros((target_h, target_w), dtype=mask.dtype)
    for i in range(target_h):
        y_start = int(i * h_ratio)
        # Widen to >= 1 pixel; this only matters when h_ratio < 1 (upsampling).
        y_end = min(max(int((i + 1) * h_ratio), y_start + 1), h)
        for j in range(target_w):
            x_start = int(j * w_ratio)
            x_end = min(max(int((j + 1) * w_ratio), x_start + 1), w)
            region = mask[y_start:y_end, x_start:x_end]
            if region.size == 0:
                continue
            values, counts = np.unique(region, return_counts=True)
            foreground = values > 0
            if np.any(foreground):
                values, counts = values[foreground], counts[foreground]
            downsampled[i, j] = values[np.argmax(counts)]
    return downsampled
def vis_mask_on_img(img, mask):
    """Render ``img`` next to a colour-coded version of ``mask``.

    Args:
        img: PIL image pasted with its top-left corner at (0, 0).
            NOTE(review): assumed to be mask-sized; a wider image would be
            partly covered by the mask half.
        mask: 2-D integer part-id array. 0 stays white; id k gets palette
            colour ``get_random_color(k - 1)``.

    Returns:
        RGB PIL image of size (2 * W, H). It has ``img`` on the left, the
        coloured mask starting at x = W, and a 2-px black vertical line at x = W.
    """
    height, width = mask.shape
    canvas = np.full((height, width, 3), 255, dtype=np.uint8)
    highest_id = int(mask.max())
    for part_id in range(1, highest_id + 1):
        selected = mask == part_id
        if not np.any(selected):
            continue
        canvas[selected, 0:3] = get_random_color((part_id - 1), use_float=False)[:3]

    result = Image.new('RGB', (width * 2, height), (255, 255, 255))
    result.paste(img, (0, 0))
    result.paste(Image.fromarray(canvas), (width, 0))
    ImageDraw.Draw(result).line([(width, 0), (width, height)], fill=(0, 0, 0), width=2)
    return result
# Fixed RGBA palette: matplotlib 'Set3' + 'Set2' + 'Set1' colours (12 + 8 + 9 = 29).
# Built once and frozen; get_random_color returns copies, so callers cannot corrupt it.
_COLOR_PALETTE = np.array(
    [
        [141, 211, 199, 255],
        [255, 255, 179, 255],
        [190, 186, 218, 255],
        [251, 128, 114, 255],
        [128, 177, 211, 255],
        [253, 180, 98, 255],
        [179, 222, 105, 255],
        [252, 205, 229, 255],
        [217, 217, 217, 255],
        [188, 128, 189, 255],
        [204, 235, 197, 255],
        [255, 237, 111, 255],
        [102, 194, 165, 255],
        [252, 141, 98, 255],
        [141, 160, 203, 255],
        [231, 138, 195, 255],
        [166, 216, 84, 255],
        [255, 217, 47, 255],
        [229, 196, 148, 255],
        [179, 179, 179, 255],
        [228, 26, 28, 255],
        [55, 126, 184, 255],
        [77, 175, 74, 255],
        [152, 78, 163, 255],
        [255, 127, 0, 255],
        [255, 255, 51, 255],
        [166, 86, 40, 255],
        [247, 129, 191, 255],
        [153, 153, 153, 255],
    ],
    dtype=np.uint8,
)
_COLOR_PALETTE.setflags(write=False)


def get_random_color(index: Optional[int] = None, use_float: bool = False):
    """Return one RGBA colour from a fixed 29-entry palette.

    Args:
        index: palette position. Any integer wraps modulo the palette length.
            ``None`` picks a random entry with ``np.random.randint``.
        use_float: if True, return float32 values scaled to [0, 1]. Otherwise
            return uint8 values in [0, 255].

    Returns:
        A new length-4 numpy array (R, G, B, A) that the caller may modify.
    """
    if index is None:
        index = np.random.randint(0, len(_COLOR_PALETTE))
    color = _COLOR_PALETTE[index % len(_COLOR_PALETTE)]
    if use_float:
        return color.astype(np.float32) / 255
    return color.copy()
def change_pcd_range(pcd, from_rg=(-1,1), to_rg=(-1,1)):
    """Linearly remap values from the interval ``from_rg`` to ``to_rg``.

    The midpoint of ``from_rg`` maps to the midpoint of ``to_rg``, and the
    span is rescaled by ``(to_hi - to_lo) / (from_hi - from_lo)``. Works on
    scalars and on anything supporting elementwise arithmetic (e.g. numpy
    arrays).
    """
    src_lo, src_hi = from_rg
    dst_lo, dst_hi = to_rg
    src_mid = (src_lo + src_hi) / 2
    dst_mid = (dst_lo + dst_hi) / 2
    src_span = src_hi - src_lo
    dst_span = dst_hi - dst_lo
    return (pcd - src_mid) / src_span * dst_span + dst_mid
def prepare_bbox_gen_input(voxel_coords_path, img_white_bg, ordered_mask_input, bins=64, device="cuda"):
    """Assemble batched inputs for the bbox-generation stage.

    Args:
        voxel_coords_path: ``.npy`` file of voxel rows. The first column is
            dropped; the remaining columns are treated as integer voxel
            indices in ``[0, bins)``. NOTE(review): presumably
            (batch, x, y, z) — confirm.
        img_white_bg: image tensor; a leading batch dim is added.
        ordered_mask_input: part-id mask tensor; a leading batch dim is added.
        bins: voxel grid resolution.
        device: torch device that every output tensor is moved to.

    Returns:
        dict with
          "points": float16 voxel-cell centres, shape (1, N, D); these lie in
              [-0.5, 0.5] for indices in [0, bins).
          "whole_voxel_index": int64 indices recovered from those centres,
              shape (1, N, D).
          "images": ``img_white_bg`` with a batch dim.
          "masks": ``ordered_mask_input`` with a batch dim.
    """
    voxel_rows = np.load(voxel_coords_path)[:, 1:]
    centres = (voxel_rows + 0.5) / bins - 0.5

    # Map centres into cell-centre space (0.5/bins .. 1-0.5/bins), scale by
    # bins and truncate to obtain integer voxel indices.
    half_cell = 0.5 / bins
    cell_space = change_pcd_range(centres, from_rg=(-0.5, 0.5), to_rg=(half_cell, 1 - half_cell))
    voxel_index = (cell_space * bins).astype(np.int32)

    return {
        "points": torch.from_numpy(centres).to(torch.float16).unsqueeze(0).to(device),
        "whole_voxel_index": torch.from_numpy(voxel_index).long().unsqueeze(0).to(device),
        "images": img_white_bg.unsqueeze(0).to(device),
        "masks": ordered_mask_input.unsqueeze(0).to(device),
    }
def vis_voxel_coords(voxel_coords, bins=64):
    """Build a ``trimesh.PointCloud`` of voxel-cell centres for visualisation.

    The first column of ``voxel_coords`` is dropped. NOTE(review): presumably
    a batch index — confirm. The remaining integer indices in ``[0, bins)``
    are mapped to cell centres in ``[-0.5, 0.5]``. The cloud is then
    transformed by a fixed matrix that maps (x, y, z) to (x, z, -y).
    """
    centres = (voxel_coords[:, 1:] + 0.5) / bins - 0.5
    cloud = trimesh.PointCloud(centres)
    swap_y_z = np.array([
        [1, 0, 0, 0],
        [0, 0, 1, 0],
        [0, -1, 0, 0],
        [0, 0, 0, 1],
    ])
    cloud.apply_transform(swap_y_z)
    return cloud
def gen_mesh_from_bounds(bounds):
    """Build a ``trimesh.Scene`` with one coloured box per entry of ``bounds``.

    Args:
        bounds: sequence of per-box bounds, each passed unchanged to
            ``trimesh.primitives.Box(bounds=...)`` (a [min, max] corner pair).

    Returns:
        A scene in which box ``k`` has palette colour ``get_random_color(k)``.
        The whole scene is transformed by a fixed matrix that maps
        (x, y, z) to (x, z, -y).
    """
    swap_y_z = np.array([
        [1, 0, 0, 0],
        [0, 0, 1, 0],
        [0, -1, 0, 0],
        [0, 0, 0, 1],
    ])
    boxes = []
    for box_idx, box_bounds in enumerate(bounds):
        box = trimesh.primitives.Box(bounds=box_bounds)
        box.visual.vertex_colors = get_random_color(box_idx, use_float=True)
        boxes.append(box)
    scene = trimesh.Scene(boxes)
    scene.apply_transform(swap_y_z)
    return scene
def _bbox_voxel_bounds(bbox, bins):
    """Convert one bbox from [-0.5, 0.5] space to unpadded int32 voxel bounds (floor of min, ceil of max)."""
    points = change_pcd_range(bbox, from_rg=(-0.5, 0.5), to_rg=(0.5/bins, 1-0.5/bins))
    bbox_min = np.floor(points[0] * bins).astype(np.int32)
    bbox_max = np.ceil(points[1] * bins).astype(np.int32)
    return bbox_min, bbox_max


def _nearest_bbox_indices(points, bbox_mins, bbox_maxs):
    """Return, for each point, the index of the bbox with the smallest face-plane distance.

    The distance is ``min over axes of min(|p - min|, |p - max|)``, i.e. the
    distance to the closest of the six axis-aligned face planes.
    NOTE(review): this is not the true point-to-box distance — confirm it is
    the intended heuristic. Ties resolve to the lowest bbox index.
    """
    best_dist = np.full(points.shape[0], np.inf)
    best_idx = np.full(points.shape[0], -1, dtype=np.int64)
    for bbox_idx, (bbox_min, bbox_max) in enumerate(zip(bbox_mins, bbox_maxs)):
        per_axis = np.minimum(np.abs(points - bbox_min), np.abs(points - bbox_max))
        dist = per_axis.min(axis=1)
        closer = dist < best_dist  # strict '<' keeps the first minimum on ties
        best_dist[closer] = dist[closer]
        best_idx[closer] = bbox_idx
    return best_idx


def prepare_part_synthesis_input(voxel_coords_path, bbox_depth_path, ordered_mask_input, padding_size=2, bins=64, device="cuda"):
    """Build voxel coordinates and per-part layouts for the part-synthesis stage.

    Each bbox loaded from ``bbox_depth_path`` goes through three steps.
    NOTE(review): each bbox is presumably a (2, 3) [min, max] pair in
    [-0.5, 0.5] space — confirm.
      1. It is converted to an integer voxel range.
      2. The range is padded by ``padding_size`` voxels and clipped to
         ``[0, bins - 1]``.
      3. Every voxel inside the padded range is copied into that bbox's part.
    A voxel can therefore belong to several parts. Voxels covered by no
    padded bbox are appended to the part of their nearest bbox (see
    ``_nearest_bbox_indices``). Parts follow the whole-shape block in bbox
    order; a bbox that ends up with no voxels produces no part.

    Args:
        voxel_coords_path: ``.npy`` voxel rows; the first column is dropped.
        bbox_depth_path: ``.npy`` array of bboxes.
        ordered_mask_input: part-id mask tensor; a leading batch dim is added.
        padding_size: voxels of padding added around each bbox.
        bins: voxel grid resolution.
        device: torch device for the returned tensors.

    Returns:
        dict with
          'coords': int32 tensor with a leading zero column. It contains the
              whole shape's voxels followed by every non-empty part's voxels.
          'part_layouts': list of slices into 'coords'. Entry 0 is the whole
              shape; then there is one entry per non-empty part.
          'masks': ``ordered_mask_input`` with a batch dim.
    """
    overall_coords = np.load(voxel_coords_path)
    overall_coords = overall_coords[:, 1:]  # Remove first column
    bbox_scene = np.load(bbox_depth_path)

    # Voxels per bbox, indexed by the bbox's position in bbox_scene. Empty
    # bboxes keep an (empty) entry so indices stay aligned with bbox_scene.
    part_coords = []
    bbox_mins = []
    bbox_maxs = []
    assigned_points = np.zeros(overall_coords.shape[0], dtype=bool)
    for bbox in bbox_scene:
        bbox_min, bbox_max = _bbox_voxel_bounds(bbox, bins)
        bbox_mins.append(bbox_min)
        bbox_maxs.append(bbox_max)
        padded_min = np.clip(bbox_min - padding_size, 0, bins - 1)
        padded_max = np.clip(bbox_max + padding_size, 0, bins - 1)
        bbox_mask = np.all((overall_coords >= padded_min) & (overall_coords <= padded_max), axis=1)
        assigned_points |= bbox_mask
        part_coords.append(overall_coords[bbox_mask])

    unassigned_coords = overall_coords[~assigned_points]
    if unassigned_coords.shape[0] > 0 and len(bbox_scene) > 0:
        print(f"Assigning {unassigned_coords.shape[0]} unassigned points to nearest bboxes")
        nearest = _nearest_bbox_indices(unassigned_coords, bbox_mins, bbox_maxs)
        for bbox_idx in range(len(bbox_scene)):
            extra = unassigned_coords[nearest == bbox_idx]
            if extra.shape[0] > 0:
                part_coords[bbox_idx] = np.vstack([part_coords[bbox_idx], extra])

    # Lay out the whole shape first, then each non-empty part contiguously.
    part_layouts = [slice(0, overall_coords.shape[0])]
    coord_tensors = [torch.from_numpy(overall_coords)]
    start_idx = overall_coords.shape[0]
    for coords_k in part_coords:
        if coords_k.shape[0] == 0:
            continue
        part_layouts.append(slice(start_idx, start_idx + coords_k.shape[0]))
        start_idx += coords_k.shape[0]
        coord_tensors.append(torch.from_numpy(coords_k))

    combined_coords = torch.cat(coord_tensors, dim=0).int()
    coords = torch.cat(
        [torch.full((combined_coords.shape[0], 1), 0, dtype=torch.int32), combined_coords],
        dim=-1
    ).to(device)
    masks = ordered_mask_input.unsqueeze(0).to(device)
    return {
        'coords': coords,
        'part_layouts': part_layouts,
        'masks': masks,
    }
def merge_parts(save_dir):
    """Merge the per-part GLB files in ``save_dir`` into two combined scenes.

    Part files are ``*.glb`` files whose file NAME contains "part" but
    neither "parts" nor "part0" (part0 is the overall model). The filter
    looks only at the base name, so "part" in a directory name does not
    change the selection. Parts are loaded in sorted path order and written to:
      * ``mesh_textured.glb`` — parts with their original visuals;
      * ``mesh_segment.glb`` — each part recoloured with palette colour ``i``.
    The individual part files are deleted only after both exports succeed.
    """
    part_list = []
    for path in glob.glob(os.path.join(save_dir, "*.glb")):
        name = os.path.basename(path)
        # part 0 is the overall model
        if "part" in name and "parts" not in name and "part0" not in name:
            part_list.append(path)
    part_list.sort()

    scene_list_texture = []
    scene_list = []
    for i, part_path in enumerate(tqdm(part_list, desc="Merging parts")):
        part_mesh = trimesh.load(part_path, force='mesh')
        scene_list_texture.append(part_mesh)
        part_mesh_color = part_mesh.copy()
        part_mesh_color.visual = trimesh.visual.ColorVisuals(
            mesh=part_mesh_color,
            vertex_colors=get_random_color(i, use_float=True)
        )
        scene_list.append(part_mesh_color)

    trimesh.Scene(scene_list_texture).export(os.path.join(save_dir, "mesh_textured.glb"))
    trimesh.Scene(scene_list).export(os.path.join(save_dir, "mesh_segment.glb"))

    # Remove inputs last so a failed export does not destroy the part files.
    for part_path in part_list:
        os.remove(part_path)