Sm0kyWu committed on
Commit
b7fa320
·
verified ·
1 Parent(s): 0cc751c

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -265
app.py CHANGED
@@ -33,102 +33,42 @@ def end_session(req: gr.Request):
33
  shutil.rmtree(user_dir)
34
 
35
 
36
- def preprocess_image(image: Image.Image) -> Image.Image:
37
  """
38
- Preprocess the input image.
39
-
40
- Args:
41
- image (Image.Image): The input image.
42
-
43
- Returns:
44
- Image.Image: The preprocessed image.
45
- """
46
- processed_image = pipeline.preprocess_image(image)
47
- return processed_image
48
-
49
-
50
- def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
51
- """
52
- Preprocess a list of input images.
53
-
54
- Args:
55
- images (List[Tuple[Image.Image, str]]): The input images.
56
-
57
- Returns:
58
- List[Image.Image]: The preprocessed images.
59
- """
60
- images = [image[0] for image in images]
61
- processed_images = [pipeline.preprocess_image(image) for image in images]
62
- return processed_images
63
-
64
-
65
- def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
66
- return {
67
- 'gaussian': {
68
- **gs.init_params,
69
- '_xyz': gs._xyz.cpu().numpy(),
70
- '_features_dc': gs._features_dc.cpu().numpy(),
71
- '_scaling': gs._scaling.cpu().numpy(),
72
- '_rotation': gs._rotation.cpu().numpy(),
73
- '_opacity': gs._opacity.cpu().numpy(),
74
- },
75
- 'mesh': {
76
- 'vertices': mesh.vertices.cpu().numpy(),
77
- 'faces': mesh.faces.cpu().numpy(),
78
- },
79
- }
80
-
81
-
82
- def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
83
- gs = Gaussian(
84
- aabb=state['gaussian']['aabb'],
85
- sh_degree=state['gaussian']['sh_degree'],
86
- mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
87
- scaling_bias=state['gaussian']['scaling_bias'],
88
- opacity_bias=state['gaussian']['opacity_bias'],
89
- scaling_activation=state['gaussian']['scaling_activation'],
90
- )
91
- gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
92
- gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
93
- gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
94
- gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
95
- gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
96
-
97
- mesh = edict(
98
- vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
99
- faces=torch.tensor(state['mesh']['faces'], device='cuda'),
100
- )
101
-
102
- return gs, mesh
103
-
104
-
105
- def get_seed(randomize_seed: bool, seed: int) -> int:
106
- """
107
- Get the random seed.
108
- """
109
- return np.random.randint(0, MAX_SEED) if randomize_seed else seed
110
-
111
-
112
- def record_click(evt, points):
113
- """
114
- 记录在图像上点击的位置,默认所有点击均为目标对象的 prompt,标签设为 1
115
  """
116
  if points is None:
117
  points = []
118
- if evt is None:
119
- return points, str(points)
120
- # 假设 evt 中包含 "index" 键,其值为 (x, y)
121
- coord = evt.get("index", None)
122
- if coord is not None:
123
- points.append((coord, 1))
124
- return points, str(points)
 
 
 
 
125
 
126
  @spaces.GPU
127
  def run_sam(predictor: SamPredictor, image, selected_points):
128
  """
129
- 调用 Segment Anything 模型进行分割,返回 mask 及其他信息
130
  """
131
- assert image.mode == 'RGB', "Image should be RGB"
 
 
 
 
132
  if len(selected_points) == 0:
133
  return [], None
134
  input_points = [p for p, _ in selected_points]
@@ -144,33 +84,24 @@ def run_sam(predictor: SamPredictor, image, selected_points):
144
 
145
  def apply_mask_overlay(image: Image.Image, mask: np.ndarray) -> Image.Image:
146
  """
147
- 在原图上叠加 mask:使用红色绘制 mask 的轮廓,
148
- 非 mask 区域叠加浅灰色半透明遮罩
149
  """
150
- # 转换图像为 numpy 数组
151
  img_arr = np.array(image)
152
- # 如果 mask 为三维,则取第一个通道
153
  if mask.ndim == 3:
154
  mask = mask[:, :, 0]
155
- # 创建副本用于叠加
156
  overlay = img_arr.copy()
