Sm0kyWu commited on
Commit
485f2ba
·
verified ·
1 Parent(s): 6324d52

Delete app_old.py

Browse files
Files changed (1) hide show
  1. app_old.py +0 -659
app_old.py DELETED
@@ -1,659 +0,0 @@
1
- import gradio as gr
2
- import spaces
3
- from gradio_litmodel3d import LitModel3D
4
-
5
- import os
6
- import shutil
7
- os.environ['SPCONV_ALGO'] = 'native'
8
- from typing import *
9
- import torch
10
- import numpy as np
11
- import imageio
12
- from easydict import EasyDict as edict
13
- from PIL import Image
14
- from Amodal3R.pipelines import Amodal3RImageTo3DPipeline
15
- from Amodal3R.representations import Gaussian, MeshExtractResult
16
- from Amodal3R.utils import render_utils, postprocessing_utils
17
- from segment_anything import sam_model_registry, SamPredictor
18
- from huggingface_hub import hf_hub_download
19
- import cv2
20
-
21
-
22
# Largest seed offered to the user: the maximum signed 32-bit integer.
MAX_SEED = np.iinfo('int32').max
# Scratch directory ("tmp") located next to this script; created on import.
TMP_DIR = os.path.join(os.path.split(os.path.abspath(__file__))[0], 'tmp')
os.makedirs(TMP_DIR, exist_ok=True)
25
-
26
-
27
def start_session(req: gr.Request):
    """Create the scratch directory named after ``req.session_hash`` under ``TMP_DIR``."""
    os.makedirs(os.path.join(TMP_DIR, str(req.session_hash)), exist_ok=True)
