Spaces:
Paused
Paused
Update web-demos/hugging_face/inpainter/base_inpainter.py
Browse files
web-demos/hugging_face/inpainter/base_inpainter.py
CHANGED
@@ -246,6 +246,7 @@ class ProInpainter:
|
|
246 |
gt_flows_f_list.append(flows_f)
|
247 |
gt_flows_b_list.append(flows_b)
|
248 |
torch.cuda.empty_cache()
|
|
|
249 |
|
250 |
gt_flows_f = torch.cat(gt_flows_f_list, dim=1)
|
251 |
gt_flows_b = torch.cat(gt_flows_b_list, dim=1)
|
@@ -253,6 +254,7 @@ class ProInpainter:
|
|
253 |
else:
|
254 |
gt_flows_bi = self.fix_raft(frames, iters=raft_iter)
|
255 |
torch.cuda.empty_cache()
|
|
|
256 |
|
257 |
if self.use_half:
|
258 |
frames, flow_masks, masks_dilated = frames.half(), flow_masks.half(), masks_dilated.half()
|
@@ -279,7 +281,7 @@ class ProInpainter:
|
|
279 |
pred_flows_f.append(pred_flows_bi_sub[0][:, pad_len_s:e_f-s_f-pad_len_e])
|
280 |
pred_flows_b.append(pred_flows_bi_sub[1][:, pad_len_s:e_f-s_f-pad_len_e])
|
281 |
torch.cuda.empty_cache()
|
282 |
-
|
283 |
pred_flows_f = torch.cat(pred_flows_f, dim=1)
|
284 |
pred_flows_b = torch.cat(pred_flows_b, dim=1)
|
285 |
pred_flows_bi = (pred_flows_f, pred_flows_b)
|
@@ -287,6 +289,7 @@ class ProInpainter:
|
|
287 |
pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(gt_flows_bi, flow_masks)
|
288 |
pred_flows_bi = self.fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, flow_masks)
|
289 |
torch.cuda.empty_cache()
|
|
|
290 |
|
291 |
# ---- image propagation ----
|
292 |
masked_frames = frames * (1 - masks_dilated)
|
@@ -313,7 +316,7 @@ class ProInpainter:
|
|
313 |
updated_frames.append(updated_frames_sub[:, pad_len_s:e_f-s_f-pad_len_e])
|
314 |
updated_masks.append(updated_masks_sub[:, pad_len_s:e_f-s_f-pad_len_e])
|
315 |
torch.cuda.empty_cache()
|
316 |
-
|
317 |
updated_frames = torch.cat(updated_frames, dim=1)
|
318 |
updated_masks = torch.cat(updated_masks, dim=1)
|
319 |
else:
|
@@ -322,6 +325,7 @@ class ProInpainter:
|
|
322 |
updated_frames = frames * (1 - masks_dilated) + prop_imgs.view(b, t, 3, h, w) * masks_dilated
|
323 |
updated_masks = updated_local_masks.view(b, t, 1, h, w)
|
324 |
torch.cuda.empty_cache()
|
|
|
325 |
|
326 |
ori_frames = frames_inp
|
327 |
comp_frames = [None] * video_length
|
@@ -369,6 +373,7 @@ class ProInpainter:
|
|
369 |
comp_frames[idx] = comp_frames[idx].astype(np.uint8)
|
370 |
|
371 |
torch.cuda.empty_cache()
|
|
|
372 |
|
373 |
# need to return numpy array, T, H, W, 3
|
374 |
comp_frames = [cv2.resize(f, out_size) for f in comp_frames]
|
|
|
246 |
gt_flows_f_list.append(flows_f)
|
247 |
gt_flows_b_list.append(flows_b)
|
248 |
torch.cuda.empty_cache()
|
249 |
+
torch.cuda.ipc_collect()
|
250 |
|
251 |
gt_flows_f = torch.cat(gt_flows_f_list, dim=1)
|
252 |
gt_flows_b = torch.cat(gt_flows_b_list, dim=1)
|
|
|
254 |
else:
|
255 |
gt_flows_bi = self.fix_raft(frames, iters=raft_iter)
|
256 |
torch.cuda.empty_cache()
|
257 |
+
torch.cuda.ipc_collect()
|
258 |
|
259 |
if self.use_half:
|
260 |
frames, flow_masks, masks_dilated = frames.half(), flow_masks.half(), masks_dilated.half()
|
|
|
281 |
pred_flows_f.append(pred_flows_bi_sub[0][:, pad_len_s:e_f-s_f-pad_len_e])
|
282 |
pred_flows_b.append(pred_flows_bi_sub[1][:, pad_len_s:e_f-s_f-pad_len_e])
|
283 |
torch.cuda.empty_cache()
|
284 |
+
torch.cuda.ipc_collect()
|
285 |
pred_flows_f = torch.cat(pred_flows_f, dim=1)
|
286 |
pred_flows_b = torch.cat(pred_flows_b, dim=1)
|
287 |
pred_flows_bi = (pred_flows_f, pred_flows_b)
|
|
|
289 |
pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(gt_flows_bi, flow_masks)
|
290 |
pred_flows_bi = self.fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, flow_masks)
|
291 |
torch.cuda.empty_cache()
|
292 |
+
torch.cuda.ipc_collect()
|
293 |
|
294 |
# ---- image propagation ----
|
295 |
masked_frames = frames * (1 - masks_dilated)
|
|
|
316 |
updated_frames.append(updated_frames_sub[:, pad_len_s:e_f-s_f-pad_len_e])
|
317 |
updated_masks.append(updated_masks_sub[:, pad_len_s:e_f-s_f-pad_len_e])
|
318 |
torch.cuda.empty_cache()
|
319 |
+
torch.cuda.ipc_collect()
|
320 |
updated_frames = torch.cat(updated_frames, dim=1)
|
321 |
updated_masks = torch.cat(updated_masks, dim=1)
|
322 |
else:
|
|
|
325 |
updated_frames = frames * (1 - masks_dilated) + prop_imgs.view(b, t, 3, h, w) * masks_dilated
|
326 |
updated_masks = updated_local_masks.view(b, t, 1, h, w)
|
327 |
torch.cuda.empty_cache()
|
328 |
+
torch.cuda.ipc_collect()
|
329 |
|
330 |
ori_frames = frames_inp
|
331 |
comp_frames = [None] * video_length
|
|
|
373 |
comp_frames[idx] = comp_frames[idx].astype(np.uint8)
|
374 |
|
375 |
torch.cuda.empty_cache()
|
376 |
+
torch.cuda.ipc_collect()
|
377 |
|
378 |
# need to return numpy array, T, H, W, 3
|
379 |
comp_frames = [cv2.resize(f, out_size) for f in comp_frames]
|