andrew3d committed on
Commit
76ef786
·
verified ·
1 Parent(s): 787d8e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +141 -187
app.py CHANGED
@@ -1,171 +1,136 @@
1
  # MIT License
2
-
3
  # Copyright (c) Microsoft
4
-
5
  # Permission is hereby granted, free of charge, to any person obtaining a copy
6
  # of this software and associated documentation files (the "Software"), to deal
7
  # in the Software without restriction, including without limitation the rights
8
  # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- # copies of the Software, and to permit persons to whom the Software is
10
- # furnished to do so, subject to the following conditions:
11
-
12
- # The above copyright notice and this permission notice shall be included in all
13
- # copies or substantial portions of the Software.
14
-
15
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- # SOFTWARE.
22
 
23
  # Copyright (c) [2025] [Microsoft]
24
- # Copyright (c) [2025] [Chongjie Ye]
25
  # SPDX-License-Identifier: MIT
26
  # This file has been modified by Chongjie Ye on 2025/04/10
27
- # Original file was released under MIT, with the full license text # available at https://github.com/atong01/conditional-flow-matching/blob/1.0.7/LICENSE.
 
28
  # This modified file is released under the same license.
29
 
30
  import gradio as gr
31
  import os
32
 
33
  # ---- Force CPU-only environment globally ----
34
- os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # hide any GPUs from torch
35
- os.environ.setdefault("ATTN_BACKEND", "sdpa") # avoid xformers
36
  os.environ.setdefault("SPCONV_ALGO", "native") # safe sparseconv algo
37
  # ---------------------------------------------
 
38
  from typing import *
39
  import torch
40
  import numpy as np
41
  import tempfile
42
-
43
  import zipfile
44
  import types
 
45
 
46
  # ---------------------------------------------------------------------------
47
- # NOTE
48
- # The original Hi3DGen implementation expects the `hi3dgen` Python package to
49
- # reside alongside this app file. Hugging Face Spaces do not currently
50
- # support uploading an entire folder via the web interface, so the `hi3dgen`
51
- # source tree is bundled into a single `hi3dgen.zip` archive. On startup we
52
- # extract this archive into the working directory if the `hi3dgen` package is
53
- # not already present. This allows the rest of the code to `import hi3dgen` as
54
- # normal.
55
  # ---------------------------------------------------------------------------
56
-
def _ensure_hi3dgen_available():
    """Unpack ``hi3dgen.zip`` next to this file if the ``hi3dgen`` folder is missing.

    Does nothing when a ``hi3dgen`` directory already exists beside this
    file, so it is safe to call more than once.

    Raises:
        FileNotFoundError: if neither the directory nor the archive exists.
        RuntimeError: if the archive exists but cannot be extracted; the
            underlying error is attached as ``__cause__``.
    """
    pkg_name = 'hi3dgen'
    # Resolve to an absolute directory (as TMP_DIR / WEIGHTS_DIR do) so the
    # paths below do not depend on the current working directory.
    here = os.path.dirname(os.path.abspath(__file__))
    pkg_dir = os.path.join(here, pkg_name)
    if os.path.isdir(pkg_dir):
        return

    archive_path = os.path.join(here, f"{pkg_name}.zip")
    if not os.path.isfile(archive_path):
        raise FileNotFoundError(
            f"Required archive {archive_path} is missing. Make sure to upload the hi3dgen.zip file alongside app.py."
        )

    try:
        with zipfile.ZipFile(archive_path, 'r') as zf:
            zf.extractall(here)
    except Exception as e:
        # Chain the original error so the traceback shows the real cause.
        raise RuntimeError(f"Failed to extract {archive_path}: {e}") from e
 
 
 
 
 
76
 
77
- # Make sure the hi3dgen package is available before importing it
78
  _ensure_hi3dgen_available()
79
 
80
  # ---------------------------------------------------------------------------
81
- # xformers stub
82
- #
83
- # Some modules in the Hi3DGen pipeline import `xformers.ops.memory_efficient_attention`
84
- # to compute multi-head attention. The official `xformers` library is not
85
- # installed in this Space (and requires GPU-only build), so we provide a
86
- # minimal in-memory stub that exposes a compatible API backed by PyTorch's
87
- # built-in scaled dot-product attention. This stub is lightweight and
88
- # CPU-friendly. It registers both the `xformers` and `xformers.ops` modules
89
- # in sys.modules so that subsequent imports succeed.
90
  # ---------------------------------------------------------------------------