30
-
31
-
32
def end_session(req: gr.Request):
    """Remove the scratch directory named after ``req.session_hash`` under ``TMP_DIR``.

    The directory may be missing (never created, or already removed), so
    removal errors are ignored instead of raising from a cleanup hook.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    shutil.rmtree(user_dir, ignore_errors=True)
35
-
36
def reset_image(predictor, img):
    """
    Called after an image is uploaded.

    Hands ``img`` to ``predictor.set_image`` and returns a 3-tuple of
    (the same predictor, a copy of ``img``, a status message).
    """
    predictor.set_image(img)
    status = "The models are ready."
    return predictor, img.copy(), status
47
-
48
def button_clickable(selected_points):
    """Return a Gradio update enabling a button only when points exist.

    Uses ``gr.update`` instead of the per-component ``gr.Button.update``
    classmethod, which is not available in Gradio 4+ (the rest of this
    file relies on Gradio 4 features such as ``delete_cache``).
    """
    return gr.update(interactive=len(selected_points) > 0)
53
-
54
def run_sam(predictor, selected_points):
    """
    Segment with SAM using the given positive point prompts.

    Args:
        predictor: object exposing ``predict(point_coords, point_labels,
            multimask_output)``; its image must already have been set.
        selected_points: sequence of ``[x, y]`` prompts, all treated as
            foreground (label 1).

    Returns:
        The first predicted mask as a ``uint8`` array after a 5x5 elliptical
        dilate-then-erode pass, or ``None`` when no points were given.
        (Previously the empty case returned the tuple ``([], None)``, which
        did not match the single-mask return of the normal path.)
    """
    if len(selected_points) == 0:
        return None
    point_coords = np.array(selected_points)
    # Every prompt is a positive (foreground) click.
    point_labels = np.ones(len(selected_points), dtype=int)
    masks, _, _ = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=False,  # single-object output
    )
    best_mask = masks[0].astype(np.uint8)
    # Dilate followed by erode with the same kernel.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    best_mask = cv2.dilate(best_mask, kernel, iterations=1)
    best_mask = cv2.erode(best_mask, kernel, iterations=1)
    return best_mask
76
-
77
def apply_mask_overlay(image, mask, color=(255, 0, 0)):
    """
    Overlay a mask on a copy of the image.

    Pixels where ``mask == 0`` are blended 50/50 with light gray
    (200, 200, 200); the external contours of the mask are drawn with
    ``color`` (red by default) at thickness 2. The input image is not
    modified.
    """
    blended = image.copy()
    background = mask == 0
    light_gray = np.array([200, 200, 200], dtype=np.uint8)
    blended[background] = (0.5 * blended[background] + 0.5 * light_gray).astype(np.uint8)
    outlines, _hierarchy = cv2.findContours(
        mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    cv2.drawContours(blended, outlines, -1, color, 2)
    return blended
89
-
90
def segment_and_overlay(image, points, sam_predictor):
    """
    Run ``run_sam`` on the point prompts and overlay the mask on the image.

    Args:
        image: image array to draw on (not modified).
        points: list of ``[x, y]`` point prompts.
        sam_predictor: predictor passed through to ``run_sam``.

    Returns:
        ``(overlaid_image, mask)``. When there is no image or no points,
        nothing can be segmented, so ``(image, None)`` is returned instead
        of failing inside the mask arithmetic.
    """
    if image is None or len(points) == 0:
        return image, None
    visible_mask = run_sam(sam_predictor, points)
    # The mask is scaled by 255 before being handed to the overlay helper.
    overlaid = apply_mask_overlay(image, visible_mask * 255)
    return overlaid, visible_mask
97
-
98
-
99
def reset_points():
    """
    Clear the point prompts: returns an empty list and an empty string.
    """
    cleared_points = []
    cleared_text = ""
    return cleared_points, cleared_text
104
-
105
-
106
@spaces.GPU
def image_to_3d(
    image: List[tuple],
    masks: List[np.ndarray],
    seed: int,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
    multiimage_algo: str,
    req: gr.Request,
) -> tuple:
    """
    Convert the input image(s) into a 3D model.

    Calls ``pipeline.run_multi_image`` with the first element of every entry
    in ``image`` and ``masks``, renders a 120-frame color video of the
    Gaussian output next to a normal-map video of the mesh output, writes it
    as ``sample.mp4`` (15 fps) into the session directory, and returns
    ``(packed_state, video_path)``.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    ss_params = {
        "steps": ss_sampling_steps,
        "cfg_strength": ss_guidance_strength,
    }
    slat_params = {
        "steps": slat_sampling_steps,
        "cfg_strength": slat_guidance_strength,
    }
    outputs = pipeline.run_multi_image(
        [entry[0] for entry in image],
        [entry[0] for entry in masks],
        seed=seed,
        formats=["gaussian", "mesh"],
        preprocess_image=False,
        sparse_structure_sampler_params=ss_params,
        slat_sampler_params=slat_params,
        mode=multiimage_algo,
    )
    color_frames = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
    normal_frames = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
    # Place each color frame and its matching normal frame side by side.
    frames = [
        np.concatenate([color, normal_frames[idx]], axis=1)
        for idx, color in enumerate(color_frames)
    ]
    video_path = os.path.join(user_dir, 'sample.mp4')
    imageio.mimsave(video_path, frames, fps=15)
    state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
    torch.cuda.empty_cache()
    return state, video_path
146
-
147
-
148
@spaces.GPU(duration=90)
def extract_glb(
    state: dict,
    mesh_simplify: float,
    texture_size: int,
    req: gr.Request,
) -> tuple:
    """
    Extract a GLB file from the generated 3D model.

    Rebuilds the Gaussian and mesh from ``state``, converts them with
    ``postprocessing_utils.to_glb``, exports the result as ``sample.glb`` in
    the session directory, and returns that path twice.
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    gaussian, mesh = unpack_state(state)
    glb_model = postprocessing_utils.to_glb(
        gaussian,
        mesh,
        simplify=mesh_simplify,
        texture_size=texture_size,
        verbose=False,
    )
    out_path = os.path.join(session_dir, 'sample.glb')
    glb_model.export(out_path)
    torch.cuda.empty_cache()
    return out_path, out_path
165
-
166
-
167
@spaces.GPU
def extract_gaussian(state: dict, req: gr.Request) -> tuple:
    """
    Extract a Gaussian file from the generated 3D model.

    Rebuilds the Gaussian from ``state``, saves it with ``save_ply`` as
    ``sample.ply`` in the session directory, and returns that path twice.
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    gaussian = unpack_state(state)[0]
    out_path = os.path.join(session_dir, 'sample.ply')
    gaussian.save_ply(out_path)
    torch.cuda.empty_cache()
    return out_path, out_path
