venite's picture
initial
f670afc
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import torch
import torch.nn as nn
import torch.nn.functional as F
import types
from imaginaire.third_party.flow_net.flownet2 import models as \
flownet2_models
from imaginaire.third_party.flow_net.flownet2.utils import tools \
as flownet2_tools
from imaginaire.model_utils.fs_vid2vid import resample
from imaginaire.utils.io import get_checkpoint
class FlowNet(nn.Module):
def __init__(self, pretrained=True, fp16=False):
super().__init__()
flownet2_args = types.SimpleNamespace()
setattr(flownet2_args, 'fp16', fp16)
setattr(flownet2_args, 'rgb_max', 1.0)
if fp16:
print('FlowNet2 is running in fp16 mode.')
self.flowNet = flownet2_tools.module_to_dict(flownet2_models)[
'FlowNet2'](flownet2_args).to('cuda')
if pretrained:
flownet2_path = get_checkpoint('flownet2.pth.tar',
'1hF8vS6YeHkx3j2pfCeQqqZGwA_PJq_Da')
checkpoint = torch.load(flownet2_path,
map_location=torch.device('cpu'))
self.flowNet.load_state_dict(checkpoint['state_dict'])
self.flowNet.eval()
def forward(self, input_A, input_B):
size = input_A.size()
assert(len(size) == 4 or len(size) == 5 or len(size) == 6)
if len(size) >= 5:
if len(size) == 5:
b, n, c, h, w = size
else:
b, t, n, c, h, w = size
input_A = input_A.contiguous().view(-1, c, h, w)
input_B = input_B.contiguous().view(-1, c, h, w)
flow, conf = self.compute_flow_and_conf(input_A, input_B)
if len(size) == 5:
return flow.view(b, n, 2, h, w), conf.view(b, n, 1, h, w)
else:
return flow.view(b, t, n, 2, h, w), conf.view(b, t, n, 1, h, w)
else:
return self.compute_flow_and_conf(input_A, input_B)
def compute_flow_and_conf(self, im1, im2):
assert(im1.size()[1] == 3)
assert(im1.size() == im2.size())
old_h, old_w = im1.size()[2], im1.size()[3]
new_h, new_w = old_h // 64 * 64, old_w // 64 * 64
if old_h != new_h:
im1 = F.interpolate(im1, size=(new_h, new_w), mode='bilinear',
align_corners=False)
im2 = F.interpolate(im2, size=(new_h, new_w), mode='bilinear',
align_corners=False)
data1 = torch.cat([im1.unsqueeze(2), im2.unsqueeze(2)], dim=2)
with torch.no_grad():
flow1 = self.flowNet(data1)
# img_diff = torch.sum(abs(im1 - resample(im2, flow1)),
# dim=1, keepdim=True)
# conf = torch.clamp(1 - img_diff, 0, 1)
conf = (self.norm(im1 - resample(im2, flow1)) < 0.02).float()
# data2 = torch.cat([im2.unsqueeze(2), im1.unsqueeze(2)], dim=2)
# with torch.no_grad():
# flow2 = self.flowNet(data2)
# warped_flow2 = resample(flow2, flow1)
# flow_sum = self.norm(flow1 + warped_flow2)
# disocc = flow_sum > (0.05 * (self.norm(flow1) +
# self.norm(warped_flow2)) + 0.5)
# conf = 1 - disocc.float()
if old_h != new_h:
flow1 = F.interpolate(flow1, size=(old_h, old_w), mode='bilinear',
align_corners=False) * old_h / new_h
conf = F.interpolate(conf, size=(old_h, old_w), mode='bilinear',
align_corners=False)
return flow1, conf
def norm(self, t):
return torch.sum(t * t, dim=1, keepdim=True)