Spaces:
Running
on
L4
Running
on
L4
Upload app.py
Browse files
app.py
CHANGED
@@ -58,7 +58,10 @@ def create_depth_demo(model, device):
|
|
58 |
image = F.pad(image, (0, 0, 40, 0))
|
59 |
with torch.no_grad():
|
60 |
pred = model(image)#['pred_d']
|
61 |
-
pred = torch.from_numpy(pred).to(device)
|
|
|
|
|
|
|
62 |
|
63 |
pred = pred[:,:,40:,:]
|
64 |
pred = torch.nn.functional.interpolate(pred, shape[2:], mode='bilinear', align_corners=True)
|
@@ -91,9 +94,28 @@ def create_refseg_demo(model, tokenizer, device):
|
|
91 |
image_t = torch.nn.functional.interpolate(image_t, (512,512), mode='bilinear', align_corners=True)
|
92 |
|
93 |
with torch.no_grad():
|
94 |
-
|
95 |
|
96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
pred = torch.nn.functional.interpolate(pred.float(), shape[2:], mode='bilinear', align_corners=True)
|
98 |
output_mask = pred.cpu().argmax(1).data.numpy().squeeze()
|
99 |
alpha = 0.65
|
|
|
58 |
image = F.pad(image, (0, 0, 40, 0))
|
59 |
with torch.no_grad():
|
60 |
pred = model(image)#['pred_d']
|
61 |
+
pred = torch.from_numpy(pred).to(device).float()
|
62 |
+
|
63 |
+
if pred.dim() == 2: # H×W
|
64 |
+
pred = pred.unsqueeze(0).unsqueeze(0)
|
65 |
|
66 |
pred = pred[:,:,40:,:]
|
67 |
pred = torch.nn.functional.interpolate(pred, shape[2:], mode='bilinear', align_corners=True)
|
|
|
94 |
image_t = torch.nn.functional.interpolate(image_t, (512,512), mode='bilinear', align_corners=True)
|
95 |
|
96 |
with torch.no_grad():
|
97 |
+
out = model(image_t, text)
|
98 |
|
99 |
+
if isinstance(out, np.ndarray):
|
100 |
+
pred = torch.from_numpy(out).to(device)
|
101 |
+
else:
|
102 |
+
pred = out
|
103 |
+
|
104 |
+
pred = pred.float()
|
105 |
+
|
106 |
+
if pred.dim() == 2:
|
107 |
+
# H×W mask -> N×C×H×W
|
108 |
+
pred = pred.unsqueeze(0).unsqueeze(0)
|
109 |
+
one_channel_mask = True
|
110 |
+
elif pred.dim() == 3:
|
111 |
+
# N×H×W -> add channel
|
112 |
+
pred = pred.unsqueeze(1)
|
113 |
+
one_channel_mask = True
|
114 |
+
elif pred.dim() == 4:
|
115 |
+
# N×C×H×W (logits) -> argmax later
|
116 |
+
one_channel_mask = (pred.shape[1] == 1)
|
117 |
+
|
118 |
+
|
119 |
pred = torch.nn.functional.interpolate(pred.float(), shape[2:], mode='bilinear', align_corners=True)
|
120 |
output_mask = pred.cpu().argmax(1).data.numpy().squeeze()
|
121 |
alpha = 0.65
|