Spaces:
Running
on
Zero
Running
on
Zero
from __future__ import absolute_import, division, print_function | |
import math | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from torch.nn import ( | |
functional as F, | |
Conv2d, | |
LeakyReLU, | |
Upsample, | |
Sigmoid, | |
ConvTranspose2d, | |
) | |
class ConvBlock(nn.Module):
    """Padded 3x3 convolution (``Conv3x3``) followed by an in-place ELU.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
    """

    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv = Conv3x3(in_channels, out_channels)
        self.nonlin = nn.ELU(inplace=True)

    def forward(self, x):
        # Convolve first, then apply the activation in place on the conv output.
        return self.nonlin(self.conv(x))
class Conv3x3(nn.Module):
    """Pad by one pixel on every side, then apply a 3x3 convolution.

    Because of the 1-pixel padding the spatial size is preserved.

    Args:
        in_channels: number of input channels (coerced with ``int``).
        out_channels: number of output channels (coerced with ``int``).
        use_refl: reflection padding when truthy, zero padding otherwise.
    """

    def __init__(self, in_channels, out_channels, use_refl=True):
        super(Conv3x3, self).__init__()
        self.pad = nn.ReflectionPad2d(1) if use_refl else nn.ZeroPad2d(1)
        self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)

    def forward(self, x):
        padded = self.pad(x)
        return self.conv(padded)
class Backprojection(nn.Module):
    """Lift a depth map to homogeneous 3D camera-space points.

    The pixel grid is built once, in normalised image coordinates: x and y both
    run over ``linspace(-1, 1)``, so the first/last pixel centres sit exactly at
    -1 and +1. ``inv_K`` is applied to these normalised coordinates.
    NOTE(review): confirm callers' intrinsics use this same [-1, 1] convention.

    Args:
        height: number of pixel rows H.
        width: number of pixel columns W.
    """

    def __init__(self, height, width):
        super(Backprojection, self).__init__()
        self.H, self.W = height, width
        yy, xx = torch.meshgrid(
            [torch.linspace(-1.0, 1.0, self.H), torch.linspace(-1.0, 1.0, self.W)]
        )
        yy = yy.contiguous().view(-1)
        xx = xx.contiguous().view(-1)
        # (1, 1, H*W) row of ones used to make coordinates homogeneous.
        self.ones = nn.Parameter(
            torch.ones(1, 1, self.H * self.W, dtype=torch.float), requires_grad=False
        )
        # (1, 3, H*W): rows are x, y, 1 for every pixel in row-major order.
        grid = torch.unsqueeze(torch.stack([xx, yy], 0), 0)
        self.coord = nn.Parameter(torch.cat([grid, self.ones], 1), requires_grad=False)

    def forward(self, depth, inv_K):
        """Back-project ``depth`` through ``inv_K``.

        Args:
            depth: depth per pixel; reshaped to (n, 1, H*W), so it must hold
                H*W values per batch element.
            inv_K: batch of inverse intrinsics; only ``inv_K[:, :3, :3]`` is used.

        Returns:
            (n, 4, H*W) homogeneous points ``[depth * inv_K @ (x, y, 1); 1]``.
        """
        n = depth.shape[0]
        cam_p_norm = torch.matmul(inv_K[:, :3, :3], self.coord.expand(n, -1, -1))
        cam_p_euc = depth.view(n, 1, -1) * cam_p_norm
        # Fixed: previously `expand(n,e -1, -1)`, which raised NameError.
        cam_p_h = torch.cat([cam_p_euc, self.ones.expand(n, -1, -1)], 1)
        return cam_p_h
