Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -106,8 +106,8 @@ def inference(ic_image, ic_mask, image1, image2):
|
|
| 106 |
ic_mask = np.array(ic_mask.convert("RGB"))
|
| 107 |
|
| 108 |
sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
|
| 109 |
-
sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
|
| 110 |
-
|
| 111 |
predictor = SamPredictor(sam)
|
| 112 |
|
| 113 |
# Image features encoding
|
|
@@ -206,8 +206,8 @@ def inference_scribble(image, image1, image2):
|
|
| 206 |
ic_mask = np.array(ic_mask.convert("RGB"))
|
| 207 |
|
| 208 |
sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
|
| 209 |
-
sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
|
| 210 |
-
|
| 211 |
predictor = SamPredictor(sam)
|
| 212 |
|
| 213 |
# Image features encoding
|
|
@@ -304,12 +304,12 @@ def inference_finetune(ic_image, ic_mask, image1, image2):
|
|
| 304 |
ic_mask = np.array(ic_mask.convert("RGB"))
|
| 305 |
|
| 306 |
gt_mask = torch.tensor(ic_mask)[:, :, 0] > 0
|
| 307 |
-
gt_mask = gt_mask.float().unsqueeze(0).flatten(1).cuda()
|
| 308 |
-
|
| 309 |
|
| 310 |
sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
|
| 311 |
-
sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
|
| 312 |
-
|
| 313 |
for name, param in sam.named_parameters():
|
| 314 |
param.requires_grad = False
|
| 315 |
predictor = SamPredictor(sam)
|
|
@@ -347,8 +347,8 @@ def inference_finetune(ic_image, ic_mask, image1, image2):
|
|
| 347 |
|
| 348 |
print('======> Start Training')
|
| 349 |
# Learnable mask weights
|
| 350 |
-
mask_weights = Mask_Weights().cuda()
|
| 351 |
-
|
| 352 |
mask_weights.train()
|
| 353 |
train_epoch = 1000
|
| 354 |
optimizer = torch.optim.AdamW(mask_weights.parameters(), lr=1e-3, eps=1e-4)
|
|
|
|
| 106 |
ic_mask = np.array(ic_mask.convert("RGB"))
|
| 107 |
|
| 108 |
sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
|
| 109 |
+
# sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
|
| 110 |
+
sam = sam_model_registry[sam_type](checkpoint=sam_ckpt)
|
| 111 |
predictor = SamPredictor(sam)
|
| 112 |
|
| 113 |
# Image features encoding
|
|
|
|
| 206 |
ic_mask = np.array(ic_mask.convert("RGB"))
|
| 207 |
|
| 208 |
sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
|
| 209 |
+
# sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
|
| 210 |
+
sam = sam_model_registry[sam_type](checkpoint=sam_ckpt)
|
| 211 |
predictor = SamPredictor(sam)
|
| 212 |
|
| 213 |
# Image features encoding
|
|
|
|
| 304 |
ic_mask = np.array(ic_mask.convert("RGB"))
|
| 305 |
|
| 306 |
gt_mask = torch.tensor(ic_mask)[:, :, 0] > 0
|
| 307 |
+
# gt_mask = gt_mask.float().unsqueeze(0).flatten(1).cuda()
|
| 308 |
+
gt_mask = gt_mask.float().unsqueeze(0).flatten(1)
|
| 309 |
|
| 310 |
sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
|
| 311 |
+
# sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
|
| 312 |
+
sam = sam_model_registry[sam_type](checkpoint=sam_ckpt)
|
| 313 |
for name, param in sam.named_parameters():
|
| 314 |
param.requires_grad = False
|
| 315 |
predictor = SamPredictor(sam)
|
|
|
|
| 347 |
|
| 348 |
print('======> Start Training')
|
| 349 |
# Learnable mask weights
|
| 350 |
+
# mask_weights = Mask_Weights().cuda()
|
| 351 |
+
mask_weights = Mask_Weights()
|
| 352 |
mask_weights.train()
|
| 353 |
train_epoch = 1000
|
| 354 |
optimizer = torch.optim.AdamW(mask_weights.parameters(), lr=1e-3, eps=1e-4)
|