# [web page banner captured with the source] Spaces: Runtime error
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import torch | |
import torch.nn as nn | |
import numpy as np | |
def parse_flownetc(modules, weights, biases):
    """Copy pretrained parameters into the conv layers found in ``modules``.

    ``modules`` is walked in order; every ``nn.Conv2d`` / ``nn.ConvTranspose2d``
    consumes the next name from the fixed key list below. Other module types
    are skipped and do not consume a key.

    Args:
        modules: Iterable of modules to fill in place.
        weights: Mapping from layer name to a NumPy weight array.
        biases: Mapping from layer name to a NumPy bias array. It is only
            consulted for layers that actually have a bias parameter.

    Raises:
        IndexError: If ``modules`` has more conv layers than there are keys.
        KeyError: If a required name is missing from ``weights`` or ``biases``.
    """
    keys = [
        'conv1',
        'conv2',
        'conv3',
        'conv_redir',
        'conv3_1',
        'conv4',
        'conv4_1',
        'conv5',
        'conv5_1',
        'conv6',
        'conv6_1',
        'deconv5',
        'deconv4',
        'deconv3',
        'deconv2',
        'Convolution1',
        'Convolution2',
        'Convolution3',
        'Convolution4',
        'Convolution5',
        'upsample_flow6to5',
        'upsample_flow5to4',
        'upsample_flow4to3',
        'upsample_flow3to2',
    ]
    i = 0
    for m in modules:
        if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            continue
        key = keys[i]
        weight = weights[key].copy()
        if key == 'conv1':
            # The first layer's kernel is reversed along axis 1
            # (presumably BGR -> RGB input ordering; confirm against the
            # source checkpoint). The copy makes the flipped view contiguous
            # so torch.from_numpy accepts it.
            weight = np.flip(weight, axis=1).copy()
        m.weight.data[:, :, :, :] = torch.from_numpy(weight)
        # Layers built with bias=False have no bias tensor to fill; guard
        # like the sibling loaders do instead of failing on ``None.data``.
        if m.bias is not None:
            m.bias.data[:] = torch.from_numpy(biases[key].copy())
        i += 1
    return
def parse_flownets(modules, weights, biases, param_prefix='net2_'):
    """Copy pretrained parameters into the conv layers found in ``modules``.

    ``modules`` is walked in order; every ``nn.Conv2d`` / ``nn.ConvTranspose2d``
    consumes the next name from the key list below. Names are prefixed with
    ``param_prefix``; the ``upsample_*`` names get the prefix twice.

    Args:
        modules: Iterable of modules to fill in place.
        weights: Mapping from prefixed layer name to a NumPy weight array.
        biases: Mapping from prefixed layer name to a NumPy bias array. It is
            only consulted for layers that actually have a bias parameter.
        param_prefix: String prepended to every layer name.

    Raises:
        IndexError: If ``modules`` has more conv layers than there are keys.
        KeyError: If a required name is missing from ``weights`` or ``biases``.
    """
    keys = [
        'conv1',
        'conv2',
        'conv3',
        'conv3_1',
        'conv4',
        'conv4_1',
        'conv5',
        'conv5_1',
        'conv6',
        'conv6_1',
        'deconv5',
        'deconv4',
        'deconv3',
        'deconv2',
        'predict_conv6',
        'predict_conv5',
        'predict_conv4',
        'predict_conv3',
        'predict_conv2',
        'upsample_flow6to5',
        'upsample_flow5to4',
        'upsample_flow4to3',
        'upsample_flow3to2',
    ]
    # Upsample layers are stored under a doubled prefix.
    keys = [
        param_prefix + param_prefix + k if 'upsample' in k else param_prefix + k
        for k in keys
    ]
    first_key = param_prefix + 'conv1'
    i = 0
    for m in modules:
        if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            continue
        key = keys[i]
        weight = weights[key].copy()
        if key == first_key:
            # Channels 0-8 are treated as three 3-channel groups, each
            # reversed on axis 1 independently (presumably three BGR images
            # converted to RGB; confirm). Channels from 9 on are copied as-is.
            for start in (0, 3, 6):
                stop = start + 3
                m.weight.data[:, start:stop, :, :] = torch.from_numpy(
                    np.flip(weight[:, start:stop, :, :], axis=1).copy())
            m.weight.data[:, 9:, :, :] = torch.from_numpy(
                weight[:, 9:, :, :].copy())
        else:
            m.weight.data[:, :, :, :] = torch.from_numpy(weight)
        # Fetch the bias only when the layer has one, so bias-free layers do
        # not need an entry in ``biases``.
        if m.bias is not None:
            m.bias.data[:] = torch.from_numpy(biases[key].copy())
        i += 1
    return
