import os
import glob
from PIL import Image
import torch
import yaml
import cv2
import importlib
import numpy as np
from tqdm import tqdm

from inpainter.util.tensor_util import resize_frames, resize_masks


class BaseInpainter:
    def __init__(self, E2FGVI_checkpoint, device) -> None:
        """
        E2FGVI_checkpoint: path to the inpainter checkpoint (HQ version,
        with multi-resolution support)
        """
        net = importlib.import_module('inpainter.model.e2fgvi_hq')
        self.model = net.InpaintGenerator().to(device)
        self.model.load_state_dict(torch.load(E2FGVI_checkpoint, map_location=device))
        self.model.eval()
        self.device = device
        # load configurations
        with open("inpainter/config/config.yaml", 'r') as stream:
            config = yaml.safe_load(stream)
        self.neighbor_stride = config['neighbor_stride']
        self.num_ref = config['num_ref']
        self.step = config['step']
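        # The config is assumed to hold three integers; an illustrative
        # inpainter/config/config.yaml (the values actually shipped with the
        # repository may differ):
        #   neighbor_stride: 5    # stride/radius of the local temporal window
        #   num_ref: -1           # -1 = sample references from the whole video
        #   step: 10              # temporal gap between sampled reference frames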

    # sample reference frames from the whole video
    def get_ref_index(self, f, neighbor_ids, length):
        ref_index = []
        if self.num_ref == -1:
            # take every `step`-th frame outside the local neighbourhood
            for i in range(0, length, self.step):
                if i not in neighbor_ids:
                    ref_index.append(i)
        else:
            # sample references from a window centred on frame f
            start_idx = max(0, f - self.step * (self.num_ref // 2))
            end_idx = min(length, f + self.step * (self.num_ref // 2))
            for i in range(start_idx, end_idx + 1, self.step):
                if i not in neighbor_ids:
                    if len(ref_index) > self.num_ref:
                        break
                    ref_index.append(i)
        return ref_index
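
    # Worked example (illustrative values, not necessarily the shipped config):
    # with num_ref == -1, step == 10, length == 50 and neighbor_ids == [0..5],
    # the candidate indices are 0, 10, 20, 30, 40, and the sampled references
    # are [10, 20, 30, 40], i.e. evenly spaced frames outside the local window.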

    def inpaint(self, frames, masks, dilate_radius=15, ratio=1):
        """
        frames: numpy array, T, H, W, 3
        masks: numpy array, T, H, W
        dilate_radius: kernel radius used when dilating the masks
        ratio: down-sampling ratio in (0, 1]
        Output:
        inpainted_frames: numpy array, T, H, W, 3 (at the down-sampled
        resolution when ratio < 1)
        """
        assert frames.shape[:3] == masks.shape, 'frames and masks must have matching T, H, W'
        assert 0 < ratio <= 1, 'ratio must be in (0, 1]'
        masks = masks.copy()
        masks = np.clip(masks, 0, 1)
        # dilate the masks with an elliptical kernel so the inpainted region
        # safely covers the object boundary
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilate_radius, dilate_radius))
        masks = np.stack([cv2.dilate(mask, kernel) for mask in masks], 0)
        T, H, W = masks.shape
        masks = np.expand_dims(masks, axis=3)  # expand to T, H, W, 1
        # size: (w, h)
        if ratio == 1:
            size = None
            binary_masks = masks
        else:
            size = [int(W * ratio), int(H * ratio)]
            size = [si + 1 if si % 2 > 0 else si for si in size]  # round up to even dimensions
            # the shorter side should be at least 50 pixels
            if min(size) < 50:
                ratio = 50. / min(H, W)
                size = [int(W * ratio), int(H * ratio)]
            binary_masks = resize_masks(masks, tuple(size))
            frames = resize_frames(frames, tuple(size))  # T, H, W, 3
        # frames and binary_masks are numpy arrays
        h, w = frames.shape[1:3]
        video_length = T
        # convert to tensors: imgs normalised to [-1, 1], shape 1, T, C, H, W
        imgs = (torch.from_numpy(frames).permute(0, 3, 1, 2).contiguous().unsqueeze(0).float().div(255)) * 2 - 1
        masks = torch.from_numpy(binary_masks).permute(0, 3, 1, 2).contiguous().unsqueeze(0)
        imgs, masks = imgs.to(self.device), masks.to(self.device)
        comp_frames = [None] * video_length
        for f in tqdm(range(0, video_length, self.neighbor_stride), desc='Inpainting frames'):
            neighbor_ids = [
                i for i in range(max(0, f - self.neighbor_stride),
                                 min(video_length, f + self.neighbor_stride + 1))
            ]
            ref_ids = self.get_ref_index(f, neighbor_ids, video_length)
            selected_imgs = imgs[:1, neighbor_ids + ref_ids, :, :, :]
            selected_masks = masks[:1, neighbor_ids + ref_ids, :, :, :]
            with torch.no_grad():
                # zero out the masked regions before feeding the model
                masked_imgs = selected_imgs * (1 - selected_masks)
                # E2FGVI-HQ expects spatial dimensions that are multiples of these
                mod_size_h = 60
                mod_size_w = 108
                h_pad = (mod_size_h - h % mod_size_h) % mod_size_h
                w_pad = (mod_size_w - w % mod_size_w) % mod_size_w
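                # Example: h == 480 gives h_pad == (60 - 480 % 60) % 60 == 0,
                # while h == 250 gives h_pad == 50, i.e. we pad up to the next
                # multiple of mod_size_h (and likewise for the width)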
                # mirror-pad along height (dim 3) and width (dim 4), then crop
                masked_imgs = torch.cat(
                    [masked_imgs, torch.flip(masked_imgs, [3])],
                    3)[:, :, :, :h + h_pad, :]
                masked_imgs = torch.cat(
                    [masked_imgs, torch.flip(masked_imgs, [4])],
                    4)[:, :, :, :, :w + w_pad]
                pred_imgs, _ = self.model(masked_imgs, len(neighbor_ids))
                pred_imgs = pred_imgs[:, :, :h, :w]
                # map predictions from [-1, 1] back to [0, 255]
                pred_imgs = (pred_imgs + 1) / 2
                pred_imgs = pred_imgs.cpu().permute(0, 2, 3, 1).numpy() * 255
                for i in range(len(neighbor_ids)):
                    idx = neighbor_ids[i]
                    # paste the prediction into the masked region only
                    img = pred_imgs[i].astype(np.uint8) * binary_masks[idx] + frames[idx] * (
                        1 - binary_masks[idx])
                    if comp_frames[idx] is None:
                        comp_frames[idx] = img
                    else:
                        # windows overlap, so average the two predictions
                        comp_frames[idx] = comp_frames[idx].astype(
                            np.float32) * 0.5 + img.astype(np.float32) * 0.5
        inpainted_frames = np.stack(comp_frames, 0)
        return inpainted_frames.astype(np.uint8)
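
# NOTE: resize_frames / resize_masks are imported from
# inpainter.util.tensor_util, which is not shown here. A minimal sketch of the
# assumed behaviour (hypothetical, for reference only): bilinear resizing for
# RGB frames, nearest-neighbour for masks so label values stay binary, e.g.
#
#   def resize_frames(frames, size):    # frames: T, H, W, 3; size: (w, h)
#       return np.stack([cv2.resize(f, size, interpolation=cv2.INTER_LINEAR)
#                        for f in frames], 0)
#
#   def resize_masks(masks, size):      # masks: T, H, W, 1
#       return np.stack([cv2.resize(m, size, interpolation=cv2.INTER_NEAREST)
#                        for m in masks], 0)[..., None]  # restore channel axis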


if __name__ == '__main__':
    frame_path = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/parkour', '*.jpg'))
    frame_path.sort()
    mask_path = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/Annotations/480p/parkour', '*.png'))
    mask_path.sort()
    save_path = '/ssd1/gaomingqi/results/inpainting/parkour'
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    frames = []
    masks = []
    for fid, mid in zip(frame_path, mask_path):
        frames.append(Image.open(fid).convert('RGB'))
        masks.append(Image.open(mid).convert('P'))
    frames = np.stack(frames, 0)
    masks = np.stack(masks, 0)

    # ----------------------------------------------
    # how to use
    # ----------------------------------------------
    # 1/3: set checkpoint and device
    checkpoint = '/ssd1/gaomingqi/checkpoints/E2FGVI-HQ-CVPR22.pth'
    device = 'cuda:6'
    # 2/3: initialise inpainter
    base_inpainter = BaseInpainter(checkpoint, device)
    # 3/3: inpainting (frames: numpy array, T, H, W, 3; masks: numpy array, T, H, W)
    # ratio: (0, 1], down-sampling ratio, defaults to 1; very small values are
    # clamped by the 50-pixel shorter-side guard inside inpaint
    inpainted_frames = base_inpainter.inpaint(frames, masks, ratio=0.01)  # numpy array, T, H, W, 3
    # ----------------------------------------------
    # end
    # ----------------------------------------------
    # save
    for ti, inpainted_frame in enumerate(inpainted_frames):
        frame = Image.fromarray(inpainted_frame).convert('RGB')
        frame.save(os.path.join(save_path, f'{ti:05d}.jpg'))