File size: 3,662 Bytes
78360e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import torch
from torch.nn import functional as F
# from .model.pytorch_msssim import ssim_matlab
from .ssim import ssim_matlab

from .RIFE_HDv3 import Model

def get_frame(frames, frame_no):
    """Return frame ``frame_no`` rescaled from [-1, 1] to [0, 1], or None.

    ``frames`` is indexed along dim 1 (``frames[:, frame_no]``). An index at or
    past the end of that axis yields ``None`` so callers can detect the end of
    the clip. The rescaled result is clipped to [0, 1].
    """
    if frame_no < frames.shape[1]:
        rescaled = (frames[:, frame_no] + 1) / 2
        return rescaled.clip(0., 1.)
    return None

def add_frame(frames, frame, h, w):
    """Map ``frame`` from [0, 1] back to [-1, 1], crop it, and append it.

    Args:
        frames: list that receives the processed frame (mutated in place).
        frame: either a batched ``(1, C, H, W)`` tensor or an unbatched
            ``(C, H, W)`` tensor. Values are rescaled with ``2 * x - 1`` and
            clipped to [-1, 1].
        h, w: spatial size to crop to (removes any padding on the bottom/right).

    The appended tensor has shape ``(C, 1, h, w)`` and lives on the CPU, so the
    list can later be concatenated along dim 1.
    """
    frame = (frame * 2) - 1
    frame = frame.clip(-1., 1.)
    # Only drop a leading batch axis. For an unbatched (C, H, W) frame, an
    # unconditional squeeze(0) would also remove the channel axis when C == 1,
    # and the 3-index crop below would then fail.
    if frame.dim() == 4:
        frame = frame.squeeze(0)
    frame = frame[:, :h, :w]
    frame = frame.unsqueeze(1)  # insert the time axis
    frames.append(frame.cpu())

def process_frames(model, device, frames, exp):
    """Temporally upsample ``frames`` by inserting model-interpolated frames.

    Between consecutive input frames, ``2**exp - 1`` new frames are generated
    by recursive bisection with ``model.inference``. Frames are padded to a
    multiple of 32 for the model and cropped back to ``(h, w)`` when appended.

    Args:
        model: object exposing ``inference(I0, I1, scale)``; presumably the
            ``Model`` built in ``temporal_interpolation`` — confirm for other
            callers.
        device: device the frames are moved to before inference.
        frames: tensor indexed as ``frames[:, t]`` (time on dim 1). Assumed to
            be (C, T, H, W) with values in [-1, 1] — TODO confirm with callers.
        exp: interpolation exponent; ``exp == 0`` inserts no new frames.

    Returns:
        The output frames concatenated along dim 1, in the same [-1, 1]
        convention. An input with zero frames yields ``frames[:, :0]`` on the
        CPU.
    """
    if frames.shape[1] == 0:
        # Empty clip: return an empty result instead of failing on
        # ``lastframe.shape`` below (frame 0 does not exist).
        return frames[:, :0].cpu()

    pos = 0
    output_frames = []

    lastframe = get_frame(frames, 0)
    _, h, w = lastframe.shape
    scale = 1
    fp16 = False

    def make_inference(I0, I1, n):
        # Recursively bisect the I0 -> I1 interval and return n frames ordered
        # in time. When n is odd, the midpoint itself sits in the middle.
        middle = model.inference(I0, I1, scale)
        if n == 1:
            return [middle]
        first_half = make_inference(I0, middle, n=n//2)
        second_half = make_inference(middle, I1, n=n//2)
        if n % 2:
            return [*first_half, middle, *second_half]
        else:
            return [*first_half, *second_half]

    # Pad H and W up to a multiple of tmp (32 while scale == 1). F.pad's tuple
    # is (left, right, top, bottom) for the last two dims, so padding goes on
    # the right/bottom only.
    tmp = max(32, int(32 / scale))
    ph = ((h - 1) // tmp + 1) * tmp
    pw = ((w - 1) // tmp + 1) * tmp
    padding = (0, pw - w, 0, ph - h)

    def pad_image(img):
        if fp16:
            return F.pad(img, padding).half()
        else:
            return F.pad(img, padding)

    I1 = lastframe.to(device, non_blocking=True).unsqueeze(0)
    I1 = pad_image(I1)
    temp = None  # frame read ahead by the near-duplicate branch, if any

    while True:
        if temp is not None:
            # Consume the frame read ahead last iteration; pos already
            # points at it.
            frame = temp
            temp = None
        else:
            pos += 1
            frame = get_frame(frames, pos)
        if frame is None:
            break
        I0 = I1
        I1 = frame.to(device, non_blocking=True).unsqueeze(0)
        I1 = pad_image(I1)
        # Cheap similarity check on 32x32 thumbnails of the first 3 channels.
        I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False)
        I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
        ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])

        break_flag = False
        # Near-duplicate pair: discard the current frame and replace it with
        # the model's interpolation between I0 and the *next* frame.
        # NOTE(review): ``pos > 100`` forces this path for every frame past
        # index 100, so those input frames are replaced rather than kept.
        # Intent unclear; confirm before relying on it.
        if ssim > 0.996 or pos > 100:
            pos += 1
            frame = get_frame(frames, pos)
            if frame is None:
                # No next frame: reuse lastframe (I0's unpadded frame) and
                # stop after this iteration.
                break_flag = True
                frame = lastframe
            else:
                temp = frame
            I1 = frame.to(device, non_blocking=True).unsqueeze(0)
            I1 = pad_image(I1)
            I1 = model.inference(I0, I1, scale)
            I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
            ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
            frame = I1[0][:, :h, :w]

        if ssim < 0.2:
            # Very dissimilar pair (presumably a scene cut): repeat I0 instead
            # of interpolating across it.
            output = []
            for _ in range((2 ** exp) - 1):
                output.append(I0)
        else:
            output = make_inference(I0, I1, 2**exp-1) if exp else []

        add_frame(output_frames, lastframe, h, w)
        for mid in output:
            add_frame(output_frames, mid, h, w)
        lastframe = frame
        if break_flag:
            break

    add_frame(output_frames, lastframe, h, w)
    return torch.cat(output_frames, dim=1)

def temporal_interpolation(model_path, frames, exp, device="cuda"):
    """Load the interpolation model from ``model_path`` and upsample ``frames``.

    Args:
        model_path: checkpoint location passed to ``Model.load_model``.
        frames: clip tensor; cast to float32 before being handed to
            ``process_frames``.
        exp: interpolation exponent forwarded to ``process_frames``.
        device: device on which the model is loaded and run.

    Returns:
        The tensor produced by ``process_frames``.
    """
    interp_model = Model()
    interp_model.load_model(model_path, -1, device=device)
    interp_model.eval()
    interp_model.to(device=device)

    # Inference only: autograd tracking is disabled for the whole pass.
    with torch.no_grad():
        return process_frames(interp_model, device, frames.float(), exp)