Sm0kyWu commited on
Commit
3ebf31b
·
verified ·
1 Parent(s): bac292e

Upload 63 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +62 -0
  2. app.py +259 -421
  3. assets/example_image/T.png +3 -0
  4. assets/example_image/typical_building_building.png +3 -0
  5. assets/example_image/typical_building_castle.png +3 -0
  6. assets/example_image/typical_building_colorful_cottage.png +3 -0
  7. assets/example_image/typical_building_maya_pyramid.png +3 -0
  8. assets/example_image/typical_building_mushroom.png +3 -0
  9. assets/example_image/typical_building_space_station.png +3 -0
  10. assets/example_image/typical_creature_dragon.png +3 -0
  11. assets/example_image/typical_creature_elephant.png +3 -0
  12. assets/example_image/typical_creature_furry.png +3 -0
  13. assets/example_image/typical_creature_quadruped.png +3 -0
  14. assets/example_image/typical_creature_robot_crab.png +3 -0
  15. assets/example_image/typical_creature_robot_dinosour.png +3 -0
  16. assets/example_image/typical_creature_rock_monster.png +3 -0
  17. assets/example_image/typical_humanoid_block_robot.png +3 -0
  18. assets/example_image/typical_humanoid_dragonborn.png +3 -0
  19. assets/example_image/typical_humanoid_dwarf.png +3 -0
  20. assets/example_image/typical_humanoid_goblin.png +3 -0
  21. assets/example_image/typical_humanoid_mech.png +3 -0
  22. assets/example_image/typical_misc_crate.png +3 -0
  23. assets/example_image/typical_misc_fireplace.png +3 -0
  24. assets/example_image/typical_misc_gate.png +3 -0
  25. assets/example_image/typical_misc_lantern.png +3 -0
  26. assets/example_image/typical_misc_magicbook.png +3 -0
  27. assets/example_image/typical_misc_mailbox.png +3 -0
  28. assets/example_image/typical_misc_monster_chest.png +3 -0
  29. assets/example_image/typical_misc_paper_machine.png +3 -0
  30. assets/example_image/typical_misc_phonograph.png +3 -0
  31. assets/example_image/typical_misc_portal2.png +3 -0
  32. assets/example_image/typical_misc_storage_chest.png +3 -0
  33. assets/example_image/typical_misc_telephone.png +3 -0
  34. assets/example_image/typical_misc_television.png +3 -0
  35. assets/example_image/typical_misc_workbench.png +3 -0
  36. assets/example_image/typical_vehicle_biplane.png +3 -0
  37. assets/example_image/typical_vehicle_bulldozer.png +3 -0
  38. assets/example_image/typical_vehicle_cart.png +3 -0
  39. assets/example_image/typical_vehicle_excavator.png +3 -0
  40. assets/example_image/typical_vehicle_helicopter.png +3 -0
  41. assets/example_image/typical_vehicle_locomotive.png +3 -0
  42. assets/example_image/typical_vehicle_pirate_ship.png +3 -0
  43. assets/example_image/weatherworn_misc_paper_machine3.png +3 -0
  44. assets/example_multi_image/character_1.png +3 -0
  45. assets/example_multi_image/character_2.png +3 -0
  46. assets/example_multi_image/character_3.png +3 -0
  47. assets/example_multi_image/mushroom_1.png +3 -0
  48. assets/example_multi_image/mushroom_2.png +3 -0
  49. assets/example_multi_image/mushroom_3.png +3 -0
  50. assets/example_multi_image/orangeguy_1.png +3 -0
.gitattributes CHANGED
@@ -35,3 +35,65 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
37
  wheels/nvdiffrast-0.3.3-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
37
  wheels/nvdiffrast-0.3.3-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
38
+ assets/example_image/T.png filter=lfs diff=lfs merge=lfs -text
39
+ assets/example_image/typical_building_building.png filter=lfs diff=lfs merge=lfs -text
40
+ assets/example_image/typical_building_castle.png filter=lfs diff=lfs merge=lfs -text
41
+ assets/example_image/typical_building_colorful_cottage.png filter=lfs diff=lfs merge=lfs -text
42
+ assets/example_image/typical_building_maya_pyramid.png filter=lfs diff=lfs merge=lfs -text
43
+ assets/example_image/typical_building_mushroom.png filter=lfs diff=lfs merge=lfs -text
44
+ assets/example_image/typical_building_space_station.png filter=lfs diff=lfs merge=lfs -text
45
+ assets/example_image/typical_creature_dragon.png filter=lfs diff=lfs merge=lfs -text
46
+ assets/example_image/typical_creature_elephant.png filter=lfs diff=lfs merge=lfs -text
47
+ assets/example_image/typical_creature_furry.png filter=lfs diff=lfs merge=lfs -text
48
+ assets/example_image/typical_creature_quadruped.png filter=lfs diff=lfs merge=lfs -text
49
+ assets/example_image/typical_creature_robot_crab.png filter=lfs diff=lfs merge=lfs -text
50
+ assets/example_image/typical_creature_robot_dinosour.png filter=lfs diff=lfs merge=lfs -text
51
+ assets/example_image/typical_creature_rock_monster.png filter=lfs diff=lfs merge=lfs -text
52
+ assets/example_image/typical_humanoid_block_robot.png filter=lfs diff=lfs merge=lfs -text
53
+ assets/example_image/typical_humanoid_dragonborn.png filter=lfs diff=lfs merge=lfs -text
54
+ assets/example_image/typical_humanoid_dwarf.png filter=lfs diff=lfs merge=lfs -text
55
+ assets/example_image/typical_humanoid_goblin.png filter=lfs diff=lfs merge=lfs -text
56
+ assets/example_image/typical_humanoid_mech.png filter=lfs diff=lfs merge=lfs -text
57
+ assets/example_image/typical_misc_crate.png filter=lfs diff=lfs merge=lfs -text
58
+ assets/example_image/typical_misc_fireplace.png filter=lfs diff=lfs merge=lfs -text
59
+ assets/example_image/typical_misc_gate.png filter=lfs diff=lfs merge=lfs -text
60
+ assets/example_image/typical_misc_lantern.png filter=lfs diff=lfs merge=lfs -text
61
+ assets/example_image/typical_misc_magicbook.png filter=lfs diff=lfs merge=lfs -text
62
+ assets/example_image/typical_misc_mailbox.png filter=lfs diff=lfs merge=lfs -text
63
+ assets/example_image/typical_misc_monster_chest.png filter=lfs diff=lfs merge=lfs -text
64
+ assets/example_image/typical_misc_paper_machine.png filter=lfs diff=lfs merge=lfs -text
65
+ assets/example_image/typical_misc_phonograph.png filter=lfs diff=lfs merge=lfs -text
66
+ assets/example_image/typical_misc_portal2.png filter=lfs diff=lfs merge=lfs -text
67
+ assets/example_image/typical_misc_storage_chest.png filter=lfs diff=lfs merge=lfs -text
68
+ assets/example_image/typical_misc_telephone.png filter=lfs diff=lfs merge=lfs -text
69
+ assets/example_image/typical_misc_television.png filter=lfs diff=lfs merge=lfs -text
70
+ assets/example_image/typical_misc_workbench.png filter=lfs diff=lfs merge=lfs -text
71
+ assets/example_image/typical_vehicle_biplane.png filter=lfs diff=lfs merge=lfs -text
72
+ assets/example_image/typical_vehicle_bulldozer.png filter=lfs diff=lfs merge=lfs -text
73
+ assets/example_image/typical_vehicle_cart.png filter=lfs diff=lfs merge=lfs -text
74
+ assets/example_image/typical_vehicle_excavator.png filter=lfs diff=lfs merge=lfs -text
75
+ assets/example_image/typical_vehicle_helicopter.png filter=lfs diff=lfs merge=lfs -text
76
+ assets/example_image/typical_vehicle_locomotive.png filter=lfs diff=lfs merge=lfs -text
77
+ assets/example_image/typical_vehicle_pirate_ship.png filter=lfs diff=lfs merge=lfs -text
78
+ assets/example_image/weatherworn_misc_paper_machine3.png filter=lfs diff=lfs merge=lfs -text
79
+ assets/example_multi_image/character_1.png filter=lfs diff=lfs merge=lfs -text
80
+ assets/example_multi_image/character_2.png filter=lfs diff=lfs merge=lfs -text
81
+ assets/example_multi_image/character_3.png filter=lfs diff=lfs merge=lfs -text
82
+ assets/example_multi_image/mushroom_1.png filter=lfs diff=lfs merge=lfs -text
83
+ assets/example_multi_image/mushroom_2.png filter=lfs diff=lfs merge=lfs -text
84
+ assets/example_multi_image/mushroom_3.png filter=lfs diff=lfs merge=lfs -text
85
+ assets/example_multi_image/orangeguy_1.png filter=lfs diff=lfs merge=lfs -text
86
+ assets/example_multi_image/orangeguy_2.png filter=lfs diff=lfs merge=lfs -text
87
+ assets/example_multi_image/orangeguy_3.png filter=lfs diff=lfs merge=lfs -text
88
+ assets/example_multi_image/popmart_1.png filter=lfs diff=lfs merge=lfs -text
89
+ assets/example_multi_image/popmart_2.png filter=lfs diff=lfs merge=lfs -text
90
+ assets/example_multi_image/popmart_3.png filter=lfs diff=lfs merge=lfs -text
91
+ assets/example_multi_image/rabbit_1.png filter=lfs diff=lfs merge=lfs -text
92
+ assets/example_multi_image/rabbit_2.png filter=lfs diff=lfs merge=lfs -text
93
+ assets/example_multi_image/rabbit_3.png filter=lfs diff=lfs merge=lfs -text
94
+ assets/example_multi_image/tiger_1.png filter=lfs diff=lfs merge=lfs -text
95
+ assets/example_multi_image/tiger_2.png filter=lfs diff=lfs merge=lfs -text
96
+ assets/example_multi_image/tiger_3.png filter=lfs diff=lfs merge=lfs -text
97
+ assets/example_multi_image/yoimiya_1.png filter=lfs diff=lfs merge=lfs -text
98
+ assets/example_multi_image/yoimiya_2.png filter=lfs diff=lfs merge=lfs -text
99
+ assets/example_multi_image/yoimiya_3.png filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,9 +1,8 @@
1
  import gradio as gr
