from __future__ import absolute_import, division, print_function
import math
import torch
import torch.nn as nn
from torch.nn import (
functional as F,
Conv2d,
LeakyReLU,
Upsample,
Sigmoid,
ConvTranspose2d,
)
class ConvBlock(nn.Module):
"""Layer to perform a convolution followed by ELU"""
def __init__(self, in_channels, out_channels):
super(ConvBlock, self).__init__()
self.conv = Conv3x3(in_channels, out_channels)
self.nonlin = nn.ELU(inplace=True)
def forward(self, x):
out = self.conv(x)
out = self.nonlin(out)
return out
class Conv3x3(nn.Module):
"""Layer to pad and convolve input"""
def __init__(self, in_channels, out_channels, use_refl=True):
super(Conv3x3, self).__init__()
if use_refl:
self.pad = nn.ReflectionPad2d(1)
else:
self.pad = nn.ZeroPad2d(1)
self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)
def forward(self, x):
out = self.pad(x)
out = self.conv(out)
return out
class Backprojection(nn.Module):
def __init__(self, height, width):
super(Backprojection, self).__init__()
self.H, self.W = height, width
        # Normalized pixel grid over [-1, 1] in both axes, later stacked with a row
        # of ones to form homogeneous coordinates
        yy, xx = torch.meshgrid(
            [torch.linspace(-1.0, 1.0, self.H), torch.linspace(-1.0, 1.0, self.W)]
        )
yy = yy.contiguous().view(-1)
xx = xx.contiguous().view(-1)
self.ones = nn.Parameter(
torch.ones(1, 1, self.H * self.W, dtype=torch.float), requires_grad=False
)
self.coord = torch.unsqueeze(torch.stack([xx, yy], 0), 0)
self.coord = nn.Parameter(
torch.cat([self.coord, self.ones], 1), requires_grad=False
)
    def forward(self, depth, inv_K):
        n = depth.shape[0]
        # Rays through each pixel, scaled by depth; returned as homogeneous
        # points of shape (n, 4, H * W)
        cam_p_norm = torch.matmul(inv_K[:, :3, :3], self.coord.expand(n, -1, -1))
        cam_p_euc = depth.view(n, 1, -1) * cam_p_norm
        cam_p_h = torch.cat([cam_p_euc, self.ones.expand(n, -1, -1)], 1)
        return cam_p_h
def point_projection(points3D, batch_size, height, width, K, T):
    """Project homogeneous 3D points into a camera with intrinsics K and pose T;
    returns pixel coordinates of shape (N, H, W, 2) and the per-point depths."""
    N, H, W = batch_size, height, width
    cam_coord = torch.matmul(torch.matmul(K, T[:, :3, :]), points3D)
    # Perspective division, with a small epsilon to avoid division by zero
    img_coord = cam_coord[:, :2, :] / (cam_coord[:, 2:3, :] + 1e-7)
    img_coord = img_coord.view(N, 2, H, W).permute(0, 2, 3, 1)
    return img_coord, cam_coord[:, 2, :]
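# Illustrative sanity check, not part of the original module: with identity
# intrinsics and an identity pose, projecting the backprojected points should
# reproduce the normalized pixel grid (up to the 1e-7 stabilizer). The function
# name `_demo_backproject_project_roundtrip` is ours, for illustration only.
def _demo_backproject_project_roundtrip(height=4, width=6):
    backproject = Backprojection(height, width)
    depth = torch.full((1, 1, height, width), 2.0)  # constant depth of 2
    inv_k = torch.eye(3).unsqueeze(0)  # normalized camera: K = I
    points = backproject(depth, inv_k)  # (1, 4, H*W) homogeneous points
    k = torch.eye(3).unsqueeze(0)
    t = torch.eye(4).unsqueeze(0)  # identity camera pose
    img_coord, z = point_projection(points, 1, height, width, k, t)
    grid = backproject.coord[:, :2].reshape(1, 2, height, width).permute(0, 2, 3, 1)
    assert torch.allclose(img_coord, grid, atol=1e-4)
    assert torch.allclose(z, torch.full_like(z, 2.0))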
def upsample(x):
"""Upsample input tensor by a factor of 2"""
return F.interpolate(x, scale_factor=2, mode="nearest")
class GaussianAverage(nn.Module):
def __init__(self) -> None:
super().__init__()
        # 3x3 Gaussian window; the entries sum to ~1, so the filter approximately
        # preserves the local mean
        self.window = torch.Tensor(
            [
                [0.0947, 0.1183, 0.0947],
                [0.1183, 0.1478, 0.1183],
                [0.0947, 0.1183, 0.0947],
            ]
        )
def forward(self, x):
kernel = self.window.to(x.device).to(x.dtype).repeat(x.shape[1], 1, 1, 1)
return F.conv2d(x, kernel, padding=0, groups=x.shape[1])
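# Illustrative check, not part of the original module: the 3x3 window above sums
# to ~0.9998, so a constant image is almost preserved, while the valid (padding=0)
# convolution trims one pixel from each border.
def _demo_gaussian_average():
    avg = GaussianAverage()
    x = torch.full((1, 3, 8, 8), 0.5)
    y = avg(x)
    assert y.shape == (1, 3, 6, 6)
    assert torch.allclose(y, torch.full_like(y, 0.5), atol=1e-3)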
class SSIM(nn.Module):
"""Layer to compute the SSIM loss between a pair of images"""
def __init__(
self,
pad_reflection=True,
gaussian_average=False,
comp_mode=False,
eval_mode=False,
):
super(SSIM, self).__init__()
self.comp_mode = comp_mode
self.eval_mode = eval_mode
if not gaussian_average:
self.mu_x_pool = nn.AvgPool2d(3, 1)
self.mu_y_pool = nn.AvgPool2d(3, 1)
self.sig_x_pool = nn.AvgPool2d(3, 1)
self.sig_y_pool = nn.AvgPool2d(3, 1)
self.sig_xy_pool = nn.AvgPool2d(3, 1)
else:
self.mu_x_pool = GaussianAverage()
self.mu_y_pool = GaussianAverage()
self.sig_x_pool = GaussianAverage()
self.sig_y_pool = GaussianAverage()
self.sig_xy_pool = GaussianAverage()
if pad_reflection:
self.pad = nn.ReflectionPad2d(1)
else:
self.pad = nn.ZeroPad2d(1)
self.C1 = 0.01**2
self.C2 = 0.03**2
def forward(self, x, y, pad=True):
if pad:
x = self.pad(x)
y = self.pad(y)
        # Local means of x and y (box filtering or Gaussian averaging, chosen at init)
        mu_x = self.mu_x_pool(x)
        mu_y = self.mu_y_pool(y)
        # Squared means and product of means
        mu_x_sq = mu_x**2
        mu_y_sq = mu_y**2
        mu_x_y = mu_x * mu_y
        # Local variances and covariance
        sigma_x = self.sig_x_pool(x**2) - mu_x_sq
        sigma_y = self.sig_y_pool(y**2) - mu_y_sq
        sigma_xy = self.sig_xy_pool(x * y) - mu_x_y
        # SSIM index:
        #   (2*mu_x*mu_y + C1) * (2*sigma_xy + C2)
        #   / ((mu_x^2 + mu_y^2 + C1) * (sigma_x + sigma_y + C2))
        # It ranges from -1 to 1: 1 for identical images, -1 for anti-correlated
        # images, and 0 for uncorrelated images.
        SSIM_n = (2 * mu_x_y + self.C1) * (2 * sigma_xy + self.C2)
        SSIM_d = (mu_x_sq + mu_y_sq + self.C1) * (sigma_x + sigma_y + self.C2)
        if not self.eval_mode:
            # Training: return the dissimilarity (1 - SSIM) / 2 as a loss in [0, 1]
            if not self.comp_mode:
                return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1)
            else:
                # Clamp before halving, so the result lies in [0, 0.5]
                return torch.clamp((1 - SSIM_n / SSIM_d), 0, 1) / 2
        else:
            # Evaluation: return the raw SSIM index
            return SSIM_n / SSIM_d
class GEO(nn.Module):
"""Layer to compute the pseudo label, L_{geo}, loss between a pair of images"""
def __init__(
self,
pad_reflection=True,
gaussian_average=False,
comp_mode=False,
eval_mode=False,
):
super(GEO, self).__init__()
self.comp_mode = comp_mode
self.eval_mode = eval_mode
if not gaussian_average:
self.mu_x_pool = nn.AvgPool2d(3, 1)
self.mu_y_pool = nn.AvgPool2d(3, 1)
self.sig_x_pool = nn.AvgPool2d(3, 1)
self.sig_y_pool = nn.AvgPool2d(3, 1)
self.sig_xy_pool = nn.AvgPool2d(3, 1)
else:
self.mu_x_pool = GaussianAverage()
self.mu_y_pool = GaussianAverage()
self.sig_x_pool = GaussianAverage()
self.sig_y_pool = GaussianAverage()
self.sig_xy_pool = GaussianAverage()
if pad_reflection:
self.pad = nn.ReflectionPad2d(1)
else:
self.pad = nn.ZeroPad2d(1)
self.C1 = 0.01**2
self.C2 = 0.03**2
def forward(self, x, y, pad=True):
if pad:
x = self.pad(x)
y = self.pad(y)
        # Local means of x and y (box filtering or Gaussian averaging, chosen at init)
        mu_x = self.mu_x_pool(x)
        mu_y = self.mu_y_pool(y)
        # Squared means and product of means
        mu_x_sq = mu_x**2
        mu_y_sq = mu_y**2
        mu_x_y = mu_x * mu_y
        # Local variances and covariance
        sigma_x = self.sig_x_pool(x**2) - mu_x_sq
        sigma_y = self.sig_y_pool(y**2) - mu_y_sq
        sigma_xy = self.sig_xy_pool(x * y) - mu_x_y
        # SSIM index:
        #   (2*mu_x*mu_y + C1) * (2*sigma_xy + C2)
        #   / ((mu_x^2 + mu_y^2 + C1) * (sigma_x + sigma_y + C2))
        # It ranges from -1 to 1: 1 for identical images, -1 for anti-correlated
        # images, and 0 for uncorrelated images.
        SSIM_n = (2 * mu_x_y + self.C1) * (2 * sigma_xy + self.C2)
        SSIM_d = (mu_x_sq + mu_y_sq + self.C1) * (sigma_x + sigma_y + self.C2)
        if not self.eval_mode:
            # Training: return the dissimilarity (1 - SSIM) / 2 as a loss in [0, 1]
            if not self.comp_mode:
                return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1)
            else:
                # Clamp before halving, so the result lies in [0, 0.5]
                return torch.clamp((1 - SSIM_n / SSIM_d), 0, 1) / 2
        else:
            # Evaluation: return the raw SSIM index
            return SSIM_n / SSIM_d
def ssim(
x,
y,
pad_reflection=True,
gaussian_average=False,
comp_mode=False,
eval_mode=False,
pad=True,
):
ssim_ = SSIM(pad_reflection, gaussian_average, comp_mode, eval_mode)
return ssim_(x, y, pad=pad)
def geo(
x,
y,
pad_reflection=True,
gaussian_average=False,
comp_mode=False,
eval_mode=False,
pad=True,
):
geo_ = GEO(pad_reflection, gaussian_average, comp_mode, eval_mode)
return geo_(x, y, pad=pad)
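# Illustrative usage sketch, not part of the original module: identical inputs give
# a loss of ~0 in the default mode and an SSIM index of ~1 in eval_mode. Note that
# the `ssim`/`geo` helpers build a fresh module on every call; in a training loop,
# construct one SSIM/GEO instance and reuse it.
def _demo_ssim():
    x = torch.rand(2, 3, 16, 16)
    loss = ssim(x, x)  # (1 - SSIM) / 2, clamped to [0, 1]
    index = ssim(x, x, eval_mode=True)  # raw SSIM index
    assert torch.allclose(loss, torch.zeros_like(loss), atol=1e-5)
    assert torch.allclose(index, torch.ones_like(index), atol=1e-5)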
# NOTE: is the following used? Ask Felix
class ResidualImage(nn.Module):
def __init__(self):
super().__init__()
self.residual_image = ResidualImageModule()
def forward(
self,
keyframe: torch.Tensor,
keyframe_pose: torch.Tensor,
keyframe_intrinsics: torch.Tensor,
depths: torch.Tensor,
frames: list,
poses: list,
intrinsics: list,
):
data_dict = {
"keyframe": keyframe,
"keyframe_pose": keyframe_pose,
"keyframe_intrinsics": keyframe_intrinsics,
"predicted_inverse_depths": [depths],
"frames": frames,
"poses": poses,
"list": list,
"intrinsics": intrinsics,
"inv_depth_max": 0,
"inv_depth_min": 1,
}
data_dict = self.residual_image(data_dict)
return data_dict["residual_image"]
class ResidualImageModule(nn.Module):
def __init__(self, use_mono=True, use_stereo=False):
super().__init__()
self.use_mono = use_mono
self.use_stereo = use_stereo
self.ssim = SSIM()
def forward(self, data_dict):
keyframe = data_dict["keyframe"]
keyframe_intrinsics = data_dict["keyframe_intrinsics"]
keyframe_pose = data_dict["keyframe_pose"]
        # Linear interpolation between inv_depth_max and inv_depth_min: despite the
        # name, `depths` holds inverse depths
        depths = (1 - data_dict["predicted_inverse_depths"][0]) * data_dict[
            "inv_depth_max"
        ] + data_dict["predicted_inverse_depths"][0] * data_dict["inv_depth_min"]
frames = []
intrinsics = []
poses = []
if self.use_mono:
frames += data_dict["frames"]
intrinsics += data_dict["intrinsics"]
poses += data_dict["poses"]
if self.use_stereo:
frames += [data_dict["stereoframe"]]
intrinsics += [data_dict["stereoframe_intrinsics"]]
poses += [data_dict["stereoframe_pose"]]
        n, c, h, w = keyframe.shape
        backproject_depth = Backprojection(h, w)
        backproject_depth.to(keyframe.device)
        inv_k = torch.inverse(keyframe_intrinsics)
        # Rays through each pixel, scaled by depth (dividing by inverse depth),
        # in homogeneous coordinates
        cam_points = inv_k[:, :3, :3] @ backproject_depth.coord
        cam_points = cam_points / depths.view(n, 1, -1)
        cam_points = torch.cat(
            [cam_points, backproject_depth.ones.expand(n, -1, -1)], 1
        )
masks = []
residuals = []
        for i, image in enumerate(frames):
            t = torch.inverse(poses[i]) @ keyframe_pose
            pix_coords, _ = point_projection(cam_points, n, h, w, intrinsics[i], t)
            # The +1 offset makes out-of-view samples (zero-filled by grid_sample)
            # detectable below; align_corners=True matches the linspace(-1, 1) grid
            warped_image = F.grid_sample(image + 1, pix_coords, align_corners=True)
            mask = torch.any(warped_image == 0, dim=1, keepdim=True)
            warped_image -= 0.5
            residual = self.ssim(warped_image, keyframe + 0.5)
masks.append(mask)
residuals.append(residual)
        masks = torch.stack(masks, dim=1)
        residuals = torch.stack(residuals, dim=1)
        # Invalidate masked pixels, then keep the smallest residual across frames
        residuals[masks.expand(-1, -1, c, -1, -1)] = float("inf")
        residual_image = torch.min(torch.mean(residuals, dim=2, keepdim=True), dim=1)[0]
        # Pixels that are masked in every frame get residual 0
        residual_image[torch.min(masks, dim=1)[0]] = 0
data_dict["residual_image"] = residual_image
return data_dict
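# Illustrative usage sketch, not part of the original module: minimal dummy inputs
# for ResidualImageModule. Warping the keyframe onto itself (identity relative pose,
# constant inverse depth) should yield a residual image that is ~0 everywhere.
def _demo_residual_image_module():
    n, c, h, w = 1, 3, 8, 10
    keyframe = torch.rand(n, c, h, w) - 0.5
    eye3 = torch.eye(3).unsqueeze(0)
    eye4 = torch.eye(4).unsqueeze(0)
    data_dict = {
        "keyframe": keyframe,
        "keyframe_intrinsics": eye3,
        "keyframe_pose": eye4,
        "predicted_inverse_depths": [torch.ones(n, 1, h, w)],
        "inv_depth_max": 1.0,
        "inv_depth_min": 1.0,
        "frames": [keyframe],
        "intrinsics": [eye3],
        "poses": [eye4],
    }
    out = ResidualImageModule(use_mono=True)(data_dict)
    assert out["residual_image"].shape == (n, 1, h, w)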
class PadSameConv2d(torch.nn.Module):
def __init__(self, kernel_size, stride=1):
"""
Imitates padding_mode="same" from tensorflow.
        :param kernel_size: Kernel size of the convolution, int or tuple/list
:param stride: Stride of the convolution, int or tuple/list
"""
super().__init__()
if isinstance(kernel_size, (tuple, list)):
self.kernel_size_y = kernel_size[0]
self.kernel_size_x = kernel_size[1]
else:
self.kernel_size_y = kernel_size
self.kernel_size_x = kernel_size
if isinstance(stride, (tuple, list)):
self.stride_y = stride[0]
self.stride_x = stride[1]
else:
self.stride_y = stride
self.stride_x = stride
def forward(self, x: torch.Tensor):
_, _, height, width = x.shape
        # For the convolution we want an output size of (n_h, n_w) = (math.ceil(h / s_y),
        # math.ceil(w / s_x)). Therefore we apply n_h convolution kernels with stride s_y,
        # which gives n_h - 1 offsets of size s_y; adding the kernel size yields the height
        # required to produce n_h outputs. We pad the real difference between this and the
        # old height: math.floor(pad_y / 2) on top and math.ceil(pad_y / 2) on the bottom.
        # The same applies to pad_x with left and right.
padding_y = (
self.stride_y * (math.ceil(height / self.stride_y) - 1)
+ self.kernel_size_y
- height
) / 2
padding_x = (
self.stride_x * (math.ceil(width / self.stride_x) - 1)
+ self.kernel_size_x
- width
) / 2
padding = [
math.floor(padding_x),
math.ceil(padding_x),
math.floor(padding_y),
math.ceil(padding_y),
]
return F.pad(input=x, pad=padding)
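# Illustrative check, not part of the original module: after PadSameConv2d, a Conv2d
# with the same kernel size and stride produces ceil(H / stride) x ceil(W / stride)
# outputs, matching TensorFlow's "same" padding.
def _demo_pad_same_conv2d():
    pad = PadSameConv2d(kernel_size=3, stride=2)
    conv = Conv2d(in_channels=3, out_channels=8, kernel_size=3, stride=2)
    y = conv(pad(torch.rand(1, 3, 7, 9)))
    assert y.shape[-2:] == (math.ceil(7 / 2), math.ceil(9 / 2))  # (4, 5)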
class PadSameConv2dTransposed(torch.nn.Module):
def __init__(self, stride):
"""
Imitates padding_mode="same" from tensorflow.
:param stride: Stride of the convolution_transposed, int or tuple/list
"""
super().__init__()
if isinstance(stride, (tuple, list)):
self.stride_y = stride[0]
self.stride_x = stride[1]
else:
self.stride_y = stride
self.stride_x = stride
    def forward(self, x: torch.Tensor, orig_shape: torch.Size):
        # Pad or crop each spatial dimension so the output is exactly stride * orig size
        target_shape = x.new_tensor(list(orig_shape))
target_shape[-2] *= self.stride_y
target_shape[-1] *= self.stride_x
oversize = target_shape[-2:] - x.new_tensor(x.shape)[-2:]
if oversize[0] > 0 and oversize[1] > 0:
x = F.pad(
x,
[
math.floor(oversize[1] / 2),
math.ceil(oversize[1] / 2),
math.floor(oversize[0] / 2),
math.ceil(oversize[0] / 2),
],
)
elif oversize[0] > 0 >= oversize[1]:
x = F.pad(
x, [0, 0, math.floor(oversize[0] / 2), math.ceil(oversize[0] / 2)]
)
x = x[:, :, :, math.floor(-oversize[1] / 2) : -math.ceil(-oversize[1] / 2)]
elif oversize[0] <= 0 < oversize[1]:
x = F.pad(x, [math.floor(oversize[1] / 2), math.ceil(oversize[1] / 2)])
x = x[:, :, math.floor(-oversize[0] / 2) : -math.ceil(-oversize[0] / 2), :]
else:
x = x[
:,
:,
math.floor(-oversize[0] / 2) : -math.ceil(-oversize[0] / 2),
math.floor(-oversize[1] / 2) : -math.ceil(-oversize[1] / 2),
]
return x
class ConvReLU2(torch.nn.Module):
def __init__(
self, in_channels, out_channels, kernel_size, stride=1, leaky_relu_neg_slope=0.1
):
"""
Performs two convolutions and a leaky relu. The first operation only convolves in y direction, the second one
only in x direction.
:param in_channels: Number of input channels
:param out_channels: Number of output channels
:param kernel_size: Kernel size for the convolutions, first in y direction, then in x direction
:param stride: Stride for the convolutions, first in y direction, then in x direction
"""
super().__init__()
self.pad_0 = PadSameConv2d(kernel_size=(kernel_size, 1), stride=(stride, 1))
self.conv_y = Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=(kernel_size, 1),
stride=(stride, 1),
)
self.leaky_relu = LeakyReLU(negative_slope=leaky_relu_neg_slope)
self.pad_1 = PadSameConv2d(kernel_size=(1, kernel_size), stride=(1, stride))
self.conv_x = Conv2d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=(1, kernel_size),
stride=(1, stride),
)
def forward(self, x: torch.Tensor):
t = self.pad_0(x)
t = self.conv_y(t)
t = self.leaky_relu(t)
t = self.pad_1(t)
t = self.conv_x(t)
return self.leaky_relu(t)
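# Illustrative check, not part of the original module: ConvReLU2 applies the stride
# in both separable passes (first y, then x), so the spatial dimensions shrink to
# ceil(H / stride) x ceil(W / stride).
def _demo_conv_relu2():
    block = ConvReLU2(in_channels=3, out_channels=8, kernel_size=3, stride=2)
    y = block(torch.rand(1, 3, 11, 13))
    assert y.shape == (1, 8, 6, 7)  # (ceil(11 / 2), ceil(13 / 2))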
class ConvReLU(torch.nn.Module):
def __init__(
self, in_channels, out_channels, kernel_size, stride=1, leaky_relu_neg_slope=0.1
):
"""
Performs two convolutions and a leaky relu. The first operation only convolves in y direction, the second one
only in x direction.
:param in_channels: Number of input channels
:param out_channels: Number of output channels
:param kernel_size: Kernel size for the convolutions, first in y direction, then in x direction
:param stride: Stride for the convolutions, first in y direction, then in x direction
"""
super().__init__()
self.pad = PadSameConv2d(kernel_size=kernel_size, stride=stride)
self.conv = Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
)
self.leaky_relu = LeakyReLU(negative_slope=leaky_relu_neg_slope)
def forward(self, x: torch.Tensor):
t = self.pad(x)
t = self.conv(t)
return self.leaky_relu(t)
class Upconv(torch.nn.Module):
def __init__(self, in_channels, out_channels):
"""
Performs two convolutions and a leaky relu. The first operation only convolves in y direction, the second one
only in x direction.
:param in_channels: Number of input channels
:param out_channels: Number of output channels
:param kernel_size: Kernel size for the convolutions, first in y direction, then in x direction
:param stride: Stride for the convolutions, first in y direction, then in x direction
"""
super().__init__()
self.upsample = Upsample(scale_factor=2)
self.pad = PadSameConv2d(kernel_size=2)
self.conv = Conv2d(
in_channels=in_channels, out_channels=out_channels, kernel_size=2, stride=1
)
def forward(self, x: torch.Tensor):
t = self.upsample(x)
t = self.pad(t)
return self.conv(t)
class ConvSig(torch.nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1):
"""
Performs two convolutions and a leaky relu. The first operation only convolves in y direction, the second one
only in x direction.
:param in_channels: Number of input channels
:param out_channels: Number of output channels
:param kernel_size: Kernel size for the convolutions, first in y direction, then in x direction
:param stride: Stride for the convolutions, first in y direction, then in x direction
"""
super().__init__()
self.pad = PadSameConv2d(kernel_size=kernel_size, stride=stride)
self.conv = Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
)
self.sig = Sigmoid()
def forward(self, x: torch.Tensor):
t = self.pad(x)
t = self.conv(t)
return self.sig(t)
class Refine(torch.nn.Module):
def __init__(self, in_channels, out_channels, leaky_relu_neg_slope=0.1):
"""
Performs a transposed conv2d with padding that imitates tensorflow same behaviour. The transposed conv2d has
parameters kernel_size=4 and stride=2.
:param in_channels: Channels that go into the conv2d_transposed
:param out_channels: Channels that come out of the conv2d_transposed
"""
super().__init__()
self.conv2d_t = ConvTranspose2d(
in_channels=in_channels, out_channels=out_channels, kernel_size=4, stride=2
)
self.pad = PadSameConv2dTransposed(stride=2)
self.leaky_relu = LeakyReLU(negative_slope=leaky_relu_neg_slope)
def forward(self, x: torch.Tensor, features_direct=None):
orig_shape = x.shape
x = self.conv2d_t(x)
x = self.leaky_relu(x)
x = self.pad(x, orig_shape)
if features_direct is not None:
x = torch.cat([x, features_direct], dim=1)
return x
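# Illustrative usage sketch, not part of the original module: Refine doubles the
# spatial resolution and, when skip features are given, concatenates them along the
# channel dimension.
def _demo_refine():
    refine = Refine(in_channels=16, out_channels=8)
    x = torch.rand(1, 16, 5, 7)
    skip = torch.rand(1, 4, 10, 14)
    y = refine(x, features_direct=skip)
    assert y.shape == (1, 12, 10, 14)  # 8 upsampled channels + 4 skip channels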