zhuhai111 commited on
Commit
53ef571
·
verified ·
1 Parent(s): fe582be

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -68
app.py CHANGED
@@ -1,19 +1,20 @@
1
  import gradio as gr
2
- import spaces
3
- from gradio_litmodel3d import LitModel3D
4
-
5
  import os
6
  import shutil
7
- os.environ['SPCONV_ALGO'] = 'native'
8
- from typing import *
9
- import torch
10
  import numpy as np
 
11
  import imageio
 
 
12
  from PIL import Image
 
 
13
  from trellis.pipelines import TrellisImageTo3DPipeline
14
  from trellis.utils import render_utils
15
- import trimesh
16
- import tempfile
17
 
18
  MAX_SEED = np.iinfo(np.int32).max
19
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
@@ -35,7 +36,7 @@ def generate_3d(image, seed=-1,
35
  ss_guidance_strength=3, ss_sampling_steps=50,
36
  slat_guidance_strength=3, slat_sampling_steps=6,):
37
  if image is None:
38
- return None, None, None
39
 
40
  if seed == -1:
41
  seed = np.random.randint(0, MAX_SEED)
@@ -59,73 +60,43 @@ def generate_3d(image, seed=-1,
59
  )
60
  generated_mesh = outputs['mesh'][0]
61
 
62
- # Save outputs
63
- import datetime
64
  output_id = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
65
  os.makedirs(os.path.join(TMP_DIR, output_id), exist_ok=True)
66
  mesh_path = f"{TMP_DIR}/{output_id}/mesh.glb"
67
 
68
  render_results = render_utils.render_video(generated_mesh, resolution=1024, ssaa=1, num_frames=8, pitch=0.25, inverse_direction=True)
 
69
  def combine_diagonal(color_np, normal_np):
70
- # Convert images to numpy arrays
71
  h, w, c = color_np.shape
72
- mask = np.fromfunction(lambda y, x: x > y, (h, w))
73
- mask = mask.astype(bool)
74
  mask = np.stack([mask] * c, axis=-1)
75
  combined_np = np.where(mask, color_np, normal_np)
76
  return Image.fromarray(combined_np)
77
 
78
  preview_images = [combine_diagonal(c, n) for c, n in zip(render_results['color'], render_results['normal'])]
79
 
80
- # Export mesh
81
  trimesh_mesh = generated_mesh.to_trimesh(transform_pose=True)
82
  trimesh_mesh.export(mesh_path)
83
 
84
  return preview_images, normal_image, mesh_path, mesh_path
85
 
86
  def convert_mesh(mesh_path, export_format):
87
- """Download the mesh in the selected format."""
88
  if not mesh_path:
89
  return None
90
-
91
  temp_file = tempfile.NamedTemporaryFile(suffix=f".{export_format}", delete=False)
92
- temp_file_path = temp_file.name
93
-
94
- new_mesh_path = mesh_path.replace(".glb", f".{export_format}")
95
  mesh = trimesh.load_mesh(mesh_path)
96
- mesh.export(temp_file_path)
97
-
98
- return temp_file_path
99
 
100
  with gr.Blocks(css="footer {visibility: hidden}") as demo:
101
- gr.Markdown(
102
- """
103
  <h1 style='text-align: center;'>Hi3DGen: High-fidelity 3D Geometry Generation from Images via Normal Bridging</h1>
104
  <p style='text-align: center;'>
105
  <strong>V0.1, Introduced By
106
  <a href="https://gaplab.cuhk.edu.cn/" target="_blank">GAP Lab</a> from CUHKSZ and
107
  <a href="https://www.nvsgames.cn/" target="_blank">Game-AIGC Team</a> from ByteDance</strong>
108
  </p>
109
- """
110
- )
111
-
112
- with gr.Row():
113
- gr.Markdown("""
114
- <p align="center">
115
- <a title="Website" href="https://stable-x.github.io/Hi3DGen/" target="_blank">
116
- <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
117
- </a>
118
- <a title="arXiv" href="https://stable-x.github.io/Hi3DGen/hi3dgen_paper.pdf" target="_blank">
119
- <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
120
- </a>
121
- <a title="Github" href="https://github.com/Stable-X/Hi3DGen" target="_blank">
122
- <img src="https://img.shields.io/github/stars/Stable-X/Hi3DGen?label=GitHub%20%E2%98%85&logo=github&color=C8C">
123
- </a>
124
- <a title="Social" href="https://x.com/ychngji6" target="_blank">
125
- <img src="https://www.obukhov.ai/img/badges/badge-social.svg">
126
- </a>
127
- </p>
128
- """)
129
 
130
  with gr.Row():
131
  with gr.Column(scale=1):
@@ -134,10 +105,8 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
134
  with gr.Row():
135
  image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil")
136
  normal_output = gr.Image(label="Normal Bridge", image_mode="RGBA", type="pil")
137
-
138
  with gr.Tab("Multiple Images"):
139
  gr.Markdown("<div style='text-align: center; padding: 40px; font-size: 24px;'>Multiple Images functionality is coming soon!</div>")
