Update app.py
Browse files
app.py
CHANGED
@@ -1,19 +1,20 @@
|
|
1 |
import gradio as gr
|
2 |
-
import spaces
|
3 |
-
from gradio_litmodel3d import LitModel3D
|
4 |
-
|
5 |
import os
|
6 |
import shutil
|
7 |
-
|
8 |
-
|
9 |
-
import torch
|
10 |
import numpy as np
|
|
|
11 |
import imageio
|
|
|
|
|
12 |
from PIL import Image
|
|
|
|
|
13 |
from trellis.pipelines import TrellisImageTo3DPipeline
|
14 |
from trellis.utils import render_utils
|
15 |
-
|
16 |
-
|
17 |
|
18 |
MAX_SEED = np.iinfo(np.int32).max
|
19 |
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
|
@@ -35,7 +36,7 @@ def generate_3d(image, seed=-1,
|
|
35 |
ss_guidance_strength=3, ss_sampling_steps=50,
|
36 |
slat_guidance_strength=3, slat_sampling_steps=6,):
|
37 |
if image is None:
|
38 |
-
return None, None, None
|
39 |
|
40 |
if seed == -1:
|
41 |
seed = np.random.randint(0, MAX_SEED)
|
@@ -59,73 +60,43 @@ def generate_3d(image, seed=-1,
|
|
59 |
)
|
60 |
generated_mesh = outputs['mesh'][0]
|
61 |
|
62 |
-
# Save outputs
|
63 |
-
import datetime
|
64 |
output_id = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
65 |
os.makedirs(os.path.join(TMP_DIR, output_id), exist_ok=True)
|
66 |
mesh_path = f"{TMP_DIR}/{output_id}/mesh.glb"
|
67 |
|
68 |
render_results = render_utils.render_video(generated_mesh, resolution=1024, ssaa=1, num_frames=8, pitch=0.25, inverse_direction=True)
|
|
|
69 |
def combine_diagonal(color_np, normal_np):
|
70 |
-
# Convert images to numpy arrays
|
71 |
h, w, c = color_np.shape
|
72 |
-
mask = np.fromfunction(lambda y, x: x > y, (h, w))
|
73 |
-
mask = mask.astype(bool)
|
74 |
mask = np.stack([mask] * c, axis=-1)
|
75 |
combined_np = np.where(mask, color_np, normal_np)
|
76 |
return Image.fromarray(combined_np)
|
77 |
|
78 |
preview_images = [combine_diagonal(c, n) for c, n in zip(render_results['color'], render_results['normal'])]
|
79 |
|
80 |
-
# Export mesh
|
81 |
trimesh_mesh = generated_mesh.to_trimesh(transform_pose=True)
|
82 |
trimesh_mesh.export(mesh_path)
|
83 |
|
84 |
return preview_images, normal_image, mesh_path, mesh_path
|
85 |
|
86 |
def convert_mesh(mesh_path, export_format):
|
87 |
-
"""Download the mesh in the selected format."""
|
88 |
if not mesh_path:
|
89 |
return None
|
90 |
-
|
91 |
temp_file = tempfile.NamedTemporaryFile(suffix=f".{export_format}", delete=False)
|
92 |
-
temp_file_path = temp_file.name
|
93 |
-
|
94 |
-
new_mesh_path = mesh_path.replace(".glb", f".{export_format}")
|
95 |
mesh = trimesh.load_mesh(mesh_path)
|
96 |
-
mesh.export(
|
97 |
-
|
98 |
-
return temp_file_path
|
99 |
|
100 |
with gr.Blocks(css="footer {visibility: hidden}") as demo:
|
101 |
-
gr.Markdown(
|
102 |
-
"""
|
103 |
<h1 style='text-align: center;'>Hi3DGen: High-fidelity 3D Geometry Generation from Images via Normal Bridging</h1>
|
104 |
<p style='text-align: center;'>
|
105 |
<strong>V0.1, Introduced By
|
106 |
<a href="https://gaplab.cuhk.edu.cn/" target="_blank">GAP Lab</a> from CUHKSZ and
|
107 |
<a href="https://www.nvsgames.cn/" target="_blank">Game-AIGC Team</a> from ByteDance</strong>
|
108 |
</p>
|
109 |
-
|
110 |
-
)
|
111 |
-
|
112 |
-
with gr.Row():
|
113 |
-
gr.Markdown("""
|
114 |
-
<p align="center">
|
115 |
-
<a title="Website" href="https://stable-x.github.io/Hi3DGen/" target="_blank">
|
116 |
-
<img src="https://www.obukhov.ai/img/badges/badge-website.svg">
|
117 |
-
</a>
|
118 |
-
<a title="arXiv" href="https://stable-x.github.io/Hi3DGen/hi3dgen_paper.pdf" target="_blank">
|
119 |
-
<img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
|
120 |
-
</a>
|
121 |
-
<a title="Github" href="https://github.com/Stable-X/Hi3DGen" target="_blank">
|
122 |
-
<img src="https://img.shields.io/github/stars/Stable-X/Hi3DGen?label=GitHub%20%E2%98%85&logo=github&color=C8C">
|
123 |
-
</a>
|
124 |
-
<a title="Social" href="https://x.com/ychngji6" target="_blank">
|
125 |
-
<img src="https://www.obukhov.ai/img/badges/badge-social.svg">
|
126 |
-
</a>
|
127 |
-
</p>
|
128 |
-
""")
|
129 |
|
130 |
with gr.Row():
|
131 |
with gr.Column(scale=1):
|
@@ -134,10 +105,8 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
|
|
134 |
with gr.Row():
|
135 |
image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil")
|
136 |
normal_output = gr.Image(label="Normal Bridge", image_mode="RGBA", type="pil")
|
137 |
-
|
138 |
with gr.Tab("Multiple Images"):
|
139 |
gr.Markdown("<div style='text-align: center; padding: 40px; font-size: 24px;'>Multiple Images functionality is coming soon!</div>")
|
140 |
-
|
141 |
with gr.Accordion("Advanced Settings", open=False):
|
142 |
seed = gr.Slider(-1, MAX_SEED, label="Seed", value=0, step=1)
|
143 |
gr.Markdown("#### Stage 1: Sparse Structure Generation")
|
@@ -148,18 +117,17 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
|
|
148 |
with gr.Row():
|
149 |
slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
|
150 |
slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=6, step=1)
|
151 |
-
|
152 |
with gr.Group():
|
153 |
with gr.Row():
|
154 |
gen_shape_btn = gr.Button("Generate Shape", size="lg", variant="primary")
|
155 |
-
|
156 |
with gr.Column(scale=1):
|
157 |
with gr.Tabs():
|
158 |
with gr.Tab("Preview"):
|
159 |
output_gallery = gr.Gallery(label="Examples", columns=4, rows=2, object_fit="contain", height="auto", show_label=False)
|
160 |
with gr.Tab("3D Model"):
|
161 |
with gr.Column():
|
162 |
-
model_output = gr.Model3D(label="3D Model Preview (Each model is
|
163 |
with gr.Column():
|
164 |
export_format = gr.Dropdown(
|
165 |
choices=["obj", "glb", "ply", "stl"],
|
@@ -167,17 +135,17 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
|
|
167 |
label="File Format"
|
168 |
)
|
169 |
download_btn = gr.DownloadButton(label="Export Mesh", interactive=False)
|
170 |
-
|
171 |
image_prompt.upload(
|
172 |
preprocess_image,
|
173 |
inputs=[image_prompt],
|
174 |
outputs=[image_prompt]
|
175 |
)
|
176 |
-
|
177 |
gen_shape_btn.click(
|
178 |
generate_3d,
|
179 |
inputs=[
|
180 |
-
image_prompt, seed,
|
181 |
ss_guidance_strength, ss_sampling_steps,
|
182 |
slat_guidance_strength, slat_sampling_steps
|
183 |
],
|
@@ -190,10 +158,9 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
|
|
190 |
def update_download_button(mesh_path, export_format):
|
191 |
if not mesh_path:
|
192 |
return gr.File.update(value=None, interactive=False)
|
193 |
-
|
194 |
download_path = convert_mesh(mesh_path, export_format)
|
195 |
return download_path
|
196 |
-
|
197 |
export_format.change(
|
198 |
update_download_button,
|
199 |
inputs=[model_output, export_format],
|
@@ -202,7 +169,7 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
|
|
202 |
lambda: gr.Button(interactive=True),
|
203 |
outputs=[download_btn],
|
204 |
)
|
205 |
-
|
206 |
examples = gr.Examples(
|
207 |
examples=[
|
208 |
f'assets/example_image/{image}'
|
@@ -211,19 +178,23 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
|
|
211 |
inputs=image_prompt,
|
212 |
)
|
213 |
|
214 |
-
gr.Markdown(
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
"""
|
220 |
-
)
|
221 |
|
222 |
if __name__ == "__main__":
|
223 |
-
#
|
224 |
pipeline = TrellisImageTo3DPipeline.from_pretrained("Stable-X/trellis-normal-v0-1")
|
225 |
-
|
226 |
-
|
227 |
-
normal_predictor = torch.hub.load(
|
228 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
229 |
demo.launch()
|
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
2 |
import os
|
3 |
import shutil
|
4 |
+
import tempfile
|
5 |
+
import datetime
|
|
|
6 |
import numpy as np
|
7 |
+
import torch
|
8 |
import imageio
|
9 |
+
import trimesh
|
10 |
+
|
11 |
from PIL import Image
|
12 |
+
from typing import *
|
13 |
+
from gradio_litmodel3d import LitModel3D
|
14 |
from trellis.pipelines import TrellisImageTo3DPipeline
|
15 |
from trellis.utils import render_utils
|
16 |
+
|
17 |
+
os.environ['SPCONV_ALGO'] = 'native'
|
18 |
|
19 |
MAX_SEED = np.iinfo(np.int32).max
|
20 |
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
|
|
|
36 |
ss_guidance_strength=3, ss_sampling_steps=50,
|
37 |
slat_guidance_strength=3, slat_sampling_steps=6,):
|
38 |
if image is None:
|
39 |
+
return None, None, None, None
|
40 |
|
41 |
if seed == -1:
|
42 |
seed = np.random.randint(0, MAX_SEED)
|
|
|
60 |
)
|
61 |
generated_mesh = outputs['mesh'][0]
|
62 |
|
|
|
|
|
63 |
output_id = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
64 |
os.makedirs(os.path.join(TMP_DIR, output_id), exist_ok=True)
|
65 |
mesh_path = f"{TMP_DIR}/{output_id}/mesh.glb"
|
66 |
|
67 |
render_results = render_utils.render_video(generated_mesh, resolution=1024, ssaa=1, num_frames=8, pitch=0.25, inverse_direction=True)
|
68 |
+
|
69 |
def combine_diagonal(color_np, normal_np):
|
|
|
70 |
h, w, c = color_np.shape
|
71 |
+
mask = np.fromfunction(lambda y, x: x > y, (h, w)).astype(bool)
|
|
|
72 |
mask = np.stack([mask] * c, axis=-1)
|
73 |
combined_np = np.where(mask, color_np, normal_np)
|
74 |
return Image.fromarray(combined_np)
|
75 |
|
76 |
preview_images = [combine_diagonal(c, n) for c, n in zip(render_results['color'], render_results['normal'])]
|
77 |
|
|
|
78 |
trimesh_mesh = generated_mesh.to_trimesh(transform_pose=True)
|
79 |
trimesh_mesh.export(mesh_path)
|
80 |
|
81 |
return preview_images, normal_image, mesh_path, mesh_path
|
82 |
|
83 |
def convert_mesh(mesh_path, export_format):
|
|
|
84 |
if not mesh_path:
|
85 |
return None
|
|
|
86 |
temp_file = tempfile.NamedTemporaryFile(suffix=f".{export_format}", delete=False)
|
|
|
|
|
|
|
87 |
mesh = trimesh.load_mesh(mesh_path)
|
88 |
+
mesh.export(temp_file.name)
|
89 |
+
return temp_file.name
|
|
|
90 |
|
91 |
with gr.Blocks(css="footer {visibility: hidden}") as demo:
|
92 |
+
gr.Markdown("""
|
|
|
93 |
<h1 style='text-align: center;'>Hi3DGen: High-fidelity 3D Geometry Generation from Images via Normal Bridging</h1>
|
94 |
<p style='text-align: center;'>
|
95 |
<strong>V0.1, Introduced By
|
96 |
<a href="https://gaplab.cuhk.edu.cn/" target="_blank">GAP Lab</a> from CUHKSZ and
|
97 |
<a href="https://www.nvsgames.cn/" target="_blank">Game-AIGC Team</a> from ByteDance</strong>
|
98 |
</p>
|
99 |
+
""")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
with gr.Row():
|
102 |
with gr.Column(scale=1):
|
|
|
105 |
with gr.Row():
|
106 |
image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil")
|
107 |
normal_output = gr.Image(label="Normal Bridge", image_mode="RGBA", type="pil")
|
|
|
108 |
with gr.Tab("Multiple Images"):
|
109 |
gr.Markdown("<div style='text-align: center; padding: 40px; font-size: 24px;'>Multiple Images functionality is coming soon!</div>")
|
|
|
110 |
with gr.Accordion("Advanced Settings", open=False):
|
111 |
seed = gr.Slider(-1, MAX_SEED, label="Seed", value=0, step=1)
|
112 |
gr.Markdown("#### Stage 1: Sparse Structure Generation")
|
|
|
117 |
with gr.Row():
|
118 |
slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
|
119 |
slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=6, step=1)
|
|
|
120 |
with gr.Group():
|
121 |
with gr.Row():
|
122 |
gen_shape_btn = gr.Button("Generate Shape", size="lg", variant="primary")
|
123 |
+
|
124 |
with gr.Column(scale=1):
|
125 |
with gr.Tabs():
|
126 |
with gr.Tab("Preview"):
|
127 |
output_gallery = gr.Gallery(label="Examples", columns=4, rows=2, object_fit="contain", height="auto", show_label=False)
|
128 |
with gr.Tab("3D Model"):
|
129 |
with gr.Column():
|
130 |
+
model_output = gr.Model3D(label="3D Model Preview (Each model is approx. 40MB)")
|
131 |
with gr.Column():
|
132 |
export_format = gr.Dropdown(
|
133 |
choices=["obj", "glb", "ply", "stl"],
|
|
|
135 |
label="File Format"
|
136 |
)
|
137 |
download_btn = gr.DownloadButton(label="Export Mesh", interactive=False)
|
138 |
+
|
139 |
image_prompt.upload(
|
140 |
preprocess_image,
|
141 |
inputs=[image_prompt],
|
142 |
outputs=[image_prompt]
|
143 |
)
|
144 |
+
|
145 |
gen_shape_btn.click(
|
146 |
generate_3d,
|
147 |
inputs=[
|
148 |
+
image_prompt, seed,
|
149 |
ss_guidance_strength, ss_sampling_steps,
|
150 |
slat_guidance_strength, slat_sampling_steps
|
151 |
],
|
|
|
158 |
def update_download_button(mesh_path, export_format):
|
159 |
if not mesh_path:
|
160 |
return gr.File.update(value=None, interactive=False)
|
|
|
161 |
download_path = convert_mesh(mesh_path, export_format)
|
162 |
return download_path
|
163 |
+
|
164 |
export_format.change(
|
165 |
update_download_button,
|
166 |
inputs=[model_output, export_format],
|
|
|
169 |
lambda: gr.Button(interactive=True),
|
170 |
outputs=[download_btn],
|
171 |
)
|
172 |
+
|
173 |
examples = gr.Examples(
|
174 |
examples=[
|
175 |
f'assets/example_image/{image}'
|
|
|
178 |
inputs=image_prompt,
|
179 |
)
|
180 |
|
181 |
+
gr.Markdown("""
|
182 |
+
**Acknowledgments**: Hi3DGen is built on the shoulders of giants. We acknowledge contributions from:
|
183 |
+
- [Trellis 3D](https://github.com/microsoft/TRELLIS)
|
184 |
+
- [StableNormal](https://github.com/hugoycj/StableNormal)
|
185 |
+
""")
|
|
|
|
|
186 |
|
187 |
if __name__ == "__main__":
|
188 |
+
# ✅ 强制使用 CPU
|
189 |
pipeline = TrellisImageTo3DPipeline.from_pretrained("Stable-X/trellis-normal-v0-1")
|
190 |
+
pipeline.to("cpu") # <-- 强制使用 CPU
|
191 |
+
|
192 |
+
normal_predictor = torch.hub.load(
|
193 |
+
"hugoycj/StableNormal",
|
194 |
+
"StableNormal_turbo",
|
195 |
+
trust_repo=True,
|
196 |
+
yoso_version="yoso-normal-v1-8-1"
|
197 |
+
)
|
198 |
+
normal_predictor.to("cpu") # <-- 也强制使用 CPU
|
199 |
+
|
200 |
demo.launch()
|