Spaces:
Paused
Paused
| # -*- coding: utf-8 -*- | |
| import sys | |
| sys.path.append(".") | |
| import os | |
| import cv2 | |
| import numpy as np | |
| import argparse | |
| from PIL import Image | |
| import torch.nn.functional as F | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from model.modules.flow_comp_raft import RAFT_bi | |
| from model.recurrent_flow_completion import RecurrentFlowCompleteNet | |
| from model.propainter import InpaintGenerator | |
| # from core.dataset import TestDataset | |
| from core.dataset import TestDataset | |
| from core.metrics import calc_psnr_and_ssim, calculate_i3d_activations, calculate_vfid, init_i3d_model | |
| from time import time | |
| import warnings | |
| warnings.filterwarnings("ignore") | |
| # sample reference frames from the whole video | |
def get_ref_index(neighbor_ids, length, ref_stride=10):
    """Sample reference frame indices from the whole video.

    Every ``ref_stride``-th index in ``range(length)`` is a candidate;
    candidates that already appear in ``neighbor_ids`` are skipped.

    Args:
        neighbor_ids: Iterable of frame indices to exclude.
        length: Total number of frames in the video.
        ref_stride: Step between candidate reference frames.

    Returns:
        Ascending list of reference frame indices.
    """
    # Build the exclusion set once so each membership test is O(1)
    # instead of a linear scan of the neighbor list per candidate.
    excluded = set(neighbor_ids)
    return [i for i in range(0, length, ref_stride) if i not in excluded]
def _sync_device(device):
    """Wait for queued CUDA work to finish; do nothing when running on CPU."""
    # torch.cuda.synchronize() raises when CUDA is unavailable, and the
    # caller explicitly supports falling back to a CPU device.
    if device.type == 'cuda':
        torch.cuda.synchronize()


def _compute_gt_flows(fix_raft, frames, raft_iter, short_len=60):
    """Estimate forward/backward flows, chunking videos longer than ``short_len``.

    Chunks after the first start one frame earlier (``f - 1``), so the frame
    pair straddling each chunk boundary is included in the later chunk.
    """
    video_length = frames.size(1)
    if video_length <= short_len:
        return fix_raft(frames, iters=raft_iter)

    flows_f_chunks, flows_b_chunks = [], []
    for f in range(0, video_length, short_len):
        end_f = min(video_length, f + short_len)
        start_f = max(f - 1, 0)
        flows_f, flows_b = fix_raft(frames[:, start_f:end_f], iters=raft_iter)
        flows_f_chunks.append(flows_f)
        flows_b_chunks.append(flows_b)
    # Concatenate once, after all chunks are available.
    return torch.cat(flows_f_chunks, dim=1), torch.cat(flows_b_chunks, dim=1)


def _inpaint_frames(model, args, updated_frames, masks, updated_masks,
                    pred_flows_bi, ori_frames, size):
    """Run the inpainting model over sliding neighbor windows and blend results.

    Returns a list with one completed frame per entry of ``ori_frames``;
    frames covered by two windows are averaged with equal weights.
    """
    w, h = size
    video_length = len(ori_frames)
    comp_frames = [None] * video_length

    neighbor_stride = args.neighbor_length // 2
    for f in range(0, video_length, neighbor_stride):
        neighbor_ids = list(range(max(0, f - neighbor_stride),
                                  min(video_length, f + neighbor_stride + 1)))
        ref_ids = get_ref_index(neighbor_ids, video_length, args.ref_stride)
        frame_ids = neighbor_ids + ref_ids
        flow_ids = neighbor_ids[:-1]

        selected_imgs = updated_frames[:, frame_ids, :, :, :]
        selected_masks = masks[:, frame_ids, :, :, :]
        selected_update_masks = updated_masks[:, frame_ids, :, :, :]
        selected_pred_flows_bi = (pred_flows_bi[0][:, flow_ids, :, :, :],
                                  pred_flows_bi[1][:, flow_ids, :, :, :])

        with torch.no_grad():
            l_t = len(neighbor_ids)
            pred_img = model(selected_imgs, selected_pred_flows_bi,
                             selected_masks, selected_update_masks, l_t)
            pred_img = pred_img.view(-1, 3, h, w)
            pred_img = (pred_img + 1) / 2
            pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255
            binary_masks = masks[0, neighbor_ids, :, :, :].cpu().permute(
                0, 2, 3, 1).numpy().astype(np.uint8)

        for i, idx in enumerate(neighbor_ids):
            # Keep original pixels outside the mask, predicted pixels inside.
            img = np.array(pred_img[i]).astype(np.uint8) * binary_masks[i] \
                + ori_frames[idx] * (1 - binary_masks[i])
            if comp_frames[idx] is None:
                comp_frames[idx] = img
            else:
                comp_frames[idx] = comp_frames[idx].astype(
                    np.float32) * 0.5 + img.astype(np.float32) * 0.5
    return comp_frames


