|
import os |
|
import numpy as np |
|
import torch |
|
import argparse |
|
from PIL import Image |
|
from omegaconf import OmegaConf |
|
|
|
from modules.bbox_gen.models.autogressive_bbox_gen import BboxGen |
|
from modules.part_synthesis.process_utils import save_parts_outputs |
|
from modules.inference_utils import load_img_mask, prepare_bbox_gen_input, prepare_part_synthesis_input, gen_mesh_from_bounds, vis_voxel_coords, merge_parts |
|
from modules.part_synthesis.pipelines import OmniPartImageTo3DPipeline |
|
|
|
def _parse_args():
    """Parse the command-line options of this inference script.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(
        description="Generate part bounding boxes and per-part 3D outputs from an image and its part mask."
    )
    parser.add_argument("--image-input", type=str, required=True, help="Path to the input image.")
    parser.add_argument("--mask-input", type=str, required=True, help="Path to the part mask for the image.")
    parser.add_argument("--output-root", type=str, default="./output",
                        help="Directory under which a per-image output folder is created.")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num-inference-steps", type=int, default=25,
                        help="Sampler steps for the part-synthesis (SLAT) stage.")
    parser.add_argument("--guidance-scale", type=float, default=3.5,
                        help="cfg_strength for the part-synthesis (SLAT) stage.")
    parser.add_argument("--simplify_ratio", type=float, default=0.3,
                        help="Forwarded to save_parts_outputs as simplify_ratio.")
    parser.add_argument("--partfield_encoder_path", type=str, default="ckpt/model_objaverse.ckpt")
    parser.add_argument("--bbox_gen_ckpt", type=str, default="ckpt/bbox_gen.ckpt")
    parser.add_argument("--part_synthesis_ckpt", type=str, default="ckpt/part_synthesis")
    # Device, BboxGen config file and sparse-structure sampler settings; the
    # defaults are the values this script uses unless overridden.
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--bbox-gen-config", type=str, default="configs/bbox_gen.yaml")
    parser.add_argument("--sparse-structure-steps", type=int, default=25,
                        help="Sampler steps for the voxel-coordinate (sparse structure) stage.")
    parser.add_argument("--sparse-structure-cfg-strength", type=float, default=7.5,
                        help="cfg_strength for the voxel-coordinate (sparse structure) stage.")
    return parser.parse_args()


def _output_dir_for(image_path, output_root):
    """Return the per-image output directory ``<output_root>/<image stem>``.

    ``os.path`` is used rather than splitting on "/" and "." so that the
    platform's path separator is honored and only the final extension is
    removed: ``imgs/chair.v2.png`` maps to ``chair.v2`` (not ``chair``, which
    would collide with ``chair.v1.png``), and ``.hidden.png`` maps to
    ``.hidden`` instead of an empty stem.
    """
    stem = os.path.splitext(os.path.basename(image_path))[0]
    return os.path.join(output_root, stem)


def _load_bbox_gen(config_path, partfield_encoder_path, ckpt_path, device):
    """Build BboxGen from its YAML config, load its weights and ready it for inference.

    Args:
        config_path: YAML file whose ``model.args`` section configures BboxGen.
        partfield_encoder_path: value written into the config's
            ``partfield_encoder_path`` before the model is constructed.
        ckpt_path: state-dict checkpoint for BboxGen.
        device: device the model is moved to.

    Returns:
        The BboxGen model in eval mode and half precision, on ``device``.
    """
    config = OmegaConf.load(config_path).model.args
    config.partfield_encoder_path = partfield_encoder_path
    model = BboxGen(config)

    # load_state_dict copies tensors into the model's existing parameters, so
    # the checkpoint does not need to be materialised on the device it was
    # saved from (which may not exist here); CPU is always available.
    state_dict = torch.load(ckpt_path, map_location="cpu")
    result = model.load_state_dict(state_dict, strict=False)
    # nn.Module.load_state_dict returns the incompatible keys; report them
    # instead of letting strict=False drop mismatches silently. getattr keeps
    # this tolerant of an override that returns something else.
    missing = getattr(result, "missing_keys", [])
    unexpected = getattr(result, "unexpected_keys", [])
    if missing or unexpected:
        print(
            f"[WARN] BboxGen checkpoint loaded with strict=False: "
            f"{len(missing)} missing keys, {len(unexpected)} unexpected keys"
        )

    model.to(device)
    model.eval().half()
    return model


def main():
    """Run the pipeline: voxel coordinates -> part bounding boxes -> per-part synthesis.

    All intermediate and final artifacts are written under
    ``<output_root>/<image stem>``.
    """
    args = _parse_args()
    device = args.device

    output_dir = _output_dir_for(args.image_input, args.output_root)
    # makedirs also creates output_root, so one call covers both directories.
    os.makedirs(output_dir, exist_ok=True)

    torch.manual_seed(args.seed)

    part_synthesis_pipeline = OmniPartImageTo3DPipeline.from_pretrained(args.part_synthesis_ckpt)
    part_synthesis_pipeline.to(device)
    print("[INFO] PartSynthesis model loaded")

    bbox_gen_model = _load_bbox_gen(
        args.bbox_gen_config, args.partfield_encoder_path, args.bbox_gen_ckpt, device
    )
    print("[INFO] BboxGen model loaded")

    img_white_bg, img_black_bg, ordered_mask_input, img_mask_vis = load_img_mask(
        args.image_input, args.mask_input
    )
    img_mask_vis.save(os.path.join(output_dir, "img_mask_vis.png"))

    # Stage 1: sample voxel coordinates from the black-background image, save
    # them as .npy and export a .ply visualisation.
    voxel_coords_path = os.path.join(output_dir, "voxel_coords.npy")
    voxel_coords = part_synthesis_pipeline.get_coords(
        img_black_bg,
        num_samples=1,
        seed=args.seed,
        sparse_structure_sampler_params={
            "steps": args.sparse_structure_steps,
            "cfg_strength": args.sparse_structure_cfg_strength,
        },
    ).cpu().numpy()
    np.save(voxel_coords_path, voxel_coords)
    vis_voxel_coords(voxel_coords).export(os.path.join(output_dir, "voxel_coords_vis.ply"))
    print("[INFO] Voxel coordinates saved")

    # Stage 2: predict part bounding boxes. The input-preparation helpers take
    # file paths, which is why the stage outputs are saved before reuse.
    bboxes_path = os.path.join(output_dir, "bboxes.npy")
    bbox_gen_input = prepare_bbox_gen_input(voxel_coords_path, img_white_bg, ordered_mask_input)
    bbox_gen_output = bbox_gen_model.generate(bbox_gen_input)
    bboxes = bbox_gen_output['bboxes'][0]  # use the first entry of the generated batch
    np.save(bboxes_path, bboxes)
    gen_mesh_from_bounds(bboxes).export(os.path.join(output_dir, "bboxes_vis.glb"))
    print("[INFO] BboxGen output saved")

    # Stage 3: synthesize each part inside its box, save the per-part outputs
    # and merge them.
    part_synthesis_input = prepare_part_synthesis_input(voxel_coords_path, bboxes_path, ordered_mask_input)
    part_synthesis_output = part_synthesis_pipeline.get_slat(
        img_black_bg,
        part_synthesis_input['coords'],
        [part_synthesis_input['part_layouts']],
        part_synthesis_input['masks'],
        seed=args.seed,
        slat_sampler_params={"steps": args.num_inference_steps, "cfg_strength": args.guidance_scale},
        formats=['mesh', 'gaussian', 'radiance_field'],
        preprocess_image=False,
    )
    save_parts_outputs(
        part_synthesis_output,
        output_dir=output_dir,
        simplify_ratio=args.simplify_ratio,
        save_video=False,
        save_glb=True,
        textured=False,
    )
    merge_parts(output_dir)
    print("[INFO] PartSynthesis output saved")


if __name__ == "__main__":
    main()