178
-
179
-
180
def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
    """Pack a Gaussian and a mesh into a dict of NumPy arrays.

    The ``'gaussian'`` entry holds ``gs.init_params`` plus the five tensor
    attributes copied to CPU; the ``'mesh'`` entry holds the vertices and
    faces copied to CPU.
    """
    gaussian_state = dict(gs.init_params)
    for attr in ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity'):
        gaussian_state[attr] = getattr(gs, attr).cpu().numpy()
    mesh_state = {
        'vertices': mesh.vertices.cpu().numpy(),
        'faces': mesh.faces.cpu().numpy(),
    }
    return {'gaussian': gaussian_state, 'mesh': mesh_state}
195
-
196
-
197
def unpack_state(state: dict) -> tuple:
    """Rebuild ``(gaussian, mesh)`` on the CUDA device from a packed state dict.

    The Gaussian is constructed from the stored init parameters and its five
    tensor attributes are restored; the mesh is an ``edict`` holding
    ``vertices`` and ``faces`` tensors.
    """
    packed_gs = state['gaussian']
    gs = Gaussian(
        aabb=packed_gs['aabb'],
        sh_degree=packed_gs['sh_degree'],
        mininum_kernel_size=packed_gs['mininum_kernel_size'],
        scaling_bias=packed_gs['scaling_bias'],
        opacity_bias=packed_gs['opacity_bias'],
        scaling_activation=packed_gs['scaling_activation'],
    )
    for attr in ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity'):
        setattr(gs, attr, torch.tensor(packed_gs[attr], device='cuda'))

    packed_mesh = state['mesh']
    mesh = edict(
        vertices=torch.tensor(packed_mesh['vertices'], device='cuda'),
        faces=torch.tensor(packed_mesh['faces'], device='cuda'),
    )

    return gs, mesh
218
-
219
-
220
def prepare_multi_example() -> list:
    """Build one side-by-side preview image per multi-image example case.

    Case names are the part of each file name in
    ``assets/example_multi_image`` before the first underscore. For every
    case, ``{case}_1.png`` .. ``{case}_3.png`` are resized to height 512
    (keeping aspect ratio) and concatenated horizontally.

    Cases are processed in sorted order so the result is deterministic
    (iteration order of a plain set is not), and each file is closed as
    soon as it has been resized.
    """
    example_dir = "assets/example_multi_image"
    cases = sorted({name.split('_')[0] for name in os.listdir(example_dir)})
    images = []
    for case in cases:
        views = []
        for i in range(1, 4):
            with Image.open(f'assets/example_multi_image/{case}_{i}.png') as img:
                W, H = img.size
                resized = img.resize((int(W / H * 512), 512))
            views.append(np.array(resized))
        images.append(Image.fromarray(np.concatenate(views, axis=1)))
    return images
232
-
233
-
234
def split_image(image: Image.Image) -> list:
    """
    Split an image into separate views (no preprocessing).

    The image is converted to an array and its fourth channel is used as
    alpha, so a four-channel (RGBA) image is required. A column counts as
    "occupied" when any of its pixels has alpha > 0; each maximal run of
    occupied columns becomes one view. As before, a view also includes the
    single empty column to its left when there is one.

    Runs touching the left or right border are now detected as well: the
    occupancy vector is padded with an empty column on both sides, so the
    start and end of every run produce a transition. Previously such runs
    had no start (or end) and the remaining starts and ends were mispaired.
    """
    pixels = np.array(image)
    occupied = np.any(pixels[..., 3] > 0, axis=0)
    padded = np.concatenate(([False], occupied, [False]))
    # first_cols[i]: index of the first occupied column of run i.
    first_cols = np.where(~padded[:-1] & padded[1:])[0].tolist()
    # stop_cols[i]: index one past the last occupied column of run i.
    stop_cols = np.where(padded[:-1] & ~padded[1:])[0].tolist()
    return [
        Image.fromarray(pixels[:, max(first - 1, 0):stop])
        for first, stop in zip(first_cols, stop_cols)
    ]
