""" Implements image encoders """ from collections import OrderedDict from torch import profiler from scenedino.models.prediction_heads.layers import * import numpy as np import torch import torch.nn as nn import torchvision.models as models import torch.utils.model_zoo as model_zoo # Code taken from https://github.com/nianticlabs/monodepth2 # # Godard, Clément, et al. # "Digging into self-supervised monocular depth estimation." # Proceedings of the IEEE/CVF international conference on computer vision. # 2019. class ResNetMultiImageInput(models.ResNet): """Constructs a resnet model with varying number of input images. Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py """ def __init__(self, block, layers, num_classes=1000, num_input_images=1): super(ResNetMultiImageInput, self).__init__(block, layers) self.inplanes = 64 self.conv1 = nn.Conv2d( num_input_images * 3, 64, kernel_size=7, stride=2, padding=3, bias=False ) self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(block, 64, layers[0]) self.layer2 = self._make_layer(block, 128, layers[1], stride=2) self.layer3 = self._make_layer(block, 256, layers[2], stride=2) self.layer4 = self._make_layer(block, 512, layers[3], stride=2) for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) def resnet_multiimage_input(num_layers, pretrained=False, num_input_images=1): """Constructs a ResNet model. Args: num_layers (int): Number of resnet layers. Must be 18 or 50 pretrained (bool): If True, returns a model pre-trained on ImageNet num_input_images (int): Number of frames stacked as input """ assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet" blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers] block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[ num_layers ] model = ResNetMultiImageInput(block_type, blocks, num_input_images=num_input_images) if pretrained: loaded = model_zoo.load_url( models.resnet.model_urls["resnet{}".format(num_layers)] ) loaded["conv1.weight"] = ( torch.cat([loaded["conv1.weight"]] * num_input_images, 1) / num_input_images ) model.load_state_dict(loaded) return model class ResnetEncoder(nn.Module): """Pytorch module for a resnet encoder""" def __init__(self, num_layers, pretrained, num_input_images=1): super(ResnetEncoder, self).__init__() self.num_ch_enc = np.array([64, 64, 128, 256, 512]) resnets = { 18: models.resnet18, 34: models.resnet34, 50: models.resnet50, 101: models.resnet101, 152: models.resnet152, } weights = { 18: models.resnet.ResNet18_Weights.IMAGENET1K_V1, 34: models.resnet.ResNet34_Weights.IMAGENET1K_V1, 50: models.resnet.ResNet50_Weights.IMAGENET1K_V1, 101: models.resnet.ResNet101_Weights.IMAGENET1K_V1, 152: models.resnet.ResNet152_Weights.IMAGENET1K_V1, } if num_layers not in resnets: raise ValueError( "{} is not a valid number of resnet layers".format(num_layers) ) if num_input_images > 1: self.encoder = resnet_multiimage_input( num_layers, pretrained, num_input_images ) else: # TODO: currently hardcoded make adaptable self.encoder = resnets[num_layers]( weights=weights[num_layers] ) if num_layers > 34: self.num_ch_enc[1:] *= 4 def forward(self, input_image): self.features = [] x = (input_image - 0.45) / 0.225 x = self.encoder.conv1(x) x = 
class DepthDecoder(nn.Module):
    def __init__(
        self, num_ch_enc, scales=range(4), num_output_channels=1, use_skips=True
    ):
        super(DepthDecoder, self).__init__()

        self.num_output_channels = num_output_channels
        self.use_skips = use_skips
        self.upsample_mode = "nearest"
        self.scales = scales

        self.num_ch_enc = num_ch_enc
        self.num_ch_dec = np.array([16, 32, 64, 128, 256])

        # decoder
        self.convs = OrderedDict()
        for i in range(4, -1, -1):
            # upconv_0
            num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1]
            num_ch_out = self.num_ch_dec[i]
            self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out)

            # upconv_1
            num_ch_in = self.num_ch_dec[i]
            if self.use_skips and i > 0:
                num_ch_in += self.num_ch_enc[i - 1]
            num_ch_out = self.num_ch_dec[i]
            self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out)

        for s in self.scales:
            self.convs[("dispconv", s)] = Conv3x3(
                self.num_ch_dec[s], self.num_output_channels
            )

        # nn.ModuleList cannot be indexed by tuples, so keep a tuple-to-index map.
        self.decoder_keys = {k: i for i, k in enumerate(self.convs.keys())}
        self.decoder = nn.ModuleList(list(self.convs.values()))
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_features):
        self.outputs = {}

        # decoder
        x = input_features[-1]
        for i in range(4, -1, -1):
            x = self.decoder[self.decoder_keys[("upconv", i, 0)]](x)
            x = [upsample(x)]
            if self.use_skips and i > 0:
                # Crop the upsampled features if they overshoot the skip
                # connection (happens for odd input resolutions).
                if x[0].shape[2] > input_features[i - 1].shape[2]:
                    x[0] = x[0][:, :, : input_features[i - 1].shape[2], :]
                if x[0].shape[3] > input_features[i - 1].shape[3]:
                    x[0] = x[0][:, :, :, : input_features[i - 1].shape[3]]
                x += [input_features[i - 1]]
            x = torch.cat(x, 1)
            x = self.decoder[self.decoder_keys[("upconv", i, 1)]](x)
            self.outputs[("features", i)] = x
            if i in self.scales:
                self.outputs[("disp", i)] = self.sigmoid(
                    self.decoder[self.decoder_keys[("dispconv", i)]](x)
                )

        return self.outputs


class Decoder(nn.Module):
    def __init__(
        self,
        num_ch_enc,
        num_ch_dec=None,
        d_out=1,
        scales=range(4),
        use_skips=True,
        extra_outs=0,
    ):
        super(Decoder, self).__init__()

        self.use_skips = use_skips
        self.upsample_mode = "nearest"

        self.num_ch_enc = num_ch_enc
        if num_ch_dec is None:
            self.num_ch_dec = np.array([128, 128, 256, 256, 512])
        else:
            self.num_ch_dec = num_ch_dec
        self.d_out = d_out
        self.scales = scales
        self.extra_outs = extra_outs

        # Every decoder stage must carry at least d_out channels.
        self.num_ch_dec = [max(self.d_out, chns) for chns in self.num_ch_dec]

        # decoder
        self.convs = OrderedDict()
        for i in range(4, -1, -1):
            # upconv_0
            num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1]
            num_ch_out = self.num_ch_dec[i]
            self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out)

            # upconv_1
            num_ch_in = self.num_ch_dec[i]
            if self.use_skips and i > 0:
                num_ch_in += self.num_ch_enc[i - 1]
            num_ch_out = self.num_ch_dec[i]
            self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out)

        for s in self.scales:
            self.convs[("dispconv", s)] = Conv3x3(self.num_ch_dec[s], self.d_out)

        if self.extra_outs > 0:
            for s in self.scales:
                self.convs[("extra_outs", s)] = Conv3x3(
                    self.num_ch_dec[s], self.extra_outs
                )
            self.d_out += self.extra_outs

        self.decoder_keys = {k: i for i, k in enumerate(self.convs.keys())}
        self.decoder = nn.ModuleList(list(self.convs.values()))
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_features):
        with profiler.record_function("decoder_forward"):
            self.outputs = {}

            # decoder
            x = input_features[-1]
            for i in range(4, -1, -1):
                x = self.decoder[self.decoder_keys[("upconv", i, 0)]](x)
                x = [F.interpolate(x, scale_factor=(2, 2), mode="nearest")]
                if self.use_skips and i > 0:
                    feats = input_features[i - 1]
                    # Crop to the skip connection's spatial size if needed.
                    if x[0].shape[2] > feats.shape[2]:
                        x[0] = x[0][:, :, : feats.shape[2], :]
                    if x[0].shape[3] > feats.shape[3]:
                        x[0] = x[0][:, :, :, : feats.shape[3]]
                    x += [feats]
                x = torch.cat(x, 1)
                x = self.decoder[self.decoder_keys[("upconv", i, 1)]](x)
                self.outputs[("features", i)] = x
                if i in self.scales:
                    self.outputs[("disp", i)] = self.decoder[
                        self.decoder_keys[("dispconv", i)]
                    ](x)
                    if self.extra_outs > 0:
                        self.outputs[("extra_outs", i)] = self.decoder[
                            self.decoder_keys[("extra_outs", i)]
                        ](x)

            return self.outputs
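# Usage sketch (illustrative, not part of the original code base): unlike
# DepthDecoder, Decoder returns raw, un-activated outputs. The dict is keyed by
# ("disp", s), ("features", s) and, when extra_outs > 0, ("extra_outs", s),
# with scale 0 the finest. Shapes below assume the 192x640 example input and
# that ConvBlock/Conv3x3 preserve spatial size, as in upstream monodepth2.
#
#     enc = ResnetEncoder(num_layers=18, pretrained=True)
#     dec = Decoder(num_ch_enc=enc.num_ch_enc, d_out=64)
#     outs = dec(enc(torch.rand(1, 3, 192, 640)))
#     outs[("disp", 0)].shape       # (1, 64, 192, 640), full input resolution
#     outs[("disp", 3)].shape       # (1, 64, 24, 80)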
class Monodepth2(nn.Module):
    """
    2D (spatial/pixel-aligned/local) image encoder
    """

    def __init__(
        self,
        resnet_layers=18,
        cp_location=None,
        freeze=False,
        num_ch_dec=None,
        d_out=128,
        scales=range(4),
        pointwise_convs=None,
        extra_outs=0,
    ):
        super().__init__()
        self.encoder = ResnetEncoder(resnet_layers, True, 1)
        self.num_ch_enc = self.encoder.num_ch_enc
        self.upsample_mode = "nearest"
        self.d_out = d_out
        self.extra_outs = extra_outs
        self.scales = scales

        if pointwise_convs is None:
            decoder_out = self.d_out
        else:
            decoder_out = pointwise_convs[0]

        # decoder
        self.decoder = Decoder(
            num_ch_enc=self.num_ch_enc,
            d_out=decoder_out,
            num_ch_dec=num_ch_dec,
            scales=self.scales,
            extra_outs=extra_outs,
        )
        self.num_ch_dec = self.decoder.num_ch_dec

        # Optional stack of 1x1 convolutions applied to every output scale.
        if pointwise_convs is None:
            self.pointwise_convs = None
        else:
            pointwise_convs.append(self.d_out)
            self.pointwise_convs = nn.Sequential(
                *[
                    nn.Sequential(
                        nn.ReLU(),
                        nn.Conv2d(pointwise_convs[i], pointwise_convs[i + 1], 1, 1, 0),
                    )
                    for i in range(len(pointwise_convs) - 1)
                ]
            )

        self.latent_size = self.d_out

        if cp_location is not None:
            cp = torch.load(cp_location)
            self.load_state_dict(cp["model"])

        if freeze:
            for p in self.parameters():
                p.requires_grad = False

    def forward(self, x):
        """
        For extracting ResNet's features.
        :param x: image (B, C, H, W)
        :return: latent (B, latent_size, H, W)
        """
        with profiler.record_function("encoder_forward"):
            # Map inputs from [-1, 1] to [0, 1], the range ResnetEncoder expects.
            x = x * 0.5 + 0.5

            image_features = self.encoder(x)  # Monodepth2 encoder output
            outputs = self.decoder(image_features)

            x = [outputs[("disp", s)] for s in self.scales]
            if self.decoder.extra_outs > 0:
                # zip keeps disp/extra pairs aligned even if scales does not
                # start at 0.
                x = [
                    torch.cat((x_s, outputs[("extra_outs", s)]), dim=1)
                    for x_s, s in zip(x, self.scales)
                ]
            if self.pointwise_convs is not None:
                x = [self.pointwise_convs(x_) for x_ in x]

        return x  # note: we need img_feat for feeding into NeuRay

    @classmethod
    def from_conf(cls, conf):
        return cls(
            cp_location=conf.get("cp_location", None),
            freeze=conf.get("freeze", False),
            num_ch_dec=conf.get("num_ch_dec", None),
            d_out=conf.get("d_out", 128),
            resnet_layers=conf.get("resnet_layers", 18),
            pointwise_convs=conf.get("pointwise_convs", None),
            extra_outs=conf.get("extra_outs", 0),
        )
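# Usage sketch (illustrative, not part of the original code base): Monodepth2
# expects inputs in [-1, 1] and returns one (B, latent_size, H_s, W_s) tensor
# per scale, finest first. The config values below are arbitrary examples;
# from_conf falls back to the defaults for any missing key.
#
#     net = Monodepth2.from_conf({"resnet_layers": 18, "d_out": 64})
#     latents = net(torch.rand(1, 3, 192, 640) * 2 - 1)
#     latents[0].shape   # (1, 64, 192, 640), scale 0 (full resolution)
#     len(latents)       # 4, one tensor per scale in range(4)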