MykolaL committed
Commit fcac98a · verified · 1 Parent(s): c866eb2

Upload app.py

Files changed (1): app.py (+28, -24)
app.py CHANGED
@@ -95,34 +95,38 @@ def create_refseg_demo(model, tokenizer, device):
 
         with torch.no_grad():
             out = model(image_t, text)
-
+
+        # --- normalize to numpy mask ---
         if isinstance(out, np.ndarray):
-            pred = torch.from_numpy(out).to(device)
+            mask = out
+        elif isinstance(out, torch.Tensor):
+            pred = out.float()
+            if pred.dim() == 2:
+                mask = pred.cpu().numpy()
+            elif pred.dim() == 3:
+                # (N,H,W) → squeeze batch
+                mask = pred.squeeze(0).cpu().numpy()
+            elif pred.dim() == 4:
+                # logits (N,C,H,W) → argmax over channel
+                pred = torch.nn.functional.interpolate(pred, size=orig_shape, mode='bilinear', align_corners=True)
+                mask = pred.argmax(1).squeeze().cpu().numpy()
+            else:
+                raise RuntimeError(f"Unexpected output shape {pred.shape}")
         else:
-            pred = out
-
-        pred = pred.float()
+            raise RuntimeError(f"Unexpected output type {type(out)}")
 
-        if pred.dim() == 2:
-            # H×W mask -> N×C×H×W
-            pred = pred.unsqueeze(0).unsqueeze(0)
-            one_channel_mask = True
-        elif pred.dim() == 3:
-            # N×H×W -> add channel
-            pred = pred.unsqueeze(1)
-            one_channel_mask = True
-        elif pred.dim() == 4:
-            # N×C×H×W (logits) -> argmax later
-            one_channel_mask = (pred.shape[1] == 1)
-
-
-        pred = torch.nn.functional.interpolate(pred.float(), shape[2:], mode='bilinear', align_corners=True)
-        output_mask = pred.cpu().argmax(1).data.numpy().squeeze()
+        # --- ensure mask is binary uint8 ---
+        if mask.dtype != np.uint8:
+            mask = (mask > 0.5).astype(np.uint8)
+
+        # --- overlay like your Colab code ---
         alpha = 0.65
-        image[output_mask == 0] = (image[output_mask == 0]*alpha).astype(np.uint8)
-        contours, _ = cv2.findContours(output_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-        cv2.drawContours(image, contours, -1, (0, 255, 0), 2)
-        return Image.fromarray(image)
+        overlay = image.copy()
+        overlay[mask == 0] = (overlay[mask == 0] * alpha).astype(np.uint8)
+        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+        cv2.drawContours(overlay, contours, -1, (0, 255, 0), 2)
+
+        return Image.fromarray(overlay)
 
     submit.click(on_submit, inputs=[input_image, input_text], outputs=refseg_image)
     examples = gr.Examples(examples=[["imgs/test_img2.jpg", "green plant"], ["imgs/test_img3.jpg", "chair"], ["imgs/test_img4.jpg", "left green plant"], ["imgs/test_img5.jpg", "man walking on foot"], ["imgs/test_img5.jpg", "the rightest camel"]],
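
For reference, a minimal self-contained sketch (not part of the commit) that exercises the output-normalization path the new code adds, using a dummy 4-D logits tensor in place of a real model output. The helper name normalize_to_mask and the dummy shapes are invented for illustration; orig_shape is assumed to be the input image's (H, W), as the diff suggests.

import numpy as np
import torch

def normalize_to_mask(out, orig_shape):
    # Sketch of the branching added in this commit (names are illustrative).
    if isinstance(out, np.ndarray):
        mask = out
    elif isinstance(out, torch.Tensor):
        pred = out.float()
        if pred.dim() == 2:
            mask = pred.cpu().numpy()
        elif pred.dim() == 3:
            mask = pred.squeeze(0).cpu().numpy()  # (N,H,W) -> (H,W)
        elif pred.dim() == 4:
            # logits (N,C,H,W): resize to the original image, then argmax over channels
            pred = torch.nn.functional.interpolate(pred, size=orig_shape,
                                                   mode='bilinear', align_corners=True)
            mask = pred.argmax(1).squeeze().cpu().numpy()
        else:
            raise RuntimeError(f"Unexpected output shape {pred.shape}")
    else:
        raise RuntimeError(f"Unexpected output type {type(out)}")
    if mask.dtype != np.uint8:
        mask = (mask > 0.5).astype(np.uint8)  # binarize to uint8 {0, 1}
    return mask

# Dummy 2-class logits at quarter resolution for a 480x640 input image.
logits = torch.randn(1, 2, 120, 160)
mask = normalize_to_mask(logits, (480, 640))
print(mask.shape, mask.dtype)  # (480, 640) uint8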