247
-
248
def get_sam_predictor():
    """Download the SAM ViT-H checkpoint from the Hub and wrap it in a ``SamPredictor``.

    No explicit device move is performed here.
    """
    checkpoint_path = hf_hub_download("ybelkada/segment-anything", "checkpoints/sam_vit_h_4b8939.pth")
    build_model = sam_model_registry["vit_h"]
    return SamPredictor(build_model(checkpoint=checkpoint_path))
255
-
256
-
257
def draw_points_on_image(image, point, point_type):
    """Draw one filled circle for ``point`` (an ``(x, y)`` pair) on a copy of the image.

    The circle is (255, 0, 0) when ``point_type`` is "vis" and (0, 255, 0)
    otherwise.
    """
    px, py = point
    if point_type == "vis":
        marker_color = (255, 0, 0)
    else:
        marker_color = (0, 255, 0)
    canvas = image.copy()
    cv2.circle(canvas, (int(px), int(py)), radius=10, color=marker_color, thickness=-1)
    return canvas
264
-
265
-
266
def see_point(image, x, y, point_type):
    """
    Preview a point: draw it on a copy of the image without touching any
    stored point list, and return the drawn image.
    """
    return draw_points_on_image(image, [x, y], point_type)
274
-
275
def add_point(x, y, point_type, visible_points, occlusion_points):
    """
    Add ``[x, y]`` to the visible list when ``point_type`` is "vis",
    otherwise to the occlusion list; duplicates are skipped.

    The chosen list is modified in place and both lists are returned.
    """
    target = visible_points if point_type == "vis" else occlusion_points
    candidate = [x, y]
    # Skip points that are already present.
    if candidate not in target:
        target.append(candidate)
    return visible_points, occlusion_points
288
-
289
def delete_point(point_type, visible_points, occlusion_points):
    """
    Delete the last point of the visible list when ``point_type`` is "vis",
    otherwise of the occlusion list.

    The chosen list is modified in place and both lists are returned.
    Deleting from an empty list is a no-op (it used to raise IndexError).
    """
    target = visible_points if point_type == "vis" else occlusion_points
    if target:
        target.pop()
    return visible_points, occlusion_points
299
-
300
-
301
def clear_all_points(image):
    """
    Return a fresh copy of ``image`` (used to show the image without any
    drawn points). The point lists themselves are not touched here.
    """
    return image.copy()
308
-
309
def see_visible_points(image, visible_points):
    """
    Draw every visible point as a filled (255, 0, 0) circle on a copy of
    the image and return the copy.
    """
    canvas = image.copy()
    marker_color = (255, 0, 0)
    for point in visible_points:
        center = (int(point[0]), int(point[1]))
        cv2.circle(canvas, center, radius=10, color=marker_color, thickness=-1)
    return canvas
317
-
318
def see_occlusion_points(image, occlusion_points):
    """
    Draw every occlusion point as a filled (0, 255, 0) circle on a copy of
    the image and return the copy.
    """
    canvas = image.copy()
    marker_color = (0, 255, 0)
    for point in occlusion_points:
        center = (int(point[0]), int(point[1]))
        cv2.circle(canvas, center, radius=10, color=marker_color, thickness=-1)
    return canvas
326
-
327
def update_all_points(visible_points, occlusion_points):
    """Build the points summary text and the two refreshed delete-dropdowns.

    Returns ``(text, visible_dropdown, occlusion_dropdown)``; each dropdown
    lists its points as "(x, y)" strings with no selection.
    """
    def _as_labels(points):
        return [f"({p[0]}, {p[1]})" for p in points]

    summary = f"Visible Points: {visible_points}\nOcclusion Points: {occlusion_points}"
    visible_choices = _as_labels(visible_points)
    occlusion_choices = _as_labels(occlusion_points)
    # Return component instances so that both choices and value are set explicitly.
    visible_dropdown = gr.Dropdown(
        label="Select Visible Point to Delete",
        choices=visible_choices,
        value=None,
        interactive=True,
    )
    occlusion_dropdown = gr.Dropdown(
        label="Select Occlusion Point to Delete",
        choices=occlusion_choices,
        value=None,
        interactive=True,
    )
    return summary, visible_dropdown, occlusion_dropdown