def point_projection(points3D, batch_size, height, width, K, T):
    """Transform homogeneous points by ``T``, project them with ``K`` and dehomogenise.

    Args:
        points3D: homogeneous points; presumably (N, 4, H*W) as produced by
            ``Backprojection`` — TODO confirm at call sites.
        batch_size, height, width: N, H, W used to reshape the projected points.
        K: camera matrices multiplied with ``T[:, :3, :]``.
        T: transforms; only the first three rows are used.

    Returns:
        ``(img_coord, depth)``: ``img_coord`` is (N, H, W, 2) holding (x, y) per
        pixel, ``depth`` is the z component of the projected points, (N, H*W).
    """
    projection = torch.matmul(K, T[:, :3, :])
    cam_coord = torch.matmul(projection, points3D)
    z = cam_coord[:, 2:3, :]
    # Small epsilon guards against division by an exactly-zero z.
    xy = cam_coord[:, :2, :] / (z + 1e-7)
    xy = xy.view(batch_size, 2, height, width)
    img_coord = xy.permute(0, 2, 3, 1)
    return img_coord, cam_coord[:, 2, :]
def upsample(x):
    """Enlarge ``x`` by a factor of 2 in every spatial dimension (nearest neighbour)."""
    factor = 2
    return F.interpolate(x, mode="nearest", scale_factor=factor)
class GaussianAverage(nn.Module):
    """Per-channel (depthwise) 3x3 Gaussian-weighted average with no padding.

    A drop-in replacement for ``nn.AvgPool2d(3, 1)``: each spatial dimension
    shrinks by 2. The fixed weights sum to ~0.9998, not exactly 1.
    """

    def __init__(self) -> None:
        super().__init__()
        # Plain tensor attribute (not a Parameter or buffer), so it is not part
        # of state_dict; it is moved to the input's device/dtype in forward.
        self.window = torch.Tensor(
            [
                [0.0947, 0.1183, 0.0947],
                [0.1183, 0.1478, 0.1183],
                [0.0947, 0.1183, 0.0947],
            ]
        )

    def forward(self, x):
        channels = x.shape[1]
        weight = self.window.to(device=x.device, dtype=x.dtype)
        # (C, 1, 3, 3): one copy of the window per channel for a grouped conv.
        weight = weight.repeat(channels, 1, 1, 1)
        return F.conv2d(x, weight, padding=0, groups=channels)