def parse_flownetsonly(modules, weights, biases, param_prefix=''):
    """Copy pretrained parameters into the conv layers found in ``modules``.

    ``modules`` is walked in order; every ``nn.Conv2d`` / ``nn.ConvTranspose2d``
    consumes the next name from the key list below. Names are prefixed with
    ``param_prefix``; the ``upsample_*`` names get the prefix twice.

    Args:
        modules: Iterable of modules to fill in place.
        weights: Mapping from prefixed layer name to a NumPy weight array.
        biases: Mapping from prefixed layer name to a NumPy bias array. It is
            only consulted for layers that actually have a bias parameter.
        param_prefix: String prepended to every layer name.

    Raises:
        IndexError: If ``modules`` has more conv layers than there are keys.
        KeyError: If a required name is missing from ``weights`` or ``biases``.
    """
    keys = [
        'conv1',
        'conv2',
        'conv3',
        'conv3_1',
        'conv4',
        'conv4_1',
        'conv5',
        'conv5_1',
        'conv6',
        'conv6_1',
        'deconv5',
        'deconv4',
        'deconv3',
        'deconv2',
        'Convolution1',
        'Convolution2',
        'Convolution3',
        'Convolution4',
        'Convolution5',
        'upsample_flow6to5',
        'upsample_flow5to4',
        'upsample_flow4to3',
        'upsample_flow3to2',
    ]
    # Upsample layers are stored under a doubled prefix.
    keys = [
        param_prefix + param_prefix + k if 'upsample' in k else param_prefix + k
        for k in keys
    ]
    first_key = param_prefix + 'conv1'
    i = 0
    for m in modules:
        if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            continue
        key = keys[i]
        weight = weights[key].copy()
        if key == first_key:
            # Only channels 0-5 of the first layer are written, as two
            # 3-channel groups each reversed on axis 1 (presumably two BGR
            # images converted to RGB; confirm).
            for start in (0, 3):
                stop = start + 3
                m.weight.data[:, start:stop, :, :] = torch.from_numpy(
                    np.flip(weight[:, start:stop, :, :], axis=1).copy())
        else:
            m.weight.data[:, :, :, :] = torch.from_numpy(weight)
        # Fetch the bias only when the layer has one, so bias-free layers do
        # not need an entry in ``biases``.
        if m.bias is not None:
            m.bias.data[:] = torch.from_numpy(biases[key].copy())
        i += 1
    return
