Update app.py
Browse files
app.py
CHANGED
@@ -31,7 +31,6 @@ def preprocess_image(image):
|
|
31 |
image = pipeline.preprocess_image(image, resolution=1024)
|
32 |
return image
|
33 |
|
34 |
-
@spaces.GPU
|
35 |
def generate_3d(image, seed=-1,
|
36 |
ss_guidance_strength=3, ss_sampling_steps=50,
|
37 |
slat_guidance_strength=3, slat_sampling_steps=6,):
|
@@ -70,11 +69,9 @@ def generate_3d(image, seed=-1,
|
|
70 |
def combine_diagonal(color_np, normal_np):
|
71 |
# Convert images to numpy arrays
|
72 |
h, w, c = color_np.shape
|
73 |
-
# Create a boolean mask that is True for pixels where x > y (diagonally)
|
74 |
mask = np.fromfunction(lambda y, x: x > y, (h, w))
|
75 |
mask = mask.astype(bool)
|
76 |
mask = np.stack([mask] * c, axis=-1)
|
77 |
-
# Where mask is True take color, else normal
|
78 |
combined_np = np.where(mask, color_np, normal_np)
|
79 |
return Image.fromarray(combined_np)
|
80 |
|
@@ -82,7 +79,6 @@ def generate_3d(image, seed=-1,
|
|
82 |
|
83 |
# Export mesh
|
84 |
trimesh_mesh = generated_mesh.to_trimesh(transform_pose=True)
|
85 |
-
|
86 |
trimesh_mesh.export(mesh_path)
|
87 |
|
88 |
return preview_images, normal_image, mesh_path, mesh_path
|
@@ -92,17 +88,15 @@ def convert_mesh(mesh_path, export_format):
|
|
92 |
if not mesh_path:
|
93 |
return None
|
94 |
|
95 |
-
# Create a temporary file to store the mesh data
|
96 |
temp_file = tempfile.NamedTemporaryFile(suffix=f".{export_format}", delete=False)
|
97 |
temp_file_path = temp_file.name
|
98 |
|
99 |
new_mesh_path = mesh_path.replace(".glb", f".{export_format}")
|
100 |
mesh = trimesh.load_mesh(mesh_path)
|
101 |
-
mesh.export(temp_file_path)
|
102 |
|
103 |
-
return temp_file_path
|
104 |
|
105 |
-
# Create the Gradio interface with improved layout
|
106 |
with gr.Blocks(css="footer {visibility: hidden}") as demo:
|
107 |
gr.Markdown(
|
108 |
"""
|
@@ -118,17 +112,17 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
|
|
118 |
with gr.Row():
|
119 |
gr.Markdown("""
|
120 |
<p align="center">
|
121 |
-
<a title="Website" href="https://stable-x.github.io/Hi3DGen/" target="_blank"
|
122 |
<img src="https://www.obukhov.ai/img/badges/badge-website.svg">
|
123 |
</a>
|
124 |
-
<a title="arXiv" href="https://stable-x.github.io/Hi3DGen/hi3dgen_paper.pdf" target="_blank"
|
125 |
<img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
|
126 |
</a>
|
127 |
-
<a title="Github" href="https://github.com/Stable-X/Hi3DGen" target="_blank"
|
128 |
-
<img src="https://img.shields.io/github/stars/Stable-X/Hi3DGen?label=GitHub%20%E2%98%85&logo=github&color=C8C"
|
129 |
</a>
|
130 |
-
<a title="Social" href="https://x.com/ychngji6" target="_blank"
|
131 |
-
<img src="https://www.obukhov.ai/img/badges/badge-social.svg"
|
132 |
</a>
|
133 |
</p>
|
134 |
""")
|
@@ -136,7 +130,6 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
|
|
136 |
with gr.Row():
|
137 |
with gr.Column(scale=1):
|
138 |
with gr.Tabs():
|
139 |
-
|
140 |
with gr.Tab("Single Image"):
|
141 |
with gr.Row():
|
142 |
image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil")
|
@@ -160,11 +153,10 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
|
|
160 |
with gr.Row():
|
161 |
gen_shape_btn = gr.Button("Generate Shape", size="lg", variant="primary")
|
162 |
|
163 |
-
# Right column - Output
|
164 |
with gr.Column(scale=1):
|
165 |
with gr.Tabs():
|
166 |
with gr.Tab("Preview"):
|
167 |
-
output_gallery = gr.Gallery(label="Examples", columns=4, rows=2, object_fit="contain", height="auto",show_label=False)
|
168 |
with gr.Tab("3D Model"):
|
169 |
with gr.Column():
|
170 |
model_output = gr.Model3D(label="3D Model Preview (Each model is approximately 40MB, may take around 1 minute to load)")
|
@@ -194,8 +186,7 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
|
|
194 |
lambda: gr.Button(interactive=True),
|
195 |
outputs=[download_btn],
|
196 |
)
|
197 |
-
|
198 |
-
|
199 |
def update_download_button(mesh_path, export_format):
|
200 |
if not mesh_path:
|
201 |
return gr.File.update(value=None, interactive=False)
|
@@ -223,20 +214,16 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
|
|
223 |
gr.Markdown(
|
224 |
"""
|
225 |
**Acknowledgments**: Hi3DGen is built on the shoulders of giants. We would like to express our gratitude to the open-source research community and the developers of these pioneering projects:
|
226 |
-
- **3D Modeling:** Our 3D Model is finetuned from the SOTA open-source 3D foundation model [Trellis](https://github.com/microsoft/TRELLIS)
|
227 |
-
- **Normal Estimation:**
|
228 |
-
|
229 |
-
**Your contributions and collaboration push the boundaries of 3D modeling!**
|
230 |
"""
|
231 |
)
|
232 |
|
233 |
if __name__ == "__main__":
|
234 |
-
#
|
235 |
pipeline = TrellisImageTo3DPipeline.from_pretrained("Stable-X/trellis-normal-v0-1")
|
236 |
-
|
237 |
-
|
238 |
-
# Initialize normal predictor
|
239 |
normal_predictor = torch.hub.load("hugoycj/StableNormal", "StableNormal_turbo", trust_repo=True, yoso_version='yoso-normal-v1-8-1')
|
240 |
|
241 |
-
# Launch the app
|
242 |
demo.launch()
|
|
|
31 |
image = pipeline.preprocess_image(image, resolution=1024)
|
32 |
return image
|
33 |
|
|
|
34 |
def generate_3d(image, seed=-1,
|
35 |
ss_guidance_strength=3, ss_sampling_steps=50,
|
36 |
slat_guidance_strength=3, slat_sampling_steps=6,):
|
|
|
69 |
def combine_diagonal(color_np, normal_np):
|
70 |
# Convert images to numpy arrays
|
71 |
h, w, c = color_np.shape
|
|
|
72 |
mask = np.fromfunction(lambda y, x: x > y, (h, w))
|
73 |
mask = mask.astype(bool)
|
74 |
mask = np.stack([mask] * c, axis=-1)
|
|
|
75 |
combined_np = np.where(mask, color_np, normal_np)
|
76 |
return Image.fromarray(combined_np)
|
77 |
|
|
|
79 |
|
80 |
# Export mesh
|
81 |
trimesh_mesh = generated_mesh.to_trimesh(transform_pose=True)
|
|
|
82 |
trimesh_mesh.export(mesh_path)
|
83 |
|
84 |
return preview_images, normal_image, mesh_path, mesh_path
|
|
|
88 |
if not mesh_path:
|
89 |
return None
|
90 |
|
|
|
91 |
temp_file = tempfile.NamedTemporaryFile(suffix=f".{export_format}", delete=False)
|
92 |
temp_file_path = temp_file.name
|
93 |
|
94 |
new_mesh_path = mesh_path.replace(".glb", f".{export_format}")
|
95 |
mesh = trimesh.load_mesh(mesh_path)
|
96 |
+
mesh.export(temp_file_path)
|
97 |
|
98 |
+
return temp_file_path
|
99 |
|
|
|
100 |
with gr.Blocks(css="footer {visibility: hidden}") as demo:
|
101 |
gr.Markdown(
|
102 |
"""
|
|
|
112 |
with gr.Row():
|
113 |
gr.Markdown("""
|
114 |
<p align="center">
|
115 |
+
<a title="Website" href="https://stable-x.github.io/Hi3DGen/" target="_blank">
|
116 |
<img src="https://www.obukhov.ai/img/badges/badge-website.svg">
|
117 |
</a>
|
118 |
+
<a title="arXiv" href="https://stable-x.github.io/Hi3DGen/hi3dgen_paper.pdf" target="_blank">
|
119 |
<img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
|
120 |
</a>
|
121 |
+
<a title="Github" href="https://github.com/Stable-X/Hi3DGen" target="_blank">
|
122 |
+
<img src="https://img.shields.io/github/stars/Stable-X/Hi3DGen?label=GitHub%20%E2%98%85&logo=github&color=C8C">
|
123 |
</a>
|
124 |
+
<a title="Social" href="https://x.com/ychngji6" target="_blank">
|
125 |
+
<img src="https://www.obukhov.ai/img/badges/badge-social.svg">
|
126 |
</a>
|
127 |
</p>
|
128 |
""")
|
|
|
130 |
with gr.Row():
|
131 |
with gr.Column(scale=1):
|
132 |
with gr.Tabs():
|
|
|
133 |
with gr.Tab("Single Image"):
|
134 |
with gr.Row():
|
135 |
image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil")
|
|
|
153 |
with gr.Row():
|
154 |
gen_shape_btn = gr.Button("Generate Shape", size="lg", variant="primary")
|
155 |
|
|
|
156 |
with gr.Column(scale=1):
|
157 |
with gr.Tabs():
|
158 |
with gr.Tab("Preview"):
|
159 |
+
output_gallery = gr.Gallery(label="Examples", columns=4, rows=2, object_fit="contain", height="auto", show_label=False)
|
160 |
with gr.Tab("3D Model"):
|
161 |
with gr.Column():
|
162 |
model_output = gr.Model3D(label="3D Model Preview (Each model is approximately 40MB, may take around 1 minute to load)")
|
|
|
186 |
lambda: gr.Button(interactive=True),
|
187 |
outputs=[download_btn],
|
188 |
)
|
189 |
+
|
|
|
190 |
def update_download_button(mesh_path, export_format):
|
191 |
if not mesh_path:
|
192 |
return gr.File.update(value=None, interactive=False)
|
|
|
214 |
gr.Markdown(
|
215 |
"""
|
216 |
**Acknowledgments**: Hi3DGen is built on the shoulders of giants. We would like to express our gratitude to the open-source research community and the developers of these pioneering projects:
|
217 |
+
- **3D Modeling:** Our 3D Model is finetuned from the SOTA open-source 3D foundation model [Trellis](https://github.com/microsoft/TRELLIS)
|
218 |
+
- **Normal Estimation:** We build on top of [StableNormal](https://github.com/hugoycj/StableNormal)
|
|
|
|
|
219 |
"""
|
220 |
)
|
221 |
|
222 |
if __name__ == "__main__":
|
223 |
+
# 初始化 pipeline,使用 CPU
|
224 |
pipeline = TrellisImageTo3DPipeline.from_pretrained("Stable-X/trellis-normal-v0-1")
|
225 |
+
|
226 |
+
# 加载 StableNormal 模型(会自动选择 CPU)
|
|
|
227 |
normal_predictor = torch.hub.load("hugoycj/StableNormal", "StableNormal_turbo", trust_repo=True, yoso_version='yoso-normal-v1-8-1')
|
228 |
|
|
|
229 |
demo.launch()
|