zhuhai111 commited on
Commit
fe582be
·
verified ·
1 Parent(s): a5dab0e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -28
app.py CHANGED
@@ -31,7 +31,6 @@ def preprocess_image(image):
31
  image = pipeline.preprocess_image(image, resolution=1024)
32
  return image
33
 
34
- @spaces.GPU
35
  def generate_3d(image, seed=-1,
36
  ss_guidance_strength=3, ss_sampling_steps=50,
37
  slat_guidance_strength=3, slat_sampling_steps=6,):
@@ -70,11 +69,9 @@ def generate_3d(image, seed=-1,
70
  def combine_diagonal(color_np, normal_np):
71
  # Convert images to numpy arrays
72
  h, w, c = color_np.shape
73
- # Create a boolean mask that is True for pixels where x > y (diagonally)
74
  mask = np.fromfunction(lambda y, x: x > y, (h, w))
75
  mask = mask.astype(bool)
76
  mask = np.stack([mask] * c, axis=-1)
77
- # Where mask is True take color, else normal
78
  combined_np = np.where(mask, color_np, normal_np)
79
  return Image.fromarray(combined_np)
80
 
@@ -82,7 +79,6 @@ def generate_3d(image, seed=-1,
82
 
83
  # Export mesh
84
  trimesh_mesh = generated_mesh.to_trimesh(transform_pose=True)
85
-
86
  trimesh_mesh.export(mesh_path)
87
 
88
  return preview_images, normal_image, mesh_path, mesh_path
@@ -92,17 +88,15 @@ def convert_mesh(mesh_path, export_format):
92
  if not mesh_path:
93
  return None
94
 
95
- # Create a temporary file to store the mesh data
96
  temp_file = tempfile.NamedTemporaryFile(suffix=f".{export_format}", delete=False)
97
  temp_file_path = temp_file.name
98
 
99
  new_mesh_path = mesh_path.replace(".glb", f".{export_format}")
100
  mesh = trimesh.load_mesh(mesh_path)
101
- mesh.export(temp_file_path) # Export to the temporary file
102
 
103
- return temp_file_path # Return the path to the temporary file
104
 
105
- # Create the Gradio interface with improved layout
106
  with gr.Blocks(css="footer {visibility: hidden}") as demo:
107
  gr.Markdown(
108
  """
@@ -118,17 +112,17 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
118
  with gr.Row():
119
  gr.Markdown("""
120
  <p align="center">
121
- <a title="Website" href="https://stable-x.github.io/Hi3DGen/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
122
  <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
123
  </a>
124
- <a title="arXiv" href="https://stable-x.github.io/Hi3DGen/hi3dgen_paper.pdf" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
125
  <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
126
  </a>
127
- <a title="Github" href="https://github.com/Stable-X/Hi3DGen" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
128
- <img src="https://img.shields.io/github/stars/Stable-X/Hi3DGen?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
129
  </a>
130
- <a title="Social" href="https://x.com/ychngji6" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
131
- <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
132
  </a>
133
  </p>
134
  """)
@@ -136,7 +130,6 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
136
  with gr.Row():
137
  with gr.Column(scale=1):
138
  with gr.Tabs():
139
-
140
  with gr.Tab("Single Image"):
141
  with gr.Row():
142
  image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil")
@@ -160,11 +153,10 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
160
  with gr.Row():
161
  gen_shape_btn = gr.Button("Generate Shape", size="lg", variant="primary")
162
 
163
- # Right column - Output
164
  with gr.Column(scale=1):
165
  with gr.Tabs():
166
  with gr.Tab("Preview"):
167
- output_gallery = gr.Gallery(label="Examples", columns=4, rows=2, object_fit="contain", height="auto",show_label=False)
168
  with gr.Tab("3D Model"):
169
  with gr.Column():
170
  model_output = gr.Model3D(label="3D Model Preview (Each model is approximately 40MB, may take around 1 minute to load)")
@@ -194,8 +186,7 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
194
  lambda: gr.Button(interactive=True),
195
  outputs=[download_btn],
196
  )
197
-
198
-
199
  def update_download_button(mesh_path, export_format):
200
  if not mesh_path:
201
  return gr.File.update(value=None, interactive=False)
@@ -223,20 +214,16 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
223
  gr.Markdown(
224
  """
225
  **Acknowledgments**: Hi3DGen is built on the shoulders of giants. We would like to express our gratitude to the open-source research community and the developers of these pioneering projects:
226
- - **3D Modeling:** Our 3D Model is finetuned from the SOTA open-source 3D foundation model [Trellis](https://github.com/microsoft/TRELLIS) and we draw inspiration from the teams behind [Rodin](https://hyperhuman.deemos.com/rodin), [Tripo](https://www.tripo3d.ai/app/home), and [Dora](https://github.com/Seed3D/Dora).
227
- - **Normal Estimation:** Our Normal Estimation Model builds on the leading normal estimation research such as [StableNormal](https://github.com/hugoycj/StableNormal) and [GenPercept](https://github.com/aim-uofa/GenPercept).
228
-
229
- **Your contributions and collaboration push the boundaries of 3D modeling!**
230
  """
231
  )
232
 
233
  if __name__ == "__main__":
234
- # Initialize pipeline
235
  pipeline = TrellisImageTo3DPipeline.from_pretrained("Stable-X/trellis-normal-v0-1")
236
- pipeline.cuda()
237
-
238
- # Initialize normal predictor
239
  normal_predictor = torch.hub.load("hugoycj/StableNormal", "StableNormal_turbo", trust_repo=True, yoso_version='yoso-normal-v1-8-1')
240
 
241
- # Launch the app
242
  demo.launch()
 
31
  image = pipeline.preprocess_image(image, resolution=1024)
32
  return image
33
 
 
34
  def generate_3d(image, seed=-1,
35
  ss_guidance_strength=3, ss_sampling_steps=50,
36
  slat_guidance_strength=3, slat_sampling_steps=6,):
 
69
  def combine_diagonal(color_np, normal_np):
70
  # Convert images to numpy arrays
71
  h, w, c = color_np.shape
 
72
  mask = np.fromfunction(lambda y, x: x > y, (h, w))
73
  mask = mask.astype(bool)
74
  mask = np.stack([mask] * c, axis=-1)
 
75
  combined_np = np.where(mask, color_np, normal_np)
76
  return Image.fromarray(combined_np)
77
 
 
79
 
80
  # Export mesh
81
  trimesh_mesh = generated_mesh.to_trimesh(transform_pose=True)
 
82
  trimesh_mesh.export(mesh_path)
83
 
84
  return preview_images, normal_image, mesh_path, mesh_path
 
88
  if not mesh_path:
89
  return None
90
 
 
91
  temp_file = tempfile.NamedTemporaryFile(suffix=f".{export_format}", delete=False)
92
  temp_file_path = temp_file.name
93
 
94
  new_mesh_path = mesh_path.replace(".glb", f".{export_format}")
95
  mesh = trimesh.load_mesh(mesh_path)
96
+ mesh.export(temp_file_path)
97
 
98
+ return temp_file_path
99
 
 
100
  with gr.Blocks(css="footer {visibility: hidden}") as demo:
101
  gr.Markdown(
102
  """
 
112
  with gr.Row():
113
  gr.Markdown("""
114
  <p align="center">
115
+ <a title="Website" href="https://stable-x.github.io/Hi3DGen/" target="_blank">
116
  <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
117
  </a>
118
+ <a title="arXiv" href="https://stable-x.github.io/Hi3DGen/hi3dgen_paper.pdf" target="_blank">
119
  <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
120
  </a>
121
+ <a title="Github" href="https://github.com/Stable-X/Hi3DGen" target="_blank">
122
+ <img src="https://img.shields.io/github/stars/Stable-X/Hi3DGen?label=GitHub%20%E2%98%85&logo=github&color=C8C">
123
  </a>
124
+ <a title="Social" href="https://x.com/ychngji6" target="_blank">
125
+ <img src="https://www.obukhov.ai/img/badges/badge-social.svg">
126
  </a>
127
  </p>
128
  """)
 
130
  with gr.Row():
131
  with gr.Column(scale=1):
132
  with gr.Tabs():
 
133
  with gr.Tab("Single Image"):
134
  with gr.Row():
135
  image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil")
 
153
  with gr.Row():
154
  gen_shape_btn = gr.Button("Generate Shape", size="lg", variant="primary")
155
 
 
156
  with gr.Column(scale=1):
157
  with gr.Tabs():
158
  with gr.Tab("Preview"):
159
+ output_gallery = gr.Gallery(label="Examples", columns=4, rows=2, object_fit="contain", height="auto", show_label=False)
160
  with gr.Tab("3D Model"):
161
  with gr.Column():
162
  model_output = gr.Model3D(label="3D Model Preview (Each model is approximately 40MB, may take around 1 minute to load)")
 
186
  lambda: gr.Button(interactive=True),
187
  outputs=[download_btn],
188
  )
189
+
 
190
  def update_download_button(mesh_path, export_format):
191
  if not mesh_path:
192
  return gr.File.update(value=None, interactive=False)
 
214
  gr.Markdown(
215
  """
216
  **Acknowledgments**: Hi3DGen is built on the shoulders of giants. We would like to express our gratitude to the open-source research community and the developers of these pioneering projects:
217
+ - **3D Modeling:** Our 3D Model is finetuned from the SOTA open-source 3D foundation model [Trellis](https://github.com/microsoft/TRELLIS)
218
+ - **Normal Estimation:** We build on top of [StableNormal](https://github.com/hugoycj/StableNormal)
 
 
219
  """
220
  )
221
 
222
  if __name__ == "__main__":
223
+ # 初始化 pipeline,使用 CPU
224
  pipeline = TrellisImageTo3DPipeline.from_pretrained("Stable-X/trellis-normal-v0-1")
225
+
226
+ # 加载 StableNormal 模型(会自动选择 CPU)
 
227
  normal_predictor = torch.hub.load("hugoycj/StableNormal", "StableNormal_turbo", trust_repo=True, yoso_version='yoso-normal-v1-8-1')
228
 
 
229
  demo.launch()