"""
Implements image encoders
"""
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
import torchvision.models as models
from torch import profiler

from scenedino.models.prediction_heads.layers import *
# Code taken from https://github.com/nianticlabs/monodepth2
#
# Godard, Clément, et al.
# "Digging into self-supervised monocular depth estimation."
# Proceedings of the IEEE/CVF international conference on computer vision.
# 2019.
class ResNetMultiImageInput(models.ResNet):
"""Constructs a resnet model with varying number of input images.
Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
"""
def __init__(self, block, layers, num_classes=1000, num_input_images=1):
        super(ResNetMultiImageInput, self).__init__(block, layers, num_classes=num_classes)
self.inplanes = 64
self.conv1 = nn.Conv2d(
num_input_images * 3, 64, kernel_size=7, stride=2, padding=3, bias=False
)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def resnet_multiimage_input(num_layers, pretrained=False, num_input_images=1):
"""Constructs a ResNet model.
Args:
num_layers (int): Number of resnet layers. Must be 18 or 50
pretrained (bool): If True, returns a model pre-trained on ImageNet
num_input_images (int): Number of frames stacked as input
"""
assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet"
blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers]
block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[
num_layers
]
model = ResNetMultiImageInput(block_type, blocks, num_input_images=num_input_images)
    if pretrained:
        # Newer torchvision releases removed ``models.resnet.model_urls``, so
        # resolve the checkpoint URL through the weights enum instead.
        url = {
            18: models.resnet.ResNet18_Weights.IMAGENET1K_V1.url,
            50: models.resnet.ResNet50_Weights.IMAGENET1K_V1.url,
        }[num_layers]
        loaded = model_zoo.load_url(url)
        # Tile the pretrained RGB filters across all stacked frames and rescale,
        # keeping first-layer activations comparable to the single-image model.
        loaded["conv1.weight"] = (
            torch.cat([loaded["conv1.weight"]] * num_input_images, 1) / num_input_images
        )
model.load_state_dict(loaded)
return model
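
# Example (a sketch, not executed here): constructing a two-frame backbone. The
# input size below is an illustrative assumption, not a value used in this repo.
#
#   model = resnet_multiimage_input(num_layers=18, pretrained=True, num_input_images=2)
#   x = torch.randn(2, 6, 192, 640)   # two RGB frames stacked along channels
#   logits = model(x)                 # inherited ResNet forward -> (2, 1000)
#   assert model.conv1.in_channels == 6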
class ResnetEncoder(nn.Module):
"""Pytorch module for a resnet encoder"""
def __init__(self, num_layers, pretrained, num_input_images=1):
super(ResnetEncoder, self).__init__()
self.num_ch_enc = np.array([64, 64, 128, 256, 512])
resnets = {
18: models.resnet18,
34: models.resnet34,
50: models.resnet50,
101: models.resnet101,
152: models.resnet152,
}
weights = {
18: models.resnet.ResNet18_Weights.IMAGENET1K_V1,
34: models.resnet.ResNet34_Weights.IMAGENET1K_V1,
50: models.resnet.ResNet50_Weights.IMAGENET1K_V1,
101: models.resnet.ResNet101_Weights.IMAGENET1K_V1,
152: models.resnet.ResNet152_Weights.IMAGENET1K_V1,
}
if num_layers not in resnets:
raise ValueError(
"{} is not a valid number of resnet layers".format(num_layers)
)
if num_input_images > 1:
self.encoder = resnet_multiimage_input(
num_layers, pretrained, num_input_images
)
else:
            # TODO: the ImageNet weight variant is hardcoded; make it configurable.
            self.encoder = resnets[num_layers](
                weights=weights[num_layers] if pretrained else None
            )
if num_layers > 34:
self.num_ch_enc[1:] *= 4
def forward(self, input_image):
self.features = []
        # Approximate ImageNet normalization (per-channel mean 0.45, std 0.225).
        x = (input_image - 0.45) / 0.225
x = self.encoder.conv1(x)
x = self.encoder.bn1(x)
self.features.append(self.encoder.relu(x))
self.features.append(
self.encoder.layer1(self.encoder.maxpool(self.features[-1]))
)
self.features.append(self.encoder.layer2(self.features[-1]))
self.features.append(self.encoder.layer3(self.features[-1]))
self.features.append(self.encoder.layer4(self.features[-1]))
return self.features
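
# Example (a sketch with assumed sizes): the encoder returns a 5-level feature
# pyramid. For a ResNet-18 backbone and an input of (1, 3, 192, 640) in [0, 1]:
#
#   enc = ResnetEncoder(num_layers=18, pretrained=True)
#   feats = enc(torch.rand(1, 3, 192, 640))
#   # feats[0]: (1,  64, 96, 320)  conv1/bn1/relu   (stride 2)
#   # feats[1]: (1,  64, 48, 160)  maxpool + layer1 (stride 4)
#   # feats[2]: (1, 128, 24,  80)  layer2           (stride 8)
#   # feats[3]: (1, 256, 12,  40)  layer3           (stride 16)
#   # feats[4]: (1, 512,  6,  20)  layer4           (stride 32)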
class DepthDecoder(nn.Module):
    """Monodepth2 depth decoder: progressively upsamples the deepest encoder
    features, fusing skip connections, and predicts a sigmoid disparity map
    at every requested scale."""
def __init__(
self, num_ch_enc, scales=range(4), num_output_channels=1, use_skips=True
):
super(DepthDecoder, self).__init__()
self.num_output_channels = num_output_channels
self.use_skips = use_skips
self.upsample_mode = "nearest"
self.scales = scales
self.num_ch_enc = num_ch_enc
self.num_ch_dec = np.array([16, 32, 64, 128, 256])
# decoder
self.convs = OrderedDict()
for i in range(4, -1, -1):
# upconv_0
num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1]
num_ch_out = self.num_ch_dec[i]
self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out)
# upconv_1
num_ch_in = self.num_ch_dec[i]
if self.use_skips and i > 0:
num_ch_in += self.num_ch_enc[i - 1]
num_ch_out = self.num_ch_dec[i]
self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out)
for s in self.scales:
self.convs[("dispconv", s)] = Conv3x3(
self.num_ch_dec[s], self.num_output_channels
)
self.decoder_keys = {k: i for i, k in enumerate(self.convs.keys())}
self.decoder = nn.ModuleList(list(self.convs.values()))
self.sigmoid = nn.Sigmoid()
def forward(self, input_features):
self.outputs = {}
# decoder
x = input_features[-1]
for i in range(4, -1, -1):
# x = self.convs[("upconv", i, 0)](x)
x = self.decoder[self.decoder_keys[("upconv", i, 0)]](x)
x = [upsample(x)]
if self.use_skips and i > 0:
if x[0].shape[2] > input_features[i - 1].shape[2]:
x[0] = x[0][:, :, : input_features[i - 1].shape[2], :]
if x[0].shape[3] > input_features[i - 1].shape[3]:
x[0] = x[0][:, :, :, : input_features[i - 1].shape[3]]
x += [input_features[i - 1]]
x = torch.cat(x, 1)
# x = self.convs[("upconv", i, 1)](x)
x = self.decoder[self.decoder_keys[("upconv", i, 1)]](x)
self.outputs[("features", i)] = x
if i in self.scales:
# self.outputs[("disp", i)] = self.sigmoid(self.convs[("dispconv", i)](x))
self.outputs[("disp", i)] = self.sigmoid(
self.decoder[self.decoder_keys[("dispconv", i)]](x)
)
return self.outputs
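
# Example (a sketch, continuing the encoder example above): keys of the returned
# dict follow the Monodepth2 convention; shapes assume the (192, 640) input.
#
#   dec = DepthDecoder(num_ch_enc=enc.num_ch_enc, scales=range(4))
#   out = dec(feats)
#   # out[("disp", s)]:     (1, 1, 192 // 2**s, 640 // 2**s)  sigmoid disparity
#   # out[("features", i)]: decoder activations at level i, for i in 0..4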
class Decoder(nn.Module):
    """Generalized Monodepth2-style decoder with configurable output channels
    (``d_out``), decoder widths (``num_ch_dec``), and optional extra output
    heads (``extra_outs``). Unlike DepthDecoder, no sigmoid is applied."""
def __init__(
self, num_ch_enc, num_ch_dec=None, d_out=1, scales=range(4), use_skips=True, extra_outs=0
):
super(Decoder, self).__init__()
self.use_skips = use_skips
self.upsample_mode = "nearest"
self.num_ch_enc = num_ch_enc
if num_ch_dec is None:
self.num_ch_dec = np.array([128, 128, 256, 256, 512])
else:
self.num_ch_dec = num_ch_dec
self.d_out = d_out
self.scales = scales
self.extra_outs = extra_outs
self.num_ch_dec = [max(self.d_out, chns) for chns in self.num_ch_dec]
# decoder
self.convs = OrderedDict()
for i in range(4, -1, -1):
# upconv_0
num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1]
num_ch_out = self.num_ch_dec[i]
self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out)
# upconv_1
num_ch_in = self.num_ch_dec[i]
if self.use_skips and i > 0:
num_ch_in += self.num_ch_enc[i - 1]
num_ch_out = self.num_ch_dec[i]
self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out)
for s in self.scales:
self.convs[("dispconv", s)] = Conv3x3(self.num_ch_dec[s], self.d_out)
if self.extra_outs > 0:
for s in self.scales:
self.convs[("extra_outs", s)] = Conv3x3(self.num_ch_dec[s], self.extra_outs)
self.d_out += self.extra_outs
self.decoder_keys = {k: i for i, k in enumerate(self.convs.keys())}
self.decoder = nn.ModuleList(list(self.convs.values()))
self.sigmoid = nn.Sigmoid()
def forward(self, input_features):
with profiler.record_function("encoder_forward"):
self.outputs = {}
# decoder
x = input_features[-1]
for i in range(4, -1, -1):
x = self.decoder[self.decoder_keys[("upconv", i, 0)]](x)
x = [F.interpolate(x, scale_factor=(2, 2), mode="nearest")]
if self.use_skips and i > 0:
feats = input_features[i - 1]
if x[0].shape[2] > feats.shape[2]:
x[0] = x[0][:, :, : feats.shape[2], :]
if x[0].shape[3] > feats.shape[3]:
x[0] = x[0][:, :, :, : feats.shape[3]]
x += [feats]
x = torch.cat(x, 1)
x = self.decoder[self.decoder_keys[("upconv", i, 1)]](x)
self.outputs[("features", i)] = x
if i in self.scales:
self.outputs[("disp", i)] = self.decoder[
self.decoder_keys[("dispconv", i)]
](x)
if self.extra_outs > 0:
self.outputs[("extra_outs", i)] = self.decoder[
self.decoder_keys[("extra_outs", i)]
](x)
return self.outputs
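
# Example (a sketch with assumed parameters): a 128-channel feature head plus a
# 2-channel extra head, reusing `enc`/`feats` from the encoder example above.
#
#   dec = Decoder(num_ch_enc=enc.num_ch_enc, d_out=128, extra_outs=2)
#   out = dec(feats)
#   # out[("disp", s)]:       (1, 128, 192 // 2**s, 640 // 2**s)  raw (no sigmoid)
#   # out[("extra_outs", s)]: (1,   2, 192 // 2**s, 640 // 2**s)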
class Monodepth2(nn.Module):
"""
2D (Spatial/Pixel-aligned/local) image encoder
"""
def __init__(
self,
resnet_layers=18,
cp_location=None,
freeze=False,
num_ch_dec=None,
d_out=128,
scales=range(4),
pointwise_convs=None,
extra_outs=0
):
super().__init__()
self.encoder = ResnetEncoder(resnet_layers, True, 1)
self.num_ch_enc = self.encoder.num_ch_enc
self.upsample_mode = "nearest"
self.d_out = d_out
self.extra_outs = extra_outs
self.scales = scales
if pointwise_convs is None:
decoder_out = self.d_out
else:
decoder_out = pointwise_convs[0]
# decoder
self.decoder = Decoder(
num_ch_enc=self.num_ch_enc,
d_out=decoder_out,
num_ch_dec=num_ch_dec,
scales=self.scales,
extra_outs=extra_outs,
)
self.num_ch_dec = self.decoder.num_ch_dec
# pointwise_convs
if pointwise_convs is None:
self.pointwise_convs = None
else:
            # Copy before appending so the caller's config list is not mutated.
            pointwise_convs = list(pointwise_convs) + [self.d_out]
self.pointwise_convs = nn.Sequential(
*[nn.Sequential(
nn.ReLU(),
nn.Conv2d(pointwise_convs[i], pointwise_convs[i+1], 1, 1, 0),
) for i in range(len(pointwise_convs)-1)]
)
self.latent_size = self.d_out
if cp_location is not None:
            cp = torch.load(cp_location, map_location="cpu")
self.load_state_dict(cp["model"])
if freeze:
for p in self.parameters(True):
p.requires_grad = False
def forward(self, x):
"""
For extracting ResNet's features.
:param x image (B, C, H, W)
:return latent (B, latent_size, H, W)
"""
with profiler.record_function("encoder_forward"):
            # Map the input from [-1, 1] to the [0, 1] range the encoder expects.
            x = x * 0.5 + 0.5
            image_features = self.encoder(x)  # multi-scale Monodepth2 encoder features
outputs = self.decoder(image_features)
            x = [outputs[("disp", s)] for s in self.scales]
            if self.decoder.extra_outs > 0:
                x = [
                    torch.cat((x_, outputs[("extra_outs", s)]), dim=1)
                    for x_, s in zip(x, self.scales)
                ]
            if self.pointwise_convs is not None:
                x = [self.pointwise_convs(x_) for x_ in x]
            return x
@classmethod
def from_conf(cls, conf):
return cls(
cp_location=conf.get("cp_location", None),
freeze=conf.get("freeze", False),
num_ch_dec=conf.get("num_ch_dec", None),
d_out=conf.get("d_out", 128),
resnet_layers=conf.get("resnet_layers", 18),
pointwise_convs=conf.get("pointwise_convs", None),
extra_outs=conf.get("extra_outs", 0),
)
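

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module; downloads
    # ImageNet weights on first run): build the default encoder and check that
    # it emits one latent map per scale.
    net = Monodepth2(resnet_layers=18, d_out=128)
    net.eval()
    inp = torch.randn(1, 3, 192, 640).clamp(-1, 1)  # fake image in [-1, 1]
    with torch.no_grad():
        latents = net(inp)
    for s, lat in zip(net.scales, latents):
        print(f"scale {s}: {tuple(lat.shape)}")  # (1, 128, 192 // 2**s, 640 // 2**s)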