Update web-demos/hugging_face/inpainter/base_inpainter.py
web-demos/hugging_face/inpainter/base_inpainter.py
CHANGED
@@ -20,367 +20,367 @@ warnings.filterwarnings("ignore")
(Removed: the previous bodies of imwrite, resize_frames, read_frame_from_videos, binary_mask, extrapolation, get_ref_index, read_mask_demo, and the ProInpainter class, old lines 20-386; the deleted source is collapsed in this view. The replacement implementations follow.)
def imwrite(img, file_path, params=None, auto_mkdir=True):
    if auto_mkdir:
        dir_name = os.path.abspath(os.path.dirname(file_path))
        os.makedirs(dir_name, exist_ok=True)
    return cv2.imwrite(file_path, img, params)

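A minimal usage sketch (an illustration, not part of the commit): imwrite creates the destination directory on demand when auto_mkdir=True. The output path below is hypothetical:

    import numpy as np

    # "tmp_out/" does not need to exist beforehand; imwrite creates it.
    ok = imwrite(np.zeros((64, 64, 3), dtype=np.uint8), "tmp_out/frame_0000.png")
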
def resize_frames(frames, size=None):
    if size is not None:
        out_size = size
        process_size = (out_size[0] - out_size[0] % 8, out_size[1] - out_size[1] % 8)
        frames = [f.resize(process_size) for f in frames]
    else:
        out_size = frames[0].size
        process_size = (out_size[0] - out_size[0] % 8, out_size[1] - out_size[1] % 8)
        if not out_size == process_size:
            frames = [f.resize(process_size) for f in frames]

    return frames, process_size, out_size

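A sketch of the sizing behavior, assuming Pillow and an arbitrary 720x481 input: the processing size is snapped down to multiples of 8, while the original size is returned so the output can be restored later:

    from PIL import Image

    frames = [Image.new("RGB", (720, 481)) for _ in range(3)]
    frames, process_size, out_size = resize_frames(frames)
    print(process_size, out_size)  # (720, 480) (720, 481)
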
def read_frame_from_videos(frame_root):
    if frame_root.endswith(('mp4', 'mov', 'avi', 'MP4', 'MOV', 'AVI')):  # input video path
        video_name = os.path.basename(frame_root)[:-4]
        vframes, aframes, info = torchvision.io.read_video(filename=frame_root, pts_unit='sec')  # RGB
        frames = list(vframes.numpy())
        frames = [Image.fromarray(f) for f in frames]
        fps = info['video_fps']
    else:
        video_name = os.path.basename(frame_root)
        frames = []
        fr_lst = sorted(os.listdir(frame_root))
        for fr in fr_lst:
            frame = cv2.imread(os.path.join(frame_root, fr))
            frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            frames.append(frame)
        fps = None
    size = frames[0].size

    return frames, fps, size, video_name

def binary_mask(mask, th=0.1):
    mask[mask > th] = 1
    mask[mask <= th] = 0
    return mask

def extrapolation(video_ori, scale):
    """Prepares the data for video outpainting."""
    nFrame = len(video_ori)
    imgW, imgH = video_ori[0].size

    # Defines the new FOV.
    imgH_extr = int(scale[0] * imgH)
    imgW_extr = int(scale[1] * imgW)
    imgH_extr = imgH_extr - imgH_extr % 8
    imgW_extr = imgW_extr - imgW_extr % 8
    H_start = int((imgH_extr - imgH) / 2)
    W_start = int((imgW_extr - imgW) / 2)

    # Extrapolates the FOV for the video.
    frames = []
    for v in video_ori:
        frame = np.zeros((imgH_extr, imgW_extr, 3), dtype=np.uint8)
        frame[H_start: H_start + imgH, W_start: W_start + imgW, :] = v
        frames.append(Image.fromarray(frame))

    # Generates the masks for the missing region.
    masks_dilated = []
    flow_masks = []

    dilate_h = 4 if H_start > 10 else 0
    dilate_w = 4 if W_start > 10 else 0
    mask = np.ones((imgH_extr, imgW_extr), dtype=np.uint8)

    mask[H_start + dilate_h: H_start + imgH - dilate_h,
         W_start + dilate_w: W_start + imgW - dilate_w] = 0
    flow_masks.append(Image.fromarray(mask * 255))

    mask[H_start: H_start + imgH, W_start: W_start + imgW] = 0
    masks_dilated.append(Image.fromarray(mask * 255))

    flow_masks = flow_masks * nFrame
    masks_dilated = masks_dilated * nFrame

    return frames, flow_masks, masks_dilated, (imgW_extr, imgH_extr)

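A sketch of the outpainting setup, assuming Pillow and arbitrary 432x240 frames: scale is (height_scale, width_scale), and the returned size is the enlarged FOV rounded down to multiples of 8:

    from PIL import Image

    video = [Image.new("RGB", (432, 240)) for _ in range(5)]
    frames, flow_masks, masks_dilated, out_size = extrapolation(video, (1.2, 1.2))
    print(out_size)         # (512, 288)
    print(len(flow_masks))  # 5, one mask per frame
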
def get_ref_index(mid_neighbor_id, neighbor_ids, length, ref_stride=10, ref_num=-1):
    ref_index = []
    if ref_num == -1:
        for i in range(0, length, ref_stride):
            if i not in neighbor_ids:
                ref_index.append(i)
    else:
        start_idx = max(0, mid_neighbor_id - ref_stride * (ref_num // 2))
        end_idx = min(length, mid_neighbor_id + ref_stride * (ref_num // 2))
        for i in range(start_idx, end_idx, ref_stride):
            if i not in neighbor_ids:
                if len(ref_index) > ref_num:
                    break
                ref_index.append(i)
    return ref_index

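A sketch of the reference-frame sampling. With the default ref_num=-1, every ref_stride-th frame outside the local neighborhood is used; otherwise only a window around the middle frame is sampled:

    neighbors = list(range(35, 46))  # local window around frame 40
    print(get_ref_index(40, neighbors, 81))             # [0, 10, 20, 30, 50, 60, 70, 80]
    print(get_ref_index(40, neighbors, 81, ref_num=4))  # [20, 30, 50]
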
def read_mask_demo(masks, length, size, flow_mask_dilates=8, mask_dilates=5):
    masks_img = []
    masks_dilated = []
    flow_masks = []

    for mp in masks:
        masks_img.append(Image.fromarray(mp.astype('uint8')))

    for mask_img in masks_img:
        if size is not None:
            mask_img = mask_img.resize(size, Image.NEAREST)
        mask_img = np.array(mask_img.convert('L'))

        # Dilate by 8 pixels so that every known pixel is trustworthy.
        if flow_mask_dilates > 0:
            flow_mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=flow_mask_dilates).astype(np.uint8)
        else:
            flow_mask_img = binary_mask(mask_img).astype(np.uint8)

        flow_masks.append(Image.fromarray(flow_mask_img * 255))

        if mask_dilates > 0:
            mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=mask_dilates).astype(np.uint8)
        else:
            mask_img = binary_mask(mask_img).astype(np.uint8)
        masks_dilated.append(Image.fromarray(mask_img * 255))

    if len(masks_img) == 1:
        flow_masks = flow_masks * length
        masks_dilated = masks_dilated * length

    return flow_masks, masks_dilated

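A sketch of the mask preparation, assuming numpy/scipy and an arbitrary 432x240 mask: a single H x W mask is dilated (more aggressively for the flow masks) and repeated for every frame:

    import numpy as np

    mask = np.zeros((240, 432), dtype=np.uint8)
    mask[100:140, 200:260] = 1  # one rectangular object mask
    flow_masks, masks_dilated = read_mask_demo([mask], length=5, size=(432, 240))
    print(len(flow_masks), len(masks_dilated))  # 5 5
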
class ProInpainter:
    def __init__(self, propainter_checkpoint, raft_checkpoint, flow_completion_checkpoint, device="cuda:0", use_half=True):
        self.device = device
        self.use_half = use_half
        if self.device == torch.device('cpu'):
            self.use_half = False

        ##############################################
        # set up RAFT and flow completion model
        ##############################################
        self.fix_raft = RAFT_bi(raft_checkpoint, self.device)

        self.fix_flow_complete = RecurrentFlowCompleteNet(flow_completion_checkpoint)
        for p in self.fix_flow_complete.parameters():
            p.requires_grad = False
        self.fix_flow_complete.to(self.device)
        self.fix_flow_complete.eval()

        ##############################################
        # set up ProPainter model
        ##############################################
        self.model = InpaintGenerator(model_path=propainter_checkpoint).to(self.device)
        self.model.eval()

        if self.use_half:
            self.fix_flow_complete = self.fix_flow_complete.half()
            self.model = self.model.half()

    def inpaint(self, npframes, masks, ratio=1.0, dilate_radius=4, raft_iter=20, subvideo_length=80, neighbor_length=10, ref_stride=10):
        """
        Perform inpainting for video subsets.

        Output:
            inpainted_frames: a list of T numpy arrays, each H, W, 3
        """
        frames = []
        for i in range(len(npframes)):
            frames.append(Image.fromarray(npframes[i].astype('uint8'), mode="RGB"))
        del npframes

        # Get the original size.
        size = frames[0].size  # (width, height)

        # Apply the ratio only if it differs from 1.0.
        if ratio != 1.0:
            size = (int(ratio * size[0]) // 2 * 2, int(ratio * size[1]) // 2 * 2)
        else:
            size = (size[0] // 2 * 2, size[1] // 2 * 2)  # just round down to the nearest even number

        frames_len = len(frames)

        # ⚠️ resize_frames does not change the resolution if it is already a multiple of 8.
        frames, size, out_size = resize_frames(frames, size)

        flow_masks, masks_dilated = read_mask_demo(masks, frames_len, size, dilate_radius, dilate_radius)
        w, h = size

        frames_inp = [np.array(f).astype(np.uint8) for f in frames]

        frames = to_tensors()(frames).unsqueeze(0) * 2 - 1
        flow_masks = to_tensors()(flow_masks).unsqueeze(0)
        masks_dilated = to_tensors()(masks_dilated).unsqueeze(0)

        frames = frames.to(self.device)
        flow_masks = flow_masks.to(self.device)
        masks_dilated = masks_dilated.to(self.device)

        ##############################################
        # ProPainter inference
        ##############################################
        video_length = frames.size(1)
        with torch.no_grad():
            # ---- compute flow ----
            if frames.size(-1) <= 640:
                short_clip_len = 12
            elif frames.size(-1) <= 720:
                short_clip_len = 8
            elif frames.size(-1) <= 1280:
                short_clip_len = 4
            else:
                short_clip_len = 2

            # use fp32 for RAFT
            if frames.size(1) > short_clip_len:
                gt_flows_f_list, gt_flows_b_list = [], []
                for f in range(0, video_length, short_clip_len):
                    end_f = min(video_length, f + short_clip_len)
                    if f == 0:
                        flows_f, flows_b = self.fix_raft(frames[:, f:end_f], iters=raft_iter)
                    else:
                        flows_f, flows_b = self.fix_raft(frames[:, f-1:end_f], iters=raft_iter)

                    gt_flows_f_list.append(flows_f)
                    gt_flows_b_list.append(flows_b)
                    torch.cuda.empty_cache()

                gt_flows_f = torch.cat(gt_flows_f_list, dim=1)
                gt_flows_b = torch.cat(gt_flows_b_list, dim=1)
                gt_flows_bi = (gt_flows_f, gt_flows_b)
            else:
                gt_flows_bi = self.fix_raft(frames, iters=raft_iter)
                torch.cuda.empty_cache()

            if self.use_half:
                frames, flow_masks, masks_dilated = frames.half(), flow_masks.half(), masks_dilated.half()
                gt_flows_bi = (gt_flows_bi[0].half(), gt_flows_bi[1].half())

            # ---- complete flow ----
            flow_length = gt_flows_bi[0].size(1)
            if flow_length > subvideo_length:
                pred_flows_f, pred_flows_b = [], []
                pad_len = 5
                for f in range(0, flow_length, subvideo_length):
                    s_f = max(0, f - pad_len)
                    e_f = min(flow_length, f + subvideo_length + pad_len)
                    pad_len_s = max(0, f) - s_f
                    pad_len_e = e_f - min(flow_length, f + subvideo_length)
                    pred_flows_bi_sub, _ = self.fix_flow_complete.forward_bidirect_flow(
                        (gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]),
                        flow_masks[:, s_f:e_f+1])
                    pred_flows_bi_sub = self.fix_flow_complete.combine_flow(
                        (gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]),
                        pred_flows_bi_sub,
                        flow_masks[:, s_f:e_f+1])

                    pred_flows_f.append(pred_flows_bi_sub[0][:, pad_len_s:e_f-s_f-pad_len_e])
                    pred_flows_b.append(pred_flows_bi_sub[1][:, pad_len_s:e_f-s_f-pad_len_e])
                    torch.cuda.empty_cache()

                pred_flows_f = torch.cat(pred_flows_f, dim=1)
                pred_flows_b = torch.cat(pred_flows_b, dim=1)
                pred_flows_bi = (pred_flows_f, pred_flows_b)
            else:
                pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(gt_flows_bi, flow_masks)
                pred_flows_bi = self.fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, flow_masks)
                torch.cuda.empty_cache()

            # ---- image propagation ----
            masked_frames = frames * (1 - masks_dilated)
            subvideo_length_img_prop = min(100, subvideo_length)  # cap each image-propagation sub-video at 100 frames
            if video_length > subvideo_length_img_prop:
                updated_frames, updated_masks = [], []
                pad_len = 10
                for f in range(0, video_length, subvideo_length_img_prop):
                    s_f = max(0, f - pad_len)
                    e_f = min(video_length, f + subvideo_length_img_prop + pad_len)
                    pad_len_s = max(0, f) - s_f
                    pad_len_e = e_f - min(video_length, f + subvideo_length_img_prop)

                    b, t, _, _, _ = masks_dilated[:, s_f:e_f].size()
                    pred_flows_bi_sub = (pred_flows_bi[0][:, s_f:e_f-1], pred_flows_bi[1][:, s_f:e_f-1])
                    prop_imgs_sub, updated_local_masks_sub = self.model.img_propagation(masked_frames[:, s_f:e_f],
                                                                                        pred_flows_bi_sub,
                                                                                        masks_dilated[:, s_f:e_f],
                                                                                        'nearest')
                    updated_frames_sub = frames[:, s_f:e_f] * (1 - masks_dilated[:, s_f:e_f]) + \
                        prop_imgs_sub.view(b, t, 3, h, w) * masks_dilated[:, s_f:e_f]
                    updated_masks_sub = updated_local_masks_sub.view(b, t, 1, h, w)

                    updated_frames.append(updated_frames_sub[:, pad_len_s:e_f-s_f-pad_len_e])
                    updated_masks.append(updated_masks_sub[:, pad_len_s:e_f-s_f-pad_len_e])
                    torch.cuda.empty_cache()

                updated_frames = torch.cat(updated_frames, dim=1)
                updated_masks = torch.cat(updated_masks, dim=1)
            else:
                b, t, _, _, _ = masks_dilated.size()
                prop_imgs, updated_local_masks = self.model.img_propagation(masked_frames, pred_flows_bi, masks_dilated, 'nearest')
                updated_frames = frames * (1 - masks_dilated) + prop_imgs.view(b, t, 3, h, w) * masks_dilated
                updated_masks = updated_local_masks.view(b, t, 1, h, w)
                torch.cuda.empty_cache()

        ori_frames = frames_inp
        comp_frames = [None] * video_length

        neighbor_stride = neighbor_length // 2
        if video_length > subvideo_length:
            ref_num = subvideo_length // ref_stride
        else:
            ref_num = -1

        # ---- feature propagation + transformer ----
        for f in tqdm(range(0, video_length, neighbor_stride)):
            neighbor_ids = [
                i for i in range(max(0, f - neighbor_stride),
                                 min(video_length, f + neighbor_stride + 1))
            ]
            ref_ids = get_ref_index(f, neighbor_ids, video_length, ref_stride, ref_num)
            selected_imgs = updated_frames[:, neighbor_ids + ref_ids, :, :, :]
            selected_masks = masks_dilated[:, neighbor_ids + ref_ids, :, :, :]
            selected_update_masks = updated_masks[:, neighbor_ids + ref_ids, :, :, :]
            selected_pred_flows_bi = (pred_flows_bi[0][:, neighbor_ids[:-1], :, :, :], pred_flows_bi[1][:, neighbor_ids[:-1], :, :, :])

            with torch.no_grad():
                # 1.0 indicates the masked region
                l_t = len(neighbor_ids)

                # pred_img = selected_imgs  # results of image propagation
                pred_img = self.model(selected_imgs, selected_pred_flows_bi, selected_masks, selected_update_masks, l_t)

                pred_img = pred_img.view(-1, 3, h, w)

                pred_img = (pred_img + 1) / 2
                pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255
                binary_masks = masks_dilated[0, neighbor_ids, :, :, :].cpu().permute(
                    0, 2, 3, 1).numpy().astype(np.uint8)
                for i in range(len(neighbor_ids)):
                    idx = neighbor_ids[i]
                    img = np.array(pred_img[i]).astype(np.uint8) * binary_masks[i] \
                        + ori_frames[idx] * (1 - binary_masks[i])
                    if comp_frames[idx] is None:
                        comp_frames[idx] = img
                    else:
                        comp_frames[idx] = comp_frames[idx].astype(np.float32) * 0.5 + img.astype(np.float32) * 0.5

                    comp_frames[idx] = comp_frames[idx].astype(np.uint8)

            torch.cuda.empty_cache()

        # need to return numpy arrays, T, H, W, 3
        comp_frames = [cv2.resize(f, out_size) for f in comp_frames]

        return comp_frames
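
An end-to-end calling sketch under stated assumptions: the checkpoint paths are placeholders for the ProPainter release weights, a CUDA device is available, and the random frames/masks stand in for a real clip:

    import numpy as np

    inpainter = ProInpainter(
        propainter_checkpoint="weights/ProPainter.pth",                      # placeholder path
        raft_checkpoint="weights/raft-things.pth",                           # placeholder path
        flow_completion_checkpoint="weights/recurrent_flow_completion.pth",  # placeholder path
        device="cuda:0",
    )

    frames = np.random.randint(0, 256, (24, 240, 432, 3), dtype=np.uint8)  # T, H, W, 3
    masks = [np.zeros((240, 432), dtype=np.uint8) for _ in range(24)]
    for m in masks:
        m[100:140, 200:260] = 1  # region to remove in every frame

    comp_frames = inpainter.inpaint(frames, masks)  # list of T frames, each H x W x 3

Passing ratio < 1.0 runs inference at a reduced resolution and resizes the result back to the original size at the end.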