Spaces:
Running
on
L4
Running
on
L4
Upload app.py
Browse files
app.py
CHANGED
@@ -95,34 +95,38 @@ def create_refseg_demo(model, tokenizer, device):
|
|
95 |
|
96 |
with torch.no_grad():
|
97 |
out = model(image_t, text)
|
98 |
-
|
|
|
99 |
if isinstance(out, np.ndarray):
|
100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
else:
|
102 |
-
|
103 |
-
|
104 |
-
pred = pred.float()
|
105 |
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
# N×H×W -> add channel
|
112 |
-
pred = pred.unsqueeze(1)
|
113 |
-
one_channel_mask = True
|
114 |
-
elif pred.dim() == 4:
|
115 |
-
# N×C×H×W (logits) -> argmax later
|
116 |
-
one_channel_mask = (pred.shape[1] == 1)
|
117 |
-
|
118 |
-
|
119 |
-
pred = torch.nn.functional.interpolate(pred.float(), shape[2:], mode='bilinear', align_corners=True)
|
120 |
-
output_mask = pred.cpu().argmax(1).data.numpy().squeeze()
|
121 |
alpha = 0.65
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
|
|
|
|
126 |
|
127 |
submit.click(on_submit, inputs=[input_image, input_text], outputs=refseg_image)
|
128 |
examples = gr.Examples(examples=[["imgs/test_img2.jpg", "green plant"], ["imgs/test_img3.jpg", "chair"], ["imgs/test_img4.jpg", "left green plant"], ["imgs/test_img5.jpg", "man walking on foot"], ["imgs/test_img5.jpg", "the rightest camel"]],
|
|
|
95 |
|
96 |
with torch.no_grad():
|
97 |
out = model(image_t, text)
|
98 |
+
|
99 |
+
# --- normalize to numpy mask ---
|
100 |
if isinstance(out, np.ndarray):
|
101 |
+
mask = out
|
102 |
+
elif isinstance(out, torch.Tensor):
|
103 |
+
pred = out.float()
|
104 |
+
if pred.dim() == 2:
|
105 |
+
mask = pred.cpu().numpy()
|
106 |
+
elif pred.dim() == 3:
|
107 |
+
# (N,H,W) → squeeze batch
|
108 |
+
mask = pred.squeeze(0).cpu().numpy()
|
109 |
+
elif pred.dim() == 4:
|
110 |
+
# logits (N,C,H,W) → argmax over channel
|
111 |
+
pred = torch.nn.functional.interpolate(pred, size=orig_shape, mode='bilinear', align_corners=True)
|
112 |
+
mask = pred.argmax(1).squeeze().cpu().numpy()
|
113 |
+
else:
|
114 |
+
raise RuntimeError(f"Unexpected output shape {pred.shape}")
|
115 |
else:
|
116 |
+
raise RuntimeError(f"Unexpected output type {type(out)}")
|
|
|
|
|
117 |
|
118 |
+
# --- ensure mask is binary uint8 ---
|
119 |
+
if mask.dtype != np.uint8:
|
120 |
+
mask = (mask > 0.5).astype(np.uint8)
|
121 |
+
|
122 |
+
# --- overlay like your Colab code ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
alpha = 0.65
|
124 |
+
overlay = image.copy()
|
125 |
+
overlay[mask == 0] = (overlay[mask == 0] * alpha).astype(np.uint8)
|
126 |
+
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
127 |
+
cv2.drawContours(overlay, contours, -1, (0, 255, 0), 2)
|
128 |
+
|
129 |
+
return Image.fromarray(overlay)
|
130 |
|
131 |
submit.click(on_submit, inputs=[input_image, input_text], outputs=refseg_image)
|
132 |
examples = gr.Examples(examples=[["imgs/test_img2.jpg", "green plant"], ["imgs/test_img3.jpg", "chair"], ["imgs/test_img4.jpg", "left green plant"], ["imgs/test_img5.jpg", "man walking on foot"], ["imgs/test_img5.jpg", "the rightest camel"]],
|