157
- # 定义浅灰色(例如 RGB=(200,200,200))
158
  gray_color = np.array([200, 200, 200], dtype=np.uint8)
159
- # 对于非 mask 区域(mask == 0),进行半透明混合
160
  non_mask = mask == 0
161
  overlay[non_mask] = (0.5 * overlay[non_mask] + 0.5 * gray_color).astype(np.uint8)
162
- # 使用 OpenCV 找到 mask 的轮廓
163
  contours, _ = cv2.findContours(mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
164
- # 在 overlay 上绘制红色轮廓,粗细为2个像素
165
  cv2.drawContours(overlay, contours, -1, (255, 0, 0), 2)
166
  return Image.fromarray(overlay)
167
 
168
 
169
  def segment_and_overlay(image: Image.Image, points):
170
  """
171
- 调用 run_sam 获得 mask,然后调用 apply_mask_overlay 生成叠加图像
172
  """
173
- # 确保输入图像为 RGB
174
  if image.mode != "RGB":
175
  image = image.convert("RGB")
176
  mask, _ = run_sam(sam_predictor, image, points)
@@ -179,9 +110,10 @@ def segment_and_overlay(image: Image.Image, points):
179
  overlaid = apply_mask_overlay(image, mask)
180
  return overlaid
181
 
 
182
  def reset_points():
183
  """
184
- 清空 prompt 点
185
  """
186
  return [], ""
187
 
@@ -189,33 +121,18 @@ def reset_points():
189
  @spaces.GPU
190
  def image_to_3d(
191
  image: Image.Image,
192
- multiimages: List[Tuple[Image.Image, str]],
193
  is_multiimage: bool,
194
  seed: int,
195
  ss_guidance_strength: float,
196
  ss_sampling_steps: int,
197
  slat_guidance_strength: float,
198
  slat_sampling_steps: int,
199
- multiimage_algo: Literal["multidiffusion", "stochastic"],
200
  req: gr.Request,
201
- ) -> Tuple[dict, str]:
202
  """
203
- Convert an image to a 3D model.
204
-
205
- Args:
206
- image (Image.Image): The input image.
207
- multiimages (List[Tuple[Image.Image, str]]): The input images in multi-image mode.
208
- is_multiimage (bool): Whether is in multi-image mode.
209
- seed (int): The random seed.
210
- ss_guidance_strength (float): The guidance strength for sparse structure generation.
211
- ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
212
- slat_guidance_strength (float): The guidance strength for structured latent generation.
213
- slat_sampling_steps (int): The number of sampling steps for structured latent generation.
214
- multiimage_algo (Literal["multidiffusion", "stochastic"]): The algorithm for multi-image generation.
215
-
216
- Returns:
217
- dict: The information of the generated 3D model.
218
- str: The path to the video of the 3D model.
219
  """
220
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
221
  if not is_multiimage:
@@ -235,7 +152,7 @@ def image_to_3d(
235
  )
236
  else:
237
  outputs = pipeline.run_multi_image(
238
- [image[0] for image in multiimages],
239
  seed=seed,
240
  formats=["gaussian", "mesh"],
241
  preprocess_image=False,
@@ -265,17 +182,9 @@ def extract_glb(
265
  mesh_simplify: float,
266
  texture_size: int,
267
  req: gr.Request,
268
- ) -> Tuple[str, str]:
269
  """
270
- Extract a GLB file from the 3D model.
271
-
272
- Args:
273
- state (dict): The state of the generated 3D model.
274
- mesh_simplify (float): The mesh simplification factor.
275
- texture_size (int): The texture resolution.
276
-
277
- Returns:
278
- str: The path to the extracted GLB file.
279
  """
280
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
281
  gs, mesh = unpack_state(state)
@@ -287,15 +196,9 @@ def extract_glb(
287
 
288
 
289
  @spaces.GPU
290
- def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
291
  """
292
- Extract a Gaussian file from the 3D model.
293
-
294
- Args:
295
- state (dict): The state of the generated 3D model.
296
-
297
- Returns:
298
- str: The path to the extracted Gaussian file.
299
  """
300
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
301
  gs, _ = unpack_state(state)
@@ -305,7 +208,47 @@ def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
305
  return gaussian_path, gaussian_path
306
 
307
 
308
- def prepare_multi_example() -> List[Image.Image]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
  multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
310
  images = []
311
  for case in multi_case:
@@ -319,49 +262,37 @@ def prepare_multi_example() -> List[Image.Image]:
319
  return images
320
 
321
 
322
- def split_image(image: Image.Image) -> List[Image.Image]:
323
  """
324
- Split an image into multiple views.
325
  """
326
  image = np.array(image)
327
  alpha = image[..., 3]
328
- alpha = np.any(alpha>0, axis=0)
329
  start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist()
330
  end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist()
331
  images = []
332
  for s, e in zip(start_pos, end_pos):
333
  images.append(Image.fromarray(image[:, s:e+1]))
334
- return [preprocess_image(image) for image in images]
335
 
336
 
337
  with gr.Blocks(delete_cache=(600, 600)) as demo:
338
  gr.Markdown("""
339
  ## 3D Amodal Reconstruction with [Amodal3R](https://sm0kywu.github.io/Amodal3R/)
340
- * Upload an image and click "Generate" to create a 3D asset.
341
- * Target object selection. Multiple point prompts are supported until you get the ideal visible area.
342
- * Occluders selection, this can be done by squential point prompts. You can choose "all occ", then all the other areas except the target object will be treated as occluders.
343
- * Different random seeds can be tried in "Generation Settings", if you think the results are not ideal.
344
- * If the reconstruction 3D asset is satisfactory, you can extract the GLB file and download it.
345
  """)
346
  with gr.Row():
347
  with gr.Column():
348
- with gr.Tabs() as input_tabs:
349
- image_prompt = gr.Image(type="numpy", label="Input Occlusion Image", height=512)
350
-
351
- # 用于交互标注的图像
352
  image_annotation = gr.Image(type="numpy", label="Select Point Prompts for Target Object", interactive=True, height=512)
353
- # 记录用户点击的点及显示当前 prompt 列表
354
  points_state = gr.State([])
355
  points_output = gr.Textbox(label="Target Object Prompts", interactive=False)
356
- # image_annotation 添加点击事件记录 prompt 点
357
- image_annotation.select(
358
- record_click,
359
- inputs=[points_state],
360
- outputs=[points_state, points_output]
361
- )
362
- # 新增:分割后展示结果的组件
363
- segmented_output = gr.Image(label="Segmented Result", height=512)
364
-
365
 
366
  with gr.Accordion(label="Generation Settings", open=False):
367
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=1, step=1)
@@ -374,119 +305,38 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
374
  with gr.Row():
375
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
376
  slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
377
-
378
- # with gr.Column():
379
-
380
-
381
- # generate_btn = gr.Button("Generate")
382
-
383
- # with gr.Accordion(label="GLB Extraction Settings", open=False):
384
- # mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
385
- # texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
386
-
387
- # with gr.Row():
388
- # extract_glb_btn = gr.Button("Extract GLB", interactive=False)
389
- # extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
390
- # gr.Markdown("""
391
- # *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
392
- # """)
393
-
394
- # with gr.Column():
395
- # video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
396
- # model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
397
-
398
- # with gr.Row():
399
- # download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
400
- # download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
401
 
402
- # is_multiimage = gr.State(False)
403
- # output_buf = gr.State()
404
-
405
- # # Example images at the bottom of the page
406
- # with gr.Row() as single_image_example:
407
- # examples = gr.Examples(
408
- # examples=[
409
- # f'assets/example_image/{image}'
410
- # for image in os.listdir("assets/example_image")
411
- # ],
412
- # inputs=[image_prompt],
413
- # fn=preprocess_image,
414
- # outputs=[image_prompt],
415
- # run_on_click=True,
416
- # examples_per_page=64,
417
- # )
418
- # with gr.Row(visible=False) as multiimage_example:
419
- # examples_multi = gr.Examples(
420
- # examples=prepare_multi_example(),
421
- # inputs=[image_prompt],
422
- # fn=split_image,
423
- # outputs=[multiimage_prompt],
424
- # run_on_click=True,
425
- # examples_per_page=8,
426
- # )
427
-
428
- # Handlers
429
  demo.load(start_session)
430
  demo.unload(end_session)
431
 
432
- # single_image_input_tab.select(
433
- # lambda: tuple([False, gr.Row.update(visible=True), gr.Row.update(visible=False)]),
434
- # outputs=[single_image_example]
435
- # )
436
- # multiimage_input_tab.select(
437
- # lambda: tuple([True, gr.Row.update(visible=False), gr.Row.update(visible=True)]),
438
- # outputs=[is_multiimage, single_image_example, multiimage_example]
439
- # )
440
-
441
  image_prompt.upload(
442
- preprocess_image,
443
  inputs=[image_prompt],
444
- outputs=[image_prompt],
445
  )
446
-
447
- # generate_btn.click(
448
- # get_seed,
449
- # inputs=[randomize_seed, seed],
450
- # outputs=[seed],
451
- # ).then(
452
- # image_to_3d,
453
- # inputs=[image_prompt, multiimage_prompt, is_multiimage, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo],
454
- # outputs=[output_buf, video_output],
455
- # ).then(
456
- # lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
457
- # outputs=[extract_glb_btn, extract_gs_btn],
458
- # )
459
-
460
- # video_output.clear(
461
- # lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
462
- # outputs=[extract_glb_btn, extract_gs_btn],
463
- # )
464
-
465
- # extract_glb_btn.click(
466
- # extract_glb,
467
- # inputs=[output_buf, mesh_simplify, texture_size],
468
- # outputs=[model_output, download_glb],
469
- # ).then(
470
- # lambda: gr.Button(interactive=True),
471
- # outputs=[download_glb],
472
- # )
473
 
474
- # extract_gs_btn.click(
475
- # extract_gaussian,
476
- # inputs=[output_buf],
477
- # outputs=[model_output, download_gs],
478
- # ).then(
479
- # lambda: gr.Button(interactive=True),
480
- # outputs=[download_gs],
481
- # )
482
-
483
- # model_output.clear(
484
- # lambda: gr.Button(interactive=False),
485
- # outputs=[download_glb],
486
- # )
 
 
487
 
 
488
 
489
- # Launch the Gradio app
490
  if __name__ == "__main__":
491
  sam_checkpoint = hf_hub_download("ybelkada/segment-anything", "checkpoints/sam_vit_h_4b8939.pth")
492
  model_type = "vit_h"
@@ -497,7 +347,7 @@ if __name__ == "__main__":
497
  pipeline = Amodal3RImageTo3DPipeline.from_pretrained("Sm0kyWu/Amodal3R")
498
  pipeline.cuda()
499
  try:
500
- pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))) # Preload rembg
501
  except:
502
  pass
503
  demo.launch()
 
33
  shutil.rmtree(user_dir)
34
 
35
 
36
+ def select_point_callback(image, points, evt):
37
  """
38
+ 当用户点击图像时,记录点击点并在图像上绘制标记(十字)。
39
+ 输入参数:
40
+ - image:当前图像(numpy 数组)。
41
+ - points:已记录的点列表。
42
+ - evt:Gradio 的点击事件数据(包含 .index,即点击坐标)。
43
+ 返回:
44
+ - 更新后的图像(带标记)。
45
+ - 更新后的点列表。
46
+ - 以字符串形式展示的点列表(用于显示在文本框中)。
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  """
48
  if points is None:
49
  points = []
50
+ annotated_img = image.copy()
51
+ # 如果没有点击事件,则直接返回原图和当前点列表
52
+ if evt is None or evt.index is None:
53
+ return image, points, str(points)
54
+ coord = evt.index # 期望返回 (x, y)
55
+ points.append((tuple(coord), 1)) # 记录为正样本 prompt
56
+ # 绘制十字标记,颜色为红色
57
+ cv2.drawMarker(annotated_img, tuple(coord), (255, 0, 0),
58
+ markerType=cv2.MARKER_CROSS, markerSize=10, thickness=2)
59
+ return annotated_img, points, str(points)
60
+
61
 
