Sm0kyWu commited on
Commit
e3ffdc8
·
verified ·
1 Parent(s): 58b8df4

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -88
app.py CHANGED
@@ -33,94 +33,6 @@ def end_session(req: gr.Request):
33
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
34
  shutil.rmtree(user_dir)
35
 
36
-
37
- def select_point(predictor: SamPredictor,
38
- annotated_img: np.ndarray,
39
- orig_img: np.ndarray,
40
- sel_pix: list,
41
- point_type: str,
42
- evt: gr.SelectData):
43
- """
44
- 当用户在标注图像上点击时:
45
- - 将点击坐标添加到 sel_pix(正/负 prompt 根据 point_type),
46
- - 根据 sel_pix 调用 SAM 得到 mask,
47
- - 在 annotated_img 上绘制所有已选点的标记,
48
- - 返回更新后的标注图像、SAM 输出(用于显示)及生成的 visible_mask(用于后续 pix2gestalt)。
49
- """
50
- # 拷贝原图(用于标注)
51
- img = annotated_img.copy()
52
- h_original, w_original, _ = orig_img.shape
53
- h_new, w_new = 256, 256
54
- scale_x = w_new / w_original
55
- scale_y = h_new / h_original
56
-
57
- # 根据 prompt 类型添加点击点(evt.index 格式为 (x, y))
58
- if point_type == 'positive_prompt':
59
- sel_pix.append((evt.index, 1))
60
- elif point_type == 'negative_prompt':
61
- sel_pix.append((evt.index, 0))
62
- else:
63
- sel_pix.append((evt.index, 1))
64
-
65
- # 将原始尺寸的点转换到 256x256 尺寸(SAM 输入要求)
66
- processed_sel_pix = []
67
- for point, label in sel_pix:
68
- x, y = point
69
- new_x = int(x * scale_x)
70
- new_y = int(y * scale_y)
71
- processed_sel_pix.append(([new_x, new_y], label))
72
-
73
- visible_mask, overlay_mask = run_sam(predictor, processed_sel_pix)
74
- # overlay_mask 是 SAM 输出的 mask(256x256),调整尺寸到原图尺寸以便显示
75
- mask = np.squeeze(overlay_mask[0][0]) # (256, 256)
76
- resized_mask = cv2.resize(mask.astype(np.uint8) * 255, (w_original, h_original), interpolation=cv2.INTER_AREA)
77
- resized_mask = resized_mask > 127
78
- # 制作 overlay 信息(供 output_mask 使用)
79
- resized_overlay_mask = [(resized_mask, 'visible_mask')]
80
-
81
- # 绘制所有点的标记
82
- COLORS = [(255, 0, 0), (0, 255, 0)]
83
- MARKERS = [1, 4]
84
- scaling_factor = min(h_original / 256, w_original / 256)
85
- marker_size = int(6 * scaling_factor)
86
- marker_thickness = int(2 * scaling_factor)
87
- for point, label in sel_pix:
88
- cv2.drawMarker(img, tuple(point), COLORS[label], markerType=MARKERS[label],
89
- markerSize=marker_size, thickness=marker_thickness)
90
-
91
- return img, (orig_img, resized_overlay_mask), visible_mask
92
-
93
- def undo_points(predictor, orig_img, sel_pix):
94
- """
95
- 撤销最后一次点击:
96
- - 从 sel_pix 中 pop 出最后一个点,
97
- - 根据剩余点重新调用 SAM 得到 mask,
98
- - 返回更新后的图像和 mask。
99
- """
100
- temp = orig_img.copy()
101
- h_original, w_original, _ = orig_img.shape
102
- COLORS = [(255, 0, 0), (0, 255, 0)]
103
- MARKERS = [0, 5]
104
- scaling_factor = min(h_original / 256, w_original / 256)
105
- marker_size = int(6 * scaling_factor)
106
- marker_thickness = int(2 * scaling_factor)
107
- if len(sel_pix) > 0:
108
- sel_pix.pop()
109
- # 重新绘制剩余点
110
- for point, label in sel_pix:
111
- cv2.drawMarker(temp, tuple(point), COLORS[label],
112
- markerType=MARKERS[label], markerSize=marker_size, thickness=marker_thickness)
113
- else:
114
- dummy_overlay_mask = [(np.zeros((h_original, w_original), dtype=np.uint8), 'visible_mask')]
115
- return orig_img, (orig_img, dummy_overlay_mask), []
116
-
117
- visible_mask, overlay_mask = run_sam(predictor, sel_pix)
118
- mask = np.squeeze(overlay_mask[0][0])
119
- resized_mask = cv2.resize(mask.astype(np.uint8) * 255, (w_original, h_original), interpolation=cv2.INTER_AREA)
120
- resized_mask = resized_mask > 127
121
- resized_overlay_mask = [(resized_mask, 'visible_mask')]
122
- return temp, (orig_img, resized_overlay_mask), visible_mask
123
-
124
  def reset_image(predictor, img):
125
  """
126
  上传图像后调用:
@@ -144,6 +56,8 @@ def run_sam(image, predictor, selected_points):
144
  """
145
  调用 SAM 模型进行分割。
146
  """
 
 
147
  predictor.set_image(image)
148
  if len(selected_points) == 0:
149
  return [], None
 
33
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
34
  shutil.rmtree(user_dir)
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  def reset_image(predictor, img):
37
  """
38
  上传图像后调用:
 
56
  """
57
  调用 SAM 模型进行分割。
58
  """
59
+ print(image.shape)
60
+ print(np.unique(image))
61
  predictor.set_image(image)
62
  if len(selected_points) == 0:
63
  return [], None