2
- from gradio_litmodel3d import LitModel3D
3
  import spaces
 
4
 
5
  import os
6
-
7
  import shutil
8
  os.environ['SPCONV_ALGO'] = 'native'
9
  from typing import *
@@ -12,97 +11,111 @@ import numpy as np
12
  import imageio
13
  from easydict import EasyDict as edict
14
  from PIL import Image
15
- from Amodal3R.pipelines import Amodal3RImageTo3DPipeline
16
  from trellis.pipelines import TrellisImageTo3DPipeline
17
- from Amodal3R.representations import Gaussian, MeshExtractResult
18
- from Amodal3R.utils import render_utils, postprocessing_utils
19
- from segment_anything import sam_model_registry, SamPredictor
20
- from huggingface_hub import hf_hub_download
21
- import cv2
22
 
23
 
24
  MAX_SEED = np.iinfo(np.int32).max
25
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
26
  os.makedirs(TMP_DIR, exist_ok=True)
27
 
 
28
  def start_session(req: gr.Request):
29
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
30
  os.makedirs(user_dir, exist_ok=True)
31
-
 
32
  def end_session(req: gr.Request):
33
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
34
  shutil.rmtree(user_dir)
35
 
36
- def reset_image(predictor, img):
 
37
  """
38
- 上传图像后调用:
39
- - 重置 predictor,
40
- - 设置 predictor 的输入图像,
41
- - 返回原图
 
42
  """
43
- predictor.set_image(img)
44
- original_img = img.copy()
45
- # 返回predictor,visible occlusion mask初始化, 原始图像
46
- return predictor, original_img, "The models are ready."
47
-
48
- def button_clickable(selected_points):
49
- if len(selected_points) > 0:
50
- return gr.Button.update(interactive=True)
51
- else:
52
- return gr.Button.update(interactive=False)
53
 
54
- def run_sam(predictor, selected_points):
55
  """
56
- 调用 SAM 模型进行分割。
 
 
 
 
 
 
57
  """
