MykolaL committed on
Commit
4a991a4
·
verified ·
1 Parent(s): fcac98a

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -29
app.py CHANGED
@@ -87,50 +87,53 @@ def create_refseg_demo(model, tokenizer, device):
87
  submit = gr.Button("Submit")
88
 
89
  def on_submit(image, text):
90
- image = np.array(image)
91
- image_t = transforms.ToTensor()(image).unsqueeze(0).to(device)
 
 
92
  image_t = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(image_t)
93
- shape = image_t.shape
94
- image_t = torch.nn.functional.interpolate(image_t, (512,512), mode='bilinear', align_corners=True)
95
-
96
  with torch.no_grad():
97
  out = model(image_t, text)
98
 
99
- # --- normalize to numpy mask ---
100
- if isinstance(out, np.ndarray):
101
- mask = out
102
- elif isinstance(out, torch.Tensor):
103
- pred = out.float()
104
- if pred.dim() == 2:
105
- mask = pred.cpu().numpy()
106
- elif pred.dim() == 3:
107
- # (N,H,W) → squeeze batch
108
- mask = pred.squeeze(0).cpu().numpy()
109
- elif pred.dim() == 4:
110
- # logits (N,C,H,W) → argmax over channel
111
- pred = torch.nn.functional.interpolate(pred, size=orig_shape, mode='bilinear', align_corners=True)
112
- mask = pred.argmax(1).squeeze().cpu().numpy()
113
- else:
114
- raise RuntimeError(f"Unexpected output shape {pred.shape}")
115
  else:
116
- raise RuntimeError(f"Unexpected output type {type(out)}")
 
 
 
 
 
117
 
118
- # --- ensure mask is binary uint8 ---
119
- if mask.dtype != np.uint8:
120
- mask = (mask > 0.5).astype(np.uint8)
121
 
122
- # --- overlay like your Colab code ---
123
  alpha = 0.65
124
- overlay = image.copy()
125
  overlay[mask == 0] = (overlay[mask == 0] * alpha).astype(np.uint8)
 
 
126
  contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
127
  cv2.drawContours(overlay, contours, -1, (0, 255, 0), 2)
128
 
129
  return Image.fromarray(overlay)
130
 
131
  submit.click(on_submit, inputs=[input_image, input_text], outputs=refseg_image)
132
- examples = gr.Examples(examples=[["imgs/test_img2.jpg", "green plant"], ["imgs/test_img3.jpg", "chair"], ["imgs/test_img4.jpg", "left green plant"], ["imgs/test_img5.jpg", "man walking on foot"], ["imgs/test_img5.jpg", "the rightest camel"]],
133
- inputs=[input_image, input_text])
 
 
 
 
 
 
 
 
 
134
 
135
 
136
  def main():
 
87
  submit = gr.Button("Submit")
88
 
89
  def on_submit(image, text):
90
+ # Convert PIL -> np array
91
+ image_np = np.array(image).copy()
92
+ transform = transforms.ToTensor()
93
+ image_t = transform(image).unsqueeze(0).to(device)
94
  image_t = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(image_t)
95
+ image_t = torch.nn.functional.interpolate(image_t, (512, 512), mode='bilinear', align_corners=True)
96
+
 
97
  with torch.no_grad():
98
  out = model(image_t, text)
99
 
100
+ # Ensure numpy mask
101
+ if isinstance(out, torch.Tensor):
102
+ mask = out.squeeze().detach().cpu().numpy()
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  else:
104
+ mask = out
105
+
106
+ # Convert to binary mask
107
+ if mask.ndim > 2:
108
+ mask = np.argmax(mask, axis=0)
109
+ mask = (mask > 0).astype(np.uint8)
110
 
111
+ # Resize mask to original image size
112
+ mask = cv2.resize(mask, (image_np.shape[1], image_np.shape[0]), interpolation=cv2.INTER_NEAREST)
 
113
 
114
+ # Overlay mask
115
  alpha = 0.65
116
+ overlay = image_np.copy()
117
  overlay[mask == 0] = (overlay[mask == 0] * alpha).astype(np.uint8)
118
+
119
+ # Draw contours
120
  contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
121
  cv2.drawContours(overlay, contours, -1, (0, 255, 0), 2)
122
 
123
  return Image.fromarray(overlay)
124
 
125
  submit.click(on_submit, inputs=[input_image, input_text], outputs=refseg_image)
126
+ examples = gr.Examples(
127
+ examples=[
128
+ ["imgs/test_img2.jpg", "green plant"],
129
+ ["imgs/test_img3.jpg", "chair"],
130
+ ["imgs/test_img4.jpg", "left green plant"],
131
+ ["imgs/test_img5.jpg", "man walking on foot"],
132
+ ["imgs/test_img5.jpg", "the rightest camel"],
133
+ ],
134
+ inputs=[input_image, input_text]
135
+ )
136
+
137
 
138
 
139
  def main():