62
  @spaces.GPU
63
  def run_sam(predictor: SamPredictor, image, selected_points):
64
  """
65
+ 调用 SAM 模型进行分割。
66
  """
67
+ # 确保图像为 RGB 模式
68
+ if isinstance(image, np.ndarray):
69
+ image = Image.fromarray(image)
70
+ if image.mode != 'RGB':
71
+ image = image.convert("RGB")
72
  if len(selected_points) == 0:
73
  return [], None
74
  input_points = [p for p, _ in selected_points]
 
84
 
85
  def apply_mask_overlay(image: Image.Image, mask: np.ndarray) -> Image.Image:
86
  """
87
+ 在原图上叠加 mask:使用红色绘制 mask 的轮廓,非 mask 区域叠加浅灰色半透明遮罩。
 
88
  """
 
89
  img_arr = np.array(image)
 
90
  if mask.ndim == 3:
91
  mask = mask[:, :, 0]
 
92
  overlay = img_arr.copy()
 
93
  gray_color = np.array([200, 200, 200], dtype=np.uint8)
 
94
  non_mask = mask == 0
95
  overlay[non_mask] = (0.5 * overlay[non_mask] + 0.5 * gray_color).astype(np.uint8)
 
96
  contours, _ = cv2.findContours(mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
 
97
  cv2.drawContours(overlay, contours, -1, (255, 0, 0), 2)
98
  return Image.fromarray(overlay)
99
 
100
 
101
  def segment_and_overlay(image: Image.Image, points):
102
  """
103
+ 调用 run_sam 获得 mask,然后叠加显示分割结果。
104
  """
 
105
  if image.mode != "RGB":
106
  image = image.convert("RGB")
107
  mask, _ = run_sam(sam_predictor, image, points)
 
110
  overlaid = apply_mask_overlay(image, mask)
111
  return overlaid
112
 
113
+
114
  def reset_points():
115
  """
116
+ 清空点击点提示。
117
  """
118
  return [], ""
119
 
 
121
  @spaces.GPU
122
  def image_to_3d(
123
  image: Image.Image,
124
+ multiimages: List[tuple],
125
  is_multiimage: bool,
126
  seed: int,
127
  ss_guidance_strength: float,
128
  ss_sampling_steps: int,
129
  slat_guidance_strength: float,
130
  slat_sampling_steps: int,
131
+ multiimage_algo: str,
132
  req: gr.Request,
133
+ ) -> tuple:
134
  """
135
+ 将图像转换为 3D 模型。
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  """
137
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
138
  if not is_multiimage:
 
152
  )
