|  | import gradio as gr | 
					
						
						|  | import spaces | 
					
						
						|  | from gradio_litmodel3d import LitModel3D | 
					
						
						|  |  | 
					
						
						|  | import os | 
					
						
						|  | import shutil | 
					
						
						|  | os.environ['SPCONV_ALGO'] = 'native' | 
					
						
						|  | from typing import * | 
					
						
						|  | import torch | 
					
						
						|  | import numpy as np | 
					
						
						|  | import imageio | 
					
						
						|  | from PIL import Image | 
					
						
						|  | from trellis.pipelines import TrellisImageTo3DPipeline | 
					
						
						|  | from trellis.utils import render_utils | 
					
						
						|  | import trimesh | 
					
						
						|  | import tempfile | 
					
						
						|  |  | 
					
						
						|  | MAX_SEED = np.iinfo(np.int32).max | 
					
						
						|  | TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp') | 
					
						
						|  | os.makedirs(TMP_DIR, exist_ok=True) | 
					
						
						|  |  | 
					
						
						|  | def preprocess_mesh(mesh_prompt): | 
					
						
						|  | print("Processing mesh") | 
					
						
						|  | trimesh_mesh = trimesh.load_mesh(mesh_prompt) | 
					
						
						|  | trimesh_mesh.export(mesh_prompt+'.glb') | 
					
						
						|  | return mesh_prompt+'.glb' | 
					
						
						|  |  | 
					
						
						|  | def preprocess_image(image): | 
					
						
						|  | if image is None: | 
					
						
						|  | return None | 
					
						
						|  | image = pipeline.preprocess_image(image, resolution=1024) | 
					
						
						|  | return image | 
					
						
						|  |  | 
					
						
						|  | @spaces.GPU | 
					
						
						|  | def generate_3d(image, seed=-1, | 
					
						
						|  | ss_guidance_strength=3, ss_sampling_steps=50, | 
					
						
						|  | slat_guidance_strength=3, slat_sampling_steps=6,): | 
					
						
						|  | if image is None: | 
					
						
						|  | return None, None, None | 
					
						
						|  |  | 
					
						
						|  | if seed == -1: | 
					
						
						|  | seed = np.random.randint(0, MAX_SEED) | 
					
						
						|  |  | 
					
						
						|  | image = pipeline.preprocess_image(image, resolution=1024) | 
					
						
						|  | normal_image = normal_predictor(image, resolution=768, match_input_resolution=True, data_type='object') | 
					
						
						|  |  | 
					
						
						|  | outputs = pipeline.run( | 
					
						
						|  | normal_image, | 
					
						
						|  | seed=seed, | 
					
						
						|  | formats=["mesh",], | 
					
						
						|  | preprocess_image=False, | 
					
						
						|  | sparse_structure_sampler_params={ | 
					
						
						|  | "steps": ss_sampling_steps, | 
					
						
						|  | "cfg_strength": ss_guidance_strength, | 
					
						
						|  | }, | 
					
						
						|  | slat_sampler_params={ | 
					
						
						|  | "steps": slat_sampling_steps, | 
					
						
						|  | "cfg_strength": slat_guidance_strength, | 
					
						
						|  | }, | 
					
						
						|  | ) | 
					
						
						|  | generated_mesh = outputs['mesh'][0] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | import datetime | 
					
						
						|  | output_id = datetime.datetime.now().strftime("%Y%m%d%H%M%S") | 
					
						
						|  | os.makedirs(os.path.join(TMP_DIR, output_id), exist_ok=True) | 
					
						
						|  | mesh_path = f"{TMP_DIR}/{output_id}/mesh.glb" | 
					
						
						|  |  | 
					
						
						|  | render_results = render_utils.render_video(generated_mesh, resolution=1024, ssaa=1, num_frames=8, pitch=0.25, inverse_direction=True) | 
					
						
						|  | def combine_diagonal(color_np, normal_np): | 
					
						
						|  |  | 
					
						
						|  | h, w, c = color_np.shape | 
					
						
						|  |  | 
					
						
						|  | mask = np.fromfunction(lambda y, x: x > y, (h, w)) | 
					
						
						|  | mask = mask.astype(bool) | 
					
						
						|  | mask = np.stack([mask] * c, axis=-1) | 
					
						
						|  |  | 
					
						
						|  | combined_np = np.where(mask, color_np, normal_np) | 
					
						
						|  | return Image.fromarray(combined_np) | 
					
						
						|  |  | 
					
						
						|  | preview_images = [combine_diagonal(c, n) for c, n in zip(render_results['color'], render_results['normal'])] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | trimesh_mesh = generated_mesh.to_trimesh(transform_pose=True) | 
					
						
						|  |  | 
					
						
						|  | trimesh_mesh.export(mesh_path) | 
					
						
						|  |  | 
					
						
						|  | return preview_images, normal_image, mesh_path, mesh_path | 
					
						
						|  |  | 
					
						
						|  | def convert_mesh(mesh_path, export_format): | 
					
						
						|  | """Download the mesh in the selected format.""" | 
					
						
						|  | if not mesh_path: | 
					
						
						|  | return None | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | temp_file = tempfile.NamedTemporaryFile(suffix=f".{export_format}", delete=False) | 
					
						
						|  | temp_file_path = temp_file.name | 
					
						
						|  |  | 
					
						
						|  | new_mesh_path = mesh_path.replace(".glb", f".{export_format}") | 
					
						
						|  | mesh = trimesh.load_mesh(mesh_path) | 
					
						
						|  | mesh.export(temp_file_path) | 
					
						
						|  |  | 
					
						
						|  | return temp_file_path | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with gr.Blocks(css="footer {visibility: hidden}") as demo: | 
					
						
						|  | gr.Markdown( | 
					
						
						|  | """ | 
					
						
						|  | <h1 style='text-align: center;'>Hi3DGen: High-fidelity 3D Geometry Generation from Images via Normal Bridging</h1> | 
					
						
						|  | <p style='text-align: center;'> | 
					
						
						|  | <strong>V0.1, Introduced By | 
					
						
						|  | <a href="https://gaplab.cuhk.edu.cn/" target="_blank">GAP Lab</a> from CUHKSZ and | 
					
						
						|  | <a href="https://www.nvsgames.cn/" target="_blank">Game-AIGC Team</a> from ByteDance</strong> | 
					
						
						|  | </p> | 
					
						
						|  | """ | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | with gr.Row(): | 
					
						
						|  | gr.Markdown(""" | 
					
						
						|  | <p align="center"> | 
					
						
						|  | <a title="Website" href="https://stable-x.github.io/Hi3DGen/" target="_blank" rel="noopener noreferrer" style="display: inline-block;"> | 
					
						
						|  | <img src="https://www.obukhov.ai/img/badges/badge-website.svg"> | 
					
						
						|  | </a> | 
					
						
						|  | <a title="arXiv" href="" target="_blank" rel="noopener noreferrer" style="display: inline-block;"> | 
					
						
						|  | <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg"> | 
					
						
						|  | </a> | 
					
						
						|  | <a title="Github" href="https://github.com/Stable-X/Hi3DGen" target="_blank" rel="noopener noreferrer" style="display: inline-block;"> | 
					
						
						|  | <img src="https://img.shields.io/github/stars/Stable-X/Hi3DGen?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars"> | 
					
						
						|  | </a> | 
					
						
						|  | <a title="Social" href="https://x.com/ychngji6" target="_blank" rel="noopener noreferrer" style="display: inline-block;"> | 
					
						
						|  | <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social"> | 
					
						
						|  | </a> | 
					
						
						|  | </p> | 
					
						
						|  | """) | 
					
						
						|  |  | 
					
						
						|  | with gr.Row(): | 
					
						
						|  | with gr.Column(scale=1): | 
					
						
						|  | with gr.Tabs(): | 
					
						
						|  |  | 
					
						
						|  | with gr.Tab("Single Image"): | 
					
						
						|  | with gr.Row(): | 
					
						
						|  | image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil") | 
					
						
						|  | normal_output = gr.Image(label="Normal Bridge", image_mode="RGBA", type="pil") | 
					
						
						|  |  | 
					
						
						|  | with gr.Tab("Multiple Images"): | 
					
						
						|  | gr.Markdown("<div style='text-align: center; padding: 40px; font-size: 24px;'>Multiple Images functionality is coming soon!</div>") | 
					
						
						|  |  | 
					
						
						|  | with gr.Accordion("Advanced Settings", open=False): | 
					
						
						|  | seed = gr.Slider(-1, MAX_SEED, label="Seed", value=0, step=1) | 
					
						
						|  | gr.Markdown("#### Stage 1: Sparse Structure Generation") | 
					
						
						|  | with gr.Row(): | 
					
						
						|  | ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3, step=0.1) | 
					
						
						|  | ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=50, step=1) | 
					
						
						|  | gr.Markdown("#### Stage 2: Structured Latent Generation") | 
					
						
						|  | with gr.Row(): | 
					
						
						|  | slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1) | 
					
						
						|  | slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=6, step=1) | 
					
						
						|  |  | 
					
						
						|  | with gr.Group(): | 
					
						
						|  | with gr.Row(): | 
					
						
						|  | gen_shape_btn = gr.Button("Generate Shape", size="lg", variant="primary") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with gr.Column(scale=1): | 
					
						
						|  | with gr.Tabs(): | 
					
						
						|  | with gr.Tab("Preview"): | 
					
						
						|  | output_gallery = gr.Gallery(label="Examples", columns=4, rows=2, object_fit="contain", height="auto",show_label=False) | 
					
						
						|  | with gr.Tab("3D Model"): | 
					
						
						|  | with gr.Column(): | 
					
						
						|  | model_output = gr.Model3D(label="3D Model Preview (Each model is approximately 40MB, may take around 1 minute to load)", height=300) | 
					
						
						|  | with gr.Column(): | 
					
						
						|  | export_format = gr.Dropdown( | 
					
						
						|  | choices=["obj", "glb", "ply", "stl"], | 
					
						
						|  | value="glb", | 
					
						
						|  | label="File Format" | 
					
						
						|  | ) | 
					
						
						|  | download_btn = gr.DownloadButton(label="Export Mesh", interactive=False) | 
					
						
						|  |  | 
					
						
						|  | image_prompt.upload( | 
					
						
						|  | preprocess_image, | 
					
						
						|  | inputs=[image_prompt], | 
					
						
						|  | outputs=[image_prompt] | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | gen_shape_btn.click( | 
					
						
						|  | generate_3d, | 
					
						
						|  | inputs=[ | 
					
						
						|  | image_prompt, seed, | 
					
						
						|  | ss_guidance_strength, ss_sampling_steps, | 
					
						
						|  | slat_guidance_strength, slat_sampling_steps | 
					
						
						|  | ], | 
					
						
						|  | outputs=[output_gallery, normal_output, model_output, download_btn] | 
					
						
						|  | ).then( | 
					
						
						|  | lambda: gr.Button(interactive=True), | 
					
						
						|  | outputs=[download_btn], | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def update_download_button(mesh_path, export_format): | 
					
						
						|  | if not mesh_path: | 
					
						
						|  | return gr.File.update(value=None, interactive=False) | 
					
						
						|  |  | 
					
						
						|  | download_path = convert_mesh(mesh_path, export_format) | 
					
						
						|  | return download_path | 
					
						
						|  |  | 
					
						
						|  | export_format.change( | 
					
						
						|  | update_download_button, | 
					
						
						|  | inputs=[model_output, export_format], | 
					
						
						|  | outputs=[download_btn] | 
					
						
						|  | ).then( | 
					
						
						|  | lambda: gr.Button(interactive=True), | 
					
						
						|  | outputs=[download_btn], | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | examples = gr.Examples( | 
					
						
						|  | examples=[ | 
					
						
						|  | f'assets/example_image/{image}' | 
					
						
						|  | for image in os.listdir("assets/example_image") | 
					
						
						|  | ], | 
					
						
						|  | inputs=image_prompt, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | if __name__ == "__main__": | 
					
						
						|  |  | 
					
						
						|  | pipeline = TrellisImageTo3DPipeline.from_pretrained("Stable-X/trellis-normal-v0-1") | 
					
						
						|  | pipeline.cuda() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | normal_predictor = torch.hub.load("hugoycj/StableNormal", "StableNormal_turbo", trust_repo=True, yoso_version='yoso-normal-v1-8-1') | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | demo.launch() | 
					
						
						|  |  |