import os os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' import numpy as np from typing import Optional from PIL import Image, ImageDraw import torchvision.transforms.functional as TF import cv2 import torch import trimesh import glob from tqdm import tqdm def load_img_mask(img_path, mask_path, size=(518, 518)): image = Image.open(img_path) alpha = np.array(image.getchannel(3)) bbox = np.array(alpha).nonzero() bbox = [bbox[1].min(), bbox[0].min(), bbox[1].max(), bbox[0].max()] center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2] hsize = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2 aug_size_ratio = 1.2 aug_hsize = hsize * aug_size_ratio aug_center_offset = [0, 0] aug_center = [center[0] + aug_center_offset[0], center[1] + aug_center_offset[1]] aug_bbox = [int(aug_center[0] - aug_hsize), int(aug_center[1] - aug_hsize), int(aug_center[0] + aug_hsize), int(aug_center[1] + aug_hsize)] img_height, img_width = alpha.shape mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED) pad_left = max(0, -aug_bbox[0]) pad_top = max(0, -aug_bbox[1]) pad_right = max(0, aug_bbox[2] - img_width) pad_bottom = max(0, aug_bbox[3] - img_height) if pad_left > 0 or pad_top > 0 or pad_right > 0 or pad_bottom > 0: img_array = np.array(image) padded_img_array = np.pad( img_array, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), mode='constant', constant_values=0 ) padded_mask_array = np.pad(mask, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), mode='constant', constant_values=0) image = Image.fromarray(padded_img_array.astype('uint8')) aug_bbox[0] += pad_left aug_bbox[1] += pad_top aug_bbox[2] += pad_left aug_bbox[3] += pad_top mask = padded_mask_array image = image.crop(aug_bbox) mask = mask[aug_bbox[1]:aug_bbox[3], aug_bbox[0]:aug_bbox[2]] ordered_mask_input, mask_vis = load_bottom_up_mask(mask) image_white_bg = np.array(image) image_black_bg = np.array(image) if image_white_bg.shape[-1] == 4: mask_img = image_white_bg[..., 3] == 0 image_white_bg[mask_img] = [255, 255, 255, 255] image_black_bg[mask_img] = [0, 0, 0, 255] image_white_bg = image_white_bg[..., :3] image_black_bg = image_black_bg[..., :3] img_white_bg = Image.fromarray(image_white_bg.astype('uint8')) img_black_bg = Image.fromarray(image_black_bg.astype('uint8')) img_white_bg = img_white_bg.resize(size, resample=Image.Resampling.LANCZOS) img_black_bg = img_black_bg.resize(size, resample=Image.Resampling.LANCZOS) img_mask_vis = vis_mask_on_img(img_white_bg, mask_vis) img_white_bg = TF.to_tensor(img_white_bg) img_black_bg = TF.to_tensor(img_black_bg) return img_white_bg, img_black_bg, ordered_mask_input, img_mask_vis def load_bottom_up_mask(mask, size=(518, 518)): mask_input = smart_downsample_mask(mask, (37, 37)) mask_vis = cv2.resize(mask_input, (518, 518), interpolation=cv2.INTER_NEAREST) mask_input = np.array(mask_input, dtype=np.int32) unique_indices = np.unique(mask_input) unique_indices = unique_indices[unique_indices > 0] part_positions = {} for idx in unique_indices: y_coords, _ = np.where(mask_input == idx) if len(y_coords) > 0: part_positions[idx] = np.max(y_coords) sorted_parts = sorted(part_positions.items(), key=lambda x: -x[1]) # Sort by y-coordinate in descending order # Create mapping from old indices to new indices (ordered by position) index_map = {} for new_idx, (old_idx, _) in enumerate(sorted_parts, 1): # Start from 1 (0 is background) index_map[old_idx] = new_idx # Apply the mapping to create position-ordered mask ordered_mask_input = np.zeros_like(mask_input) for old_idx, new_idx in index_map.items(): ordered_mask_input[mask_input == old_idx] = new_idx mask_vis = np.array(mask_vis, dtype=np.int32) ordered_mask_input = torch.from_numpy(ordered_mask_input).long() return ordered_mask_input, mask_vis def smart_downsample_mask(mask, target_size): h, w = mask.shape[:2] target_h, target_w = target_size h_ratio = h / target_h w_ratio = w / target_w downsampled = np.zeros((target_h, target_w), dtype=mask.dtype) for i in range(target_h): for j in range(target_w): y_start = int(i * h_ratio) y_end = min(int((i + 1) * h_ratio), h) x_start = int(j * w_ratio) x_end = min(int((j + 1) * w_ratio), w) region = mask[y_start:y_end, x_start:x_end] if region.size == 0: continue unique_values, counts = np.unique(region.flatten(), return_counts=True) non_zero_mask = unique_values > 0 if np.any(non_zero_mask): non_zero_values = unique_values[non_zero_mask] non_zero_counts = counts[non_zero_mask] max_idx = np.argmax(non_zero_counts) downsampled[i, j] = non_zero_values[max_idx] else: max_idx = np.argmax(counts) downsampled[i, j] = unique_values[max_idx] return downsampled def vis_mask_on_img(img, mask): H, W = mask.shape mask_vis = np.zeros((H, W, 3), dtype=np.uint8) + 255 for part_id in range(1, int(mask.max()) + 1): part_mask = (mask == part_id) if part_mask.sum() > 0: color = get_random_color((part_id - 1), use_float=False)[:3] mask_vis[part_mask, 0:3] = color mask_img = Image.fromarray(mask_vis) combined_width = W * 2 combined_height = H combined_img = Image.new('RGB', (combined_width, combined_height), (255, 255, 255)) combined_img.paste(img, (0, 0)) combined_img.paste(mask_img, (W, 0)) draw = ImageDraw.Draw(combined_img) draw.line([(W, 0), (W, H)], fill=(0, 0, 0), width=2) return combined_img def get_random_color(index: Optional[int] = None, use_float: bool = False): # some pleasing colors # matplotlib.colormaps['Set3'].colors + matplotlib.colormaps['Set2'].colors + matplotlib.colormaps['Set1'].colors palette = np.array( [ [141, 211, 199, 255], [255, 255, 179, 255], [190, 186, 218, 255], [251, 128, 114, 255], [128, 177, 211, 255], [253, 180, 98, 255], [179, 222, 105, 255], [252, 205, 229, 255], [217, 217, 217, 255], [188, 128, 189, 255], [204, 235, 197, 255], [255, 237, 111, 255], [102, 194, 165, 255], [252, 141, 98, 255], [141, 160, 203, 255], [231, 138, 195, 255], [166, 216, 84, 255], [255, 217, 47, 255], [229, 196, 148, 255], [179, 179, 179, 255], [228, 26, 28, 255], [55, 126, 184, 255], [77, 175, 74, 255], [152, 78, 163, 255], [255, 127, 0, 255], [255, 255, 51, 255], [166, 86, 40, 255], [247, 129, 191, 255], [153, 153, 153, 255], ], dtype=np.uint8, ) if index is None: index = np.random.randint(0, len(palette)) if index >= len(palette): index = index % len(palette) if use_float: return palette[index].astype(np.float32) / 255 else: return palette[index] def change_pcd_range(pcd, from_rg=(-1,1), to_rg=(-1,1)): pcd = (pcd - (from_rg[0] + from_rg[1]) / 2) / (from_rg[1] - from_rg[0]) * (to_rg[1] - to_rg[0]) + (to_rg[0] + to_rg[1]) / 2 return pcd def prepare_bbox_gen_input(voxel_coords_path, img_white_bg, ordered_mask_input, bins=64, device="cuda"): whole_voxel = np.load(voxel_coords_path) whole_voxel = whole_voxel[:, 1:] whole_voxel = (whole_voxel + 0.5) / bins - 0.5 whole_voxel_index = change_pcd_range(whole_voxel, from_rg=(-0.5, 0.5), to_rg=(0.5/bins, 1-0.5/bins)) whole_voxel_index = (whole_voxel_index * bins).astype(np.int32) points = torch.from_numpy(whole_voxel).to(torch.float16).unsqueeze(0).to(device) whole_voxel_index = torch.from_numpy(whole_voxel_index).long().unsqueeze(0).to(device) images = img_white_bg.unsqueeze(0).to(device) masks = ordered_mask_input.unsqueeze(0).to(device) return { "points": points, "whole_voxel_index": whole_voxel_index, "images": images, "masks": masks, } def vis_voxel_coords(voxel_coords, bins=64): voxel_coords = voxel_coords[:, 1:] voxel_coords = (voxel_coords + 0.5) / bins - 0.5 voxel_coords_ply = trimesh.PointCloud(voxel_coords) rot_matrix = np.array([[1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]]) voxel_coords_ply.apply_transform(rot_matrix) return voxel_coords_ply def gen_mesh_from_bounds(bounds): bboxes = [] rot_matrix = np.array([[1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]]) for j in range(bounds.shape[0]): bbox = trimesh.primitives.Box(bounds=bounds[j]) color = get_random_color(j, use_float=True) bbox.visual.vertex_colors = color bboxes.append(bbox) mesh = trimesh.Scene(bboxes) mesh.apply_transform(rot_matrix) return mesh def prepare_part_synthesis_input(voxel_coords_path, bbox_depth_path, ordered_mask_input, padding_size=2, bins=64, device="cuda"): overall_coords = np.load(voxel_coords_path) overall_coords = overall_coords[:, 1:] # Remove first column bbox_scene = np.load(bbox_depth_path) all_coords_wnoise = [] part_layouts = [] start_idx = 0 part_layouts.append(slice(start_idx, start_idx + overall_coords.shape[0])) start_idx += overall_coords.shape[0] assigned_points = np.zeros(overall_coords.shape[0], dtype=bool) bbox_coords_list = [] bbox_masks = [] for bbox in bbox_scene: points = change_pcd_range(bbox, from_rg=(-0.5, 0.5), to_rg=(0.5/bins, 1-0.5/bins)) bbox_min = np.floor(points[0] * bins).astype(np.int32) bbox_max = np.ceil(points[1] * bins).astype(np.int32) bbox_min = np.clip(bbox_min - padding_size, 0, bins - 1) bbox_max = np.clip(bbox_max + padding_size, 0, bins - 1) bbox_mask = np.all((overall_coords >= bbox_min) & (overall_coords <= bbox_max), axis=1) bbox_masks.append(bbox_mask) if np.sum(bbox_mask) == 0: continue assigned_points = assigned_points | bbox_mask bbox_coords = overall_coords[bbox_mask] bbox_coords_list.append(bbox_coords) part_layouts.append(slice(start_idx, start_idx + bbox_coords.shape[0])) start_idx += bbox_coords.shape[0] bbox_coords = torch.from_numpy(bbox_coords) all_coords_wnoise.append(bbox_coords) unassigned_mask = ~assigned_points unassigned_coords = overall_coords[unassigned_mask] if np.sum(unassigned_mask) > 0 and len(bbox_scene) > 0: print(f"Assigning {np.sum(unassigned_mask)} unassigned points to nearest bboxes") nearest_bbox_indices = [] for point_idx, point in enumerate(unassigned_coords): min_dist = float('inf') nearest_idx = -1 for bbox_idx, bbox in enumerate(bbox_scene): points = change_pcd_range(bbox, from_rg=(-0.5, 0.5), to_rg=(0.5/bins, 1-0.5/bins)) bbox_min = np.floor(points[0] * bins).astype(np.int32) bbox_max = np.ceil(points[1] * bins).astype(np.int32) dx = min(abs(point[0] - bbox_min[0]), abs(point[0] - bbox_max[0])) dy = min(abs(point[1] - bbox_min[1]), abs(point[1] - bbox_max[1])) dz = min(abs(point[2] - bbox_min[2]), abs(point[2] - bbox_max[2])) # dist = dx + dy + dz dist = min(dx, dy, dz) if dist < min_dist: min_dist = dist; nearest_idx = bbox_idx nearest_bbox_indices.append(nearest_idx) for bbox_idx in range(len(bbox_scene)): points_for_this_bbox = np.array([i for i, idx in enumerate(nearest_bbox_indices) if idx == bbox_idx]) if len(points_for_this_bbox) > 0: additional_coords = unassigned_coords[points_for_this_bbox] if bbox_idx < len(bbox_coords_list): combined_coords = np.vstack([bbox_coords_list[bbox_idx], additional_coords]) old_slice = part_layouts[bbox_idx + 1] # +1 because first slice is whole model new_slice = slice(old_slice.start, old_slice.start + combined_coords.shape[0]) part_layouts[bbox_idx + 1] = new_slice additional_points = additional_coords.shape[0] for i in range(bbox_idx + 2, len(part_layouts)): old_slice = part_layouts[i] new_slice = slice(old_slice.start + additional_points, old_slice.stop + additional_points) part_layouts[i] = new_slice all_coords_wnoise[bbox_idx] = torch.from_numpy(combined_coords) start_idx += additional_points else: part_layouts.append(slice(start_idx, start_idx + additional_coords.shape[0])) start_idx += additional_coords.shape[0] all_coords_wnoise.append(torch.from_numpy(additional_coords)) overall_coords = torch.from_numpy(overall_coords) all_coords_wnoise.insert(0, overall_coords) combined_coords = torch.cat(all_coords_wnoise, dim=0).int() coords = torch.cat( [torch.full((combined_coords.shape[0], 1), 0, dtype=torch.int32), combined_coords], dim=-1 ).to(device) masks = ordered_mask_input.unsqueeze(0).to(device) return { 'coords': coords, 'part_layouts': part_layouts, 'masks': masks, } def merge_parts(save_dir): scene_list = [] scene_list_texture = [] part_list = glob.glob(os.path.join(save_dir, "*.glb")) part_list = [p for p in part_list if "part" in p and "parts" not in p and "part0" not in p] # part 0 is the overall model part_list.sort() for i, part_path in enumerate(tqdm(part_list, desc="Merging parts")): part_mesh = trimesh.load(part_path, force='mesh') scene_list_texture.append(part_mesh) random_color = get_random_color(i, use_float=True) part_mesh_color = part_mesh.copy() part_mesh_color.visual = trimesh.visual.ColorVisuals( mesh=part_mesh_color, vertex_colors=random_color ) scene_list.append(part_mesh_color) os.remove(part_path) scene_texture = trimesh.Scene(scene_list_texture) scene_texture.export(os.path.join(save_dir, "mesh_textured.glb")) scene = trimesh.Scene(scene_list) scene.export(os.path.join(save_dir, "mesh_segment.glb"))