# MIT License
# (see original notice and terms)

import os
import types
import zipfile
import importlib

import gradio as gr
import numpy as np
import torch
import tempfile

# ---- Force CPU-only environment globally ----
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"          # hide GPUs from torch
os.environ.setdefault("ATTN_BACKEND", "sdpa")      # avoid xformers path
os.environ.setdefault("SPCONV_ALGO", "native")     # safe sparseconv algo
# ---------------------------------------------
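# Note: torch initializes CUDA lazily, so setting CUDA_VISIBLE_DEVICES here,
# after `import torch` but before anything touches torch.cuda, is still early
# enough to hide the GPUs from this process.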

# ---------------------------------------------------------------------------
# Ensure bundled hi3dgen sources are available (extracted from hi3dgen.zip)
# ---------------------------------------------------------------------------
def _ensure_hi3dgen_available():
    pkg_name = 'hi3dgen'
    here = os.path.dirname(__file__)
    pkg_dir = os.path.join(here, pkg_name)
    if os.path.isdir(pkg_dir):
        return
    archive_path = os.path.join(here, f"{pkg_name}.zip")
    if not os.path.isfile(archive_path):
        raise FileNotFoundError(
            f"Required archive {archive_path} is missing. Upload hi3dgen.zip next to app.py."
        )
    try:
        with zipfile.ZipFile(archive_path, 'r') as zf:
            zf.extractall(here)
    except Exception as e:
        raise RuntimeError(f"Failed to extract {archive_path}: {e}")

_ensure_hi3dgen_available()
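# Once extracted, the sources sit in ./hi3dgen next to this file, so the
# `from hi3dgen...` imports below resolve via the script's own directory on sys.path.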

# ---------------------------------------------------------------------------
# xformers stub (CPU-friendly fallback for xformers.ops.memory_efficient_attention)
# ---------------------------------------------------------------------------
def _ensure_xformers_stub():
    import sys
    if 'xformers.ops' in sys.modules:
        return
    import torch.nn.functional as F
    xf_mod = types.ModuleType('xformers')
    ops_mod = types.ModuleType('xformers.ops')

    def memory_efficient_attention(query, key, value, attn_bias=None):
        # CPU fallback via PyTorch's scaled_dot_product_attention. Caveat:
        # xformers uses a [batch, seq, heads, dim] layout while SDPA expects
        # [batch, heads, seq, dim]; hi3dgen is assumed to call this with an
        # SDPA-compatible layout.
        return F.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias)

    ops_mod.memory_efficient_attention = memory_efficient_attention
    xf_mod.ops = ops_mod
    sys.modules['xformers'] = xf_mod
    sys.modules['xformers.ops'] = ops_mod

_ensure_xformers_stub()
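# The stub only covers memory_efficient_attention; if hi3dgen ever calls other
# xformers APIs, they would need to be added to the stub module as well.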

# ---------------------------------------------------------------------------
# Patch CUDA hotspots to CPU **BEFORE** importing the pipeline
# ---------------------------------------------------------------------------
print("[PATCH] Applying CPU monkey-patches to hi3dgen")

# 1) utils_cube.construct_dense_grid(..., device=...) -> force CPU
uc = importlib.import_module("hi3dgen.representations.mesh.utils_cube")
if not hasattr(uc, "_CPU_PATCHED"):
    _uc_orig_construct_dense_grid = uc.construct_dense_grid

    def _construct_dense_grid_cpu(res, device=None):
        # ignore any requested device, always CPU
        return _uc_orig_construct_dense_grid(res, device="cpu")

    uc.construct_dense_grid = _construct_dense_grid_cpu
    uc._CPU_PATCHED = True
    print("[PATCH] utils_cube.construct_dense_grid -> CPU")

# 2) cube2mesh.EnhancedMarchingCubes default device -> force CPU (flexible)
cm = importlib.import_module("hi3dgen.representations.mesh.cube2mesh")
M = cm.EnhancedMarchingCubes
if not hasattr(M, "_CPU_PATCHED"):
    _orig_init = M.__init__

    def _init_cpu(self, *args, **kwargs):
        # ensure device ends up on CPU regardless of how it's passed
        if "device" in kwargs:
            kwargs["device"] = torch.device("cpu")
        else:
            kwargs.setdefault("device", torch.device("cpu"))
        return _orig_init(self, *args, **kwargs)

    M.__init__ = _init_cpu
    M._CPU_PATCHED = True
    print("[PATCH] cube2mesh.EnhancedMarchingCubes.__init__ -> CPU (flex)")

# 3) IMPORTANT: cube2mesh does "from .utils_cube import construct_dense_grid",
#    so the BOUND symbol inside cube2mesh must be overridden as well. Comparing
#    against the (already patched) uc.construct_dense_grid also avoids a
#    NameError when the utils_cube patch was applied on an earlier run.
if getattr(cm, "construct_dense_grid", None) is not uc.construct_dense_grid:
    cm.construct_dense_grid = uc.construct_dense_grid
    print("[PATCH] cube2mesh.construct_dense_grid (bound name) -> CPU")

# 4) Belt & suspenders: coerce torch.arange(device='cuda') to CPU if anything slips through
if not hasattr(torch, "_ARANGE_CPU_PATCHED"):
    _orig_arange = torch.arange

    def _arange_cpu(*args, **kwargs):
        dev = kwargs.get("device", None)
        if dev is not None and str(dev).startswith("cuda"):
            kwargs["device"] = "cpu"
        return _orig_arange(*args, **kwargs)

    torch.arange = _arange_cpu
    torch._ARANGE_CPU_PATCHED = True
    print("[PATCH] torch.arange(device='cuda') -> CPU")
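
# Optional sanity check (a minimal sketch; assumes the patches above have run):
#   assert torch.arange(4, device="cuda").device.type == "cpu"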

# ---------------------------------------------------------------------------
# Now import pipeline (AFTER patches so bound names are already overridden)
# ---------------------------------------------------------------------------
from hi3dgen.pipelines import Hi3DGenPipeline
import trimesh

MAX_SEED = np.iinfo(np.int32).max
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
WEIGHTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'weights')
os.makedirs(TMP_DIR, exist_ok=True)
os.makedirs(WEIGHTS_DIR, exist_ok=True)
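# tmp/ collects per-request outputs (one timestamped folder per generation);
# weights/ holds the locally cached model snapshots downloaded by cache_weights() below.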

# ---------------------------------------------------------------------------
# Weights caching
# ---------------------------------------------------------------------------
def cache_weights(weights_dir: str) -> dict:
    from huggingface_hub import snapshot_download
    os.makedirs(weights_dir, exist_ok=True)
    model_ids = [
        "Stable-X/trellis-normal-v0-1",
        "Stable-X/yoso-normal-v1-8-1",
        "ZhengPeng7/BiRefNet",
    ]
    cached_paths = {}
    for model_id in model_ids:
        print(f"Caching weights for: {model_id}")
        local_path = os.path.join(weights_dir, model_id.split("/")[-1])
        if os.path.exists(local_path):
            print(f"Already cached at: {local_path}")
            cached_paths[model_id] = local_path
            continue
        print(f"Downloading and caching model: {model_id}")
        local_path = snapshot_download(
            repo_id=model_id,
            local_dir=os.path.join(weights_dir, model_id.split("/")[-1]),
            force_download=False
        )
        cached_paths[model_id] = local_path
        print(f"Cached at: {local_path}")
    return cached_paths
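
# The first run needs network access to pull the three snapshots from the
# Hugging Face Hub; later runs find weights/<model-name> on disk and skip the
# download entirely.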

# ---------------------------------------------------------------------------
# Pre/Post processing and generation
# ---------------------------------------------------------------------------
def preprocess_mesh(mesh_prompt):
    print("Processing mesh")
    trimesh_mesh = trimesh.load_mesh(mesh_prompt)
    out_path = mesh_prompt + '.glb'
    trimesh_mesh.export(out_path)
    return out_path
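
# Note: preprocess_mesh simply re-exports any trimesh-loadable upload as .glb
# next to the original file; it is not currently wired into the Gradio UI below.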

def preprocess_image(image):
    if image is None:
        return None
    return hi3dgen_pipeline.preprocess_image(image, resolution=1024)
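
# generate_3d re-runs the same preprocessing, so images that reach it without
# going through the upload handler below are still preprocessed at resolution=1024.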

def generate_3d(
    image,
    seed: int = -1,
    ss_guidance_strength: float = 3,
    ss_sampling_steps: int = 50,
    slat_guidance_strength: float = 3,
    slat_sampling_steps: int = 6,
):
    if image is None:
        return None, None, None

    if seed == -1:
        seed = np.random.randint(0, MAX_SEED)

    image = hi3dgen_pipeline.preprocess_image(image, resolution=1024)
    normal_image = normal_predictor(
        image,
        resolution=768,
        match_input_resolution=True,
        data_type='object'
    )

    outputs = hi3dgen_pipeline.run(
        normal_image,
        seed=seed,
        formats=["mesh"],
        preprocess_image=False,
        sparse_structure_sampler_params={
            "steps": ss_sampling_steps,
            "cfg_strength": ss_guidance_strength,
        },
        slat_sampler_params={
            "steps": slat_sampling_steps,
            "cfg_strength": slat_guidance_strength,
        },
    )
    generated_mesh = outputs['mesh'][0]

    import datetime
    output_id = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    os.makedirs(os.path.join(TMP_DIR, output_id), exist_ok=True)
    mesh_path = os.path.join(TMP_DIR, output_id, "mesh.glb")

    trimesh_mesh = generated_mesh.to_trimesh(transform_pose=True)
    trimesh_mesh.export(mesh_path)

    return normal_image, mesh_path, mesh_path

def convert_mesh(mesh_path, export_format):
    if not mesh_path:
        return None
    temp_file = tempfile.NamedTemporaryFile(suffix=f".{export_format}", delete=False)
    temp_file_path = temp_file.name
    mesh = trimesh.load_mesh(mesh_path)
    mesh.export(temp_file_path)
    return temp_file_path
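
# The temporary file is created with delete=False so Gradio can still serve it
# after this function returns; converted meshes therefore accumulate in the
# system temp directory until cleaned up externally.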

# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
with gr.Blocks(css="footer {visibility: hidden}") as demo:
    gr.Markdown(
        """
        <h1 style='text-align: center;'>Hi3DGen: High-fidelity 3D Geometry Generation from Images via Normal Bridging</h1>
        <p style='text-align: center;'>
            <strong>V0.1, Introduced By 
            <a href="https://gaplab.cuhk.edu.cn/" target="_blank">GAP Lab</a> (CUHKSZ) and 
            <a href="https://www.nvsgames.cn/" target="_blank">Game-AIGC Team</a> (ByteDance)</strong>
        </p>
        """
    )

    with gr.Row():
        gr.Markdown("""
            <p align="center">
            <a title="Website" href="https://stable-x.github.io/Hi3DGen/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
            </a>
            <a title="arXiv" href="https://stable-x.github.io/Hi3DGen/hi3dgen_paper.pdf" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
            </a>
            <a title="Github" href="https://github.com/Stable-X/Hi3DGen" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                <img src="https://img.shields.io/github/stars/Stable-X/Hi3DGen?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
            </a>
            <a title="Social" href="https://x.com/ychngji6" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
            </a>
            </p>
        """)

    with gr.Row():
        with gr.Column(scale=1):
            with gr.Tabs():
                with gr.Tab("Single Image"):
                    with gr.Row():
                        image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil")
                        normal_output = gr.Image(label="Normal Bridge", image_mode="RGBA", type="pil")
                with gr.Tab("Multiple Images"):
                    gr.Markdown(
                        "<div style='text-align: center; padding: 40px; font-size: 24px;'>Multiple Images functionality is coming soon!</div>"
                    )

            with gr.Accordion("Advanced Settings", open=False):
                seed = gr.Slider(-1, MAX_SEED, label="Seed", value=0, step=1)
                gr.Markdown("#### Stage 1: Sparse Structure Generation")
                with gr.Row():
                    ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3, step=0.1)
                    ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=50, step=1)
                gr.Markdown("#### Stage 2: Structured Latent Generation")
                with gr.Row():
                    slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
                    slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=6, step=1)

            with gr.Group():
                with gr.Row():
                    gen_shape_btn = gr.Button("Generate Shape", size="lg", variant="primary")

        with gr.Column(scale=1):
            with gr.Column():
                model_output = gr.Model3D(label="3D Model Preview (Each model is ~40MB; may take ~1 min to load)")
            with gr.Column():
                export_format = gr.Dropdown(
                    choices=["obj", "glb", "ply", "stl"],
                    value="glb",
                    label="File Format"
                )
                download_btn = gr.DownloadButton(label="Export Mesh", interactive=False)

    image_prompt.upload(
        preprocess_image,
        inputs=[image_prompt],
        outputs=[image_prompt]
    )

    gen_shape_btn.click(
        generate_3d,
        inputs=[
            image_prompt, seed,
            ss_guidance_strength, ss_sampling_steps,
            slat_guidance_strength, slat_sampling_steps
        ],
        outputs=[normal_output, model_output, download_btn]
    ).then(
        lambda: gr.DownloadButton(interactive=True),
        outputs=[download_btn],
    )

    def update_download_button(mesh_path, export_format):
        if not mesh_path:
            # No mesh generated yet: keep the export button disabled.
            return gr.DownloadButton(value=None, interactive=False)
        download_path = convert_mesh(mesh_path, export_format)
        return gr.DownloadButton(value=download_path, interactive=True)

    export_format.change(
        update_download_button,
        inputs=[model_output, export_format],
        outputs=[download_btn]
    )

    gr.Markdown(
        """
        **Acknowledgments**: Hi3DGen is built on the shoulders of giants. We would like to express our gratitude to the open-source research community and the developers of these pioneering projects:
        - **3D Modeling:** Fine-tuned from the SOTA open-source 3D foundation model Trellis.
        - **Normal Estimation:** Builds on StableNormal and GenPercept.
        """
    )

# ---------------------------------------------------------------------------
# Entry
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Cache model weights locally
    cache_weights(WEIGHTS_DIR)

    # Load pipeline on CPU
    hi3dgen_pipeline = Hi3DGenPipeline.from_pretrained(
        os.path.join(WEIGHTS_DIR, "trellis-normal-v0-1")
    )
    try:
        hi3dgen_pipeline.to("cpu")
    except Exception:
        pass  # some pipelines may not implement .to

    # Initialize normal predictor (CPU)
    try:
        normal_predictor = torch.hub.load(
            os.path.join(torch.hub.get_dir(), 'hugoycj_StableNormal_main'),
            "StableNormal_turbo",
            yoso_version='yoso-normal-v1-8-1',
            source='local',
            local_cache_dir='./weights',
            pretrained=True
        )
    except Exception:
        normal_predictor = torch.hub.load(
            "hugoycj/StableNormal",
            "StableNormal_turbo",
            trust_repo=True,
            yoso_version='yoso-normal-v1-8-1',
            local_cache_dir='./weights'
        )
    try:
        normal_predictor.to("cpu")
    except Exception:
        pass

    # Launch the Gradio app (0.0.0.0 exposes it on all interfaces; Gradio's
    # default port 7860 is used unless GRADIO_SERVER_PORT overrides it)
    demo.launch(share=False, server_name="0.0.0.0")