Spaces:
Paused
Paused
| import os | |
| import torch | |
| from torch.nn import functional as F | |
| # from .model.pytorch_msssim import ssim_matlab | |
| from .ssim import ssim_matlab | |
| from .RIFE_HDv3 import Model | |
def get_frame(frames, frame_no):
    """Fetch one time step of ``frames`` mapped into the [0, 1] range.

    Args:
        frames: tensor whose dim 1 is the frame axis (indexed as
            ``frames[:, frame_no]``); values presumably in [-1, 1].
        frame_no: index along dim 1. Only indices ``>= frames.shape[1]``
            count as out of range; negative indices index from the end.

    Returns:
        ``(frames[:, frame_no] + 1) / 2`` clipped to [0, 1], or ``None``
        when ``frame_no`` is past the last frame.
    """
    num_frames = frames.shape[1]
    if frame_no < num_frames:
        rescaled = (frames[:, frame_no] + 1) / 2
        return rescaled.clip(0., 1.)
    return None
def add_frame(frames, frame, h, w):
    """Append ``frame`` to the list ``frames`` as a single time step.

    The frame is mapped from [0, 1] to [-1, 1] and clipped. A leading batch
    axis is dropped if present. The frame is then cropped to ``(h, w)`` and
    given a time axis at dim 1, so the stored tensor is ``(C, 1, h, w)`` and
    ready for ``torch.cat(..., dim=1)``. The stored tensor is moved to CPU.

    Args:
        frames: output list; mutated in place.
        frame: ``(C, H, W)`` or ``(1, C, H, W)`` tensor, values presumably
            in [0, 1].
        h: output height (rows kept from the top).
        w: output width (columns kept from the left).

    Raises:
        ValueError: if ``frame`` is not ``(C, H, W)`` or ``(1, C, H, W)``.
    """
    in_shape = tuple(frame.shape)
    frame = (frame * 2) - 1
    frame = frame.clip(-1., 1.)
    # Drop the batch axis only when one is actually present. The previous
    # unconditional squeeze(0) also removed the channel axis of a
    # single-channel (1, H, W) frame, which broke the crop below.
    if frame.dim() == 4:
        frame = frame.squeeze(0)
    if frame.dim() != 3:
        raise ValueError(
            f"expected a (C, H, W) or (1, C, H, W) frame, got shape {in_shape}"
        )
    frame = frame[:, :h, :w]  # crop away right/bottom padding
    frame = frame.unsqueeze(1)  # (C, 1, h, w): one step on the time axis
    frames.append(frame.cpu())