def main_worker(args):
    """Evaluate the inpainting pipeline on every video of the test dataset.

    For each video: obtain bidirectional flows (loaded or estimated), complete
    them, propagate pixels, run the inpainting model, and record timing. For
    the ``video_completion`` task PSNR/SSIM/VFID are also computed and written
    to ``<result_path>/<dataset>_metrics.txt``.

    Args:
        args: Parsed command-line namespace. Reads ``width``, ``height``,
            ``dataset``, ``num_workers``, ``raft_model_path``,
            ``fc_model_path``, ``propainter_model_path``, ``task``,
            ``ref_stride``, ``neighbor_length``, ``load_flow``, ``raft_iter``
            and ``save_results``. ``args.size`` is set as a side effect.

    Raises:
        AssertionError: If ``args.dataset`` is neither ``'davis'`` nor
            ``'youtube-vos'``.
    """
    args.size = (args.width, args.height)
    w, h = args.size

    # set up datasets and data loader
    assert (args.dataset == 'davis') or args.dataset == 'youtube-vos', \
        f"{args.dataset} dataset is not supported"
    test_dataset = TestDataset(vars(args))
    test_loader = DataLoader(test_dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=args.num_workers)

    # set up models
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    fix_raft = RAFT_bi(args.raft_model_path, device)

    fix_flow_complete = RecurrentFlowCompleteNet(args.fc_model_path)
    for p in fix_flow_complete.parameters():
        p.requires_grad = False
    fix_flow_complete.to(device)
    fix_flow_complete.eval()

    model = InpaintGenerator(model_path=args.propainter_model_path).to(device)
    model.eval()

    time_all = []
    print('Start evaluation ...')

    is_completion = args.task == 'video_completion'
    task_suffix = 'video_completion' if is_completion else 'object_removal'
    result_path = os.path.join(
        'results_eval',
        f'{args.dataset}_rs_{args.ref_stride}_nl_{args.neighbor_length}_{task_suffix}')
    # exist_ok=True makes a separate existence check unnecessary (and race-free).
    os.makedirs(result_path, exist_ok=True)

    eval_summary = None
    try:
        if is_completion:
            eval_summary = open(os.path.join(result_path, f"{args.dataset}_metrics.txt"), "w")
            total_frame_psnr = []
            total_frame_ssim = []
            output_i3d_activations = []
            real_i3d_activations = []
            i3d_model = init_i3d_model('weights/i3d_rgb_imagenet.pt')

        for index, items in enumerate(test_loader):
            torch.cuda.empty_cache()

            frames, masks, flows_f, flows_b, video_name, frames_PIL = items
            video_name = video_name[0]
            print('Processing:', video_name)

            video_length = frames.size(1)
            frames, masks = frames.to(device), masks.to(device)
            masked_frames = frames * (1 - masks)

            _sync_device(device)
            time_start = time()

            with torch.no_grad():
                # ---- compute flow ----
                if args.load_flow:
                    gt_flows_bi = (flows_f.to(device), flows_b.to(device))
                else:
                    gt_flows_bi = _compute_gt_flows(fix_raft, frames, args.raft_iter)

                # ---- complete flow ----
                pred_flows_bi, _ = fix_flow_complete.forward_bidirect_flow(gt_flows_bi, masks)
                pred_flows_bi = fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, masks)

                # ---- temporal propagation ----
                prop_imgs, updated_local_masks = model.img_propagation(
                    masked_frames, pred_flows_bi, masks, 'nearest')
                b, t, _, _, _ = masks.size()
                updated_masks = updated_local_masks.view(b, t, 1, h, w)
                updated_frames = frames * (1 - masks) + prop_imgs.view(b, t, 3, h, w) * masks  # merge

            # Drop large intermediates before the inpainting passes.
            del gt_flows_bi, frames, updated_local_masks
            if not args.load_flow:
                torch.cuda.empty_cache()

            ori_frames = [
                frames_PIL[i].squeeze().cpu().numpy() for i in range(video_length)
            ]

            # complete holes by our model
            comp_frames = _inpaint_frames(model, args, updated_frames, masks,
                                          updated_masks, pred_flows_bi,
                                          ori_frames, args.size)

            _sync_device(device)
            time_i = (time() - time_start) * 1.0 / video_length
            time_all.append(time_i)
            avg_time = sum(time_all) / len(time_all)

            if is_completion:
                # calculate metrics
                cur_video_psnr = []
                cur_video_ssim = []
                comp_PIL = []  # to calculate VFID
                frames_PIL = []
                for ori, comp in zip(ori_frames, comp_frames):
                    psnr, ssim = calc_psnr_and_ssim(ori, comp)
                    cur_video_psnr.append(psnr)
                    cur_video_ssim.append(ssim)
                    total_frame_psnr.append(psnr)
                    total_frame_ssim.append(ssim)
                    frames_PIL.append(Image.fromarray(ori.astype(np.uint8)))
                    comp_PIL.append(Image.fromarray(comp.astype(np.uint8)))

                # saving i3d activations
                frames_i3d, comp_i3d = calculate_i3d_activations(frames_PIL,
                                                                 comp_PIL,
                                                                 i3d_model,
                                                                 device=device)
                real_i3d_activations.append(frames_i3d)
                output_i3d_activations.append(comp_i3d)

                cur_psnr = sum(cur_video_psnr) / len(cur_video_psnr)
                cur_ssim = sum(cur_video_ssim) / len(cur_video_ssim)
                avg_psnr = sum(total_frame_psnr) / len(total_frame_psnr)
                avg_ssim = sum(total_frame_ssim) / len(total_frame_ssim)

                # Adjacent f-strings rather than a backslash continuation inside
                # the literal, so the text never depends on source indentation;
                # built once and reused for both console and file.
                summary_line = (
                    f'[{index+1:3}/{len(test_loader)}] Name: {str(video_name):25} | PSNR/SSIM: {cur_psnr:.4f}/{cur_ssim:.4f} '
                    f'| Avg PSNR/SSIM: {avg_psnr:.4f}/{avg_ssim:.4f} | Time: {avg_time:.4f}'
                )
                print(summary_line)
                eval_summary.write(summary_line + '\n')
            else:
                print(
                    f'[{index+1:3}/{len(test_loader)}] Name: {str(video_name):25} | Time: {avg_time:.4f}'
                )

            # saving images for evaluating warpping errors
            if args.save_results:
                save_frame_path = os.path.join(result_path, video_name)
                os.makedirs(save_frame_path, exist_ok=True)
                for i, frame in enumerate(comp_frames):
                    cv2.imwrite(
                        os.path.join(save_frame_path, str(i).zfill(5) + '.png'),
                        cv2.cvtColor(frame.astype(np.uint8), cv2.COLOR_RGB2BGR))

        if not time_all:
            # Nothing was processed: averages below would divide by zero.
            print('No videos were evaluated.')
            return

        avg_time = sum(time_all) / len(time_all)
        if is_completion:
            avg_frame_psnr = sum(total_frame_psnr) / len(total_frame_psnr)
            avg_frame_ssim = sum(total_frame_ssim) / len(total_frame_ssim)
            fid_score = calculate_vfid(real_i3d_activations, output_i3d_activations)
            final_line = (
                'Finish evaluation... Average Frame PSNR/SSIM/VFID: '
                f'{avg_frame_psnr:.2f}/{avg_frame_ssim:.4f}/{fid_score:.3f} | Time: {avg_time:.4f}')
            print(final_line)
            eval_summary.write(final_line)
        else:
            # The f prefix was missing, so the placeholder was printed verbatim.
            print(f'Finish evaluation... Time: {avg_time:.4f}')
    finally:
        # Close the metrics file even when evaluation fails part-way through.
        if eval_summary is not None:
            eval_summary.close()