153
  else:
154
  outputs = pipeline.run_multi_image(
155
+ [img[0] for img in multiimages],
156
  seed=seed,
157
  formats=["gaussian", "mesh"],
158
  preprocess_image=False,
 
182
  mesh_simplify: float,
183
  texture_size: int,
184
  req: gr.Request,
185
+ ) -> tuple:
186
  """
187
+ 从生成的 3D 模型中提取 GLB 文件。
 
 
 
 
 
 
 
 
188
  """
189
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
190
  gs, mesh = unpack_state(state)
 
196
 
197
 
198
  @spaces.GPU
199
+ def extract_gaussian(state: dict, req: gr.Request) -> tuple:
200
  """
201
+ 从生成的 3D 模型中提取 Gaussian 文件。
 
 
 
 
 
 
202
  """
203
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
204
  gs, _ = unpack_state(state)
 
208
  return gaussian_path, gaussian_path
209
 
210
 
211
+ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
212
+ return {
213
+ 'gaussian': {
214
+ **gs.init_params,
215
+ '_xyz': gs._xyz.cpu().numpy(),
216
+ '_features_dc': gs._features_dc.cpu().numpy(),
217
+ '_scaling': gs._scaling.cpu().numpy(),
218
+ '_rotation': gs._rotation.cpu().numpy(),
219
+ '_opacity': gs._opacity.cpu().numpy(),
220
+ },
221
+ 'mesh': {
222
+ 'vertices': mesh.vertices.cpu().numpy(),
223
+ 'faces': mesh.faces.cpu().numpy(),
224
+ },
225
+ }
226
+
227
+
228
+ def unpack_state(state: dict) -> tuple:
229
+ gs = Gaussian(
230
+ aabb=state['gaussian']['aabb'],
231
+ sh_degree=state['gaussian']['sh_degree'],
232
+ mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
233
+ scaling_bias=state['gaussian']['scaling_bias'],
234
+ opacity_bias=state['gaussian']['opacity_bias'],
235
+ scaling_activation=state['gaussian']['scaling_activation'],
236
+ )
237
+ gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
238
+ gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
239
+ gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
240
+ gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
241
+ gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
242
+
243
+ mesh = edict(
244
+ vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
245
+ faces=torch.tensor(state['mesh']['faces'], device='cuda'),
246
+ )
247
+
248
+ return gs, mesh
249
+
250
+
251
+ def prepare_multi_example() -> list:
252
  multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
