Sm0kyWu commited on
Commit
2e379a5
·
verified ·
1 Parent(s): 56e0f94

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -6
app.py CHANGED
@@ -128,7 +128,7 @@ def reset_image(predictor, img):
128
  - 设置 predictor 的输入图像,
129
  - 返回原图
130
  """
131
- predictor.set_image(img)
132
  original_img = img.copy()
133
  # 返回predictor,visible occlusion mask初始化, 原始图像
134
  return predictor, img, img, original_img
@@ -140,22 +140,25 @@ def button_clickable(selected_points):
140
  return gr.Button.update(interactive=False)
141
 
142
 
143
- def run_sam(predictor: SamPredictor, selected_points):
144
  """
145
  调用 SAM 模型进行分割。
146
  """
147
- # 确保图像为 RGB 模式
148
  if len(selected_points) == 0:
149
  return [], None
150
  input_points = [p for p in selected_points]
151
  input_labels = [1 for _ in range(len(selected_points))]
 
 
152
  masks, _, _ = predictor.predict(
153
  point_coords=np.array(input_points),
154
- point_labels=input_labels,
155
  multimask_output=False, # 单对象输出
156
  )
 
157
  visible_mask = 255 * np.squeeze(masks).astype(np.uint8)
158
- return visible_mask, None
159
 
160
  def apply_mask_overlay(image, mask):
161
  """
@@ -176,7 +179,7 @@ def segment_and_overlay(image, points, sam_predictor):
176
  """
177
  调用 run_sam 获得 mask,然后叠加显示分割结果。
178
  """
179
- mask, _ = run_sam(sam_predictor, points)
180
  overlaid = apply_mask_overlay(image, mask)
181
  return overlaid, mask
182
 
 
128
  - 设置 predictor 的输入图像,
129
  - 返回原图
130
  """
131
+ # predictor.set_image(img)
132
  original_img = img.copy()
133
  # 返回predictor,visible occlusion mask初始化, 原始图像
134
  return predictor, img, img, original_img
 
140
  return gr.Button.update(interactive=False)
141
 
142
 
143
+ def run_sam(image, predictor, selected_points):
144
  """
145
  调用 SAM 模型进行分割。
146
  """
147
+ predictor.set_image(image)
148
  if len(selected_points) == 0:
149
  return [], None
150
  input_points = [p for p in selected_points]
151
  input_labels = [1 for _ in range(len(selected_points))]
152
+ print(input_points)
153
+ print(input_labels)
154
  masks, _, _ = predictor.predict(
155
  point_coords=np.array(input_points),
156
+ point_labels=np.array(input_labels),
157
  multimask_output=False, # 单对象输出
158
  )
159
+ print(masks.shape)
160
  visible_mask = 255 * np.squeeze(masks).astype(np.uint8)
161
+ return visible_mask
162
 
163
  def apply_mask_overlay(image, mask):
164
  """
 
179
  """
180
  调用 run_sam 获得 mask,然后叠加显示分割结果。
181
  """
182
+ mask, _ = run_sam(image, sam_predictor, points)
183
  overlaid = apply_mask_overlay(image, mask)
184
  return overlaid, mask
185