333
-
334
def delete_selected_visible(image, visible_points, occlusion_points, selected_value):
    """Delete the visible point whose "(x, y)" label equals ``selected_value``.

    If no label matches, nothing is removed. Returns a copy of ``image``
    with the remaining visible points drawn in (255, 0, 0), both point
    lists, and the refreshed text and dropdowns from ``update_all_points``.
    """
    # selected_value is a string of the form "(x, y)"
    labels = [f"({p[0]}, {p[1]})" for p in visible_points]
    if selected_value in labels:
        visible_points.pop(labels.index(selected_value))
    canvas = image.copy()
    # Redraw all remaining visible points
    for point in visible_points:
        center = (int(point[0]), int(point[1]))
        cv2.circle(canvas, center, radius=10, color=(255, 0, 0), thickness=-1)
    summary, vis_dropdown, occ_dropdown = update_all_points(visible_points, occlusion_points)
    return canvas, visible_points, occlusion_points, summary, vis_dropdown, occ_dropdown
348
-
349
def delete_selected_occlusion(image, visible_points, occlusion_points, selected_value):
    """Delete the occlusion point whose "(x, y)" label equals ``selected_value``.

    If no label matches, nothing is removed. Returns a copy of ``image``
    with the remaining occlusion points drawn in (0, 255, 0), both point
    lists, and the refreshed text and dropdowns from ``update_all_points``.
    """
    labels = [f"({p[0]}, {p[1]})" for p in occlusion_points]
    if selected_value in labels:
        occlusion_points.pop(labels.index(selected_value))
    canvas = image.copy()
    # Redraw all remaining occlusion points
    for point in occlusion_points:
        center = (int(point[0]), int(point[1]))
        cv2.circle(canvas, center, radius=10, color=(0, 255, 0), thickness=-1)
    summary, vis_dropdown, occ_dropdown = update_all_points(visible_points, occlusion_points)
    return canvas, visible_points, occlusion_points, summary, vis_dropdown, occ_dropdown
362
-
363
def add_mask(mask, mask_list):
    """Append ``mask`` to ``mask_list`` unless it is missing or a repeat.

    ``None`` (no mask generated yet) is ignored so the list never holds a
    value that is not an array; a mask equal to the last stored mask is
    skipped. The list is modified in place and returned.
    """
    if mask is None:
        return mask_list
    # Skip the mask when it is identical to the most recently added one.
    if mask_list and np.array_equal(mask, mask_list[-1]):
        return mask_list
    mask_list.append(mask)
    return mask_list
370
-
371
def vis_mask(image, mask_list):
    """Overlay the union of all masks in ``mask_list`` on a copy of the image."""
    canvas = image.copy()
    # Union of every stored mask, starting from an all-zero single channel.
    union = np.zeros_like(canvas[:, :, 0])
    for stored in mask_list:
        union = cv2.bitwise_or(union, stored)
    # Draw the union on top of the image.
    return apply_mask_overlay(canvas, union)
380
-
381
def delete_mask(mask_list):
    """Remove the last mask from ``mask_list`` (if any) in place and return the list."""
    if len(mask_list) == 0:
        return mask_list
    del mask_list[-1]
    return mask_list