class SSIM(nn.Module):
    """Layer computing a local (3x3 window) SSIM between a pair of images.

    Per pixel and per channel the SSIM index is

        (2*mu_x*mu_y + C1) * (2*sigma_xy + C2)
        / ((mu_x^2 + mu_y^2 + C1) * (sigma_x + sigma_y + C2)),

    where the local statistics use either a 3x3 box filter
    (``nn.AvgPool2d(3, 1)``) or a fixed 3x3 Gaussian window. The index lies in
    [-1, 1]; 1 means the local patches are identical.

    What ``forward`` returns depends on the flags:

    * default (``eval_mode=False, comp_mode=False``): the dissimilarity
      ``clamp((1 - SSIM) / 2, 0, 1)``, usable directly as a loss;
    * ``comp_mode=True``: ``clamp(1 - SSIM, 0, 1) / 2``. This clips
      ``1 - SSIM`` at 1 *before* halving, so it saturates at 0.5 wherever SSIM < 0;
    * ``eval_mode=True``: the raw SSIM index, with no clamping or rescaling.
      ``comp_mode`` is ignored in this case.

    Args:
        pad_reflection: pad inputs by 1 pixel with reflection (True) or zeros.
        gaussian_average: use the 3x3 Gaussian window instead of a box filter.
        comp_mode: see above.
        eval_mode: see above.
    """

    def __init__(
        self,
        pad_reflection=True,
        gaussian_average=False,
        comp_mode=False,
        eval_mode=False,
    ):
        super(SSIM, self).__init__()
        self.comp_mode = comp_mode
        self.eval_mode = eval_mode
        if not gaussian_average:
            self.mu_x_pool = nn.AvgPool2d(3, 1)
            self.mu_y_pool = nn.AvgPool2d(3, 1)
            self.sig_x_pool = nn.AvgPool2d(3, 1)
            self.sig_y_pool = nn.AvgPool2d(3, 1)
            self.sig_xy_pool = nn.AvgPool2d(3, 1)
        else:
            self.mu_x_pool = GaussianAverage()
            self.mu_y_pool = GaussianAverage()
            self.sig_x_pool = GaussianAverage()
            self.sig_y_pool = GaussianAverage()
            self.sig_xy_pool = GaussianAverage()
        # The 3x3 windows use no padding, so a 1-pixel pad keeps the output size.
        if pad_reflection:
            self.pad = nn.ReflectionPad2d(1)
        else:
            self.pad = nn.ZeroPad2d(1)
        # Standard SSIM stabilisers (K1*L)^2, (K2*L)^2 with L = 1; presumably
        # inputs are scaled to a unit dynamic range — TODO confirm at call sites.
        self.C1 = 0.01**2
        self.C2 = 0.03**2

    def forward(self, x, y, pad=True):
        """Compare ``x`` and ``y`` (same shape, NCHW). See the class docstring for the output."""
        if pad:
            x = self.pad(x)
            y = self.pad(y)

        # Local means (box or Gaussian weighted, chosen in __init__).
        mu_x = self.mu_x_pool(x)
        mu_y = self.mu_y_pool(y)
        mu_x_sq = mu_x**2
        mu_y_sq = mu_y**2
        mu_x_y = mu_x * mu_y

        # Local variances and covariance as E[a*b] - E[a]*E[b].
        sigma_x = self.sig_x_pool(x**2) - mu_x_sq
        sigma_y = self.sig_y_pool(y**2) - mu_y_sq
        sigma_xy = self.sig_xy_pool(x * y) - mu_x_y

        SSIM_n = (2 * mu_x_y + self.C1) * (2 * sigma_xy + self.C2)
        SSIM_d = (mu_x_sq + mu_y_sq + self.C1) * (sigma_x + sigma_y + self.C2)
        ssim_index = SSIM_n / SSIM_d

        if self.eval_mode:
            # Raw SSIM index in [-1, 1].
            return ssim_index
        if self.comp_mode:
            # Dissimilarity clipped to [0, 1] before halving -> range [0, 0.5].
            return torch.clamp(1 - ssim_index, 0, 1) / 2
        # Dissimilarity rescaled to [0, 1]; used as the training loss.
        return torch.clamp((1 - ssim_index) / 2, 0, 1)
class GEO(SSIM):
    """Layer to compute the pseudo-label loss L_{geo} between a pair of images.

    L_{geo} currently uses exactly the same windowed-SSIM computation as
    ``SSIM``. That includes the same ``pad_reflection`` / ``gaussian_average`` /
    ``comp_mode`` / ``eval_mode`` semantics and the same ``forward(x, y, pad=True)``.
    It therefore inherits from ``SSIM`` rather than duplicating the whole
    implementation. Override ``forward`` here if the two ever need to diverge.

    Args:
        pad_reflection: pad inputs by 1 pixel with reflection (True) or zeros.
        gaussian_average: use a 3x3 Gaussian window instead of a box filter.
        comp_mode: return ``clamp(1 - SSIM, 0, 1) / 2`` instead of
            ``clamp((1 - SSIM) / 2, 0, 1)``.
        eval_mode: return the raw SSIM index (overrides ``comp_mode``).
    """

    def __init__(
        self,
        pad_reflection=True,
        gaussian_average=False,
        comp_mode=False,
        eval_mode=False,
    ):
        super(GEO, self).__init__(
            pad_reflection=pad_reflection,
            gaussian_average=gaussian_average,
            comp_mode=comp_mode,
            eval_mode=eval_mode,
        )