58
- # predictor.set_image(image)
59
- if len(selected_points) == 0:
60
- return [], None
61
- input_points = [p for p in selected_points]
62
- input_labels = [1 for _ in range(len(selected_points))]
63
- masks, _, _ = predictor.predict(
64
- point_coords=np.array(input_points),
65
- point_labels=np.array(input_labels),
66
- multimask_output=False, # 单对象输出
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  )
68
- best_mask = masks[0].astype(np.uint8)
69
- # dilate
70
- if len(selected_points) > 1:
71
- kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
72
- best_mask = cv2.dilate(best_mask, kernel, iterations=1)
73
- best_mask = cv2.erode(best_mask, kernel, iterations=1)
74
- return best_mask
75
-
76
- def apply_mask_overlay(image, mask, color=(255, 0, 0)):
77
- """
78
- 在原图上叠加 mask:使用红色绘制 mask 的轮廓,非 mask 区域叠加浅灰色半透明遮罩。
79
- """
80
- img_arr = image
81
- overlay = img_arr.copy()
82
- gray_color = np.array([200, 200, 200], dtype=np.uint8)
83
- non_mask = mask == 0
84
- overlay[non_mask] = (0.5 * overlay[non_mask] + 0.5 * gray_color).astype(np.uint8)
85
- contours, _ = cv2.findContours(mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
86
- cv2.drawContours(overlay, contours, -1, color, 2)
87
- return overlay
88
-
89
- def segment_and_overlay(image, points, sam_predictor):
90
  """
91
- 调用 run_sam 获得 mask,然后叠加显示分割结果。
92
  """
93
- visible_mask = run_sam(sam_predictor, points)
94
- overlaid = apply_mask_overlay(image, visible_mask * 255)
95
- return overlaid, visible_mask
96
 
97
 
98
  @spaces.GPU
99
  def image_to_3d(
100
- images: np.ndarray,
 
 
101
  seed: int,
102
  ss_guidance_strength: float,
103
  ss_sampling_steps: int,
104
  slat_guidance_strength: float,
105
  slat_sampling_steps: int,
 
106
  req: gr.Request,
107
  ) -> Tuple[dict, str]:
108
  """
@@ -122,21 +135,37 @@ def image_to_3d(
122
  str: The path to the video of the 3D model.
123
  """
124
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
125
- outputs = pipeline.run_multi_image(
126
- [Image.fromarray(images)],
127
- seed=seed,
128
- formats=["gaussian", "mesh"],
129
- preprocess_image=False,
130
- sparse_structure_sampler_params={
131
- "steps": ss_sampling_steps,
132
- "cfg_strength": ss_guidance_strength,
133
- },
134
- slat_sampler_params={
135
- "steps": slat_sampling_steps,
136
- "cfg_strength": slat_guidance_strength,
137
- },
138
- mode="stochastic",
139
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
141
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
142
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
@@ -153,9 +182,15 @@ def extract_glb(
153
  mesh_simplify: float,
154
  texture_size: int,
155
  req: gr.Request,
156
- ) -> tuple:
157
  """
158
- 从生成的 3D 模型中提取 GLB 文件。
 
 
 
 
 
 
159
  """
160
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
161
  gs, mesh = unpack_state(state)
@@ -167,9 +202,13 @@ def extract_glb(
167
 
168
 
169
  @spaces.GPU
170
- def extract_gaussian(state: dict, req: gr.Request) -> tuple:
171
  """
172
- 从生成的 3D 模型中提取 Gaussian 文件。
 
 
 
 
173
  """
174
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
175
  gs, _ = unpack_state(state)
@@ -179,273 +218,59 @@ def extract_gaussian(state: dict, req: gr.Request) -> tuple:
179
  return gaussian_path, gaussian_path
180
 
181
 
182
- def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
183
- return {
184
- 'gaussian': {
185
- **gs.init_params,
186
- '_xyz': gs._xyz.cpu().numpy(),
187
- '_features_dc': gs._features_dc.cpu().numpy(),
188
- '_scaling': gs._scaling.cpu().numpy(),
189
- '_rotation': gs._rotation.cpu().numpy(),
190
- '_opacity': gs._opacity.cpu().numpy(),
191
- },
192
- 'mesh': {
193
- 'vertices': mesh.vertices.cpu().numpy(),
194
- 'faces': mesh.faces.cpu().numpy(),
195
- },
196
- }
197
-
198
-
199
- def unpack_state(state: dict) -> tuple:
200
- gs = Gaussian(
201
- aabb=state['gaussian']['aabb'],
202
- sh_degree=state['gaussian']['sh_degree'],
203
- mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
204
- scaling_bias=state['gaussian']['scaling_bias'],
205
- opacity_bias=state['gaussian']['opacity_bias'],
206
- scaling_activation=state['gaussian']['scaling_activation'],
207
- )
208
- gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
209
- gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
210
- gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
211
- gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
212
- gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
213
-
214
- mesh = edict(
215
- vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
216
- faces=torch.tensor(state['mesh']['faces'], device='cuda'),
217
- )
218
-
219
- return gs, mesh
220
-
221
- def get_sam_predictor():
222
- # sam_checkpoint = hf_hub_download("ybelkada/segment-anything", "checkpoints/sam_vit_h_4b8939.pth")
223
- # model_type = "vit_h"
224
- # sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
225
- # sam_predictor = SamPredictor(sam)
226
- # return sam_predictor
227
- return predictor
228
-
229
-
230
- def draw_points_on_image(image, point):
231
- """在图像上绘制所有点,points 为 [(x, y, point_type), ...]"""
232
- image_with_points = image.copy()
233
- x, y = point
234
- color = (255, 0, 0)
235
- cv2.circle(image_with_points, (int(x), int(y)), radius=10, color=color, thickness=-1)
236
- return image_with_points
237
-
238
-
239
- def see_point(image, x, y):
240
- """
241
- see操作:不修改 points 列表,仅在图像上临时显示这个点,
242
- 并返回更新后的图像和当前列表(不更新)。
243
- """
244
- # 复制当前列表,并在副本中加上新点(仅用于显示)
245
- updated_image = draw_points_on_image(image, [x,y])
246
- return updated_image
247
-
248
- def add_point(x, y, visible_points):
249
- """
250
- add操作:将新点添加到 points 列表中,
251
- 并返回更新后的图像和新的点列表。
252
- """
253
- if [x, y] not in visible_points:
254
- visible_points.append([x, y])
255
- return visible_points
256
-
257
- def delete_point(visible_points):
258
- """
259
- delete操作:删除 points 列表中的最后一个点,
260
- 并返回更新后的图像和新的点列表。
261
- """
262
- visible_points.pop()
263
- return visible_points
264
-
265
 
266
- def clear_all_points(image):
267
- """
268
- 清除所有点:返回原图、空的 visible 和 occlusion 列表,
269
- 以及更新后的点文本信息和空下拉菜单列表。
270
- """
271
- updated_image = image.copy()
272
- return updated_image
273
 
274
- def see_visible_points(image, visible_points):
275
  """
276
- 在图像上绘制所有 visible 点(红色)。
277
  """
278
- updated_image = image.copy()
279
- for p in visible_points:
280
- cv2.circle(updated_image, (int(p[0]), int(p[1])), radius=10, color=(255, 0, 0), thickness=-1)
281
- return updated_image
282
-
283
- def update_all_points(visible_points):
284
- text = f"Points: {visible_points}"
285
- visible_dropdown_choices = [f"({p[0]}, {p[1]})" for p in visible_points]
286
- # 返回更新字典来明确设置 choices value
287
- return text, gr.Dropdown(label="Select Point to Delete", choices=visible_dropdown_choices, value=None, interactive=True)
288
-
289
- def delete_selected_visible(image, visible_points, selected_value):
290
- # selected_value 是类似 "(x, y)" 的字符串
291
- try:
292
- selected_index = [f"({p[0]}, {p[1]})" for p in visible_points].index(selected_value)
293
- except ValueError:
294
- selected_index = None
295
- if selected_index is not None and 0 <= selected_index < len(visible_points):
296
- visible_points.pop(selected_index)
297
- updated_image = image.copy()
298
- # 重新绘制所有 visible 点(红色)
299
- for p in visible_points:
300
- cv2.circle(updated_image, (int(p[0]), int(p[1])), radius=10, color=(255, 0, 0), thickness=-1)
301
- updated_text, vis_dropdown = update_all_points(visible_points)
302
- return updated_image, visible_points, updated_text, vis_dropdown
303
-
304
- def add_mask(mask, mask_list):
305
- # check if the mask if same as the last mask in the list
306
- if len(mask_list) > 0:
307
- if np.array_equal(mask, mask_list[-1]):
308
- return mask_list
309
- mask_list.append(mask)
310
- return mask_list
311
-
312
- def vis_mask(image, mask_list):
313
- updated_image = image.copy()
314
- # combine all the mask:
315
- combined_mask = np.zeros_like(updated_image[:, :, 0])
316
- for mask in mask_list:
317
- combined_mask = cv2.bitwise_or(combined_mask, mask)
318
- # overlay the mask on the image
319
- updated_image = apply_mask_overlay(updated_image, combined_mask)
320
- return updated_image
321
-
322
- def delete_mask(mask_list):
323
- if len(mask_list) > 0:
324
- mask_list.pop()
325
- return mask_list
326
-
327
- def check_combined_mask(image, visibility_mask, mask_list, scale=0.6):
328
- updated_image = image.copy()
329
- # combine all the mask:
330
- combined_mask = np.zeros_like(updated_image[:, :, 0])
331
- occluded_mask = np.zeros_like(updated_image[:, :, 0])
332
- if len(mask_list) == 0:
333
- combined_mask = visibility_mask
334
- else:
335
- for mask in mask_list:
336
- combined_mask = cv2.bitwise_or(combined_mask, mask)
337
-
338
- if len(mask_list) > 1:
339
- kernel = np.ones((5, 5), np.uint8)
340
- dilate_iterations = 1
341
- combined_mask = cv2.dilate(combined_mask, kernel, iterations=dilate_iterations)
342
- combined_mask = cv2.erode(combined_mask, kernel, iterations=dilate_iterations)
343
-
344
- masked_img = updated_image * combined_mask[:, :, None]
345
- occluded_mask[combined_mask == 1] = 127
346
-
347
- # move the visible part to the center of the image
348
- x, y, w, h = cv2.boundingRect(combined_mask.astype(np.uint8))
349
- cropped_occluded_mask = (occluded_mask[y:y+h, x:x+w]).astype(np.uint8)
350
- cropped_img = masked_img[y:y+h, x:x+w]
351
-
352
- target_size = 512
353
- scale_factor = target_size / max(w, h)
354
- new_w = int(round(w * scale_factor * scale))
355
- new_h = int(round(h * scale_factor * scale))
356
-
357
- resized_occluded_mask = cv2.resize(cropped_occluded_mask.astype(np.uint8), (new_w, new_h), cv2.INTER_NEAREST)
358
- resized_img = cv2.resize(cropped_img, (new_w, new_h), cv2.INTER_NEAREST)
359
-
360
- final_img = np.zeros((target_size, target_size, 3), dtype=updated_image.dtype)
361
- final_occluded_mask = np.zeros((target_size, target_size), dtype=np.uint8)
362
-
363
- x_offset = (target_size - new_w) // 2
364
- y_offset = (target_size - new_h) // 2
365
-
366
- final_img[y_offset:y_offset+new_h, x_offset:x_offset+new_w] = resized_img
367
- final_occluded_mask[y_offset:y_offset+new_h, x_offset:x_offset+new_w] = resized_occluded_mask
368
-
369
- return final_img, occluded_mask
370
-
371
-
372
-
373
- def get_seed(randomize_seed: bool, seed: int) -> int:
374
- """
375
- Get the random seed.
376
- """
377
- return np.random.randint(0, MAX_SEED) if randomize_seed else seed
378
 
379
 
380
  with gr.Blocks(delete_cache=(600, 600)) as demo:
381
  gr.Markdown("""
382
- ## 3D Amodal Reconstruction with [Amodal3R](https://sm0kywu.github.io/Amodal3R/)
 
 
 
 
383
  """)
384
-
385
- # 定义各状态变量
386
- predictor = gr.State(value=get_sam_predictor())
387
- visible_points_state = gr.State(value=[])
388
- occlusion_points_state = gr.State(value=[])
389
- original_image = gr.State(value=None)
390
- visibility_mask = gr.State(value=None)
391
- visibility_mask_list = gr.State(value=[])
392
-
393
- occluded_mask = gr.State(value=None)
394
- output_buf = gr.State()
395
-
396
-
397
- with gr.Row():
398
- gr.Markdown("""* Step 1 - Generate Visibility Mask and Occlusion Mask.
399
- * Please wait for a few seconds after uploading the image. The 2D segmenter is getting ready.
400
- * Add the point prompts to indicate the target object and occluders separately.
401
- * "Render Point", see the position of the point to be added.
402
- * "Add Point", the point will be added to the list.
403
- * "Generate mask", see the segmented area corresponding to current point list.
404
- * "Add mask", current mask will be added for 3D amodal completion.
405
- """)
406
- with gr.Row():
407
- with gr.Column():
408
- input_image = gr.Image(type="numpy", label='Input Occlusion Image', sources="upload", height=300)
409
- with gr.Row():
410
- message = gr.Markdown("Please wait a few seconds after uploading the image.", label="Message") # 用于显示提示信息
411
- with gr.Row():
412
- x_input = gr.Number(label="X Coordinate", value=0)
413
- y_input = gr.Number(label="Y Coordinate", value=0)
414
- with gr.Row():
415
- see_button = gr.Button("Render Point")
416
- add_button = gr.Button("Add Point")
417
- with gr.Row():
418
- clear_button = gr.Button("Clear Points")
419
- see_visible_button = gr.Button("Render Added Points")
420
- with gr.Row():
421
- # 新增文本框实时显示点列表
422
- points_text = gr.Textbox(label="Points List", interactive=False)
423
- with gr.Row():
424
- # 新增下拉菜单,用户可选择需要删除的点
425
- visible_points_dropdown = gr.Dropdown(label="Select Point to Delete", choices=[], value=None, interactive=True)
426
- delete_visible_button = gr.Button("Delete Selected Visible")
427
- with gr.Column():
428
- # 用于显示 SAM 分割结果
429
- visible_mask = gr.Image(label='Visible Mask', interactive=False, height=300)
430
- with gr.Row():
431
- gen_vis_mask = gr.Button("Generate Mask")
432
- add_vis_mask = gr.Button("Add Mask")
433
- with gr.Row():
434
- render_vis_mask = gr.Button("Render Mask")
435
- undo_vis_mask = gr.Button("Undo Last Mask")
436
- vis_input = gr.Image(label='Visible Input', interactive=False, height=300)
437
- with gr.Row():
438
- zoom_scale = gr.Slider(0.3, 1.0, label="Target Object Scale", value=0.6, step=0.1)
439
- check_visible_input = gr.Button("Generate Occluded Input")
440
- with gr.Row():
441
- gr.Markdown("""* Step 2 - 3D Amodal Completion.
442
- * Different random seeds can be tried in "Generation Settings", if you think the results are not ideal.
443
- * If the reconstruction 3D asset is satisfactory, you can extract the GLB file and download it.
444
- """)
445
  with gr.Row():
446
  with gr.Column():
447
- with gr.Accordion(label="Generation Settings", open=True):
448
- seed = gr.Slider(0, MAX_SEED, label="Seed", value=1, step=1)
 
 
 
 
 
 
 
 
 
 
 
449
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
450
  gr.Markdown("Stage 1: Sparse Structure Generation")
451
  with gr.Row():
@@ -455,114 +280,127 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
455
  with gr.Row():
456
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
457
  slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
 
 
458
  generate_btn = gr.Button("Generate")
 
 
 
 
 
 
 
 
 
 
 
 
459
  with gr.Column():
460
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
 
 
 
 
 
461
 
462
- # # Handlers
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
463
  demo.load(start_session)
464
  demo.unload(end_session)
465
-
466
- # ---------------------------
467
- # 原有交互逻辑(略)
468
- # ---------------------------
469
- input_image.upload(
470
- reset_image,
471
- [predictor, input_image],
472
- [predictor, original_image, message],
473
- )
474
- see_button.click(
475
- see_point,
476
- inputs=[original_image, x_input, y_input],
477
- outputs=[input_image]
478
- )
479
- add_button.click(
480
- add_point,
481
- inputs=[x_input, y_input, visible_points_state],
482
- outputs=[visible_points_state]
483
- )
484
 
485
- # ---------------------------
486
- # 新增的交互逻辑
487
- # ---------------------------
488
- clear_button.click(
489
- clear_all_points,
490
- inputs=[original_image],
491
- outputs=[input_image]
492
  )
493
- see_visible_button.click(
494
- see_visible_points,
495
- inputs=[input_image, visible_points_state],
496
- outputs=input_image
497
  )
498
- # 当 visible_points_state 或 occlusion_points_state 变化时,更新文本框和下拉菜单
499
- visible_points_state.change(
500
- update_all_points,
501
- inputs=[visible_points_state],
502
- outputs=[points_text, visible_points_dropdown]
503
  )
504
- delete_visible_button.click(
505
- delete_selected_visible,
506
- inputs=[input_image, visible_points_state, visible_points_dropdown],
507
- outputs=[input_image, visible_points_state, points_text, visible_points_dropdown]
508
  )
509
 
510
- # 生成mask的逻辑
511
- gen_vis_mask.click(
512
- segment_and_overlay,
513
- inputs=[original_image, visible_points_state, predictor],
514
- outputs=[visible_mask, visibility_mask]
515
- )
516
- add_vis_mask.click(
517
- add_mask,
518
- inputs=[visibility_mask, visibility_mask_list],
519
- outputs=[visibility_mask_list]
520
- )
521
- render_vis_mask.click(
522
- vis_mask,
523
- inputs=[original_image, visibility_mask_list],
524
- outputs=[visible_mask]
525
  )
526
- undo_vis_mask.click(
527
- delete_mask,
528
- inputs=[visibility_mask_list],
529
- outputs=[visibility_mask_list]
530
  )
531
 
532
- check_visible_input.click(
533
- check_combined_mask,
534
- inputs=[original_image, visibility_mask, visibility_mask_list, zoom_scale],
535
- outputs=[vis_input, occluded_mask]
 
 
 
536
  )
537
 
 
 
 
 
 
 
 
 
538
 
539
- # 3D Amodal Reconstruction
540
- # generate_btn.click(
541
- # get_seed,
542
- # inputs=[randomize_seed, seed],
543
- # outputs=[seed],
544
- # ).then(
545
- # image_to_3d,
546
- # inputs=[vis_input, occluded_mask, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
547
- # outputs=[output_buf, video_output],
548
- # )
549
-
550
- generate_btn.click(
551
- image_to_3d,
552
- inputs=[vis_input, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
553
- outputs=[output_buf, video_output],
554
  )
555
 
556
 
557
- # 启动 Gradio App
558
  if __name__ == "__main__":
559
- # pipeline = Amodal3RImageTo3DPipeline.from_pretrained("Sm0kyWu/Amodal3R")
560
  pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
561
  pipeline.cuda()
562
- predictor = get_sam_predictor()
563
- predictor = predictor.cuda()
564
  try:
565
- pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
566
  except:
567
  pass
568
- demo.launch()
 
1
  import gradio as gr
 
2
  import spaces
3
+ from gradio_litmodel3d import LitModel3D
4
 
5
  import os
 
6
  import shutil
7
  os.environ['SPCONV_ALGO'] = 'native'
8
  from typing import *
 
11
  import imageio
12
  from easydict import EasyDict as edict
13
  from PIL import Image
 
14
  from trellis.pipelines import TrellisImageTo3DPipeline
15
+ from trellis.representations import Gaussian, MeshExtractResult
16
+ from trellis.utils import render_utils, postprocessing_utils
 
 
 
17
 
18
 
19
  MAX_SEED = np.iinfo(np.int32).max
20
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
21
  os.makedirs(TMP_DIR, exist_ok=True)
22
 
23
+
24
  def start_session(req: gr.Request):
25
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
26
  os.makedirs(user_dir, exist_ok=True)
27
+
28
+
29
  def end_session(req: gr.Request):
30
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
31
  shutil.rmtree(user_dir)
32
 
33
+
34
+ def preprocess_image(image: Image.Image) -> Image.Image:
35
  """
36
+ Preprocess the input image.
37
+ Args:
38
+ image (Image.Image): The input image.
39
+ Returns:
40
+ Image.Image: The preprocessed image.
41
  """
42
+ processed_image = pipeline.preprocess_image(image)
43
+ return processed_image
44
+
 
 
 
 
 
 
 
45
 
46
+ def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
47
  """
48
+ Preprocess a list of input images.
49
+
50
+ Args:
51
+ images (List[Tuple[Image.Image, str]]): The input images.
52
+
53
+ Returns:
54
+ List[Image.Image]: The preprocessed images.
55
  """
56
+ images = [image[0] for image in images]
57
+ processed_images = [pipeline.preprocess_image(image) for image in images]
58
+ return processed_images
59
+
60
+
61
+ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
62
+ return {
63
+ 'gaussian': {
64
+ **gs.init_params,
65
+ '_xyz': gs._xyz.cpu().numpy(),
66
+ '_features_dc': gs._features_dc.cpu().numpy(),
67
+ '_scaling': gs._scaling.cpu().numpy(),
68
+ '_rotation': gs._rotation.cpu().numpy(),
69
+ '_opacity': gs._opacity.cpu().numpy(),
70
+ },
71
+ 'mesh': {
72
+ 'vertices': mesh.vertices.cpu().numpy(),
73
+ 'faces': mesh.faces.cpu().numpy(),
74
+ },
75
+ }
76
+
77
+
78
def unpack_state(state: dict) -> Tuple[Gaussian, edict]:
    """
    Rebuild the Gaussian and mesh objects from a packed state dict.

    Args:
        state (dict): A dict with a 'gaussian' entry (constructor parameters
            plus the `_xyz`, `_features_dc`, `_scaling`, `_rotation` and
            `_opacity` arrays) and a 'mesh' entry ('vertices' and 'faces').

    Returns:
        Tuple[Gaussian, edict]: The reconstructed Gaussian and an `edict`
        with `vertices` and `faces` tensors. All tensors are created on the
        'cuda' device.
    """
    gaussian_state = state['gaussian']
    gs = Gaussian(
        aabb=gaussian_state['aabb'],
        sh_degree=gaussian_state['sh_degree'],
        mininum_kernel_size=gaussian_state['mininum_kernel_size'],
        scaling_bias=gaussian_state['scaling_bias'],
        opacity_bias=gaussian_state['opacity_bias'],
        scaling_activation=gaussian_state['scaling_activation'],
    )
    # Restore the tensor attributes on the GPU.
    gs._xyz = torch.tensor(gaussian_state['_xyz'], device='cuda')
    gs._features_dc = torch.tensor(gaussian_state['_features_dc'], device='cuda')
    gs._scaling = torch.tensor(gaussian_state['_scaling'], device='cuda')
    gs._rotation = torch.tensor(gaussian_state['_rotation'], device='cuda')
    gs._opacity = torch.tensor(gaussian_state['_opacity'], device='cuda')

    mesh_state = state['mesh']
    mesh = edict(
        vertices=torch.tensor(mesh_state['vertices'], device='cuda'),
        faces=torch.tensor(mesh_state['faces'], device='cuda'),
    )

    # Exactly two values are returned; the annotation previously claimed three.
    return gs, mesh
99
+
100
+
101
def get_seed(randomize_seed: bool, seed: int) -> int:
    """
    Pick the seed to use for a generation run.

    Returns a value drawn with `np.random.randint(0, MAX_SEED)` when
    `randomize_seed` is truthy; otherwise returns `seed` unchanged.
    """
    if randomize_seed:
        return np.random.randint(0, MAX_SEED)
    return seed
 
 
106
 
107
 
108
  @spaces.GPU
109
  def image_to_3d(
110
+ image: Image.Image,
111
+ multiimages: List[Tuple[Image.Image, str]],
112
+ is_multiimage: bool,
113
  seed: int,
114
  ss_guidance_strength: float,
115
  ss_sampling_steps: int,
116
  slat_guidance_strength: float,
117
  slat_sampling_steps: int,
118
+ multiimage_algo: Literal["multidiffusion", "stochastic"],
119
  req: gr.Request,
120
  ) -> Tuple[dict, str]:
121
  """
 
135
  str: The path to the video of the 3D model.
136
  """
137
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
138
+ if not is_multiimage:
139
+ outputs = pipeline.run(
140
+ image,
141
+ seed=seed,
142
+ formats=["gaussian", "mesh"],
143
+ preprocess_image=False,
144
+ sparse_structure_sampler_params={
145
+ "steps": ss_sampling_steps,
146
+ "cfg_strength": ss_guidance_strength,
147
+ },
148
+ slat_sampler_params={
149
+ "steps": slat_sampling_steps,
150
+ "cfg_strength": slat_guidance_strength,
151
+ },
152
+ )
153
+ else:
154
+ outputs = pipeline.run_multi_image(
155
+ [image[0] for image in multiimages],
156
+ seed=seed,
157
+ formats=["gaussian", "mesh"],
158
+ preprocess_image=False,
159
+ sparse_structure_sampler_params={
160
+ "steps": ss_sampling_steps,
161
+ "cfg_strength": ss_guidance_strength,
162
+ },
163
+ slat_sampler_params={
164
+ "steps": slat_sampling_steps,
165
+ "cfg_strength": slat_guidance_strength,
166
+ },
167
+ mode=multiimage_algo,
168
+ )
169
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
170
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
171
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
 
182
  mesh_simplify: float,
183
  texture_size: int,
184
  req: gr.Request,
185
+ ) -> Tuple[str, str]:
186
  """
187
+ Extract a GLB file from the 3D model.
188
+ Args:
189
+ state (dict): The state of the generated 3D model.
190
+ mesh_simplify (float): The mesh simplification factor.
191
+ texture_size (int): The texture resolution.
192
+ Returns:
193
+ str: The path to the extracted GLB file.
194
  """
195
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
196
  gs, mesh = unpack_state(state)
 
202
 
203
 
204
  @spaces.GPU
205
+ def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
206
  """
207
+ Extract a Gaussian file from the 3D model.
208
+ Args:
209
+ state (dict): The state of the generated 3D model.
210
+ Returns:
211
+ str: The path to the extracted Gaussian file.
212
  """
213
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
214
  gs, _ = unpack_state(state)
 
218
  return gaussian_path, gaussian_path
219
 
220
 
221
def prepare_multi_example() -> List[Image.Image]:
    """
    Build one preview image per multi-image example case.

    Case names are the part of each file name in
    `assets/example_multi_image` before the first underscore. For every
    case, views `<case>_1.png` to `<case>_3.png` are resized to a height of
    512 pixels (width scaled by the original aspect ratio) and concatenated
    side by side.

    Returns:
        List[Image.Image]: One concatenated image per case, ordered by case
        name.
    """
    example_dir = "assets/example_multi_image"
    # Sort the de-duplicated case names: set iteration order is not stable
    # across interpreter runs, which would shuffle the examples.
    cases = sorted({name.split('_')[0] for name in os.listdir(example_dir)})

    previews = []
    for case in cases:
        views = []
        for index in range(1, 4):
            # Use a context manager so the underlying file handle is closed
            # as soon as the resized copy has been produced.
            with Image.open(f'{example_dir}/{case}_{index}.png') as img:
                width, height = img.size
                resized = img.resize((int(width / height * 512), 512))
            views.append(np.array(resized))
        previews.append(Image.fromarray(np.concatenate(views, axis=1)))
    return previews
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
 
 
 
 
 
 
 
 
234
 
235
def split_image(image: Image.Image) -> List[Image.Image]:
    """
    Split a side-by-side composite image into its individual views.

    A column counts as occupied when any pixel in it has alpha > 0. Each
    maximal run of occupied columns becomes one view, which is then passed
    through `preprocess_image`.

    Args:
        image (Image.Image): Composite image; channel index 3 is read as
            alpha, so the image must have at least four channels.

    Returns:
        List[Image.Image]: The preprocessed views, left to right.
    """
    pixels = np.array(image)
    occupied = np.any(pixels[..., 3] > 0, axis=0)

    # Pad with one empty column on each side so that a view touching the
    # left or right border still produces a start/end transition. Without
    # the padding such a view is dropped or mispaired with another view.
    padded = np.concatenate(([False], occupied, [False]))
    starts = np.where(~padded[:-1] & padded[1:])[0].tolist()  # first occupied column
    stops = np.where(padded[:-1] & ~padded[1:])[0].tolist()   # one past last occupied column

    views = []
    for start, stop in zip(starts, stops):
        # Keep one column of left margin where available, matching the
        # slices the previous implementation produced for interior views.
        left = max(start - 1, 0)
        views.append(Image.fromarray(pixels[:, left:stop]))
    return [preprocess_image(view) for view in views]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
 
249
 
250
  with gr.Blocks(delete_cache=(600, 600)) as demo:
251
  gr.Markdown("""
252
+ ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
253
+ * Upload an image and click "Generate" to create a 3D asset. If the image has an alpha channel, it will be used as the mask. Otherwise, we use `rembg` to remove the background.
254
+ * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
255
+
256
+ ✨New: 1) Experimental multi-image support. 2) Gaussian file extraction.
257
  """)
258
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
  with gr.Row():
260
  with gr.Column():
261
+ with gr.Tabs() as input_tabs:
262
+ with gr.Tab(label="Single Image", id=0) as single_image_input_tab:
263
+ image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
264
+ with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab:
265
+ multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
266
+ gr.Markdown("""
267
+ Input different views of the object in separate images.
268
+
269
+ *NOTE: this is an experimental algorithm without training a specialized model. It may not produce the best results for all images, especially those having different poses or inconsistent details.*
270
+ """)
271
+
272
+ with gr.Accordion(label="Generation Settings", open=False):
273
+ seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
274
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
275
  gr.Markdown("Stage 1: Sparse Structure Generation")
276
  with gr.Row():
 
280
  with gr.Row():
281
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
282
  slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
283
+ multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")
284
+
285
  generate_btn = gr.Button("Generate")
286
+
287
+ with gr.Accordion(label="GLB Extraction Settings", open=False):
288
+ mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
289
+ texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
290
+
291
+ with gr.Row():
292
+ extract_glb_btn = gr.Button("Extract GLB", interactive=False)
293
+ extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
294
+ gr.Markdown("""
295
+ *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
296
+ """)
297
+
298
  with gr.Column():
299
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
300
+ model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
301
+
302
+ with gr.Row():
303
+ download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
304
+ download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
305
 
306
+ is_multiimage = gr.State(False)
307
+ output_buf = gr.State()
308
+
309
+ # Example images at the bottom of the page
310
+ with gr.Row() as single_image_example:
311
+ examples = gr.Examples(
312
+ examples=[
313
+ f'assets/example_image/{image}'
314
+ for image in os.listdir("assets/example_image")
315
+ ],
316
+ inputs=[image_prompt],
317
+ fn=preprocess_image,
318
+ outputs=[image_prompt],
319
+ run_on_click=True,
320
+ examples_per_page=64,
321
+ )
322
+ with gr.Row(visible=False) as multiimage_example:
323
+ examples_multi = gr.Examples(
324
+ examples=prepare_multi_example(),
325
+ inputs=[image_prompt],
326
+ fn=split_image,
327
+ outputs=[multiimage_prompt],
328
+ run_on_click=True,
329
+ examples_per_page=8,
330
+ )
331
+
332
+ # Handlers
333
  demo.load(start_session)
334
  demo.unload(end_session)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
 
336
+ single_image_input_tab.select(
337
+ lambda: tuple([False, gr.Row.update(visible=True), gr.Row.update(visible=False)]),
338
+ outputs=[is_multiimage, single_image_example, multiimage_example]
 
 
 
 
339
  )
340
+ multiimage_input_tab.select(
341
+ lambda: tuple([True, gr.Row.update(visible=False), gr.Row.update(visible=True)]),
342
+ outputs=[is_multiimage, single_image_example, multiimage_example]
 
343
  )
344
+
345
+ image_prompt.upload(
346
+ preprocess_image,
347
+ inputs=[image_prompt],
348
+ outputs=[image_prompt],
349
  )
350
+ multiimage_prompt.upload(
351
+ preprocess_images,
352
+ inputs=[multiimage_prompt],
353
+ outputs=[multiimage_prompt],
354
  )
355
 
356
+ generate_btn.click(
357
+ get_seed,
358
+ inputs=[randomize_seed, seed],
359
+ outputs=[seed],
360
+ ).then(
361
+ image_to_3d,
362
+ inputs=[image_prompt, multiimage_prompt, is_multiimage, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo],
363
+ outputs=[output_buf, video_output],
364
+ ).then(
365
+ lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
366
+ outputs=[extract_glb_btn, extract_gs_btn],
 
 
 
 
367
  )
368
+
369
+ video_output.clear(
370
+ lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
371
+ outputs=[extract_glb_btn, extract_gs_btn],
372
  )
373
 
374
+ extract_glb_btn.click(
375
+ extract_glb,
376
+ inputs=[output_buf, mesh_simplify, texture_size],
377
+ outputs=[model_output, download_glb],
378
+ ).then(
379
+ lambda: gr.Button(interactive=True),
380
+ outputs=[download_glb],
381
  )
382
 
383
+ extract_gs_btn.click(
384
+ extract_gaussian,
385
+ inputs=[output_buf],
386
+ outputs=[model_output, download_gs],
387
+ ).then(
388
+ lambda: gr.Button(interactive=True),
389
+ outputs=[download_gs],
390
+ )
391
 
392
+ model_output.clear(
393
+ lambda: gr.Button(interactive=False),
394
+ outputs=[download_glb],
 
 
 
 
 
 
 
 
 
 
 
 
395
  )
396
 
397
 
398
# Launch the Gradio app
if __name__ == "__main__":
    pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
    pipeline.cuda()
    try:
        # Warm-up call on a blank image to preload rembg before the first request.
        pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
    except Exception as exc:
        # The warm-up is best effort: report the failure and keep starting.
        # Catching Exception (not a bare except) lets KeyboardInterrupt and
        # SystemExit still stop the process.
        print(f"Warning: rembg preload failed: {exc!r}")
    demo.launch()
assets/example_image/T.png ADDED

Git LFS Details

  • SHA256: e29ddc83a5bd3a05fe9b34732169bc4ea7131f7c36527fdc5f626a90a73076d2
  • Pointer size: 131 Bytes
  • Size of remote file: 955 kB
assets/example_image/typical_building_building.png ADDED

Git LFS Details

  • SHA256: 8faa11d557be95c000c475247e61a773d511114c7d1e517c04f8d3d88a6049ec
  • Pointer size: 131 Bytes
  • Size of remote file: 547 kB
assets/example_image/typical_building_castle.png ADDED

Git LFS Details

  • SHA256: 076f0554b087b921863643d2b1ab3e0572a13a347fd66bc29cd9d194034affae
  • Pointer size: 131 Bytes
  • Size of remote file: 426 kB
assets/example_image/typical_building_colorful_cottage.png ADDED

Git LFS Details

  • SHA256: 687305b4e35da759692be0de614d728583a2a9cd2fd3a55593fa753e567d0d47
  • Pointer size: 131 Bytes
  • Size of remote file: 609 kB
assets/example_image/typical_building_maya_pyramid.png ADDED

Git LFS Details

  • SHA256: 4d514f7f4db244ee184af4ddfbc5948d417b4e5bf1c6ee5f5a592679561690df
  • Pointer size: 131 Bytes
  • Size of remote file: 232 kB
assets/example_image/typical_building_mushroom.png ADDED

Git LFS Details

  • SHA256: de9b72d3e13e967e70844ddc54643832a84a1b35ca043a11e7c774371d0ccdab
  • Pointer size: 131 Bytes
  • Size of remote file: 488 kB
assets/example_image/typical_building_space_station.png ADDED

Git LFS Details

  • SHA256: 212c7b4c27ba1e01a7908dbc7f245e7115850eadbc9974aa726327cf35062846
  • Pointer size: 131 Bytes
  • Size of remote file: 620 kB
assets/example_image/typical_creature_dragon.png ADDED

Git LFS Details

  • SHA256: 0e8d6720dfa1e7b332b76e897e617b7f0863187f30879451b4724f482c84185a
  • Pointer size: 131 Bytes
  • Size of remote file: 564 kB
assets/example_image/typical_creature_elephant.png ADDED

Git LFS Details

  • SHA256: 86a171e37a3d781e7215977f565cd63e813341c1f89e2c586fa61937e4ed6916
  • Pointer size: 131 Bytes
  • Size of remote file: 482 kB
assets/example_image/typical_creature_furry.png ADDED

Git LFS Details

  • SHA256: 5b5445b8f1996cf6d72497b2d7564c656f4048e6c1fa626fd7bb3ee582fee671
  • Pointer size: 131 Bytes
  • Size of remote file: 648 kB
assets/example_image/typical_creature_quadruped.png ADDED

Git LFS Details

  • SHA256: 7469f43f58389adec101e9685f60188bd4e7fbede77eef975102f6a8865bc786
  • Pointer size: 131 Bytes
  • Size of remote file: 685 kB
assets/example_image/typical_creature_robot_crab.png ADDED

Git LFS Details

  • SHA256: d7e716abe8f8895080f562d1dc26b14fa0e20a05aa5beb2770c6fb3b87b3476a
  • Pointer size: 131 Bytes
  • Size of remote file: 594 kB
assets/example_image/typical_creature_robot_dinosour.png ADDED

Git LFS Details

  • SHA256: d0986f29557a6fddf9b52b5251a6b6103728c61e201b1cfad1e709b090b72f56
  • Pointer size: 131 Bytes
  • Size of remote file: 632 kB
assets/example_image/typical_creature_rock_monster.png ADDED

Git LFS Details

  • SHA256: e29458a6110bee8374c0d4d12471e7167a6c1c98c18f6e2d7ff4f5f0ca3fa01b
  • Pointer size: 131 Bytes
  • Size of remote file: 648 kB
assets/example_image/typical_humanoid_block_robot.png ADDED

Git LFS Details

  • SHA256: 3a0acbb532668e1bf35f3eef5bcbfdd094c22219ef2d837fa01ccf51cce75ca3
  • Pointer size: 131 Bytes
  • Size of remote file: 441 kB
assets/example_image/typical_humanoid_dragonborn.png ADDED

Git LFS Details

  • SHA256: 5d7c547909a6c12da55dbab1c1c98181ff09e58c9ba943682ca105e71be9548e
  • Pointer size: 131 Bytes
  • Size of remote file: 481 kB
assets/example_image/typical_humanoid_dwarf.png ADDED

Git LFS Details

  • SHA256: a4a7c157d5d8071128c27594e45a7a03e5113b3333b7f1c5ff1379481e3e0264
  • Pointer size: 131 Bytes
  • Size of remote file: 498 kB
assets/example_image/typical_humanoid_goblin.png ADDED

Git LFS Details

  • SHA256: 2b0e9a04ae3e7bef44b7180a70306f95374b60727ffa0f6f01fd6c746595cd77
  • Pointer size: 131 Bytes
  • Size of remote file: 496 kB
assets/example_image/typical_humanoid_mech.png ADDED

Git LFS Details

  • SHA256: a244ec54b7984e646e54d433de6897657081dd5b9cd5ccd3d865328d813beb49
  • Pointer size: 131 Bytes
  • Size of remote file: 850 kB
assets/example_image/typical_misc_crate.png ADDED

Git LFS Details

  • SHA256: 59fd9884301faca93265166d90078e8c31e76c7f93524b1db31975df4b450748
  • Pointer size: 131 Bytes
  • Size of remote file: 642 kB
assets/example_image/typical_misc_fireplace.png ADDED

Git LFS Details

  • SHA256: 2288c034603e289192d63cbc73565107caefd99e81c4b7afa2983c8b13e34440
  • Pointer size: 131 Bytes
  • Size of remote file: 558 kB
assets/example_image/typical_misc_gate.png ADDED

Git LFS Details

  • SHA256: ec8db5389b74fe56b826e3c6d860234541033387350e09268591c46d411cc8e9
  • Pointer size: 131 Bytes
  • Size of remote file: 572 kB
assets/example_image/typical_misc_lantern.png ADDED

Git LFS Details

  • SHA256: e17bd83adf433ebfca17abd220097b2b7f08affc649518bd7822e03797e83d41
  • Pointer size: 131 Bytes
  • Size of remote file: 300 kB
assets/example_image/typical_misc_magicbook.png ADDED

Git LFS Details

  • SHA256: aff9c14589c340e31b61bf82e4506d77d72c511e741260fa1e600cefa4e103e6
  • Pointer size: 131 Bytes
  • Size of remote file: 496 kB
assets/example_image/typical_misc_mailbox.png ADDED

Git LFS Details

  • SHA256: 01e86a5d68edafb7e11d7a86f7e8081f5ed1b02578198a3271554c5fb8fb9fcf
  • Pointer size: 131 Bytes
  • Size of remote file: 631 kB
assets/example_image/typical_misc_monster_chest.png ADDED

Git LFS Details

  • SHA256: c57a598e842225a31b9770bf3bbb9ae86197ec57d0c2883caf8cb5eed4908fbc
  • Pointer size: 131 Bytes
  • Size of remote file: 690 kB
assets/example_image/typical_misc_paper_machine.png ADDED

Git LFS Details

  • SHA256: 2d55400ae5d4df2377258400d800ece75766d5274e80ce07c3b29a4d1fd1fa36
  • Pointer size: 131 Bytes
  • Size of remote file: 614 kB
assets/example_image/typical_misc_phonograph.png ADDED

Git LFS Details

  • SHA256: 14fff9a27ea769d3ca711e9ff55ab3d9385486a5e8b99117f506df326a0a357e
  • Pointer size: 131 Bytes
  • Size of remote file: 517 kB
assets/example_image/typical_misc_portal2.png ADDED

Git LFS Details

  • SHA256: 57aab2bba56bc946523a3fca77ca70651a4ad8c6fbf1b91a1a824418df48faae
  • Pointer size: 131 Bytes
  • Size of remote file: 386 kB
assets/example_image/typical_misc_storage_chest.png ADDED

Git LFS Details

  • SHA256: 0e4ac1c67fdda902ecb709447b8defd949c738954c844c1b8364b8e3f7d9e55a
  • Pointer size: 131 Bytes
  • Size of remote file: 632 kB
assets/example_image/typical_misc_telephone.png ADDED

Git LFS Details

  • SHA256: 00048be46234a2709c12614b04cbad61c6e3c7e63c2a4ef33d999185f5393e36
  • Pointer size: 131 Bytes
  • Size of remote file: 648 kB
assets/example_image/typical_misc_television.png ADDED

Git LFS Details

  • SHA256: 6a1947b737398bf535ec212668a4d78cd38fe84cf9da1ccd6c0c0d838337755e
  • Pointer size: 131 Bytes
  • Size of remote file: 627 kB
assets/example_image/typical_misc_workbench.png ADDED

Git LFS Details

  • SHA256: a6d9ed4d005a5253b8571fd976b0d102e293512d7b5a8ed5e3f7f17c5f4e19da
  • Pointer size: 131 Bytes
  • Size of remote file: 463 kB
assets/example_image/typical_vehicle_biplane.png ADDED

Git LFS Details

  • SHA256: c73e98112eb603b4ba635b8965cad7807d0588f083811bc2faa0c7ab9668a65a
  • Pointer size: 131 Bytes
  • Size of remote file: 574 kB
assets/example_image/typical_vehicle_bulldozer.png ADDED

Git LFS Details

  • SHA256: 23d821b4daea61cbea28cc6ddd3ae46712514dfcdff995c2664f5a70d21f4ef3
  • Pointer size: 131 Bytes
  • Size of remote file: 693 kB
assets/example_image/typical_vehicle_cart.png ADDED

Git LFS Details

  • SHA256: b72c04a2aa5cf57717c05151a2982d6dc31afde130d5e830adf37a84a70616cb
  • Pointer size: 131 Bytes
  • Size of remote file: 693 kB
assets/example_image/typical_vehicle_excavator.png ADDED

Git LFS Details

  • SHA256: 27a418853eefa197f1e10ed944a7bb071413fd2bc1681804ee773a6ce3799c52
  • Pointer size: 131 Bytes
  • Size of remote file: 712 kB
assets/example_image/typical_vehicle_helicopter.png ADDED

Git LFS Details

  • SHA256: 7f1a1b37bc52417c0e1048927a30bf3a52dde81345f90114040608186196ffe7
  • Pointer size: 131 Bytes
  • Size of remote file: 353 kB
assets/example_image/typical_vehicle_locomotive.png ADDED

Git LFS Details

  • SHA256: 67d5124e7069b133dc0aaa16047a52c6dc1d7c2a4e4510ffd3235fe95597fbef
  • Pointer size: 131 Bytes
  • Size of remote file: 806 kB
assets/example_image/typical_vehicle_pirate_ship.png ADDED

Git LFS Details

  • SHA256: 8926ec7c9f36a52e3bf1ca4e8cfc75d297da934fe7c0e8d7a73f0d35a5ef38ad
  • Pointer size: 131 Bytes
  • Size of remote file: 611 kB
assets/example_image/weatherworn_misc_paper_machine3.png ADDED

Git LFS Details

  • SHA256: 3c6fbf47ed53ffad1a3027f72bf0806c238682c7bf7604b8770aef428906d33b
  • Pointer size: 131 Bytes
  • Size of remote file: 502 kB
assets/example_multi_image/character_1.png ADDED

Git LFS Details

  • SHA256: 729e2e0214232e1dd45c9187e339f8a2a87c6e41257ef701e578a4f0a8be7ef1
  • Pointer size: 131 Bytes
  • Size of remote file: 172 kB
assets/example_multi_image/character_2.png ADDED

Git LFS Details

  • SHA256: e8afc8af9960a5f2315d9d5b9815f29137ef9b63c4d512c451a8ba374003c3ac
  • Pointer size: 131 Bytes
  • Size of remote file: 198 kB
assets/example_multi_image/character_3.png ADDED

Git LFS Details

  • SHA256: 3413a2b4f67105b947a42ebbe14d3f6ab9f68a99f2258a86fd50d94342b49bdd
  • Pointer size: 131 Bytes
  • Size of remote file: 146 kB
assets/example_multi_image/mushroom_1.png ADDED

Git LFS Details

  • SHA256: 3e5fd9ee75d39c827b0c5544392c255a89b4ca62bf3cf31f702d39b150bea00c
  • Pointer size: 131 Bytes
  • Size of remote file: 434 kB
assets/example_multi_image/mushroom_2.png ADDED

Git LFS Details

  • SHA256: 6e5709b910341b12d149632b8442aebd218dd591b238ccb7e4e8b185860aae04
  • Pointer size: 131 Bytes
  • Size of remote file: 462 kB
assets/example_multi_image/mushroom_3.png ADDED

Git LFS Details

  • SHA256: 115c9c3a11d3d08de568680468ad42ac1322c21bdad46a43e149fc97cc687e48
  • Pointer size: 131 Bytes
  • Size of remote file: 425 kB
assets/example_multi_image/orangeguy_1.png ADDED

Git LFS Details

  • SHA256: ab30ee372fc365e5d100f2e06ea7cd17b3ea3f53b1a76ebe44e69a1cf834700e
  • Pointer size: 131 Bytes
  • Size of remote file: 632 kB