91
-
def _ensure_xformers_stub():
    """Register a minimal ``xformers`` / ``xformers.ops`` stub in ``sys.modules``.

    The stub exposes ``xformers.ops.memory_efficient_attention`` implemented
    with ``torch.nn.functional.scaled_dot_product_attention``. Nothing is
    registered when ``xformers.ops`` is already present in ``sys.modules``.
    """
    import sys
    if 'xformers.ops' in sys.modules:
        return
    import torch.nn.functional as F

    xformers_mod = types.ModuleType('xformers')
    ops_mod = types.ModuleType('xformers.ops')

    def memory_efficient_attention(query, key, value, attn_bias=None, p=0.0, scale=None):
        """Attention fallback that accepts xformers-layout tensors.

        xformers documents 4-D inputs as ``[batch, seq, heads, dim]``, while
        PyTorch's SDPA expects ``[batch, heads, seq, dim]``. 4-D inputs are
        therefore transposed on the way in and back on the way out; other
        ranks are passed through unchanged.

        Args:
            query, key, value: attention inputs.
            attn_bias: optional additive mask tensor, forwarded as ``attn_mask``.
                Non-tensor xformers bias objects are not supported here.
            p: dropout probability, forwarded as ``dropout_p``.
            scale: optional softmax scale; forwarded only when given, so
                PyTorch builds without the ``scale`` keyword keep working.
        """
        extra = {} if scale is None else {"scale": scale}
        if query.dim() == 4:
            out = F.scaled_dot_product_attention(
                query.transpose(1, 2),
                key.transpose(1, 2),
                value.transpose(1, 2),
                attn_mask=attn_bias,
                dropout_p=p,
                **extra,
            )
            # Return in the caller's layout, contiguous so .view() works.
            return out.transpose(1, 2).contiguous()
        return F.scaled_dot_product_attention(
            query, key, value, attn_mask=attn_bias, dropout_p=p, **extra
        )

    ops_mod.memory_efficient_attention = memory_efficient_attention
    xformers_mod.ops = ops_mod
    # Register both names so `import xformers` and `import xformers.ops` succeed.
    sys.modules['xformers'] = xformers_mod
    sys.modules['xformers.ops'] = ops_mod
122
 
123
- # Ensure the xformers stub is registered before importing Hi3DGen
124
  _ensure_xformers_stub()
125
 
126
  # ---------------------------------------------------------------------------
127
- # Monkey-patch Hi3DGen to run on CPU
128
- #
129
- # Some utility functions and classes in the Hi3DGen codebase assume the
130
- # presence of a CUDA device by default. Specifically, the function
131
- # `construct_dense_grid` in `hi3dgen.representations.mesh.utils_cube` uses
132
- # `device='cuda'` as its default argument, and the class
133
- # `EnhancedMarchingCubes` in `hi3dgen.representations.mesh.cube2mesh` has
134
- # a constructor that defaults to `device="cuda"`. On CPU-only Spaces, these
135
- # defaults cause runtime errors when PyTorch attempts to allocate tensors on
136
- # a non-existent GPU. To avoid this, we override the default arguments for
137
- # these functions to use the CPU instead. If the patch cannot be applied
138
- # (for example, if the module structure changes in a future version), we
139
- # catch any exceptions and log a warning without stopping execution.
140
def _swap_trailing_cuda_default(func):
    """Rewrite *func*'s last default argument from ``'cuda'`` to ``'cpu'``.

    Leaves *func* untouched when it has no ``__defaults__`` attribute, has
    no defaults, or its last default is not equal to ``'cuda'``.
    """
    if not hasattr(func, '__defaults__'):
        return
    current = func.__defaults__ or ()
    if current and current[-1] == 'cuda':
        func.__defaults__ = current[:-1] + ('cpu',)


# Best effort: any failure (import error, changed module layout, ...) is
# reported as a warning and start-up continues.
try:
    from hi3dgen.representations.mesh import utils_cube
    _swap_trailing_cuda_default(utils_cube.construct_dense_grid)

    from hi3dgen.representations.mesh.cube2mesh import EnhancedMarchingCubes
    _swap_trailing_cuda_default(EnhancedMarchingCubes.__init__)
except Exception as _e:
    print(f"Warning: failed to apply CPU device overrides: {_e}")
156
-
157
  from hi3dgen.pipelines import Hi3DGenPipeline
158
  import trimesh
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  MAX_SEED = np.iinfo(np.int32).max
160
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
161
  WEIGHTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'weights')
162
  os.makedirs(TMP_DIR, exist_ok=True)
163
  os.makedirs(WEIGHTS_DIR, exist_ok=True)
164
 
 
 
 
165
  def cache_weights(weights_dir: str) -> dict:
166
- import os
167
  from huggingface_hub import snapshot_download
168
-
169
  os.makedirs(weights_dir, exist_ok=True)