140
-
141
  with gr.Accordion("Advanced Settings", open=False):
142
  seed = gr.Slider(-1, MAX_SEED, label="Seed", value=0, step=1)
143
  gr.Markdown("#### Stage 1: Sparse Structure Generation")
@@ -148,18 +117,17 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
148
  with gr.Row():
149
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
150
  slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=6, step=1)
151
-
152
  with gr.Group():
153
  with gr.Row():
154
  gen_shape_btn = gr.Button("Generate Shape", size="lg", variant="primary")
155
-
156
  with gr.Column(scale=1):
157
  with gr.Tabs():
158
  with gr.Tab("Preview"):
159
  output_gallery = gr.Gallery(label="Examples", columns=4, rows=2, object_fit="contain", height="auto", show_label=False)
160
  with gr.Tab("3D Model"):
161
  with gr.Column():
162
- model_output = gr.Model3D(label="3D Model Preview (Each model is approximately 40MB, may take around 1 minute to load)")
163
  with gr.Column():
164
  export_format = gr.Dropdown(
165
  choices=["obj", "glb", "ply", "stl"],
@@ -167,17 +135,17 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
167
  label="File Format"
168
  )
169
  download_btn = gr.DownloadButton(label="Export Mesh", interactive=False)
170
-
171
  image_prompt.upload(
172
  preprocess_image,
173
  inputs=[image_prompt],
174
  outputs=[image_prompt]
175
  )
176
-
177
  gen_shape_btn.click(
178
  generate_3d,
179
  inputs=[
180
- image_prompt, seed,
181
  ss_guidance_strength, ss_sampling_steps,
182
  slat_guidance_strength, slat_sampling_steps
183
  ],
@@ -190,10 +158,9 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
190
  def update_download_button(mesh_path, export_format):
191
  if not mesh_path:
192
  return gr.File.update(value=None, interactive=False)
193
-
194
  download_path = convert_mesh(mesh_path, export_format)
195
  return download_path
196
-
197
  export_format.change(
198
  update_download_button,
199
  inputs=[model_output, export_format],
@@ -202,7 +169,7 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
202
  lambda: gr.Button(interactive=True),
203
  outputs=[download_btn],
204
  )
205
-
206
  examples = gr.Examples(
207
  examples=[
208
  f'assets/example_image/{image}'
@@ -211,19 +178,23 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
211
  inputs=image_prompt,
212
  )
213
 
214
- gr.Markdown(
215
- """
216
- **Acknowledgments**: Hi3DGen is built on the shoulders of giants. We would like to express our gratitude to the open-source research community and the developers of these pioneering projects:
217
- - **3D Modeling:** Our 3D Model is finetuned from the SOTA open-source 3D foundation model [Trellis](https://github.com/microsoft/TRELLIS)
218
- - **Normal Estimation:** We build on top of [StableNormal](https://github.com/hugoycj/StableNormal)
219
- """
220
- )
221
 
222
  if __name__ == "__main__":
223
- # 初始化 pipeline,使用 CPU
224
  pipeline = TrellisImageTo3DPipeline.from_pretrained("Stable-X/trellis-normal-v0-1")
225
-
226
- # 加载 StableNormal 模型(会自动选择 CPU)
227
- normal_predictor = torch.hub.load("hugoycj/StableNormal", "StableNormal_turbo", trust_repo=True, yoso_version='yoso-normal-v1-8-1')
228
-
 
 
 
 
 
 
229
  demo.launch()
 
1
  import gradio as gr
 
 
 
2
  import os
3
  import shutil
4
+ import tempfile
5
+ import datetime
 
6
  import numpy as np
7
+ import torch
8
  import imageio
9
+ import trimesh
10
+
11
  from PIL import Image
12
+ from typing import *
13
+ from gradio_litmodel3d import LitModel3D
14
  from trellis.pipelines import TrellisImageTo3DPipeline
15
  from trellis.utils import render_utils
16
+
17
+ os.environ['SPCONV_ALGO'] = 'native'
18
 
19
  MAX_SEED = np.iinfo(np.int32).max
20
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
 
36
  ss_guidance_strength=3, ss_sampling_steps=50,
37
  slat_guidance_strength=3, slat_sampling_steps=6,):
38
  if image is None:
39
+ return None, None, None, None
40
 
41
  if seed == -1:
42
  seed = np.random.randint(0, MAX_SEED)
 
60
  )
61
  generated_mesh = outputs['mesh'][0]
62
 
 
 
63
  output_id = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
64
  os.makedirs(os.path.join(TMP_DIR, output_id), exist_ok=True)
65
  mesh_path = f"{TMP_DIR}/{output_id}/mesh.glb"
66
 
67
  render_results = render_utils.render_video(generated_mesh, resolution=1024, ssaa=1, num_frames=8, pitch=0.25, inverse_direction=True)
68
+
69
  def combine_diagonal(color_np, normal_np):
 
70
  h, w, c = color_np.shape
71
+ mask = np.fromfunction(lambda y, x: x > y, (h, w)).astype(bool)
 
72
  mask = np.stack([mask] * c, axis=-1)