def ssim(
    x,
    y,
    pad_reflection=True,
    gaussian_average=False,
    comp_mode=False,
    eval_mode=False,
    pad=True,
):
    """Functional form of ``SSIM``: build the layer with the given options and apply it.

    A fresh ``SSIM`` module is constructed on every call; it holds no
    trainable parameters.
    """
    layer = SSIM(
        pad_reflection=pad_reflection,
        gaussian_average=gaussian_average,
        comp_mode=comp_mode,
        eval_mode=eval_mode,
    )
    return layer(x, y, pad=pad)
def geo(
    x,
    y,
    pad_reflection=True,
    gaussian_average=False,
    comp_mode=False,
    eval_mode=False,
    pad=True,
):
    """Functional form of ``GEO``: build the layer with the given options and apply it.

    A fresh ``GEO`` module is constructed on every call.
    """
    layer = GEO(
        pad_reflection=pad_reflection,
        gaussian_average=gaussian_average,
        comp_mode=comp_mode,
        eval_mode=eval_mode,
    )
    return layer(x, y, pad=pad)
# NOTE: It is unclear whether ResidualImage / ResidualImageModule below are still used; check with Felix before relying on or removing them.
class ResidualImage(nn.Module):
    """Convenience wrapper that runs ``ResidualImageModule`` on plain arguments.

    It packs the arguments into the dict layout ``ResidualImageModule`` reads
    and returns only its ``"residual_image"`` output.
    """

    def __init__(self):
        super().__init__()
        self.residual_image = ResidualImageModule()

    def forward(
        self,
        keyframe: torch.Tensor,
        keyframe_pose: torch.Tensor,
        keyframe_intrinsics: torch.Tensor,
        depths: torch.Tensor,
        frames: list,
        poses: list,
        intrinsics: list,
    ):
        data_dict = {
            "keyframe": keyframe,
            "keyframe_pose": keyframe_pose,
            "keyframe_intrinsics": keyframe_intrinsics,
            # With inv_depth_max=0 and inv_depth_min=1, the module's rescaling
            # (1 - d) * max + d * min reduces to d, so `depths` is used as-is.
            "predicted_inverse_depths": [depths],
            "frames": frames,
            "poses": poses,
            "intrinsics": intrinsics,
            "inv_depth_max": 0,
            "inv_depth_min": 1,
        }
        data_dict = self.residual_image(data_dict)
        return data_dict["residual_image"]