170
  model_ids = [
171
  "Stable-X/trellis-normal-v0-1",
@@ -175,49 +140,62 @@ def cache_weights(weights_dir: str) -> dict:
175
  cached_paths = {}
176
  for model_id in model_ids:
177
  print(f"Caching weights for: {model_id}")
178
- # Check if the model is already cached
179
  local_path = os.path.join(weights_dir, model_id.split("/")[-1])
180
  if os.path.exists(local_path):
181
  print(f"Already cached at: {local_path}")
182
  cached_paths[model_id] = local_path
183
  continue
184
- # Download the model and cache it
185
  print(f"Downloading and caching model: {model_id}")
186
- # Use snapshot_download to download the model
187
- local_path = snapshot_download(repo_id=model_id, local_dir=os.path.join(weights_dir, model_id.split("/")[-1]), force_download=False)
 
 
 
188
  cached_paths[model_id] = local_path
189
  print(f"Cached at: {local_path}")
190
-
191
  return cached_paths
192
 
 
 
 
def preprocess_mesh(mesh_prompt):
    """Load the mesh at *mesh_prompt* and re-export it with ``.glb`` appended.

    Returns the path of the exported file (``mesh_prompt + '.glb'``).
    """
    print("Processing mesh")
    loaded = trimesh.load_mesh(mesh_prompt)
    glb_path = mesh_prompt + '.glb'
    loaded.export(glb_path)
    return glb_path
 
198
 
def preprocess_image(image):
    """Run the pipeline's image preprocessing at resolution 1024.

    Returns ``None`` unchanged when no image is supplied; otherwise returns
    whatever ``hi3dgen_pipeline.preprocess_image`` produces.
    """
    if image is not None:
        return hi3dgen_pipeline.preprocess_image(image, resolution=1024)
    return None
204
-
205
- def generate_3d(image, seed=-1,
206
- ss_guidance_strength=3, ss_sampling_steps=50,
207
- slat_guidance_strength=3, slat_sampling_steps=6,):
 
 
 
 
208
  if image is None:
209
  return None, None, None
210
 
211
  if seed == -1:
212
  seed = np.random.randint(0, MAX_SEED)
213
-
214
  image = hi3dgen_pipeline.preprocess_image(image, resolution=1024)
215
- normal_image = normal_predictor(image, resolution=768, match_input_resolution=True, data_type='object')
 
 
 
 
 
216
 
217
  outputs = hi3dgen_pipeline.run(
218
  normal_image,
219
  seed=seed,
220
- formats=["mesh",],
221
  preprocess_image=False,
222
  sparse_structure_sampler_params={
223
  "steps": ss_sampling_steps,
@@ -229,80 +207,71 @@ def generate_3d(image, seed=-1,
229
  },
230
  )
231
  generated_mesh = outputs['mesh'][0]
232
-
233
- # Save outputs
234
- import datetime
235
-
236
 
 
237
  output_id = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
238
  os.makedirs(os.path.join(TMP_DIR, output_id), exist_ok=True)
239
  mesh_path = f"{TMP_DIR}/{output_id}/mesh.glb"
240
-
241
- # Export mesh
242
- trimesh_mesh = generated_mesh.to_trimesh(transform_pose=True)
243
 
 
244
  trimesh_mesh.export(mesh_path)
245
 
246
  return normal_image, mesh_path, mesh_path
247
 
def convert_mesh(mesh_path, export_format):
    """Convert the mesh at *mesh_path* into a new temporary file.

    Args:
        mesh_path: path of the mesh to load; a falsy value yields ``None``.
        export_format: file extension (without the dot) used as the
            temporary file's suffix.

    Returns:
        Path of the temporary file holding the exported mesh, or ``None``
        when *mesh_path* is falsy. The file is not deleted automatically.
    """
    if not mesh_path:
        return None

    # Only the reserved name is needed; close the handle immediately so no
    # descriptor leaks and the exporter can reopen the path.
    with tempfile.NamedTemporaryFile(suffix=f".{export_format}", delete=False) as temp_file:
        temp_file_path = temp_file.name

    try:
        mesh = trimesh.load_mesh(mesh_path)
        mesh.export(temp_file_path)
    except Exception:
        # Do not leave an orphaned temp file behind when conversion fails.
        try:
            os.remove(temp_file_path)
        except OSError:
            pass
        raise
    return temp_file_path
262
 
263
- # Create the Gradio interface with improved layout
 
 
264
  with gr.Blocks(css="footer {visibility: hidden}") as demo:
265
  gr.Markdown(
266
  """
267
  <h1 style='text-align: center;'>Hi3DGen: High-fidelity 3D Geometry Generation from Images via Normal Bridging</h1>
268
  <p style='text-align: center;'>
269
  <strong>V0.1, Introduced By
270
- <a href="https://gaplab.cuhk.edu.cn/" target="_blank">GAP Lab</a> from CUHKSZ and
271
- <a href="https://www.nvsgames.cn/" target="_blank">Game-AIGC Team</a> from ByteDance</strong>
272
  </p>
273
  """
274
  )
275
-
276
  with gr.Row():
277
  gr.Markdown("""
278
- <p align="center">
279
- <a title="Website" href="https://stable-x.github.io/Hi3DGen/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
280
- <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
281
- </a>
282
- <a title="arXiv" href="https://stable-x.github.io/Hi3DGen/hi3dgen_paper.pdf" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
283
- <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
284
- </a>
285
- <a title="Github" href="https://github.com/Stable-X/Hi3DGen" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
286
- <img src="https://img.shields.io/github/stars/Stable-X/Hi3DGen?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
287
- </a>
288
- <a title="Social" href="https://x.com/ychngji6" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
289
- <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
290
- </a>
291
- </p>
292
- """)
293
 
294
  with gr.Row():
295
  with gr.Column(scale=1):
296
  with gr.Tabs():
297
-
298
  with gr.Tab("Single Image"):
299
  with gr.Row():
300
  image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil")
301
  normal_output = gr.Image(label="Normal Bridge", image_mode="RGBA", type="pil")
302
-
303
  with gr.Tab("Multiple Images"):
304
- gr.Markdown("<div style='text-align: center; padding: 40px; font-size: 24px;'>Multiple Images functionality is coming soon!</div>")
305
-
 
 
306
  with gr.Accordion("Advanced Settings", open=False):
307
  seed = gr.Slider(-1, MAX_SEED, label="Seed", value=0, step=1)
308
  gr.Markdown("#### Stage 1: Sparse Structure Generation")
@@ -313,15 +282,14 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
313
  with gr.Row():
314
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
315
  slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=6, step=1)
316
-
317
  with gr.Group():
318
  with gr.Row():
319
  gen_shape_btn = gr.Button("Generate Shape", size="lg", variant="primary")
320
-
321
- # Right column - Output
322
  with gr.Column(scale=1):
323
  with gr.Column():
324
- model_output = gr.Model3D(label="3D Model Preview (Each model is approximately 40MB, may take around 1 minute to load)")
325
  with gr.Column():
326
  export_format = gr.Dropdown(
327
  choices=["obj", "glb", "ply", "stl"],
@@ -335,11 +303,11 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
335
  inputs=[image_prompt],
336
  outputs=[image_prompt]
337
  )
338
-
339
  gen_shape_btn.click(
340
  generate_3d,
341
  inputs=[
342
- image_prompt, seed,
343
  ss_guidance_strength, ss_sampling_steps,
344
  slat_guidance_strength, slat_sampling_steps
345
  ],
@@ -348,15 +316,13 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
348
  lambda: gr.Button(interactive=True),
349
  outputs=[download_btn],
350
  )
351
-
352
-
def update_download_button(mesh_path, export_format):
    """Return the value for the download output after a format change.

    Args:
        mesh_path: path of the generated mesh, or a falsy value when none exists.
        export_format: target extension passed on to ``convert_mesh``.

    Returns:
        A generic Gradio update that clears and disables the output when
        there is no mesh, otherwise the path returned by ``convert_mesh``.
    """
    if not mesh_path:
        # gr.update is component-agnostic; the per-component
        # `gr.File.update` classmethod is not available in Gradio 4+.
        return gr.update(value=None, interactive=False)
    return convert_mesh(mesh_path, export_format)
359
-
360
  export_format.change(
361
  update_download_button,
362
  inputs=[model_output, export_format],
@@ -365,31 +331,19 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
365
  lambda: gr.Button(interactive=True),
366
  outputs=[download_btn],
367
  )
368
-
369
  examples = None
370
 
371
  gr.Markdown(
372
  """
373
  **Acknowledgments**: Hi3DGen is built on the shoulders of giants. We would like to express our gratitude to the open-source research community and the developers of these pioneering projects:
374
- - **3D Modeling:** Our 3D Model is finetuned from the SOTA open-source 3D foundation model [Trellis](https://github.com/microsoft/TRELLIS) and we draw inspiration from the teams behind [Rodin](https://hyperhuman.deemos.com/rodin), [Tripo](https://www.tripo3d.ai/app/home), and [Dora](https://github.com/Seed3D/Dora).
375
- - **Normal Estimation:** Our Normal Estimation Model builds on the leading normal estimation research such as [StableNormal](https://github.com/hugoycj/StableNormal) and [GenPercept](https://github.com/aim-uofa/GenPercept).
376
-
377
- **Your contributions and collaboration push the boundaries of 3D modeling!**
378
  """
379
  )
380
 
 
 
 
if __name__ == "__main__":
    # Download and cache the model weights under WEIGHTS_DIR.
    cache_weights(WEIGHTS_DIR)

    # Load from the directory cache_weights just populated instead of a
    # path relative to the current working directory.
    hi3dgen_pipeline = Hi3DGenPipeline.from_pretrained(
        os.path.join(WEIGHTS_DIR, "trellis-normal-v0-1")
    )
    # This file sets CUDA_VISIBLE_DEVICES to "-1", so only move the
    # pipeline to the GPU when torch actually reports one.
    if torch.cuda.is_available():
        hi3dgen_pipeline.cuda()

    # Initialize the normal predictor: prefer the local torch.hub checkout,
    # fall back to fetching the repository.
    try:
        normal_predictor = torch.hub.load(
            os.path.join(torch.hub.get_dir(), 'hugoycj_StableNormal_main'),
            "StableNormal_turbo",
            yoso_version='yoso-normal-v1-8-1',
            source='local',
            local_cache_dir='./weights',
            pretrained=True,
        )
    except Exception as exc:
        # `except Exception` (not a bare except) lets Ctrl-C / SystemExit
        # propagate instead of triggering the fallback.
        print(f"Local StableNormal load failed ({exc}); falling back to torch.hub download")
        normal_predictor = torch.hub.load(
            "hugoycj/StableNormal",
            "StableNormal_turbo",
            trust_repo=True,
            yoso_version='yoso-normal-v1-8-1',
            local_cache_dir='./weights',
        )

    # Launch the app
    demo.launch(share=False, server_name="0.0.0.0")
 
1
  # MIT License
 
2
  # Copyright (c) Microsoft
 
3
  # Permission is hereby granted, free of charge, to any person obtaining a copy
4
  # of this software and associated documentation files (the "Software"), to deal
5
  # in the Software without restriction, including without limitation the rights
6
  # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
+ # copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
8
+ # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
9
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
10
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
11
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 
 
 
 
 
 
 
 
12
 
13
  # Copyright (c) [2025] [Microsoft]
14
+ # Copyright (c) [2025] [Chongjie Ye]
15
  # SPDX-License-Identifier: MIT
16
  # This file has been modified by Chongjie Ye on 2025/04/10
17
+ # Original file was released under MIT, with the full license text available at:
18
+ # https://github.com/atong01/conditional-flow-matching/blob/1.0.7/LICENSE
19
  # This modified file is released under the same license.
20
 
21
  import gradio as gr
22
  import os
23
 
24
  # ---- Force CPU-only environment globally ----
25
+ os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # hide GPUs from torch
26
+ os.environ.setdefault("ATTN_BACKEND", "sdpa") # avoid xformers path
27
  os.environ.setdefault("SPCONV_ALGO", "native") # safe sparseconv algo
28
  # ---------------------------------------------
29
+
30
  from typing import *
31
  import torch
32
  import numpy as np
33
  import tempfile
 
34
  import zipfile
35
  import types
36
+ import importlib
37
 
38
  # ---------------------------------------------------------------------------
39
+ # Ensure bundled hi3dgen sources are available (extracted from hi3dgen.zip)
 
 
 
 
 
 
 
40
  # ---------------------------------------------------------------------------
 
def _ensure_hi3dgen_available():
    """Unpack ``hi3dgen.zip`` next to this file if the ``hi3dgen`` folder is missing.

    Does nothing when a ``hi3dgen`` directory already exists beside this
    file, so it is safe to call more than once.

    Raises:
        FileNotFoundError: if neither the directory nor the archive exists.
        RuntimeError: if the archive exists but cannot be extracted; the
            underlying error is attached as ``__cause__``.
    """
    pkg_name = 'hi3dgen'
    # Resolve to an absolute directory (as TMP_DIR / WEIGHTS_DIR do) so the
    # paths below do not depend on the current working directory.
    here = os.path.dirname(os.path.abspath(__file__))
    pkg_dir = os.path.join(here, pkg_name)
    if os.path.isdir(pkg_dir):
        return

    archive_path = os.path.join(here, f"{pkg_name}.zip")
    if not os.path.isfile(archive_path):
        raise FileNotFoundError(
            f"Required archive {archive_path} is missing. Upload hi3dgen.zip next to app.py."
        )

    try:
        with zipfile.ZipFile(archive_path, 'r') as zf:
            zf.extractall(here)
    except Exception as e:
        # Chain the original error so the traceback shows the real cause.
        raise RuntimeError(f"Failed to extract {archive_path}: {e}") from e
57
 
 
58
  _ensure_hi3dgen_available()
59
 
60
  # ---------------------------------------------------------------------------
61
+ # xformers stub (CPU-friendly fallback for xformers.ops.memory_efficient_attention)
 
 
 
 
 
 
 
 
62
  # ---------------------------------------------------------------------------
 
def _ensure_xformers_stub():
    """Register a minimal ``xformers`` / ``xformers.ops`` stub in ``sys.modules``.

    The stub exposes ``xformers.ops.memory_efficient_attention`` implemented
    with ``torch.nn.functional.scaled_dot_product_attention``. Nothing is
    registered when ``xformers.ops`` is already present in ``sys.modules``.
    """
    import sys
    if 'xformers.ops' in sys.modules:
        return
    import torch.nn.functional as F

    xf_mod = types.ModuleType('xformers')
    ops_mod = types.ModuleType('xformers.ops')

    def memory_efficient_attention(query, key, value, attn_bias=None, p=0.0, scale=None):
        """Attention fallback that accepts xformers-layout tensors.

        xformers documents 4-D inputs as ``[batch, seq, heads, dim]``, while
        PyTorch's SDPA expects ``[batch, heads, seq, dim]``. 4-D inputs are
        therefore transposed on the way in and back on the way out; other
        ranks are passed through unchanged.

        Args:
            query, key, value: attention inputs.
            attn_bias: optional additive mask tensor, forwarded as ``attn_mask``.
                Non-tensor xformers bias objects are not supported here.
            p: dropout probability, forwarded as ``dropout_p``.
            scale: optional softmax scale; forwarded only when given, so
                PyTorch builds without the ``scale`` keyword keep working.
        """
        extra = {} if scale is None else {"scale": scale}
        if query.dim() == 4:
            out = F.scaled_dot_product_attention(
                query.transpose(1, 2),
                key.transpose(1, 2),
                value.transpose(1, 2),
                attn_mask=attn_bias,
                dropout_p=p,
                **extra,
            )
            # Return in the caller's layout, contiguous so .view() works.
            return out.transpose(1, 2).contiguous()
        return F.scaled_dot_product_attention(
            query, key, value, attn_mask=attn_bias, dropout_p=p, **extra
        )

    ops_mod.memory_efficient_attention = memory_efficient_attention
    xf_mod.ops = ops_mod
    # Register both names so `import xformers` and `import xformers.ops` succeed.
    sys.modules['xformers'] = xf_mod
    sys.modules['xformers.ops'] = ops_mod
78
 
 
79
  _ensure_xformers_stub()
80
 
81
  # ---------------------------------------------------------------------------
82
+ # Import pipeline AFTER stubbing xformers, then patch CUDA-hotspots to CPU
83
+ # ---------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  from hi3dgen.pipelines import Hi3DGenPipeline
85
  import trimesh
86
+
87
+ # ---- Force CPU inside hi3dgen (avoid any CUDA paths) ----
88
print("[PATCH] Applying CPU monkey-patches to hi3dgen")


def _pin_device_argument(func, forced_device):
    """Wrap *func* so its ``device`` argument is always *forced_device*.

    The wrapper accepts whatever call shape *func* itself accepts, whereas
    a fixed ``(res, device=None)`` wrapper raises ``TypeError`` for any
    other shape. When *func* has no parameter named ``device`` the value is
    passed as a ``device=`` keyword.
    """
    import functools
    import inspect

    signature = inspect.signature(func)
    has_device_param = "device" in signature.parameters

    @functools.wraps(func)
    def _wrapper(*args, **kwargs):
        if not has_device_param:
            kwargs["device"] = forced_device
            return func(*args, **kwargs)
        # bind_partial maps positional and keyword arguments onto parameter
        # names, so `device` is overridden however the caller passed it.
        bound = signature.bind_partial(*args, **kwargs)
        bound.arguments["device"] = forced_device
        return func(*bound.args, **bound.kwargs)

    return _wrapper


# 1) utils_cube.construct_dense_grid(..., device=...) -> force CPU
# NOTE(review): modules that already did
# `from ...utils_cube import construct_dense_grid` keep their own reference
# to the unpatched function - confirm none of them matter here.
uc = importlib.import_module("hi3dgen.representations.mesh.utils_cube")
if not hasattr(uc, "_CPU_PATCHED"):
    _orig_construct_dense_grid = uc.construct_dense_grid
    _construct_dense_grid_cpu = _pin_device_argument(_orig_construct_dense_grid, "cpu")
    uc.construct_dense_grid = _construct_dense_grid_cpu
    uc._CPU_PATCHED = True  # marker so the patch is applied only once
    print("[PATCH] utils_cube.construct_dense_grid -> CPU")

# 2) cube2mesh.EnhancedMarchingCubes default device -> force CPU
cm = importlib.import_module("hi3dgen.representations.mesh.cube2mesh")
M = cm.EnhancedMarchingCubes
if not hasattr(M, "_CPU_PATCHED"):
    _orig_init = M.__init__
    _init_cpu = _pin_device_argument(_orig_init, torch.device("cpu"))
    M.__init__ = _init_cpu
    M._CPU_PATCHED = True
    print("[PATCH] cube2mesh.EnhancedMarchingCubes.__init__ -> CPU")

# 3) Belt & suspenders: coerce torch.arange(device='cuda') to CPU if any call slips through
if not hasattr(torch, "_ARANGE_CPU_PATCHED"):
    _orig_arange = torch.arange

    def _arange_cpu(*args, **kwargs):
        """Call the original ``torch.arange``, rewriting any CUDA ``device=`` to CPU."""
        dev = kwargs.get("device")
        # str() covers both string and torch.device values ("cuda", "cuda:0", ...).
        if dev is not None and str(dev).startswith("cuda"):
            kwargs["device"] = "cpu"
        return _orig_arange(*args, **kwargs)

    torch.arange = _arange_cpu
    torch._ARANGE_CPU_PATCHED = True
    print("[PATCH] torch.arange(device='cuda') -> CPU")
121
+ # ----------------------------------------------------------
122
+
123
  MAX_SEED = np.iinfo(np.int32).max
124
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
125
  WEIGHTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'weights')