253
  images = []
254
  for case in multi_case:
 
262
  return images
263
 
264
 
265
+ def split_image(image: Image.Image) -> list:
266
  """
267
+ 将图像拆分为多个视图(不进行预处理)。
268
  """
269
  image = np.array(image)
270
  alpha = image[..., 3]
271
+ alpha = np.any(alpha > 0, axis=0)
272
  start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist()
273
  end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist()
274
  images = []
275
  for s, e in zip(start_pos, end_pos):
276
  images.append(Image.fromarray(image[:, s:e+1]))
277
+ return [image for image in images]
278
 
279
 
280
  with gr.Blocks(delete_cache=(600, 600)) as demo:
281
  gr.Markdown("""
282
  ## 3D Amodal Reconstruction with [Amodal3R](https://sm0kywu.github.io/Amodal3R/)
283
+ * 上传图像后,点击图像选择目标区域,点击的点会在图像上显示。
 
 
 
 
284
  """)
285
  with gr.Row():
286
  with gr.Column():
287
+ # 上传的图像不经过预处理,直接展示原始图像
288
+ image_prompt = gr.Image(type="numpy", label="Input Occlusion Image", height=512)
289
+ # 用于交互标注的图像,点击时更新显示标记
 
290
  image_annotation = gr.Image(type="numpy", label="Select Point Prompts for Target Object", interactive=True, height=512)
291
+ # 存储点击点状态以及显示点击点坐标
292
  points_state = gr.State([])
293
  points_output = gr.Textbox(label="Target Object Prompts", interactive=False)
294
+ # 展示 SAM 分割结果(只用于显示,不允许上传)
295
+ segmented_output = gr.Image(label="Segmented Result", height=512, interactive=False)
 
 
 
 
 
 
 
296
 
297
  with gr.Accordion(label="Generation Settings", open=False):
298
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=1, step=1)
 
305
  with gr.Row():
306
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
307
  slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
308
+ # 其他组件(如生成按钮、视频展示、GLB 提取等)可根据需要添加
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
 
310
+ # 会话启动与结束
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
  demo.load(start_session)
312
  demo.unload(end_session)
313
 
314
+ # 上传图像后直接显示,不做预处理
 
 
 
 
 
 
 
 
315
  image_prompt.upload(
316
+ lambda x: x,
317
  inputs=[image_prompt],
318
+ outputs=[image_prompt]
319
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
 
321
+ # 点击 image_annotation 时调用 select_point_callback,
322
+ # 更新图像显示、点状态以及文本显示点击点信息
323
+ image_annotation.select(
324
+ select_point_callback,
325
+ inputs=[image_annotation, points_state],
326
+ outputs=[image_annotation, points_state, points_output]
327
+ )
328
+
329
+ # 添加一个按钮,用于运行 SAM 分割并展示叠加结果
330
+ segment_button = gr.Button("Run Segmentation")
331
+ segment_button.click(
332
+ segment_and_overlay,
333
+ inputs=[image_prompt, points_state],
334
+ outputs=[segmented_output]
335
+ )
336
 
337
+ # 后续可添加生成 3D 模型等其他流程...
338
 
339
+ # 启动 Gradio App
340
  if __name__ == "__main__":
341
  sam_checkpoint = hf_hub_download("ybelkada/segment-anything", "checkpoints/sam_vit_h_4b8939.pth")
342
  model_type = "vit_h"
 
347
  pipeline = Amodal3RImageTo3DPipeline.from_pretrained("Sm0kyWu/Amodal3R")
348
  pipeline.cuda()
349
  try:
350
+ pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
351
  except:
352
  pass
353
  demo.launch()