def process_frames(model, device, frames, exp):
    """Temporally upsample a clip with ``model``, skipping near-duplicate frames.

    For each consecutive pair (I0, I1), ``2**exp - 1`` frames produced by
    ``model.inference`` are inserted between them (none when ``exp`` is 0).
    Similarity of the pair is measured with ``ssim_matlab`` on 32x32
    thumbnails of the first 3 channels:

    * ssim > 0.996 (or more than 100 input frames already read, see NOTE
      below): I1 is replaced by the model's midpoint between I0 and the
      *next* input frame. That look-ahead frame is stashed and processed on
      the following iteration. At the end of the clip the last kept frame is
      used instead, and the loop stops after this pair.
    * ssim < 0.2: no blending; I0 is repeated ``2**exp - 1`` times.

    Args:
        model: object exposing ``inference(I0, I1, scale)`` on batched,
            padded tensors.
        device: device the input frames are moved to before inference.
        frames: clip tensor with the frame axis at dim 1 (read through
            ``get_frame``), presumably (C, T, H, W) in [-1, 1]; TODO confirm.
        exp: log2 of the temporal upsampling factor.

    Returns:
        CPU tensor of the output frames concatenated along dim 1, with values
        clipped to [-1, 1] (by ``add_frame``). An empty clip (zero frames
        along dim 1) is returned unchanged, moved to CPU.
    """
    pos = 0
    output_frames = []
    lastframe = get_frame(frames, 0)
    if lastframe is None:
        # Empty clip: previously this crashed on ``lastframe.shape`` with an
        # AttributeError. There is nothing to interpolate, so return the
        # (empty) input on CPU like a normal result.
        return frames.cpu()
    _, h, w = lastframe.shape
    scale = 1  # forwarded to model.inference
    fp16 = False  # when True, padded inputs are cast to half precision

    def make_inference(I0, I1, n):
        """Return ``n`` frames between I0 and I1 by recursive midpoint bisection.

        Must be called with n >= 1; n == 0 would recurse forever, which is
        why the caller guards with ``if exp``.
        """
        middle = model.inference(I0, I1, scale)
        if n == 1:
            return [middle]
        first_half = make_inference(I0, middle, n=n//2)
        second_half = make_inference(middle, I1, n=n//2)
        if n%2:
            return [*first_half, middle, *second_half]
        else:
            return [*first_half, *second_half]

    # Pad H and W up to a multiple of ``tmp`` (32 for scale == 1). The padding
    # goes on the right/bottom only, so ``[:h, :w]`` recovers the original
    # size. Presumably the model needs sizes divisible by 32; confirm in RIFE_HDv3.
    tmp = max(32, int(32 / scale))
    ph = ((h - 1) // tmp + 1) * tmp
    pw = ((w - 1) // tmp + 1) * tmp
    padding = (0, pw - w, 0, ph - h)

    def pad_image(img):
        """Pad ``img`` on the right/bottom to (ph, pw); cast to fp16 if enabled."""
        if(fp16):
            return F.pad(img, padding).half()
        else:
            return F.pad(img, padding)

    I1 = lastframe.to(device, non_blocking=True).unsqueeze(0)
    I1 = pad_image(I1)
    # Look-ahead input frame read while handling a near-duplicate pair. It is
    # consumed at the start of the next iteration instead of reading a new frame.
    temp = None
    while True:
        if temp is not None:
            frame = temp
            temp = None
        else:
            pos += 1
            frame = get_frame(frames, pos)
        if frame is None:
            break
        I0 = I1
        I1 = frame.to(device, non_blocking=True).unsqueeze(0)
        I1 = pad_image(I1)
        # Cheap similarity check on 32x32 thumbnails of the first 3 channels.
        I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False)
        I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
        ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
        break_flag = False
        # NOTE(review): once pos > 100, every pair takes this branch. Each
        # real frame is then replaced by a midpoint toward the next one. It
        # looks like either a deliberate cap or leftover debug code — confirm
        # before relying on long clips.
        if ssim > 0.996 or pos > 100:
            # Near-duplicate: read one frame ahead and replace I1 by the
            # model's midpoint between I0 and that frame (or the last kept
            # frame when the clip has ended).
            pos += 1
            frame = get_frame(frames, pos)
            if frame is None:
                break_flag = True
                frame = lastframe
            else:
                temp = frame
            I1 = frame.to(device, non_blocking=True).unsqueeze(0)
            I1 = pad_image(I1)
            I1 = model.inference(I0, I1, scale)
            I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
            ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
            frame = I1[0][:, :h, :w]
        if ssim < 0.2:
            # Very dissimilar pair (presumably a scene cut): repeat I0 rather
            # than blending across it.
            output = [I0] * ((2 ** exp) - 1)
        else:
            output = make_inference(I0, I1, 2**exp-1) if exp else []
        add_frame(output_frames, lastframe, h, w)
        for mid in output:
            add_frame(output_frames, mid, h, w)
        lastframe = frame
        if break_flag:
            break
    add_frame(output_frames, lastframe, h, w)
    return torch.cat(output_frames, dim=1)
def temporal_interpolation(model_path, frames, exp, device="cuda"):
    """Load an interpolation ``Model`` and temporally upsample ``frames`` with it.

    Args:
        model_path: forwarded to ``Model.load_model`` together with a second
            positional argument of ``-1``. Its exact meaning (file vs.
            directory) is defined by RIFE_HDv3; confirm there.
        frames: clip tensor, converted to float32 and handed to
            ``process_frames``.
        exp: log2 of the temporal upsampling factor, forwarded unchanged.
        device: device the model is loaded onto and moved to.

    Returns:
        Whatever ``process_frames`` returns for this clip.
    """
    rife = Model()
    rife.load_model(model_path, -1, device=device)
    rife.eval()
    rife.to(device=device)
    # Inference only: skip autograd graph construction.
    with torch.no_grad():
        return process_frames(rife, device, frames.float(), exp)