126
  os.makedirs(TMP_DIR, exist_ok=True)
127
  os.makedirs(WEIGHTS_DIR, exist_ok=True)
128
 
129
+ # ---------------------------------------------------------------------------
130
+ # Weights caching
131
+ # ---------------------------------------------------------------------------
132
  def cache_weights(weights_dir: str) -> dict:
 
133
  from huggingface_hub import snapshot_download
 
134
  os.makedirs(weights_dir, exist_ok=True)
135
  model_ids = [
136
  "Stable-X/trellis-normal-v0-1",
 
140
  cached_paths = {}
141
  for model_id in model_ids:
142
  print(f"Caching weights for: {model_id}")
 
143
  local_path = os.path.join(weights_dir, model_id.split("/")[-1])
144
  if os.path.exists(local_path):
145
  print(f"Already cached at: {local_path}")
146
  cached_paths[model_id] = local_path
147
  continue
 
148
  print(f"Downloading and caching model: {model_id}")
149
+ local_path = snapshot_download(
150
+ repo_id=model_id,
151
+ local_dir=os.path.join(weights_dir, model_id.split("/")[-1]),
152
+ force_download=False
153
+ )
154
  cached_paths[model_id] = local_path
155
  print(f"Cached at: {local_path}")
 
156
  return cached_paths
157
 
158
+ # ---------------------------------------------------------------------------
159
+ # Pre/Post processing and generation
160
+ # ---------------------------------------------------------------------------
def preprocess_mesh(mesh_prompt):
    """Load the mesh at *mesh_prompt* and re-export it with ``.glb`` appended.

    Returns the path of the exported file (``mesh_prompt + '.glb'``).
    """
    print("Processing mesh")
    loaded = trimesh.load_mesh(mesh_prompt)
    glb_path = mesh_prompt + '.glb'
    loaded.export(glb_path)
    return glb_path
167
 
def preprocess_image(image):
    """Run the pipeline's image preprocessing at resolution 1024.

    Returns ``None`` unchanged when no image is supplied; otherwise returns
    whatever ``hi3dgen_pipeline.preprocess_image`` produces.
    """
    if image is not None:
        return hi3dgen_pipeline.preprocess_image(image, resolution=1024)
    return None
172
+
173
+ def generate_3d(
174
+ image,
175
+ seed: int = -1,
176
+ ss_guidance_strength: float = 3,
177
+ ss_sampling_steps: int = 50,
178
+ slat_guidance_strength: float = 3,
179
+ slat_sampling_steps: int = 6,
180
+ ):
181
  if image is None:
182
  return None, None, None
183
 
184
  if seed == -1:
185
  seed = np.random.randint(0, MAX_SEED)
186
+
187
  image = hi3dgen_pipeline.preprocess_image(image, resolution=1024)
188
+ normal_image = normal_predictor(
189
+ image,
190
+ resolution=768,
191
+ match_input_resolution=True,
192
+ data_type='object'
193
+ )
194
 
195
  outputs = hi3dgen_pipeline.run(
196
  normal_image,
197
  seed=seed,
198
+ formats=["mesh"],
199
  preprocess_image=False,
200
  sparse_structure_sampler_params={
201
  "steps": ss_sampling_steps,
 
207
  },
208
  )
209
  generated_mesh = outputs['mesh'][0]
 
 
 
 
210
 
211
+ import datetime
212
  output_id = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
213
  os.makedirs(os.path.join(TMP_DIR, output_id), exist_ok=True)
214
  mesh_path = f"{TMP_DIR}/{output_id}/mesh.glb"
 
 
 
215
 
216
+ trimesh_mesh = generated_mesh.to_trimesh(transform_pose=True)
217
  trimesh_mesh.export(mesh_path)
218
 
219
  return normal_image, mesh_path, mesh_path
220
 
def convert_mesh(mesh_path, export_format):
    """Convert the mesh at *mesh_path* into a new temporary file.

    Args:
        mesh_path: path of the mesh to load; a falsy value yields ``None``.
        export_format: file extension (without the dot) used as the
            temporary file's suffix.

    Returns:
        Path of the temporary file holding the exported mesh, or ``None``
        when *mesh_path* is falsy. The file is not deleted automatically.
    """
    if not mesh_path:
        return None

    # Only the reserved name is needed; close the handle immediately so no
    # descriptor leaks and the exporter can reopen the path.
    with tempfile.NamedTemporaryFile(suffix=f".{export_format}", delete=False) as temp_file:
        temp_file_path = temp_file.name

    try:
        mesh = trimesh.load_mesh(mesh_path)
        mesh.export(temp_file_path)
    except Exception:
        # Do not leave an orphaned temp file behind when conversion fails.
        try:
            os.remove(temp_file_path)
        except OSError:
            pass
        raise
    return temp_file_path
 
229
 
230
+ # ---------------------------------------------------------------------------
231
+ # UI
232
+ # ---------------------------------------------------------------------------
233
  with gr.Blocks(css="footer {visibility: hidden}") as demo:
234
  gr.Markdown(
235
  """
236
  <h1 style='text-align: center;'>Hi3DGen: High-fidelity 3D Geometry Generation from Images via Normal Bridging</h1>
237
  <p style='text-align: center;'>
238
  <strong>V0.1, Introduced By
239
+ <a href="https://gaplab.cuhk.edu.cn/" target="_blank">GAP Lab</a> (CUHKSZ) and
240
+ <a href="https://www.nvsgames.cn/" target="_blank">Game-AIGC Team</a> (ByteDance)</strong>
241
  </p>
242
  """
243
  )
244
+
245
  with gr.Row():
246
  gr.Markdown("""
247
+ <p align="center">
248
+ <a title="Website" href="https://stable-x.github.io/Hi3DGen/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
249
+ <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
250
+ </a>
251
+ <a title="arXiv" href="https://stable-x.github.io/Hi3DGen/hi3dgen_paper.pdf" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
252
+ <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
253
+ </a>
254
+ <a title="Github" href="https://github.com/Stable-X/Hi3DGen" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
255
+ <img src="https://img.shields.io/github/stars/Stable-X/Hi3DGen?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
256
+ </a>
257
+ <a title="Social" href="https://x.com/ychngji6" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
258
+ <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
259
+ </a>
260
+ </p>
261
+ """)
262
 
263
  with gr.Row():
264
  with gr.Column(scale=1):
265
  with gr.Tabs():
 
266
  with gr.Tab("Single Image"):
267
  with gr.Row():
268
  image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil")
269
  normal_output = gr.Image(label="Normal Bridge", image_mode="RGBA", type="pil")
 
270
  with gr.Tab("Multiple Images"):
271
+ gr.Markdown(
272
+ "<div style='text-align: center; padding: 40px; font-size: 24px;'>Multiple Images functionality is coming soon!</div>"
273
+ )
274
+
275
  with gr.Accordion("Advanced Settings", open=False):
276
  seed = gr.Slider(-1, MAX_SEED, label="Seed", value=0, step=1)
277
  gr.Markdown("#### Stage 1: Sparse Structure Generation")
 
282
  with gr.Row():
283
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
284
  slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=6, step=1)
