# NOTE: `spaces` is imported first, before torch, as in the original file;
# ZeroGPU Spaces expect it to be imported before CUDA is touched.
import spaces
import gradio as gr
import os
import time
import uuid
import torch
import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download
from accelerate.utils import set_seed
import trimesh

from src.utils.data_utils import get_colored_mesh_composition, export_renderings
from src.utils.image_utils import prepare_image
from src.pipelines.pipeline_partcrafter import PartCrafterPipeline
from src.models.briarmbg import BriaRMBG

# Constants
MAX_NUM_PARTS = 16
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): float16 is applied regardless of DEVICE; CPU fallback may not
# support every fp16 op — confirm if CPU execution matters.
DTYPE = torch.float16

# Download and initialize models (runs once at import time).
partcrafter_weights_dir = "pretrained_weights/PartCrafter"
rmbg_weights_dir = "pretrained_weights/RMBG-1.4"
snapshot_download(repo_id="wgsxm/PartCrafter", local_dir=partcrafter_weights_dir)
snapshot_download(repo_id="briaai/RMBG-1.4", local_dir=rmbg_weights_dir)

rmbg_net = BriaRMBG.from_pretrained(rmbg_weights_dir).to(DEVICE)
rmbg_net.eval()
pipe: PartCrafterPipeline = PartCrafterPipeline.from_pretrained(partcrafter_weights_dir).to(DEVICE, DTYPE)


def _parse_max_num_expanded_coords(value) -> int:
    """Convert the "Max Expanded Coords" field value into an int.

    The UI field is a ``gr.Text``, so Gradio delivers a string such as
    ``"1e9"``; ``float()`` accepts scientific notation and ``int()`` then
    truncates. Numeric inputs (when called programmatically) also work.

    Raises:
        gr.Error: if the value is not a finite number.
    """
    try:
        return int(float(str(value).strip()))
    except (TypeError, ValueError, OverflowError) as err:
        raise gr.Error(f"Max Expanded Coords must be a finite number, got {value!r}.") from err


@spaces.GPU()
@torch.no_grad()
def run_triposg(image: Image.Image, num_parts: int, seed: int, num_tokens: int, num_inference_steps: int, guidance_scale: float, max_num_expanded_coords: str, use_flash_decoder: bool, rmbg: bool):
    """
    Generate 3D part meshes from an input image.

    Args:
        image: Input PIL image (``None`` when nothing was uploaded).
        num_parts: Number of parts; the image is repeated this many times
            and the count is forwarded to the pipeline's attention kwargs.
        seed: Random seed for ``set_seed`` and the torch generator
            (an empty field falls back to 0).
        num_tokens: Forwarded to the pipeline as ``num_tokens``.
        num_inference_steps: Number of pipeline inference steps.
        guidance_scale: Forwarded to the pipeline as ``guidance_scale``.
        max_num_expanded_coords: Text (or number) such as ``"1e9"``; parsed
            to an int before being forwarded to the pipeline.
        use_flash_decoder: Forwarded to the pipeline as ``use_flash_decoder``.
        rmbg: If true, run ``prepare_image`` with the RMBG network and a
            white background before generation.

    Returns:
        A tuple ``(merged_glb_path, export_dir)``: the path of the exported
        merged ``object.glb`` (what ``gr.Model3D`` needs) and the directory
        holding it plus one ``part_XX.glb`` per part.

    Raises:
        gr.Error: if no image was supplied or the coords field is not numeric.
    """
    if image is None:
        raise gr.Error("Please upload an input image.")

    # Gradio sliders/numbers may deliver floats; list repetition, manual_seed
    # and step counts require ints. Validate cheap inputs before GPU work.
    num_parts = int(num_parts)
    seed = int(seed) if seed is not None else 0
    max_coords = _parse_max_num_expanded_coords(max_num_expanded_coords)

    if rmbg:
        img_pil = prepare_image(image, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
    else:
        img_pil = image

    set_seed(seed)
    start_time = time.time()
    outputs = pipe(
        image=[img_pil] * num_parts,
        attention_kwargs={"num_parts": num_parts},
        num_tokens=int(num_tokens),
        generator=torch.Generator(device=pipe.device).manual_seed(seed),
        num_inference_steps=int(num_inference_steps),
        guidance_scale=float(guidance_scale),
        max_num_expanded_coords=max_coords,
        use_flash_decoder=use_flash_decoder,
    ).meshes
    duration = time.time() - start_time
    print(f"Generation time: {duration:.2f}s")

    # Replace failed (None) parts with a single-vertex placeholder so that
    # composition and export still operate on one mesh per requested part.
    outputs = [
        mesh if mesh is not None else trimesh.Trimesh(vertices=[[0, 0, 0]], faces=[[0, 0, 0]])
        for mesh in outputs
    ]

    # Merge and color
    merged = get_colored_mesh_composition(outputs)

    # A random suffix keeps requests that finish within the same second from
    # writing into (and overwriting) the same directory.
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    export_dir = os.path.join("results", f"{timestamp}_{uuid.uuid4().hex[:8]}")
    os.makedirs(export_dir, exist_ok=True)
    for idx, mesh in enumerate(outputs):
        mesh.export(os.path.join(export_dir, f"part_{idx:02}.glb"))
    merged_path = os.path.join(export_dir, "object.glb")
    merged.export(merged_path)

    # gr.Model3D renders from a file path, not an in-memory trimesh object.
    return merged_path, export_dir


# Gradio Interface
def build_demo():
    """Build the Gradio Blocks UI wiring the input controls to ``run_triposg``."""
    with gr.Blocks() as demo:
        gr.Markdown("# PartCrafter 3D Generation Demo")
        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(type="pil", label="Input Image")
                num_parts = gr.Slider(1, MAX_NUM_PARTS, value=4, step=1, label="Number of Parts")
                seed = gr.Number(value=0, label="Random Seed", precision=0)
                num_tokens = gr.Slider(256, 2048, value=1024, step=64, label="Num Tokens")
                num_steps = gr.Slider(1, 100, value=50, step=1, label="Inference Steps")
                guidance = gr.Slider(1.0, 20.0, value=7.0, step=0.1, label="Guidance Scale")
                # Free text so scientific notation like "1e9" can be entered;
                # run_triposg parses it into an int.
                max_coords = gr.Text(value="1e9", label="Max Expanded Coords")
                flash_decoder = gr.Checkbox(value=False, label="Use Flash Decoder")
                remove_bg = gr.Checkbox(value=False, label="Remove Background (RMBG)")
                run_button = gr.Button("Generate 3D Parts")
            with gr.Column(scale=1):
                output_model = gr.Model3D(label="Merged 3D Object")
                output_dir = gr.Textbox(label="Export Directory")

        # Input order must match run_triposg's positional parameters.
        run_button.click(
            fn=run_triposg,
            inputs=[input_image, num_parts, seed, num_tokens, num_steps, guidance, max_coords, flash_decoder, remove_bg],
            outputs=[output_model, output_dir],
        )
    return demo


if __name__ == "__main__":
    demo = build_demo()
    demo.launch()