andrew3d committed on
Commit
d21f0b7
·
verified ·
1 Parent(s): 9126975

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +312 -0
app.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MIT License
2
+
3
+ # Copyright (c) Microsoft
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # Copyright (c) [2025] [Microsoft]
24
+ # Copyright (c) [2025] [Chongjie Ye]
25
+ # SPDX-License-Identifier: MIT
26
+ # This file has been modified by Chongjie Ye on 2025/04/10
27
+ # Original file was released under MIT, with the full license text # available at https://github.com/atong01/conditional-flow-matching/blob/1.0.7/LICENSE.
28
+ # This modified file is released under the same license.
29
+
30
+ import gradio as gr
31
+ import os
32
+ os.environ['SPCONV_ALGO'] = 'native'
33
+ from typing import *
34
+ import torch
35
+ import numpy as np
36
+ import tempfile
37
+
38
+ import zipfile
39
+
40
+ # ---------------------------------------------------------------------------
41
+ # NOTE
42
+ # The original Hi3DGen implementation expects the `hi3dgen` Python package to
43
+ # reside alongside this app file. Hugging Face Spaces do not currently
44
+ # support uploading an entire folder via the web interface, so the `hi3dgen`
45
+ # source tree is bundled into a single `hi3dgen.zip` archive. On startup we
46
+ # extract this archive into the working directory if the `hi3dgen` package is
47
+ # not already present. This allows the rest of the code to `import hi3dgen` as
48
+ # normal.
49
+ # ---------------------------------------------------------------------------
50
+
51
def _ensure_hi3dgen_available():
    """Unpack ``hi3dgen.zip`` beside this file if the ``hi3dgen`` package is missing.

    The function returns immediately when a ``hi3dgen`` directory already
    exists next to this file, so calling it repeatedly is harmless.

    Raises:
        FileNotFoundError: If neither the package directory nor the archive
            exists next to this file.
        RuntimeError: If the archive exists but cannot be extracted. The
            underlying error is attached as ``__cause__``.
    """
    pkg_name = 'hi3dgen'
    # Resolve to an absolute directory so the lookup does not depend on the
    # current working directory when ``__file__`` is a relative path.
    base_dir = os.path.dirname(os.path.abspath(__file__))
    pkg_dir = os.path.join(base_dir, pkg_name)
    if os.path.isdir(pkg_dir):
        return

    archive_path = os.path.join(base_dir, f"{pkg_name}.zip")
    if not os.path.isfile(archive_path):
        raise FileNotFoundError(
            f"Required archive {archive_path} is missing. Make sure to upload the hi3dgen.zip file alongside app.py."
        )

    try:
        with zipfile.ZipFile(archive_path, 'r') as zf:
            zf.extractall(base_dir)
    except Exception as e:
        # Chain the original error so the traceback shows the real cause
        # (corrupt archive, permission problem, full disk, ...).
        raise RuntimeError(f"Failed to extract {archive_path}: {e}") from e
70
+
71
+ # Make sure the hi3dgen package is available before importing it
72
+ _ensure_hi3dgen_available()
73
+
74
+ from hi3dgen.pipelines import Hi3DGenPipeline
75
+ import trimesh
76
# Upper bound for user-selectable seeds: the largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max

# Scratch output and model weights live next to this file, independent of
# the current working directory.
_APP_ROOT = os.path.dirname(os.path.abspath(__file__))
TMP_DIR, WEIGHTS_DIR = (os.path.join(_APP_ROOT, sub) for sub in ('tmp', 'weights'))
for _folder in (TMP_DIR, WEIGHTS_DIR):
    os.makedirs(_folder, exist_ok=True)