class ResidualImageModule(nn.Module):
    """Photometric (SSIM) residual between a keyframe and frames warped into it.

    For every source frame, the keyframe pixels are back-projected with the
    predicted depth, projected into that frame, and the frame is sampled there.
    The SSIM dissimilarity to the keyframe is then computed. The result keeps,
    per pixel, the minimum channel-mean residual over all frames. Samples that
    fall outside a frame are ignored, and pixels invalid in every frame get 0.

    Args:
        use_mono: include ``data_dict["frames"/"intrinsics"/"poses"]``.
        use_stereo: include ``data_dict["stereoframe"/"stereoframe_intrinsics"/"stereoframe_pose"]``.
    """

    def __init__(self, use_mono=True, use_stereo=False):
        super().__init__()
        self.use_mono = use_mono
        self.use_stereo = use_stereo
        self.ssim = SSIM()

    def forward(self, data_dict):
        """Compute ``data_dict["residual_image"]`` (n, 1, h, w) and return the same dict."""
        keyframe = data_dict["keyframe"]
        keyframe_intrinsics = data_dict["keyframe_intrinsics"]
        keyframe_pose = data_dict["keyframe_pose"]
        # Map the prediction in [0, 1] onto [inv_depth_max, inv_depth_min]:
        # the result is an inverse depth.
        depths = (1 - data_dict["predicted_inverse_depths"][0]) * data_dict[
            "inv_depth_max"
        ] + data_dict["predicted_inverse_depths"][0] * data_dict["inv_depth_min"]

        frames = []
        intrinsics = []
        poses = []
        if self.use_mono:
            frames += data_dict["frames"]
            intrinsics += data_dict["intrinsics"]
            poses += data_dict["poses"]
        if self.use_stereo:
            frames += [data_dict["stereoframe"]]
            intrinsics += [data_dict["stereoframe_intrinsics"]]
            poses += [data_dict["stereoframe_pose"]]

        n, c, h, w = keyframe.shape

        # Backprojection takes (height, width) and stores its homogeneous grid in
        # `.coord` (1, 3, h*w) and `.ones` (1, 1, h*w); both need batch n here.
        backproject_depth = Backprojection(h, w).to(keyframe.device)

        inv_k = torch.inverse(keyframe_intrinsics)
        cam_points = inv_k[:, :3, :3] @ backproject_depth.coord
        # Dividing by inverse depth scales each ray by its depth.
        cam_points = cam_points / depths.view(n, 1, -1)
        cam_points = torch.cat(
            [cam_points, backproject_depth.ones.expand(n, -1, -1)], 1
        )

        masks = []
        residuals = []
        for i, image in enumerate(frames):
            t = torch.inverse(poses[i]) @ keyframe_pose
            # point_projection returns (img_coord, depth); only the coords are needed.
            pix_coords, _ = point_projection(cam_points, n, h, w, intrinsics[i], t)
            # Shift by +1 so that exact zeros come only from grid_sample's zero
            # padding (out-of-view samples). Assumes image values > -1, e.g.
            # images in [-0.5, 0.5] — TODO confirm.
            # NOTE(review): Backprojection's linspace(-1, 1) grid matches
            # align_corners=True, while grid_sample's default is False — verify.
            warped_image = F.grid_sample(image + 1, pix_coords)
            mask = torch.any(warped_image == 0, dim=1, keepdim=True)
            warped_image -= 0.5
            residual = self.ssim(warped_image, keyframe + 0.5)
            masks.append(mask)
            residuals.append(residual)

        masks = torch.stack(masks, dim=1)  # (n, F, 1, h, w)
        residuals = torch.stack(residuals, dim=1)  # (n, F, c, h, w)
        # Invalid samples must never win the per-pixel minimum below.
        residuals[masks.expand(-1, -1, c, -1, -1)] = float("inf")
        residual_image = torch.min(torch.mean(residuals, dim=2, keepdim=True), dim=1)[0]
        # Pixels invalid in every frame would be inf; reset them to 0.
        residual_image[torch.all(masks, dim=1)] = 0
        data_dict["residual_image"] = residual_image
        return data_dict
class PadSameConv2d(torch.nn.Module):
    def __init__(self, kernel_size, stride=1):
        """
        Imitates padding_mode="same" from tensorflow: zero-pads the input so that a following
        convolution with the same kernel_size/stride (and no padding) outputs
        (ceil(h / stride_y), ceil(w / stride_x)).
        :param kernel_size: Kernelsize of the convolution, int or tuple/list (y, x)
        :param stride: Stride of the convolution, int or tuple/list (y, x)
        """
        super().__init__()
        if isinstance(kernel_size, (tuple, list)):
            self.kernel_size_y = kernel_size[0]
            self.kernel_size_x = kernel_size[1]
        else:
            self.kernel_size_y = kernel_size
            self.kernel_size_x = kernel_size
        if isinstance(stride, (tuple, list)):
            self.stride_y = stride[0]
            self.stride_x = stride[1]
        else:
            self.stride_y = stride
            self.stride_x = stride

    @staticmethod
    def _same_padding(size, kernel_size, stride):
        """
        Return the (before, after) zero padding for one axis, as tensorflow computes it.
        To get n = ceil(size / stride) outputs we need n - 1 strides plus one kernel, i.e.
        stride * (n - 1) + kernel_size input pixels. The difference to `size` is split with the
        extra pixel (if odd) going after. Like tensorflow, the total is clamped at 0: when the
        kernel is smaller than the stride the required extent can be below `size`, and negative
        padding would crop the input instead.
        """
        out_size = math.ceil(size / stride)
        total = max(stride * (out_size - 1) + kernel_size - size, 0)
        before = total // 2
        return before, total - before

    def forward(self, x: torch.Tensor):
        _, _, height, width = x.shape
        top, bottom = self._same_padding(height, self.kernel_size_y, self.stride_y)
        left, right = self._same_padding(width, self.kernel_size_x, self.stride_x)
        # F.pad takes (before, after) pairs starting from the last dimension: width, then height.
        return F.pad(input=x, pad=[left, right, top, bottom])