def str2bool(value):
    """Convert a command-line token to a bool.

    ``type=bool`` in argparse calls ``bool()`` on the raw string, so any
    non-empty text (including ``"False"``) becomes ``True``. This converter
    interprets the usual spellings explicitly.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def build_arg_parser():
    """Create the command-line parser for the evaluation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--height', type=int, default=240)
    parser.add_argument('--width', type=int, default=432)
    parser.add_argument("--ref_stride", type=int, default=10)
    parser.add_argument("--neighbor_length", type=int, default=20)
    parser.add_argument("--raft_iter", type=int, default=20)
    parser.add_argument('--task', default='video_completion', choices=['object_removal', 'video_completion'])
    parser.add_argument('--raft_model_path', default='weights/raft-things.pth', type=str)
    parser.add_argument('--fc_model_path', default='weights/recurrent_flow_completion.pth', type=str)
    parser.add_argument('--propainter_model_path', default='weights/ProPainter.pth', type=str)
    # Required: without it the run can only fail later on the dataset check.
    parser.add_argument('--dataset', choices=['davis', 'youtube-vos'], type=str, required=True)
    parser.add_argument('--video_root', default='dataset_root', type=str)
    parser.add_argument('--mask_root', default='mask_root', type=str)
    parser.add_argument('--flow_root', default='flow_ground_truth_root', type=str)
    # Accepts "--load_flow", "--load_flow True" and "--load_flow False".
    parser.add_argument('--load_flow', default=False, type=str2bool, nargs='?', const=True)
    parser.add_argument('--save_results', action='store_true')
    parser.add_argument('--num_workers', default=4, type=int)
    return parser


if __name__ == '__main__':
    args = build_arg_parser().parse_args()
    main_worker(args)