81
+
82
def cache_weights(weights_dir: str) -> dict:
    """Make sure every required model snapshot is present under ``weights_dir``.

    Each repository is stored in a sub-folder named after the last component
    of its repo id (``Stable-X/trellis-normal-v0-1`` goes to
    ``<weights_dir>/trellis-normal-v0-1``). Repositories whose folder already
    contains files are skipped; the others are fetched with
    ``huggingface_hub.snapshot_download``.

    Args:
        weights_dir: Directory holding one sub-folder per model. It is
            created if it does not exist.

    Returns:
        Mapping from repo id to the local path of its snapshot (the existing
        folder, or the value returned by ``snapshot_download``).
    """
    os.makedirs(weights_dir, exist_ok=True)
    model_ids = [
        "Stable-X/trellis-normal-v0-1",
        "Stable-X/yoso-normal-v1-8-1",
        "ZhengPeng7/BiRefNet",
    ]
    cached_paths = {}
    for model_id in model_ids:
        print(f"Caching weights for: {model_id}")
        local_path = os.path.join(weights_dir, model_id.split("/")[-1])

        # An empty folder holds no weights, so only a populated folder counts
        # as cached; a bare existence check would skip the download forever.
        if os.path.isdir(local_path) and os.listdir(local_path):
            print(f"Already cached at: {local_path}")
            cached_paths[model_id] = local_path
            continue

        print(f"Downloading and caching model: {model_id}")
        # Imported lazily: a fully cached start-up needs no hub client at all.
        from huggingface_hub import snapshot_download
        local_path = snapshot_download(repo_id=model_id, local_dir=local_path, force_download=False)
        cached_paths[model_id] = local_path
        print(f"Cached at: {local_path}")

    return cached_paths
109
+
110
def preprocess_mesh(mesh_prompt):
    """Re-export a mesh file as GLB and return the path of the new file.

    Args:
        mesh_prompt: Path of the mesh file to convert, or ``None``/empty when
            no mesh has been provided.

    Returns:
        ``mesh_prompt`` with ``.glb`` appended (the file written by the
        export), or ``None`` when no mesh was given.
    """
    # Same contract as preprocess_image: an empty input is passed through
    # instead of failing on the string concatenation below.
    if not mesh_prompt:
        return None

    print("Processing mesh")
    glb_path = mesh_prompt + '.glb'
    trimesh_mesh = trimesh.load_mesh(mesh_prompt)
    trimesh_mesh.export(glb_path)
    return glb_path
115
+
116
def preprocess_image(image):
    """Run the pipeline's image preprocessing at a resolution of 1024.

    ``None`` is handed straight back, so the handler tolerates an empty
    image input.
    """
    return None if image is None else hi3dgen_pipeline.preprocess_image(image, resolution=1024)
121
+
122
def generate_3d(image, seed=-1,
                ss_guidance_strength=3, ss_sampling_steps=50,
                slat_guidance_strength=3, slat_sampling_steps=6,):
    """Generate a mesh from a single image and save it as a GLB file.

    The image is preprocessed, passed through ``normal_predictor``, and the
    result is fed to ``hi3dgen_pipeline.run``. The first mesh returned is
    exported to ``<TMP_DIR>/<unique folder>/mesh.glb``.

    Args:
        image: Input image, or ``None`` when nothing was provided.
        seed: Seed forwarded to the pipeline; ``-1`` picks a random seed in
            ``[0, MAX_SEED)``.
        ss_guidance_strength: ``cfg_strength`` of the sparse-structure sampler.
        ss_sampling_steps: ``steps`` of the sparse-structure sampler.
        slat_guidance_strength: ``cfg_strength`` of the structured-latent sampler.
        slat_sampling_steps: ``steps`` of the structured-latent sampler.

    Returns:
        ``(normal_image, mesh_path, mesh_path)``; the path is returned twice
        because the click handler routes it to two separate outputs.
        ``(None, None, None)`` when ``image`` is ``None``.
    """
    if image is None:
        return None, None, None

    import datetime

    # NOTE(review): UI sliders may deliver the seed as a float; the cast
    # assumes the pipeline wants an integer seed. Confirm against its API.
    seed = int(seed)
    if seed == -1:
        seed = np.random.randint(0, MAX_SEED)

    image = hi3dgen_pipeline.preprocess_image(image, resolution=1024)
    normal_image = normal_predictor(image, resolution=768, match_input_resolution=True, data_type='object')

    outputs = hi3dgen_pipeline.run(
        normal_image,
        seed=seed,
        formats=["mesh",],
        preprocess_image=False,  # already preprocessed above
        sparse_structure_sampler_params={
            "steps": ss_sampling_steps,
            "cfg_strength": ss_guidance_strength,
        },
        slat_sampler_params={
            "steps": slat_sampling_steps,
            "cfg_strength": slat_guidance_strength,
        },
    )
    generated_mesh = outputs['mesh'][0]

    # Save outputs. A timestamp alone has one-second granularity, so two
    # requests finishing in the same second would share a folder and overwrite
    # each other's mesh.glb; mkdtemp appends a random suffix to make the
    # folder unique.
    os.makedirs(TMP_DIR, exist_ok=True)
    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    output_dir = tempfile.mkdtemp(prefix=f"{timestamp}_", dir=TMP_DIR)
    mesh_path = os.path.join(output_dir, "mesh.glb")

    # Export mesh
    trimesh_mesh = generated_mesh.to_trimesh(transform_pose=True)
    trimesh_mesh.export(mesh_path)

    return normal_image, mesh_path, mesh_path