def parse_flownetsd(modules, weights, biases, param_prefix='netsd_'):
    """Copy pretrained parameters into the conv layers found in ``modules``.

    ``modules`` is walked in order; every ``nn.Conv2d`` / ``nn.ConvTranspose2d``
    consumes the next name from the key list below. Every name is prefixed
    once with ``param_prefix``.

    Args:
        modules: Iterable of modules to fill in place.
        weights: Mapping from prefixed layer name to a NumPy weight array.
        biases: Mapping from prefixed layer name to a NumPy bias array. It is
            only consulted for layers that actually have a bias parameter.
        param_prefix: String prepended to every layer name.

    Raises:
        IndexError: If ``modules`` has more conv layers than there are keys.
        KeyError: If a required name is missing from ``weights`` or ``biases``.
    """
    keys = [
        'conv0',
        'conv1',
        'conv1_1',
        'conv2',
        'conv2_1',
        'conv3',
        'conv3_1',
        'conv4',
        'conv4_1',
        'conv5',
        'conv5_1',
        'conv6',
        'conv6_1',
        'deconv5',
        'deconv4',
        'deconv3',
        'deconv2',
        'interconv5',
        'interconv4',
        'interconv3',
        'interconv2',
        'Convolution1',
        'Convolution2',
        'Convolution3',
        'Convolution4',
        'Convolution5',
        'upsample_flow6to5',
        'upsample_flow5to4',
        'upsample_flow4to3',
        'upsample_flow3to2',
    ]
    keys = [param_prefix + k for k in keys]
    first_key = param_prefix + 'conv0'
    i = 0
    for m in modules:
        if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            continue
        key = keys[i]
        weight = weights[key].copy()
        if key == first_key:
            # Only channels 0-5 of the first layer are written, as two
            # 3-channel groups each reversed on axis 1 (presumably two BGR
            # images converted to RGB; confirm).
            for start in (0, 3):
                stop = start + 3
                m.weight.data[:, start:stop, :, :] = torch.from_numpy(
                    np.flip(weight[:, start:stop, :, :], axis=1).copy())
        else:
            m.weight.data[:, :, :, :] = torch.from_numpy(weight)
        # Fetch the bias only when the layer has one, so bias-free layers do
        # not need an entry in ``biases``.
        if m.bias is not None:
            m.bias.data[:] = torch.from_numpy(biases[key].copy())
        i += 1
    return
def parse_flownetfusion(modules, weights, biases, param_prefix='fuse_'):
    """Copy pretrained parameters into the conv layers found in ``modules``.

    ``modules`` is walked in order; every ``nn.Conv2d`` / ``nn.ConvTranspose2d``
    consumes the next name from the key list below. Every name is prefixed
    once with ``param_prefix``.

    Args:
        modules: Iterable of modules to fill in place.
        weights: Mapping from prefixed layer name to a NumPy weight array.
        biases: Mapping from prefixed layer name to a NumPy bias array. It is
            only consulted for layers that actually have a bias parameter.
        param_prefix: String prepended to every layer name.

    Raises:
        IndexError: If ``modules`` has more conv layers than there are keys.
        KeyError: If a required name is missing from ``weights`` or ``biases``.
    """
    keys = [
        'conv0',
        'conv1',
        'conv1_1',
        'conv2',
        'conv2_1',
        'deconv1',
        'deconv0',
        'interconv1',
        'interconv0',
        '_Convolution5',
        '_Convolution6',
        '_Convolution7',
        'upsample_flow2to1',
        'upsample_flow1to0',
    ]
    keys = [param_prefix + k for k in keys]
    first_key = param_prefix + 'conv0'
    i = 0
    for m in modules:
        if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            continue
        key = keys[i]
        weight = weights[key].copy()
        if key == first_key:
            # Channels 0-2 of the first layer are reversed on axis 1
            # (presumably one BGR image converted to RGB; confirm). Channels
            # from 3 on are copied unchanged.
            m.weight.data[:, 0:3, :, :] = torch.from_numpy(
                np.flip(weight[:, 0:3, :, :], axis=1).copy())
            m.weight.data[:, 3:, :, :] = torch.from_numpy(
                weight[:, 3:, :, :].copy())
        else:
            m.weight.data[:, :, :, :] = torch.from_numpy(weight)
        # Fetch the bias only when the layer has one, so bias-free layers do
        # not need an entry in ``biases``.
        if m.bias is not None:
            m.bias.data[:] = torch.from_numpy(biases[key].copy())
        i += 1
    return