goryhon commited on
Commit
7fe2640
·
verified ·
1 Parent(s): f3aed33

Update web-demos/hugging_face/inpainter/base_inpainter.py

Browse files
web-demos/hugging_face/inpainter/base_inpainter.py CHANGED
@@ -246,6 +246,7 @@ class ProInpainter:
246
  gt_flows_f_list.append(flows_f)
247
  gt_flows_b_list.append(flows_b)
248
  torch.cuda.empty_cache()
 
249
 
250
  gt_flows_f = torch.cat(gt_flows_f_list, dim=1)
251
  gt_flows_b = torch.cat(gt_flows_b_list, dim=1)
@@ -253,6 +254,7 @@ class ProInpainter:
253
  else:
254
  gt_flows_bi = self.fix_raft(frames, iters=raft_iter)
255
  torch.cuda.empty_cache()
 
256
 
257
  if self.use_half:
258
  frames, flow_masks, masks_dilated = frames.half(), flow_masks.half(), masks_dilated.half()
@@ -279,7 +281,7 @@ class ProInpainter:
279
  pred_flows_f.append(pred_flows_bi_sub[0][:, pad_len_s:e_f-s_f-pad_len_e])
280
  pred_flows_b.append(pred_flows_bi_sub[1][:, pad_len_s:e_f-s_f-pad_len_e])
281
  torch.cuda.empty_cache()
282
-
283
  pred_flows_f = torch.cat(pred_flows_f, dim=1)
284
  pred_flows_b = torch.cat(pred_flows_b, dim=1)
285
  pred_flows_bi = (pred_flows_f, pred_flows_b)
@@ -287,6 +289,7 @@ class ProInpainter:
287
  pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(gt_flows_bi, flow_masks)
288
  pred_flows_bi = self.fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, flow_masks)
289
  torch.cuda.empty_cache()
 
290
 
291
  # ---- image propagation ----
292
  masked_frames = frames * (1 - masks_dilated)
@@ -313,7 +316,7 @@ class ProInpainter:
313
  updated_frames.append(updated_frames_sub[:, pad_len_s:e_f-s_f-pad_len_e])
314
  updated_masks.append(updated_masks_sub[:, pad_len_s:e_f-s_f-pad_len_e])
315
  torch.cuda.empty_cache()
316
-
317
  updated_frames = torch.cat(updated_frames, dim=1)
318
  updated_masks = torch.cat(updated_masks, dim=1)
319
  else:
@@ -322,6 +325,7 @@ class ProInpainter:
322
  updated_frames = frames * (1 - masks_dilated) + prop_imgs.view(b, t, 3, h, w) * masks_dilated
323
  updated_masks = updated_local_masks.view(b, t, 1, h, w)
324
  torch.cuda.empty_cache()
 
325
 
326
  ori_frames = frames_inp
327
  comp_frames = [None] * video_length
@@ -369,6 +373,7 @@ class ProInpainter:
369
  comp_frames[idx] = comp_frames[idx].astype(np.uint8)
370
 
371
  torch.cuda.empty_cache()
 
372
 
373
  # need to return numpy array, T, H, W, 3
374
  comp_frames = [cv2.resize(f, out_size) for f in comp_frames]
 
246
  gt_flows_f_list.append(flows_f)
247
  gt_flows_b_list.append(flows_b)
248
  torch.cuda.empty_cache()
249
+ torch.cuda.ipc_collect()
250
 
251
  gt_flows_f = torch.cat(gt_flows_f_list, dim=1)
252
  gt_flows_b = torch.cat(gt_flows_b_list, dim=1)
 
254
  else:
255
  gt_flows_bi = self.fix_raft(frames, iters=raft_iter)
256
  torch.cuda.empty_cache()
257
+ torch.cuda.ipc_collect()
258
 
259
  if self.use_half:
260
  frames, flow_masks, masks_dilated = frames.half(), flow_masks.half(), masks_dilated.half()
 
281
  pred_flows_f.append(pred_flows_bi_sub[0][:, pad_len_s:e_f-s_f-pad_len_e])
282
  pred_flows_b.append(pred_flows_bi_sub[1][:, pad_len_s:e_f-s_f-pad_len_e])
283
  torch.cuda.empty_cache()
284
+ torch.cuda.ipc_collect()
285
  pred_flows_f = torch.cat(pred_flows_f, dim=1)
286
  pred_flows_b = torch.cat(pred_flows_b, dim=1)
287
  pred_flows_bi = (pred_flows_f, pred_flows_b)
 
289
  pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(gt_flows_bi, flow_masks)
290
  pred_flows_bi = self.fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, flow_masks)
291
  torch.cuda.empty_cache()
292
+ torch.cuda.ipc_collect()
293
 
294
  # ---- image propagation ----
295
  masked_frames = frames * (1 - masks_dilated)
 
316
  updated_frames.append(updated_frames_sub[:, pad_len_s:e_f-s_f-pad_len_e])
317
  updated_masks.append(updated_masks_sub[:, pad_len_s:e_f-s_f-pad_len_e])
318
  torch.cuda.empty_cache()
319
+ torch.cuda.ipc_collect()
320
  updated_frames = torch.cat(updated_frames, dim=1)
321
  updated_masks = torch.cat(updated_masks, dim=1)
322
  else:
 
325
  updated_frames = frames * (1 - masks_dilated) + prop_imgs.view(b, t, 3, h, w) * masks_dilated
326
  updated_masks = updated_local_masks.view(b, t, 1, h, w)
327
  torch.cuda.empty_cache()
328
+ torch.cuda.ipc_collect()
329
 
330
  ori_frames = frames_inp
331
  comp_frames = [None] * video_length
 
373
  comp_frames[idx] = comp_frames[idx].astype(np.uint8)
374
 
375
  torch.cuda.empty_cache()
376
+ torch.cuda.ipc_collect()
377
 
378
  # need to return numpy array, T, H, W, 3
379
  comp_frames = [cv2.resize(f, out_size) for f in comp_frames]