164
+
165
def convert_mesh(mesh_path, export_format):
    """Re-export the mesh at ``mesh_path`` into a new temporary file.

    Args:
        mesh_path: Path of the mesh to convert; falsy when there is no mesh.
        export_format: File extension (without the dot) used as the suffix of
            the temporary file, e.g. ``"obj"``.

    Returns:
        Path of the newly written temporary file, or ``None`` when
        ``mesh_path`` is empty. The file is not deleted automatically.
    """
    if not mesh_path:
        return None

    # Reserve a uniquely named file and close the handle right away: only the
    # name is needed, and keeping the handle open leaked one descriptor per
    # export (and blocks re-opening the file on some platforms).
    with tempfile.NamedTemporaryFile(suffix=f".{export_format}", delete=False) as temp_file:
        temp_file_path = temp_file.name

    mesh = trimesh.load_mesh(mesh_path)
    # NOTE(review): presumably trimesh picks the output format from the file
    # extension; confirm against the trimesh export documentation.
    mesh.export(temp_file_path)

    return temp_file_path
179
+
180
+ # Create the Gradio interface with improved layout
181
# Layout: left column holds the inputs (image tabs, advanced sampler settings,
# generate button); right column holds the 3D preview and the export controls.
with gr.Blocks(css="footer {visibility: hidden}") as demo:
    gr.Markdown(
        """
        <h1 style='text-align: center;'>Hi3DGen: High-fidelity 3D Geometry Generation from Images via Normal Bridging</h1>
        <p style='text-align: center;'>
        <strong>V0.1, Introduced By
        <a href="https://gaplab.cuhk.edu.cn/" target="_blank">GAP Lab</a> from CUHKSZ and
        <a href="https://www.nvsgames.cn/" target="_blank">Game-AIGC Team</a> from ByteDance</strong>
        </p>
        """
    )

    with gr.Row():
        gr.Markdown("""
        <p align="center">
        <a title="Website" href="https://stable-x.github.io/Hi3DGen/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
        <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
        </a>
        <a title="arXiv" href="https://stable-x.github.io/Hi3DGen/hi3dgen_paper.pdf" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
        <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
        </a>
        <a title="Github" href="https://github.com/Stable-X/Hi3DGen" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
        <img src="https://img.shields.io/github/stars/Stable-X/Hi3DGen?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
        </a>
        <a title="Social" href="https://x.com/ychngji6" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
        <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
        </a>
        </p>
        """)

    with gr.Row():
        # Left column - Input
        with gr.Column(scale=1):
            with gr.Tabs():

                with gr.Tab("Single Image"):
                    with gr.Row():
                        image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil")
                        normal_output = gr.Image(label="Normal Bridge", image_mode="RGBA", type="pil")

                with gr.Tab("Multiple Images"):
                    gr.Markdown("<div style='text-align: center; padding: 40px; font-size: 24px;'>Multiple Images functionality is coming soon!</div>")

            with gr.Accordion("Advanced Settings", open=False):
                # -1 asks generate_3d to draw a random seed.
                seed = gr.Slider(-1, MAX_SEED, label="Seed", value=0, step=1)
                gr.Markdown("#### Stage 1: Sparse Structure Generation")
                with gr.Row():
                    ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3, step=0.1)
                    ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=50, step=1)
                gr.Markdown("#### Stage 2: Structured Latent Generation")
                with gr.Row():
                    slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
                    slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=6, step=1)

            with gr.Group():
                with gr.Row():
                    gen_shape_btn = gr.Button("Generate Shape", size="lg", variant="primary")

        # Right column - Output
        with gr.Column(scale=1):
            with gr.Column():
                model_output = gr.Model3D(label="3D Model Preview (Each model is approximately 40MB, may take around 1 minute to load)")
            with gr.Column():
                export_format = gr.Dropdown(
                    choices=["obj", "glb", "ply", "stl"],
                    value="glb",
                    label="File Format"
                )
                # Disabled until a mesh exists; enabled by the handlers below.
                download_btn = gr.DownloadButton(label="Export Mesh", interactive=False)

    # Replace the uploaded image with its preprocessed version.
    image_prompt.upload(
        preprocess_image,
        inputs=[image_prompt],
        outputs=[image_prompt]
    )

    gen_shape_btn.click(
        generate_3d,
        inputs=[
            image_prompt, seed,
            ss_guidance_strength, ss_sampling_steps,
            slat_guidance_strength, slat_sampling_steps
        ],
        outputs=[normal_output, model_output, download_btn]
    ).then(
        # Enable the export button only when generation produced a mesh;
        # generate_3d returns None for every output when no image was given.
        lambda mesh_path: gr.DownloadButton(interactive=bool(mesh_path)),
        inputs=[model_output],
        outputs=[download_btn],
    )

    def update_download_button(mesh_path, export_format):
        """Point the export button at the mesh converted to ``export_format``.

        Returns a ``gr.DownloadButton`` update: disabled and without a file
        when there is no mesh, otherwise enabled with the converted file.
        """
        if not mesh_path:
            # ``gr.File.update`` is not available in the Gradio releases that
            # provide ``gr.DownloadButton``; returning a component instance is
            # the update mechanism there.
            return gr.DownloadButton(value=None, interactive=False)

        download_path = convert_mesh(mesh_path, export_format)
        return gr.DownloadButton(value=download_path, interactive=True)

    # The handler sets ``interactive`` itself, so no follow-up step is needed
    # (an unconditional one would enable the button even without a mesh).
    export_format.change(
        update_download_button,
        inputs=[model_output, export_format],
        outputs=[download_btn]
    )

    examples = None

    gr.Markdown(
        """
        **Acknowledgments**: Hi3DGen is built on the shoulders of giants. We would like to express our gratitude to the open-source research community and the developers of these pioneering projects:
        - **3D Modeling:** Our 3D Model is finetuned from the SOTA open-source 3D foundation model [Trellis](https://github.com/microsoft/TRELLIS) and we draw inspiration from the teams behind [Rodin](https://hyperhuman.deemos.com/rodin), [Tripo](https://www.tripo3d.ai/app/home), and [Dora](https://github.com/Seed3D/Dora).
        - **Normal Estimation:** Our Normal Estimation Model builds on the leading normal estimation research such as [StableNormal](https://github.com/hugoycj/StableNormal) and [GenPercept](https://github.com/aim-uofa/GenPercept).

        **Your contributions and collaboration push the boundaries of 3D modeling!**
        """
    )