385
-
386
-
387
def apply_combined_mask_overlay(image, vis_mask, occ_mask):
    """
    Overlay both masks on a copy of the image.

    Pixels outside both masks are blended 50/50 with light gray
    (200, 200, 200). External contours of the occlusion mask are drawn in
    (0, 0, 255) first, then those of the visible mask in (255, 0, 0), both
    at thickness 2.
    """
    blended = image.copy()
    outside = (vis_mask == 0) & (occ_mask == 0)
    light_gray = np.array([200, 200, 200], dtype=np.uint8)
    blended[outside] = (0.5 * blended[outside] + 0.5 * light_gray).astype(np.uint8)
    occ_outlines, _occ_hierarchy = cv2.findContours(
        occ_mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    cv2.drawContours(blended, occ_outlines, -1, (0, 0, 255), 2)
    vis_outlines, _vis_hierarchy = cv2.findContours(
        vis_mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    cv2.drawContours(blended, vis_outlines, -1, (255, 0, 0), 2)
    return blended
401
-
402
-
403
def combine_mask(image, visible_mask_list, occlusion_mask_list):
    """Merge the stored visible and occlusion masks into one labelled mask.

    Args:
        image: H x W x C image array; its first channel supplies the shape
            and dtype of the masks.
        visible_mask_list: masks of the visible region.
        occlusion_mask_list: masks of the occluders.

    Returns:
        ``(preview, combined_mask)`` where ``preview`` is the overlay image
        concatenated horizontally with a 3-channel copy of the combined
        mask, and ``combined_mask`` holds 255 for visible pixels, 128 for
        (dilated) occlusion pixels and 0 elsewhere.

    Fixes over the previous version: each union now accumulates into its
    own mask (it used to OR against an all-zero array, keeping only the
    last mask), and the undefined name ``occluded_image`` is no longer
    returned (it raised NameError on every call). The two return values
    match the two outputs this handler is wired to.
    """
    blank = np.zeros_like(image[:, :, 0])

    combined_vis_mask = blank.copy()
    for mask in visible_mask_list:
        combined_vis_mask = cv2.bitwise_or(combined_vis_mask, mask)

    combined_occ_mask = blank.copy()
    for mask in occlusion_mask_list:
        combined_occ_mask = cv2.bitwise_or(combined_occ_mask, mask)

    overlay = apply_combined_mask_overlay(image, combined_vis_mask, combined_occ_mask)

    # Dilate the occlusion mask with a 5x5 elliptical kernel.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    combined_occ_mask = cv2.dilate(combined_occ_mask, kernel, iterations=1)

    # Visible pixels are written last so they win where the masks overlap.
    combined_mask = blank.copy()
    combined_mask[combined_occ_mask > 0] = 128
    combined_mask[combined_vis_mask > 0] = 255

    # Concatenate the overlay and the mask into a single preview image.
    result = cv2.hconcat([overlay, combined_mask[..., None].repeat(3, axis=-1)])
    return result, combined_mask
423
-
424
-
425
def get_seed(randomize_seed: bool, seed: int) -> int:
    """
    Return a random seed drawn from ``[0, MAX_SEED)`` when ``randomize_seed``
    is set; otherwise return ``seed`` unchanged.
    """
    if randomize_seed:
        return np.random.randint(0, MAX_SEED)
    return seed
430
-
431
-
432
# Build the Gradio UI: state holders, the two-step layout, and the event wiring.
with gr.Blocks(delete_cache=(600, 600)) as demo:
    gr.Markdown("""
    ## 3D Amodal Reconstruction with [Amodal3R](https://sm0kywu.github.io/Amodal3R/)
    """)

    # Define the state variables
    predictor = gr.State(value=get_sam_predictor())
    visible_points_state = gr.State(value=[])
    occlusion_points_state = gr.State(value=[])
    original_image = gr.State(value=None)
    visibility_mask = gr.State(value=None)
    occlusion_mask = gr.State(value=None)
    visibility_mask_list = gr.State(value=[])
    occlusion_mask_list = gr.State(value=[])

    combined_mask = gr.State(value=None)
    occluded_image = gr.State(value=None)


    with gr.Row():
        gr.Markdown("""* Step 1 - Generate Visibility Mask and Occlusion Mask.
                    * Please wait for a few seconds after uploading the image. The 2D segmenter is getting ready.
                    * Add the point prompts to indicate the target object and occluders separately.
                    * "Render Point", see the position of the point to be added.
                    * "Add Point", the point will be added to the list.
                    * "Generate mask", see the segmented area corresponding to current point list.
                    * "Add mask", current mask will be added for 3D amodal completion.
                    """)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="numpy", label='Input Occlusion Image', sources="upload", height=300)
            with gr.Row():
                message = gr.Markdown("Please wait a few seconds after uploading the image.", label="Message")  # shows status messages
            with gr.Row():
                x_input = gr.Number(label="X Coordinate", value=0)
                y_input = gr.Number(label="Y Coordinate", value=0)
                point_type = gr.Radio(["vis", "occ"], label="Point Prompt Type", value="vis")
            with gr.Row():
                see_button = gr.Button("Render Point")
                add_button = gr.Button("Add Point")
            with gr.Row():
                # Added buttons: clear, and view the visible / occlusion points separately
                clear_button = gr.Button("Clear Points")
                see_visible_button = gr.Button("Visible Points")
                see_occlusion_button = gr.Button("Occluded Points")
            with gr.Row():
                # Added text box that shows the current point lists
                points_text = gr.Textbox(label="Points List", interactive=False)
            with gr.Row():
                # Added dropdowns from which the user picks the point to delete
                visible_points_dropdown = gr.Dropdown(label="Select Visible Point to Delete", choices=[], value=None, interactive=True)
                occlusion_points_dropdown = gr.Dropdown(label="Select Occlusion Point to Delete", choices=[], value=None, interactive=True)
            with gr.Row():
                delete_visible_button = gr.Button("Delete Selected Visible")
                delete_occlusion_button = gr.Button("Delete Selected Occlusion")
        with gr.Column():
            # Displays the SAM segmentation result
            visible_mask = gr.Image(label='Visible Mask', interactive=False, height=300)
            with gr.Row():
                gen_vis_mask = gr.Button("Generate Mask")
                add_vis_mask = gr.Button("Add Mask")
            with gr.Row():
                render_vis_mask = gr.Button("Render Mask")
                undo_vis_mask = gr.Button("Undo Last Mask")
            occluded_mask = gr.Image(label='Occlusion Mask', interactive=False, height=300)
            with gr.Row():
                gen_occ_mask = gr.Button("Generate Mask")
                add_occ_mask = gr.Button("Add Mask")
            with gr.Row():
                render_occ_mask = gr.Button("Render Mask")
                undo_occ_mask = gr.Button("Undo Last Mask")
    with gr.Row():
        with gr.Column():
            mask_check = gr.Image(label='Combined Mask', interactive=False, height=300)
            with gr.Row():
                check_combine_button = gr.Button("Check Combined Mask, make sure there is no GAP between the visible area (white) and occluded area (gray)")
    with gr.Row():
        gr.Markdown("""* Step 2 - 3D Amodal Completion.
                    * Different random seeds can be tried in "Generation Settings", if you think the results are not ideal.
                    * If the reconstruction 3D asset is satisfactory, you can extract the GLB file and download it.
                    """)
    with gr.Row():
        with gr.Column():
            with gr.Accordion(label="Generation Settings", open=True):
                seed = gr.Slider(0, MAX_SEED, label="Seed", value=1, step=1)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                gr.Markdown("Stage 1: Sparse Structure Generation")
                with gr.Row():
                    ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                gr.Markdown("Stage 2: Structured Latent Generation")
                with gr.Row():
                    slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
                    slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
            generate_btn = gr.Button("Generate")
        with gr.Column():
            video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)


    # ---------------------------
    # Original interaction logic
    # ---------------------------
    # On upload: feed the image to the predictor and keep an untouched copy.
    input_image.upload(
        reset_image,
        [predictor, input_image],
        [predictor, original_image, message],
    )
    # Preview a single point on the original image.
    see_button.click(
        see_point,
        inputs=[original_image, x_input, y_input, point_type],
        outputs=[input_image]
    )
    # Store the point in the list selected by point_type.
    add_button.click(
        add_point,
        inputs=[x_input, y_input, point_type, visible_points_state, occlusion_points_state],
        outputs=[visible_points_state, occlusion_points_state]
    )

    # ---------------------------
    # Newly added interaction logic
    # ---------------------------
    # Shows the original image again; the point lists are not among the outputs.
    clear_button.click(
        clear_all_points,
        inputs=[original_image],
        outputs=[input_image]
    )
    see_visible_button.click(
        see_visible_points,
        inputs=[input_image, visible_points_state],
        outputs=input_image
    )
    see_occlusion_button.click(
        see_occlusion_points,
        inputs=[input_image, occlusion_points_state],
        outputs=input_image
    )
    # When visible_points_state or occlusion_points_state changes, refresh the text box and dropdowns
    visible_points_state.change(
        update_all_points,
        inputs=[visible_points_state, occlusion_points_state],
        outputs=[points_text, visible_points_dropdown, occlusion_points_dropdown]
    )
    occlusion_points_state.change(
        update_all_points,
        inputs=[visible_points_state, occlusion_points_state],
        outputs=[points_text, visible_points_dropdown, occlusion_points_dropdown]
    )
    delete_visible_button.click(
        delete_selected_visible,
        inputs=[input_image, visible_points_state, occlusion_points_state, visible_points_dropdown],
        outputs=[input_image, visible_points_state, occlusion_points_state, points_text, visible_points_dropdown, occlusion_points_dropdown]
    )
    delete_occlusion_button.click(
        delete_selected_occlusion,
        inputs=[input_image, visible_points_state, occlusion_points_state, occlusion_points_dropdown],
        outputs=[input_image, visible_points_state, occlusion_points_state, points_text, visible_points_dropdown, occlusion_points_dropdown]
    )

    # Mask generation logic
    gen_vis_mask.click(
        segment_and_overlay,
        inputs=[original_image, visible_points_state, predictor],
        outputs=[visible_mask, visibility_mask]
    )
    add_vis_mask.click(
        add_mask,
        inputs=[visibility_mask, visibility_mask_list],
        outputs=[visibility_mask_list]
    )
    render_vis_mask.click(
        vis_mask,
        inputs=[original_image, visibility_mask_list],
        outputs=[visible_mask]
    )
    undo_vis_mask.click(
        delete_mask,
        inputs=[visibility_mask_list],
        outputs=[visibility_mask_list]
    )
    gen_occ_mask.click(
        segment_and_overlay,
        inputs=[original_image, occlusion_points_state, predictor],
        outputs=[occluded_mask, occlusion_mask]
    )
    add_occ_mask.click(
        add_mask,
        inputs=[occlusion_mask, occlusion_mask_list],
        outputs=[occlusion_mask_list]
    )
    render_occ_mask.click(
        vis_mask,
        inputs=[original_image, occlusion_mask_list],
        outputs=[occluded_mask]
    )
    undo_occ_mask.click(
        delete_mask,
        inputs=[occlusion_mask_list],
        outputs=[occlusion_mask_list]
    )

    # check combined mask
    check_combine_button.click(
        combine_mask,
        inputs=[original_image, visibility_mask_list, occlusion_mask_list],
        outputs=[mask_check, combined_mask]
    )

    # 3D Amodal Reconstruction
    # NOTE(review): the inputs below contain a nested list ([combined_mask]) and a
    # plain string ("multiimage") next to components, and image_to_3d returns two
    # values while a single output (visibility_mask, not video_output) is listed —
    # confirm this wiring against the Gradio event-listener API before relying on it.
    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        image_to_3d,
        inputs=[original_image, [combined_mask], seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, "multiimage"],
        outputs=[visibility_mask]
    )
649
-
650
-
651
# Launch the Gradio app
if __name__ == "__main__":
    # `pipeline` is a module-level global read by image_to_3d.
    pipeline = Amodal3RImageTo3DPipeline.from_pretrained("Sm0kyWu/Amodal3R")
    pipeline.cuda()
    try:
        # Best-effort call on a blank 512x512 image (presumably a warm-up);
        # a failure here must not prevent the app from starting.
        pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
    except Exception as exc:
        # Catch Exception rather than using a bare except, so that
        # KeyboardInterrupt / SystemExit still propagate, and report the failure.
        print(f"Pipeline preprocess_image call failed and was skipped: {exc!r}")
    demo.launch()