class PadSameConv2dTransposed(torch.nn.Module):
    def __init__(self, stride):
        """
        Imitates padding_mode="same" from tensorflow for a transposed convolution: its output is
        zero-padded or centre-cropped so that the spatial size becomes exactly
        stride * (spatial size of the transposed convolution's input).
        :param stride: Stride of the convolution_transposed, int or tuple/list (y, x)
        """
        super().__init__()
        if isinstance(stride, (tuple, list)):
            self.stride_y = stride[0]
            self.stride_x = stride[1]
        else:
            self.stride_y = stride
            self.stride_x = stride

    @staticmethod
    def _fit_axis(x: torch.Tensor, dim: int, oversize: int) -> torch.Tensor:
        """
        Resize `x` along `dim` (-2 for height, -1 for width) by `oversize` pixels: zero-pad when
        positive, crop when negative, return unchanged when zero. For odd amounts the extra pixel
        is added/removed at the end of the axis.
        """
        if oversize > 0:
            before = oversize // 2
            after = oversize - before
            # F.pad takes (before, after) pairs starting from the last dimension.
            pad = [before, after] if dim == -1 else [0, 0, before, after]
            return F.pad(x, pad)
        if oversize < 0:
            start = (-oversize) // 2
            # Use an explicit length: a `start:-end` slice becomes empty when end == 0.
            return x.narrow(dim, start, x.shape[dim] + oversize)
        return x

    def forward(self, x: torch.Tensor, orig_shape: torch.Tensor):
        """
        :param x: Output of the transposed convolution, (N, C, H', W')
        :param orig_shape: Shape of the transposed convolution's input (torch.Size, tuple, list or
            1-D tensor); only the last two entries are used
        :return: `x` resized to (N, C, orig_H * stride_y, orig_W * stride_x)
        """
        # Plain ints rather than tensor arithmetic keep the shape math exact for every dtype.
        target_h = int(orig_shape[-2]) * self.stride_y
        target_w = int(orig_shape[-1]) * self.stride_x
        x = self._fit_axis(x, -2, target_h - x.shape[-2])
        x = self._fit_axis(x, -1, target_w - x.shape[-1])
        return x
class ConvReLU2(torch.nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size, stride=1, leaky_relu_neg_slope=0.1
    ):
        """
        Separable convolution block: a (kernel_size x 1) convolution along y, then a
        (1 x kernel_size) convolution along x. Each convolution is preceded by tensorflow-style
        "same" padding and followed by a leaky relu.
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels (of both convolutions)
        :param kernel_size: Kernel length used along y for the first and along x for the second convolution
        :param stride: Stride along y for the first and along x for the second convolution
        :param leaky_relu_neg_slope: Negative slope of the leaky relu
        """
        super().__init__()
        # Vertical stage.
        self.pad_0 = PadSameConv2d(kernel_size=(kernel_size, 1), stride=(stride, 1))
        self.conv_y = Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(kernel_size, 1),
            stride=(stride, 1),
        )
        # Shared activation, applied after each stage.
        self.leaky_relu = LeakyReLU(negative_slope=leaky_relu_neg_slope)
        # Horizontal stage.
        self.pad_1 = PadSameConv2d(kernel_size=(1, kernel_size), stride=(1, stride))
        self.conv_x = Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=(1, kernel_size),
            stride=(1, stride),
        )

    def forward(self, x: torch.Tensor):
        vertical = self.leaky_relu(self.conv_y(self.pad_0(x)))
        horizontal = self.conv_x(self.pad_1(vertical))
        return self.leaky_relu(horizontal)
