Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,326 Bytes
f108aa8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import spaces
import gradio as gr
import os
import time
import torch
import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download
from accelerate.utils import set_seed
import trimesh
from src.utils.data_utils import get_colored_mesh_composition, export_renderings
from src.utils.image_utils import prepare_image
from src.pipelines.pipeline_partcrafter import PartCrafterPipeline
from src.models.briarmbg import BriaRMBG
# Constants
# Upper bound used for the "Number of Parts" slider in the UI.
MAX_NUM_PARTS = 16
# Use the GPU when torch reports CUDA as available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Precision the PartCrafter pipeline is cast to below.
# NOTE(review): float16 is used even when DEVICE falls back to "cpu" -- confirm
# that half precision is intended/supported on that path.
DTYPE = torch.float16
# Download and initialize models
# Everything below runs at import time: both repositories are fetched into
# local directories first, and the models are then loaded from those directories.
partcrafter_weights_dir = "pretrained_weights/PartCrafter"
rmbg_weights_dir = "pretrained_weights/RMBG-1.4"
snapshot_download(repo_id="wgsxm/PartCrafter", local_dir=partcrafter_weights_dir)
snapshot_download(repo_id="briaai/RMBG-1.4", local_dir=rmbg_weights_dir)
# Background-removal network; only used when the "Remove Background" option is on.
rmbg_net = BriaRMBG.from_pretrained(rmbg_weights_dir).to(DEVICE)
rmbg_net.eval()
# Generation pipeline, moved to the selected device and cast to DTYPE.
pipe: PartCrafterPipeline = PartCrafterPipeline.from_pretrained(partcrafter_weights_dir).to(DEVICE, DTYPE)
@spaces.GPU()
@torch.no_grad()
def run_triposg(image: Image.Image,
                num_parts: int,
                seed: int,
                num_tokens: int,
                num_inference_steps: int,
                guidance_scale: float,
                max_num_expanded_coords: float,
                use_flash_decoder: bool,
                rmbg: bool):
    """
    Generate 3D part meshes from an input image.

    Args:
        image: Input picture. Must not be None.
        num_parts: Number of parts to generate; the image is repeated this
            many times in the pipeline input.
        seed: Seed passed to ``set_seed`` and to the torch generator.
        num_tokens: Forwarded to the pipeline as ``num_tokens``.
        num_inference_steps: Forwarded to the pipeline.
        guidance_scale: Forwarded to the pipeline.
        max_num_expanded_coords: Forwarded to the pipeline after conversion
            to an integer. Numeric strings such as ``"1e9"`` are accepted
            because the UI supplies this value from a text box.
        use_flash_decoder: Forwarded to the pipeline.
        rmbg: When true, the image is first run through ``prepare_image``
            with ``rmbg_net`` and a background color of ``[1.0, 1.0, 1.0]``.

    Returns:
        A ``(merged, export_dir)`` tuple: the object returned by
        ``get_colored_mesh_composition`` and the directory that the
        per-part ``part_XX.glb`` files and ``object.glb`` were written to.

    Raises:
        gr.Error: If no image was supplied or ``max_num_expanded_coords``
            cannot be interpreted as a finite number.
    """
    if image is None:
        raise gr.Error("Please provide an input image.")

    # The UI wires this argument to a text box, so it can arrive as a string
    # (e.g. "1e9") rather than a number; normalize it before it reaches the
    # pipeline.
    try:
        max_num_expanded_coords = int(float(max_num_expanded_coords))
    except (TypeError, ValueError, OverflowError) as err:
        raise gr.Error(
            f"Max Expanded Coords must be a finite number, got {max_num_expanded_coords!r}."
        ) from err

    # List repetition and torch's manual_seed both require real ints; UI
    # widgets may hand over integral floats.
    num_parts = int(num_parts)
    seed = int(seed)

    if rmbg:
        img_pil = prepare_image(image, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
    else:
        img_pil = image

    set_seed(seed)
    start_time = time.time()
    outputs = pipe(
        image=[img_pil] * num_parts,
        attention_kwargs={"num_parts": num_parts},
        num_tokens=num_tokens,
        generator=torch.Generator(device=pipe.device).manual_seed(seed),
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        max_num_expanded_coords=max_num_expanded_coords,
        use_flash_decoder=use_flash_decoder,
    ).meshes
    duration = time.time() - start_time
    print(f"Generation time: {duration:.2f}s")

    # Ensure no None outputs: substitute a degenerate single-vertex mesh for
    # any missing part so merging and exporting below always have a mesh.
    # A new list is built instead of mutating the pipeline's result in place.
    outputs = [
        mesh if mesh is not None
        else trimesh.Trimesh(vertices=[[0, 0, 0]], faces=[[0, 0, 0]])
        for mesh in outputs
    ]

    # Merge and color
    merged = get_colored_mesh_composition(outputs)

    # Export meshes and return results
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    export_dir = os.path.join("results", timestamp)
    os.makedirs(export_dir, exist_ok=True)
    for idx, mesh in enumerate(outputs):
        mesh.export(os.path.join(export_dir, f"part_{idx:02}.glb"))
    merged.export(os.path.join(export_dir, "object.glb"))
    return merged, export_dir
# Gradio Interface
def build_demo():
    """Build the Gradio Blocks interface for the PartCrafter demo.

    The left column holds the input image and generation settings; the right
    column shows the merged 3D object and the export directory. Clicking the
    button runs ``run_triposg`` with the widget values in positional order.

    Returns:
        The configured ``gr.Blocks`` instance (not launched).
    """

    def _generate_for_viewer(*args):
        # gr.Model3D takes a file path as its output value, whereas
        # run_triposg returns the in-memory merged mesh. run_triposg has
        # already written that mesh to "object.glb" inside export_dir, so
        # hand the viewer that file instead of the mesh object.
        _, export_dir = run_triposg(*args)
        return os.path.join(export_dir, "object.glb"), export_dir

    with gr.Blocks() as demo:
        gr.Markdown("# PartCrafter 3D Generation Demo")
        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(type="pil", label="Input Image")
                num_parts = gr.Slider(1, MAX_NUM_PARTS, value=4, step=1, label="Number of Parts")
                seed = gr.Number(value=0, label="Random Seed", precision=0)
                num_tokens = gr.Slider(256, 2048, value=1024, step=64, label="Num Tokens")
                num_steps = gr.Slider(1, 100, value=50, step=1, label="Inference Steps")
                guidance = gr.Slider(1.0, 20.0, value=7.0, step=0.1, label="Guidance Scale")
                max_coords = gr.Text(value="1e9", label="Max Expanded Coords")
                flash_decoder = gr.Checkbox(value=False, label="Use Flash Decoder")
                remove_bg = gr.Checkbox(value=False, label="Remove Background (RMBG)")
                run_button = gr.Button("Generate 3D Parts")
            with gr.Column(scale=1):
                output_model = gr.Model3D(label="Merged 3D Object")
                output_dir = gr.Textbox(label="Export Directory")
        # The input order must match run_triposg's positional parameters.
        run_button.click(fn=_generate_for_viewer,
                         inputs=[input_image, num_parts, seed, num_tokens, num_steps,
                                 guidance, max_coords, flash_decoder, remove_bg],
                         outputs=[output_model, output_dir])
    return demo
# Build and launch the UI only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == "__main__":
    demo = build_demo()
    demo.launch()
|