297
+
298
if __name__ == "__main__":
    # Download and cache the weights
    cache_weights(WEIGHTS_DIR)

    # Load from WEIGHTS_DIR, the absolute folder cache_weights just filled,
    # instead of a path relative to the current working directory, which only
    # resolves when the app is started from its own folder.
    hi3dgen_pipeline = Hi3DGenPipeline.from_pretrained(os.path.join(WEIGHTS_DIR, "trellis-normal-v0-1"))
    hi3dgen_pipeline.cuda()

    # Initialize normal predictor: try the copy already present in the
    # torch.hub directory first, then fall back to fetching the repository.
    try:
        normal_predictor = torch.hub.load(os.path.join(torch.hub.get_dir(), 'hugoycj_StableNormal_main'), "StableNormal_turbo", yoso_version='yoso-normal-v1-8-1', source='local', local_cache_dir=WEIGHTS_DIR, pretrained=True)
    except Exception as exc:
        # ``except Exception`` rather than a bare ``except`` so Ctrl-C and
        # SystemExit are not swallowed and answered with a download.
        print(f"Local StableNormal checkout unavailable ({exc}); loading from hub")
        normal_predictor = torch.hub.load("hugoycj/StableNormal", "StableNormal_turbo", trust_repo=True, yoso_version='yoso-normal-v1-8-1', local_cache_dir=WEIGHTS_DIR)

    # Launch the app
    demo.launch(share=False, server_name="0.0.0.0")