Update app.py

app.py CHANGED
@@ -114,6 +114,20 @@ GLEEmodel_swin = GLEE_Model(cfg_swin, None, device, None, True).to(device)
 GLEEmodel_swin.load_state_dict(checkpoints_swin, strict=False)
 GLEEmodel_swin.eval()
 
+
+cfg_eva02 = get_cfg()
+add_deeplab_config(cfg_eva02)
+add_glee_config(cfg_eva02)
+conf_files_swin = 'GLEE/configs/EVA02.yaml'
+checkpoints_eva = torch.load('GLEE/GLEE_{}.pth'.format(args.version))
+cfg_eva02.merge_from_file(conf_files_swin)
+GLEEmodel_eva02 = GLEE_Model(cfg_eva02, None, device, None, True).to(device)
+GLEEmodel_eva02.load_state_dict(checkpoints_eva, strict=False)
+GLEEmodel_eva02.eval()
+# inference_type = 'LSJ'
+
+
+
 pixel_mean = torch.Tensor( [123.675, 116.28, 103.53]).to(device).view(3, 1, 1)
 pixel_std = torch.Tensor([58.395, 57.12, 57.375]).to(device).view(3, 1, 1)
 normalizer = lambda x: (x - pixel_mean) / pixel_std
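This hunk registers a third checkpoint, the EVA02 backbone used by GLEE-Pro, with the same five-step recipe the file already uses for the R50 and SwinL models (note that it reuses the `conf_files_swin` name for the EVA02 config path). A hypothetical helper capturing that shared recipe might look like the sketch below; `build_glee_model` is illustrative and not in the repo, `get_cfg`/`add_deeplab_config` are standard detectron2 calls, and `add_glee_config`/`GLEE_Model` come from the GLEE package bundled with this Space.

```python
import torch
from detectron2.config import get_cfg
from detectron2.projects.deeplab import add_deeplab_config
# add_glee_config and GLEE_Model are imported from the Space's bundled GLEE package.

def build_glee_model(config_file, weight_file, device):
    # The five steps this module repeats for R50, SwinL, and now EVA02.
    cfg = get_cfg()
    add_deeplab_config(cfg)
    add_glee_config(cfg)
    cfg.merge_from_file(config_file)
    model = GLEE_Model(cfg, None, device, None, True).to(device)
    ckpt = torch.load(weight_file, map_location=device)
    model.load_state_dict(ckpt, strict=False)  # strict=False tolerates key mismatches
    return model.eval()
```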
@@ -130,16 +144,26 @@ TEXT_Y_OFFSET_SCALE = 1e-2
 if inference_type != 'LSJ':
     resizer = torchvision.transforms.Resize(inference_size,antialias=True)
     videoresizer = torchvision.transforms.Resize(video_inference_size,antialias=True)
+else:
+    resizer = torchvision.transforms.Resize(size = 1535, max_size=1536, antialias=True)
+    videoresizer = torchvision.transforms.Resize(size = 1535, max_size=1536, antialias=True)
+
 
 
 def segment_image(img, prompt_mode, categoryname, custom_category, expressiong, results_select, num_inst_select, threshold_select, mask_image_mix_ration, model_selection):
     torch.cuda.empty_cache()
     if model_selection == 'GLEE-Plus (SwinL)':
         GLEEmodel = GLEEmodel_swin
+        inference_type = 'resize_shot'
         print('use GLEE-Plus')
-    else:
+    elif model_selection == 'GLEE-Lite (R50)':
+        inference_type = 'resize_shot'
         GLEEmodel = GLEEmodel_r50
         print('use GLEE-Lite')
+    else:
+        GLEEmodel = GLEEmodel_eva02
+        print('use GLEE-Pro')
+        inference_type = 'LSJ'
 
     copyed_img = img['background'][:,:,:3].copy()
 
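In `segment_image`, `inference_type` is now chosen per call from the model dropdown: 'resize_shot' for SwinL and R50, 'LSJ' for the new EVA02 branch. The LSJ resizers use torchvision's integer-`size` plus `max_size` form, which scales the short edge toward 1535 while capping the long edge at 1536, so every output fits the 1536x1536 canvas used later. A quick standalone check of that behavior (the input shape is an example, not from the app):

```python
import torch
import torchvision

lsj_resizer = torchvision.transforms.Resize(size=1535, max_size=1536, antialias=True)
frame = torch.rand(1, 3, 720, 1280)   # example 16:9 input
out = lsj_resizer(frame)
print(out.shape)                      # torch.Size([1, 3, 864, 1536]): long edge capped
assert max(out.shape[-2:]) <= 1536
```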
@@ -148,8 +172,12 @@ def segment_image(img, prompt_mode, categoryname, custom_category, expressiong,
     _,_, ori_height, ori_width = ori_image.shape
 
     if inference_type == 'LSJ':
-
-
+        resize_image = resizer(ori_image)
+        image_size = torch.as_tensor((resize_image.shape[-2],resize_image.shape[-1]))
+        re_size = resize_image.shape[-2:]
+        infer_image = torch.zeros(1,3,1536,1536).to(ori_image)
+        infer_image[:,:,:image_size[0],:image_size[1]] = resize_image
+        padding_size = (1536,1536)
     else:
         resize_image = resizer(ori_image)
         image_size = torch.as_tensor((resize_image.shape[-2],resize_image.shape[-1]))
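The filled-in LSJ branch resizes, records the pre-padding size, and pastes the frame into the top-left of a zero-filled 1536x1536 tensor; the identical block is added to `process_frames` further down. A minimal standalone sketch of the step, with an illustrative function name and example input:

```python
import torch
import torchvision

def pad_to_lsj_canvas(image, canvas=1536):
    """Resize so the long edge is at most `canvas`, then paste the result
    into the top-left corner of a zero-filled canvas x canvas tensor."""
    resizer = torchvision.transforms.Resize(size=canvas - 1, max_size=canvas,
                                            antialias=True)
    resized = resizer(image)               # (1, 3, h, w) with h, w <= canvas
    h, w = resized.shape[-2:]
    padded = torch.zeros(1, 3, canvas, canvas, dtype=image.dtype, device=image.device)
    padded[:, :, :h, :w] = resized         # zeros pad the right and bottom
    return padded, (h, w)

padded, re_size = pad_to_lsj_canvas(torch.rand(1, 3, 720, 1280))
```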
@@ -309,8 +337,9 @@ def segment_image(img, prompt_mode, categoryname, custom_category, expressiong,
 
     fakemask = torch.from_numpy(fakemask).unsqueeze(0).to(ori_image)
     if inference_type == 'LSJ':
-
-        infer_visual_prompt
+        resize_fakemask = resizer(fakemask)
+        infer_visual_prompt = torch.zeros(1,1536,1536).to(resize_fakemask)
+        infer_visual_prompt[:,:image_size[0],:image_size[1]] = resize_fakemask
     else:
         resize_fakemask = resizer(fakemask)
         if size_divisibility > 1:
@@ -377,8 +406,12 @@ def process_frames(frame_list):
     _,_, ori_height, ori_width = ori_image.shape
 
     if inference_type == 'LSJ':
-
-
+        resize_image = resizer(ori_image)
+        image_size = torch.as_tensor((resize_image.shape[-2],resize_image.shape[-1]))
+        re_size = resize_image.shape[-2:]
+        infer_image = torch.zeros(1,3,1536,1536).to(ori_image)
+        infer_image[:,:,:image_size[0],:image_size[1]] = resize_image
+        padding_size = (1536,1536)
     else:
         resize_image = videoresizer(ori_image)
         image_size = torch.as_tensor((resize_image.shape[-2],resize_image.shape[-1]))
@@ -414,14 +447,23 @@ def match_from_embds(tgt_embds, cur_embds):
 def segment_video(video, prompt_mode, categoryname, custom_category, expressiong, results_select, num_inst_select, threshold_select, mask_image_mix_ration, model_selection,video_frames_select, prompter):
     torch.cuda.empty_cache()
     ### model selection
+
+
     if model_selection == 'GLEE-Plus (SwinL)':
         GLEEmodel = GLEEmodel_swin
+        inference_type = 'resize_shot'
         print('use GLEE-Plus')
         clip_length = 2 #batchsize
-    else:
+    elif model_selection == 'GLEE-Lite (R50)':
+        inference_type = 'resize_shot'
         GLEEmodel = GLEEmodel_r50
         print('use GLEE-Lite')
         clip_length = 4 #batchsize
+    else:
+        GLEEmodel = GLEEmodel_eva02
+        print('use GLEE-Pro')
+        inference_type = 'LSJ'
+        clip_length = 1 #batchsize
 
     # read video and get sparse frames
     cap = cv2.VideoCapture(video)
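`segment_video` gets the same three-way selection as the image path, and additionally scales `clip_length`, the number of frames per forward pass, inversely with backbone size: 4 for R50, 2 for SwinL, 1 for EVA02. The ladder could equally be written as a lookup table, as in this illustrative sketch; the 'GLEE-Pro (EVA02)' key is an assumed label, since the diff's `else:` branch never names the third dropdown option.

```python
# (inference_type, clip_length) per dropdown choice; the GLEE-Pro key is assumed.
MODEL_TABLE = {
    'GLEE-Plus (SwinL)': ('resize_shot', 2),
    'GLEE-Lite (R50)':   ('resize_shot', 4),
    'GLEE-Pro (EVA02)':  ('LSJ', 1),
}
inference_type, clip_length = MODEL_TABLE.get(
    model_selection, MODEL_TABLE['GLEE-Pro (EVA02)'])
```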
@@ -678,8 +720,9 @@ def segment_video(video, prompt_mode, categoryname, custom_category, expressiong
 
     fakemask = torch.from_numpy(fakemask).unsqueeze(0).to(ori_image)
     if inference_type == 'LSJ':
-
-        infer_visual_prompt
+        resize_fakemask = resizer(fakemask)
+        infer_visual_prompt = torch.zeros(1,1536,1536).to(resize_fakemask)
+        infer_visual_prompt[:,:image_size[0],:image_size[1]] = resize_fakemask
     else:
         resize_fakemask = videoresizer(fakemask)
         if size_divisibility > 1: