# OmniPart end-to-end inference script: image + part mask -> part-level 3D assets.
import os
import numpy as np
import torch
import argparse
from PIL import Image
from omegaconf import OmegaConf
from modules.bbox_gen.models.autogressive_bbox_gen import BboxGen
from modules.part_synthesis.process_utils import save_parts_outputs
from modules.inference_utils import load_img_mask, prepare_bbox_gen_input, prepare_part_synthesis_input, gen_mesh_from_bounds, vis_voxel_coords, merge_parts
from modules.part_synthesis.pipelines import OmniPartImageTo3DPipeline
def main() -> None:
    """Run the OmniPart image-to-3D pipeline end to end.

    Pipeline stages:
      1. Parse CLI arguments and create ``<output-root>/<image-stem>/``.
      2. Load the part-synthesis pipeline and the BboxGen model.
      3. Generate sparse voxel coordinates from the input image.
      4. Predict per-part bounding boxes from voxels + image + mask.
      5. Synthesize per-part geometry (mesh/gaussian/radiance field),
         save each part, and merge them into a single asset.

    All intermediate artifacts (voxel coords, bboxes, visualizations)
    are written to the per-image output directory as a side effect.
    """
    device = "cuda"

    parser = argparse.ArgumentParser()
    parser.add_argument("--image-input", type=str, required=True)
    parser.add_argument("--mask-input", type=str, required=True)
    parser.add_argument("--output-root", type=str, default="./output")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num-inference-steps", type=int, default=25)
    parser.add_argument("--guidance-scale", type=float, default=3.5)
    parser.add_argument("--simplify_ratio", type=float, default=0.3)
    parser.add_argument("--partfield_encoder_path", type=str, default="ckpt/model_objaverse.ckpt")
    parser.add_argument("--bbox_gen_ckpt", type=str, default="ckpt/bbox_gen.ckpt")
    parser.add_argument("--part_synthesis_ckpt", type=str, default="ckpt/part_synthesis")
    args = parser.parse_args()

    # Name the output directory after the input image's stem. Using
    # os.path instead of splitting on "/" keeps this portable (Windows
    # paths, trailing components) and robust to dots in directory names.
    stem = os.path.splitext(os.path.basename(args.image_input))[0]
    output_dir = os.path.join(args.output_root, stem)
    # makedirs creates intermediate directories, so output_root is
    # covered by this single call.
    os.makedirs(output_dir, exist_ok=True)

    torch.manual_seed(args.seed)

    # --- load part_synthesis model ---
    part_synthesis_pipeline = OmniPartImageTo3DPipeline.from_pretrained(args.part_synthesis_ckpt)
    part_synthesis_pipeline.to(device)
    print("[INFO] PartSynthesis model loaded")

    # --- load bbox_gen model ---
    bbox_gen_config = OmegaConf.load("configs/bbox_gen.yaml").model.args
    bbox_gen_config.partfield_encoder_path = args.partfield_encoder_path
    bbox_gen_model = BboxGen(bbox_gen_config)
    # map_location="cpu" avoids a hard failure when the checkpoint was
    # saved on a device that is unavailable here; the model is moved to
    # `device` right after. strict=False tolerates missing/unexpected
    # keys — presumably the separately loaded PartField encoder; confirm.
    state_dict = torch.load(args.bbox_gen_ckpt, map_location="cpu")
    bbox_gen_model.load_state_dict(state_dict, strict=False)
    bbox_gen_model.to(device)
    bbox_gen_model.eval().half()
    print("[INFO] BboxGen model loaded")

    # --- stage 1: voxel coordinates from the conditioning image ---
    img_white_bg, img_black_bg, ordered_mask_input, img_mask_vis = load_img_mask(args.image_input, args.mask_input)
    img_mask_vis.save(os.path.join(output_dir, "img_mask_vis.png"))

    voxel_coords = part_synthesis_pipeline.get_coords(
        img_black_bg,
        num_samples=1,
        seed=args.seed,
        sparse_structure_sampler_params={"steps": 25, "cfg_strength": 7.5},
    )
    voxel_coords = voxel_coords.cpu().numpy()
    np.save(os.path.join(output_dir, "voxel_coords.npy"), voxel_coords)
    voxel_coords_ply = vis_voxel_coords(voxel_coords)
    voxel_coords_ply.export(os.path.join(output_dir, "voxel_coords_vis.ply"))
    print("[INFO] Voxel coordinates saved")

    # --- stage 2: per-part bounding boxes ---
    bbox_gen_input = prepare_bbox_gen_input(
        os.path.join(output_dir, "voxel_coords.npy"), img_white_bg, ordered_mask_input
    )
    bbox_gen_output = bbox_gen_model.generate(bbox_gen_input)
    np.save(os.path.join(output_dir, "bboxes.npy"), bbox_gen_output['bboxes'][0])
    bboxes_vis = gen_mesh_from_bounds(bbox_gen_output['bboxes'][0])
    bboxes_vis.export(os.path.join(output_dir, "bboxes_vis.glb"))
    print("[INFO] BboxGen output saved")

    # --- stage 3: part synthesis and merge ---
    part_synthesis_input = prepare_part_synthesis_input(
        os.path.join(output_dir, "voxel_coords.npy"),
        os.path.join(output_dir, "bboxes.npy"),
        ordered_mask_input,
    )
    part_synthesis_output = part_synthesis_pipeline.get_slat(
        img_black_bg,
        part_synthesis_input['coords'],
        [part_synthesis_input['part_layouts']],
        part_synthesis_input['masks'],
        seed=args.seed,
        slat_sampler_params={"steps": args.num_inference_steps, "cfg_strength": args.guidance_scale},
        formats=['mesh', 'gaussian', 'radiance_field'],
        preprocess_image=False,
    )
    save_parts_outputs(
        part_synthesis_output,
        output_dir=output_dir,
        simplify_ratio=args.simplify_ratio,
        save_video=False,
        save_glb=True,
        textured=False,
    )
    merge_parts(output_dir)
    print("[INFO] PartSynthesis output saved")


if __name__ == "__main__":
    main()