Spaces:
Running
on
Zero
Running
on
Zero
Upload app.py
Browse files
app.py
CHANGED
|
@@ -33,94 +33,6 @@ def end_session(req: gr.Request):
|
|
| 33 |
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
| 34 |
shutil.rmtree(user_dir)
|
| 35 |
|
| 36 |
-
|
| 37 |
-
def select_point(predictor: SamPredictor,
|
| 38 |
-
annotated_img: np.ndarray,
|
| 39 |
-
orig_img: np.ndarray,
|
| 40 |
-
sel_pix: list,
|
| 41 |
-
point_type: str,
|
| 42 |
-
evt: gr.SelectData):
|
| 43 |
-
"""
|
| 44 |
-
当用户在标注图像上点击时:
|
| 45 |
-
- 将点击坐标添加到 sel_pix(正/负 prompt 根据 point_type),
|
| 46 |
-
- 根据 sel_pix 调用 SAM 得到 mask,
|
| 47 |
-
- 在 annotated_img 上绘制所有已选点的标记,
|
| 48 |
-
- 返回更新后的标注图像、SAM 输出(用于显示)及生成的 visible_mask(用于后续 pix2gestalt)。
|
| 49 |
-
"""
|
| 50 |
-
# 拷贝原图(用于标注)
|
| 51 |
-
img = annotated_img.copy()
|
| 52 |
-
h_original, w_original, _ = orig_img.shape
|
| 53 |
-
h_new, w_new = 256, 256
|
| 54 |
-
scale_x = w_new / w_original
|
| 55 |
-
scale_y = h_new / h_original
|
| 56 |
-
|
| 57 |
-
# 根据 prompt 类型添加点击点(evt.index 格式为 (x, y))
|
| 58 |
-
if point_type == 'positive_prompt':
|
| 59 |
-
sel_pix.append((evt.index, 1))
|
| 60 |
-
elif point_type == 'negative_prompt':
|
| 61 |
-
sel_pix.append((evt.index, 0))
|
| 62 |
-
else:
|
| 63 |
-
sel_pix.append((evt.index, 1))
|
| 64 |
-
|
| 65 |
-
# 将原始尺寸的点转换到 256x256 尺寸(SAM 输入要求)
|
| 66 |
-
processed_sel_pix = []
|
| 67 |
-
for point, label in sel_pix:
|
| 68 |
-
x, y = point
|
| 69 |
-
new_x = int(x * scale_x)
|
| 70 |
-
new_y = int(y * scale_y)
|
| 71 |
-
processed_sel_pix.append(([new_x, new_y], label))
|
| 72 |
-
|
| 73 |
-
visible_mask, overlay_mask = run_sam(predictor, processed_sel_pix)
|
| 74 |
-
# overlay_mask 是 SAM 输出的 mask(256x256),调整尺寸到原图尺寸以便显示
|
| 75 |
-
mask = np.squeeze(overlay_mask[0][0]) # (256, 256)
|
| 76 |
-
resized_mask = cv2.resize(mask.astype(np.uint8) * 255, (w_original, h_original), interpolation=cv2.INTER_AREA)
|
| 77 |
-
resized_mask = resized_mask > 127
|
| 78 |
-
# 制作 overlay 信息(供 output_mask 使用)
|
| 79 |
-
resized_overlay_mask = [(resized_mask, 'visible_mask')]
|
| 80 |
-
|
| 81 |
-
# 绘制所有点的标记
|
| 82 |
-
COLORS = [(255, 0, 0), (0, 255, 0)]
|
| 83 |
-
MARKERS = [1, 4]
|
| 84 |
-
scaling_factor = min(h_original / 256, w_original / 256)
|
| 85 |
-
marker_size = int(6 * scaling_factor)
|
| 86 |
-
marker_thickness = int(2 * scaling_factor)
|
| 87 |
-
for point, label in sel_pix:
|
| 88 |
-
cv2.drawMarker(img, tuple(point), COLORS[label], markerType=MARKERS[label],
|
| 89 |
-
markerSize=marker_size, thickness=marker_thickness)
|
| 90 |
-
|
| 91 |
-
return img, (orig_img, resized_overlay_mask), visible_mask
|
| 92 |
-
|
| 93 |
-
def undo_points(predictor, orig_img, sel_pix):
|
| 94 |
-
"""
|
| 95 |
-
撤销最后一次点击:
|
| 96 |
-
- 从 sel_pix 中 pop 出最后一个点,
|
| 97 |
-
- 根据剩余点重新调用 SAM 得到 mask,
|
| 98 |
-
- 返回更新后的图像和 mask。
|
| 99 |
-
"""
|
| 100 |
-
temp = orig_img.copy()
|
| 101 |
-
h_original, w_original, _ = orig_img.shape
|
| 102 |
-
COLORS = [(255, 0, 0), (0, 255, 0)]
|
| 103 |
-
MARKERS = [0, 5]
|
| 104 |
-
scaling_factor = min(h_original / 256, w_original / 256)
|
| 105 |
-
marker_size = int(6 * scaling_factor)
|
| 106 |
-
marker_thickness = int(2 * scaling_factor)
|
| 107 |
-
if len(sel_pix) > 0:
|
| 108 |
-
sel_pix.pop()
|
| 109 |
-
# 重新绘制剩余点
|
| 110 |
-
for point, label in sel_pix:
|
| 111 |
-
cv2.drawMarker(temp, tuple(point), COLORS[label],
|
| 112 |
-
markerType=MARKERS[label], markerSize=marker_size, thickness=marker_thickness)
|
| 113 |
-
else:
|
| 114 |
-
dummy_overlay_mask = [(np.zeros((h_original, w_original), dtype=np.uint8), 'visible_mask')]
|
| 115 |
-
return orig_img, (orig_img, dummy_overlay_mask), []
|
| 116 |
-
|
| 117 |
-
visible_mask, overlay_mask = run_sam(predictor, sel_pix)
|
| 118 |
-
mask = np.squeeze(overlay_mask[0][0])
|
| 119 |
-
resized_mask = cv2.resize(mask.astype(np.uint8) * 255, (w_original, h_original), interpolation=cv2.INTER_AREA)
|
| 120 |
-
resized_mask = resized_mask > 127
|
| 121 |
-
resized_overlay_mask = [(resized_mask, 'visible_mask')]
|
| 122 |
-
return temp, (orig_img, resized_overlay_mask), visible_mask
|
| 123 |
-
|
| 124 |
def reset_image(predictor, img):
|
| 125 |
"""
|
| 126 |
上传图像后调用:
|
|
@@ -144,6 +56,8 @@ def run_sam(image, predictor, selected_points):
|
|
| 144 |
"""
|
| 145 |
调用 SAM 模型进行分割。
|
| 146 |
"""
|
|
|
|
|
|
|
| 147 |
predictor.set_image(image)
|
| 148 |
if len(selected_points) == 0:
|
| 149 |
return [], None
|
|
|
|
| 33 |
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
| 34 |
shutil.rmtree(user_dir)
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
def reset_image(predictor, img):
|
| 37 |
"""
|
| 38 |
上传图像后调用:
|
|
|
|
| 56 |
"""
|
| 57 |
调用 SAM 模型进行分割。
|
| 58 |
"""
|
| 59 |
+
print(image.shape)
|
| 60 |
+
print(np.unique(image))
|
| 61 |
predictor.set_image(image)
|
| 62 |
if len(selected_points) == 0:
|
| 63 |
return [], None
|