class ConvReLU(torch.nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size, stride=1, leaky_relu_neg_slope=0.1
    ):
        """
        Performs a single 2D convolution with tensorflow-style "same" padding, followed by a leaky
        relu. (The previous docstring was copied from ConvReLU2 and wrongly described two
        separable y/x convolutions.)
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        :param kernel_size: Kernel size of the convolution, int or tuple/list (y, x)
        :param stride: Stride of the convolution, int or tuple/list (y, x)
        :param leaky_relu_neg_slope: Negative slope of the leaky relu
        """
        super().__init__()
        self.pad = PadSameConv2d(kernel_size=kernel_size, stride=stride)
        self.conv = Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
        )
        self.leaky_relu = LeakyReLU(negative_slope=leaky_relu_neg_slope)

    def forward(self, x: torch.Tensor):
        t = self.pad(x)
        t = self.conv(t)
        return self.leaky_relu(t)
class Upconv(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        """
        Upsamples by a factor of 2 (torch.nn.Upsample with its default "nearest" mode), pads
        tensorflow-"same"-style for a 2x2 kernel, then applies a 2x2 convolution with stride 1.
        No activation is applied. (The previous docstring was copied from ConvReLU2 and
        documented nonexistent kernel_size/stride parameters.)
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        """
        super().__init__()
        self.upsample = Upsample(scale_factor=2)
        self.pad = PadSameConv2d(kernel_size=2)
        self.conv = Conv2d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=2, stride=1
        )

    def forward(self, x: torch.Tensor):
        t = self.upsample(x)
        t = self.pad(t)
        return self.conv(t)
class ConvSig(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1):
        """
        Performs a single 2D convolution with tensorflow-style "same" padding, followed by a
        sigmoid, so outputs lie in (0, 1). (The previous docstring was copied from ConvReLU2 and
        wrongly described two convolutions with a leaky relu.)
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        :param kernel_size: Kernel size of the convolution, int or tuple/list (y, x)
        :param stride: Stride of the convolution, int or tuple/list (y, x)
        """
        super().__init__()
        self.pad = PadSameConv2d(kernel_size=kernel_size, stride=stride)
        self.conv = Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
        )
        self.sig = Sigmoid()

    def forward(self, x: torch.Tensor):
        t = self.pad(x)
        t = self.conv(t)
        return self.sig(t)
class Refine(torch.nn.Module):
    def __init__(self, in_channels, out_channels, leaky_relu_neg_slope=0.1):
        """
        Learned 2x upsampling: a transposed convolution (kernel_size=4, stride=2), a leaky relu,
        then padding/cropping that imitates tensorflow "same", so the output spatial size is
        exactly twice the input's. Optionally concatenates skip features along the channel axis.
        :param in_channels: Channels that go into the conv2d_transposed
        :param out_channels: Channels that come out of the conv2d_transposed
        :param leaky_relu_neg_slope: Negative slope of the leaky relu
        """
        super().__init__()
        self.conv2d_t = ConvTranspose2d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=4, stride=2
        )
        self.pad = PadSameConv2dTransposed(stride=2)
        self.leaky_relu = LeakyReLU(negative_slope=leaky_relu_neg_slope)

    def forward(self, x: torch.Tensor, features_direct=None):
        # The pre-convolution shape determines the target size (2x) for the padding step.
        input_shape = x.shape
        activated = self.leaky_relu(self.conv2d_t(x))
        resized = self.pad(activated, input_shape)
        if features_direct is None:
            return resized
        return torch.cat([resized, features_direct], dim=1)