285
+
286
  with gr.Group():
287
  with gr.Row():
288
  gen_shape_btn = gr.Button("Generate Shape", size="lg", variant="primary")
289
+
 
290
  with gr.Column(scale=1):
291
  with gr.Column():
292
+ model_output = gr.Model3D(label="3D Model Preview (Each model is ~40MB; may take ~1 min to load)")
293
  with gr.Column():
294
  export_format = gr.Dropdown(
295
  choices=["obj", "glb", "ply", "stl"],
 
303
  inputs=[image_prompt],
304
  outputs=[image_prompt]
305
  )
306
+
307
  gen_shape_btn.click(
308
  generate_3d,
309
  inputs=[
310
+ image_prompt, seed,
311
  ss_guidance_strength, ss_sampling_steps,
312
  slat_guidance_strength, slat_sampling_steps
313
  ],
 
316
  lambda: gr.Button(interactive=True),
317
  outputs=[download_btn],
318
  )
319
+
 
def update_download_button(mesh_path, export_format):
    """Return the value for the download output after a format change.

    Args:
        mesh_path: path of the generated mesh, or a falsy value when none exists.
        export_format: target extension passed on to ``convert_mesh``.

    Returns:
        A generic Gradio update that clears and disables the output when
        there is no mesh, otherwise the path returned by ``convert_mesh``.
    """
    if not mesh_path:
        # gr.update is component-agnostic; the per-component
        # `gr.File.update` classmethod is not available in Gradio 4+.
        return gr.update(value=None, interactive=False)
    return convert_mesh(mesh_path, export_format)
325
+
326
  export_format.change(
327
  update_download_button,
328
  inputs=[model_output, export_format],
 
331
  lambda: gr.Button(interactive=True),
332
  outputs=[download_btn],
333
  )
334
+
335
  examples = None
336
 
337
  gr.Markdown(
338
  """
339
  **Acknowledgments**: Hi3DGen is built on the shoulders of giants. We would like to express our gratitude to the open-source research community and the developers of these pioneering projects:
340
+ - **3D Modeling:** Our 3D Model is finetuned from the SOTA open-source 3D foundation model [Trellis](https://github.com/microsoft/TRELLIS); inspired by [Rodin], [Tripo], and [Dora].
341
+ - **Normal Estimation:** Builds on [StableNormal] and [GenPercept].
 
 
342
  """
343
  )
344
 
345
+ # ---------------------------------------------------------------------------
346
+ # Entry
347
+ # ---------------------------------------------------------------------------
348
  if __name__ == "__main__":
349
+ # Cache model w