73
  combined_np = np.where(mask, color_np, normal_np)
74
  return Image.fromarray(combined_np)
75
 
76
  preview_images = [combine_diagonal(c, n) for c, n in zip(render_results['color'], render_results['normal'])]
77
 
 
78
  trimesh_mesh = generated_mesh.to_trimesh(transform_pose=True)
79
  trimesh_mesh.export(mesh_path)
80
 
81
  return preview_images, normal_image, mesh_path, mesh_path
82
 
83
  def convert_mesh(mesh_path, export_format):
 
84
  if not mesh_path:
85
  return None
 
86
  temp_file = tempfile.NamedTemporaryFile(suffix=f".{export_format}", delete=False)
 
 
 
87
  mesh = trimesh.load_mesh(mesh_path)
88
+ mesh.export(temp_file.name)
89
+ return temp_file.name
 
90
 
91
  with gr.Blocks(css="footer {visibility: hidden}") as demo:
92
+ gr.Markdown("""
 
93
  <h1 style='text-align: center;'>Hi3DGen: High-fidelity 3D Geometry Generation from Images via Normal Bridging</h1>
94
  <p style='text-align: center;'>
95
  <strong>V0.1, Introduced By
96
  <a href="https://gaplab.cuhk.edu.cn/" target="_blank">GAP Lab</a> from CUHKSZ and
97
  <a href="https://www.nvsgames.cn/" target="_blank">Game-AIGC Team</a> from ByteDance</strong>
98
  </p>
99
+ """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  with gr.Row():
102
  with gr.Column(scale=1):
 
105
  with gr.Row():
106
  image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil")
107
  normal_output = gr.Image(label="Normal Bridge", image_mode="RGBA", type="pil")
 
108
  with gr.Tab("Multiple Images"):
109
  gr.Markdown("<div style='text-align: center; padding: 40px; font-size: 24px;'>Multiple Images functionality is coming soon!</div>")
 
110
  with gr.Accordion("Advanced Settings", open=False):
111
  seed = gr.Slider(-1, MAX_SEED, label="Seed", value=0, step=1)
112
  gr.Markdown("#### Stage 1: Sparse Structure Generation")
 
117
  with gr.Row():
118
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
119
  slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=6, step=1)
 
120
  with gr.Group():
121
  with gr.Row():
122
  gen_shape_btn = gr.Button("Generate Shape", size="lg", variant="primary")
123
+
124
  with gr.Column(scale=1):
125
  with gr.Tabs():
126
  with gr.Tab("Preview"):
127
  output_gallery = gr.Gallery(label="Examples", columns=4, rows=2, object_fit="contain", height="auto", show_label=False)
128
  with gr.Tab("3D Model"):
129
  with gr.Column():
130
+ model_output = gr.Model3D(label="3D Model Preview (Each model is approx. 40MB)")
131
  with gr.Column():
132
  export_format = gr.Dropdown(
133
  choices=["obj", "glb", "ply", "stl"],
 
135
  label="File Format"
136
  )
137
  download_btn = gr.DownloadButton(label="Export Mesh", interactive=False)
138
+
139
  image_prompt.upload(
140
  preprocess_image,
141
  inputs=[image_prompt],
142
  outputs=[image_prompt]
143
  )
144
+
145
  gen_shape_btn.click(
146
  generate_3d,
147
  inputs=[
148
+ image_prompt, seed,
149
  ss_guidance_strength, ss_sampling_steps,
150
  slat_guidance_strength, slat_sampling_steps
151
  ],
 
158
  def update_download_button(mesh_path, export_format):
159
  if not mesh_path:
160
  return gr.File.update(value=None, interactive=False)
 
161
  download_path = convert_mesh(mesh_path, export_format)
162
  return download_path
163
+
164
  export_format.change(
165
  update_download_button,
166
  inputs=[model_output, export_format],
 
169
  lambda: gr.Button(interactive=True),
170
  outputs=[download_btn],
171
  )
172
+
173
  examples = gr.Examples(
174
  examples=[
175
  f'assets/example_image/{image}'
 
178
  inputs=image_prompt,
179
  )
180
 
181
+ gr.Markdown("""
182
+ **Acknowledgments**: Hi3DGen is built on the shoulders of giants. We acknowledge contributions from:
183
+ - [Trellis 3D](https://github.com/microsoft/TRELLIS)
184
+ - [StableNormal](https://github.com/hugoycj/StableNormal)
185
+ """)
 
 
186
 
187
  if __name__ == "__main__":
188
+ # 强制使用 CPU
189
  pipeline = TrellisImageTo3DPipeline.from_pretrained("Stable-X/trellis-normal-v0-1")
190
+ pipeline.to("cpu") # <-- 强制使用 CPU
191
+
192
+ normal_predictor = torch.hub.load(
193
+ "hugoycj/StableNormal",
194
+ "StableNormal_turbo",
195
+ trust_repo=True,
196
+ yoso_version="yoso-normal-v1-8-1"
197
+ )
198
+ normal_predictor.to("cpu") # <-- 也强制使用 CPU
199
+
200
  demo.launch()