MykolaL commited on
Commit
c866eb2
·
verified ·
1 Parent(s): 3099000

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -3
app.py CHANGED
@@ -58,7 +58,10 @@ def create_depth_demo(model, device):
58
  image = F.pad(image, (0, 0, 40, 0))
59
  with torch.no_grad():
60
  pred = model(image)#['pred_d']
61
- pred = torch.from_numpy(pred).to(device)
 
 
 
62
 
63
  pred = pred[:,:,40:,:]
64
  pred = torch.nn.functional.interpolate(pred, shape[2:], mode='bilinear', align_corners=True)
@@ -91,9 +94,28 @@ def create_refseg_demo(model, tokenizer, device):
91
  image_t = torch.nn.functional.interpolate(image_t, (512,512), mode='bilinear', align_corners=True)
92
 
93
  with torch.no_grad():
94
- pred = model(image_t, text)
95
 
96
- pred = torch.from_numpy(pred).unsqueeze(0).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  pred = torch.nn.functional.interpolate(pred.float(), shape[2:], mode='bilinear', align_corners=True)
98
  output_mask = pred.cpu().argmax(1).data.numpy().squeeze()
99
  alpha = 0.65
 
58
  image = F.pad(image, (0, 0, 40, 0))
59
  with torch.no_grad():
60
  pred = model(image)#['pred_d']
61
+ pred = torch.from_numpy(pred).to(device).float()
62
+
63
+ if pred.dim() == 2: # H×W
64
+ pred = pred.unsqueeze(0).unsqueeze(0)
65
 
66
  pred = pred[:,:,40:,:]
67
  pred = torch.nn.functional.interpolate(pred, shape[2:], mode='bilinear', align_corners=True)
 
94
  image_t = torch.nn.functional.interpolate(image_t, (512,512), mode='bilinear', align_corners=True)
95
 
96
  with torch.no_grad():
97
+ out = model(image_t, text)
98
 
99
+ if isinstance(out, np.ndarray):
100
+ pred = torch.from_numpy(out).to(device)
101
+ else:
102
+ pred = out
103
+
104
+ pred = pred.float()
105
+
106
+ if pred.dim() == 2:
107
+ # H×W mask -> N×C×H×W
108
+ pred = pred.unsqueeze(0).unsqueeze(0)
109
+ one_channel_mask = True
110
+ elif pred.dim() == 3:
111
+ # N×H×W -> add channel
112
+ pred = pred.unsqueeze(1)
113
+ one_channel_mask = True
114
+ elif pred.dim() == 4:
115
+ # N×C×H×W (logits) -> argmax later
116
+ one_channel_mask = (pred.shape[1] == 1)
117
+
118
+
119
  pred = torch.nn.functional.interpolate(pred.float(), shape[2:], mode='bilinear', align_corners=True)
120
  output_mask = pred.cpu().argmax(1).data.numpy().squeeze()
121
  alpha = 0.65