# Uploaded by: jev-aleks
# Commit message: scenedino init
# Commit: 9e15541
import torch
import torchvision
from torch import nn
from scenedino.models.backbones.monodepth2 import Decoder
class NoDecoder(nn.Module):
def __init__(self, image_size, interpolation, normalize_features):
super().__init__()
match interpolation:
case 'nearest':
inter_mode = torchvision.transforms.InterpolationMode.NEAREST
case 'bilinear':
inter_mode = torchvision.transforms.InterpolationMode.BILINEAR
case 'bicubic':
inter_mode = torchvision.transforms.InterpolationMode.BICUBIC
case _:
raise NotImplementedError(f"Interpolation mode \"{interpolation}\" not implemented!")
self.image_size = image_size
self.resize_tf = torchvision.transforms.Resize(size=image_size, interpolation=inter_mode)
self.normalize_features = normalize_features
def forward(self, x):
features = x[-1]
resized_features = self.resize_tf(features)
if self.normalize_features:
resized_features = resized_features / torch.linalg.norm(resized_features, dim=1, keepdim=True)
return [resized_features]
class SimpleFeaturePyramidDecoder(nn.Module):
    """Build a five-level feature pyramid from one feature map, then decode it.

    The last entry of the encoder output is projected to five resolutions
    (upsampled x8, x4, x2, kept at x1, downsampled x2) with the channel
    counts given by ``num_ch_enc``. The resulting list is fed to ``Decoder``
    (imported from ``scenedino.models.backbones.monodepth2``).
    """

    def __init__(self,
                 latent_size,
                 num_ch_enc,
                 num_ch_dec,
                 d_out,
                 scales,
                 use_skips,
                 device):
        """
        Args:
            latent_size: Channel count of the input feature map ``x[-1]``.
            num_ch_enc: Channel counts of the five pyramid levels, finest
                (x8 upsampled) first; also forwarded to ``Decoder``.
            num_ch_dec: Decoder channel counts; each is raised to at least
                ``d_out`` before being forwarded to ``Decoder``.
            d_out: Output channel count, forwarded to ``Decoder``.
            scales: Scales requested from ``Decoder``; ``forward`` returns one
                output per entry, in this order.
            use_skips: Forwarded to ``Decoder``.
            device: Device on which the pyramid projection layers are created.
        """
        super().__init__()
        self.scales = scales
        # nn.ModuleList (not a plain list) registers the projection layers as
        # submodules, so their parameters reach the optimizer and state_dict
        # and follow .to()/.cuda()/.train()/.eval(). NOTE: checkpoints saved
        # before this change have no ``resize_layers.*`` keys (those layers
        # were re-initialized randomly on every construction); load such
        # checkpoints with ``strict=False``.
        self.resize_layers = self._build_resize_layers(latent_size, num_ch_enc, device)
        num_ch_dec = [max(d_out, chns) for chns in num_ch_dec]
        self.decoder = Decoder(
            num_ch_enc=num_ch_enc,
            num_ch_dec=num_ch_dec,
            d_out=d_out,
            scales=scales,
            use_skips=use_skips,
            extra_outs=0,
        )

    @staticmethod
    def _build_resize_layers(latent_size, num_ch_enc, device):
        """Return the five layers projecting the input map to the pyramid levels.

        Spatial scale relative to the input: x8, x4, x2 (transposed convs with
        kernel == stride), x1 (3x3 conv, stride 1) and x1/2 (3x3 conv, stride 2).
        """
        return nn.ModuleList([
            nn.ConvTranspose2d(in_channels=latent_size, out_channels=num_ch_enc[0], kernel_size=8, stride=8, padding=0, device=device),
            nn.ConvTranspose2d(in_channels=latent_size, out_channels=num_ch_enc[1], kernel_size=4, stride=4, padding=0, device=device),
            nn.ConvTranspose2d(in_channels=latent_size, out_channels=num_ch_enc[2], kernel_size=2, stride=2, padding=0, device=device),
            nn.Conv2d(in_channels=latent_size, out_channels=num_ch_enc[3], kernel_size=3, stride=1, padding=1, device=device),
            nn.Conv2d(in_channels=latent_size, out_channels=num_ch_enc[4], kernel_size=3, stride=2, padding=1, device=device),
        ])

    def forward(self, x):
        """Decode the pyramid built from the last feature map in ``x``.

        Args:
            x: Sequence of feature maps; only ``x[-1]`` is used (presumably
                (B, latent_size, H, W) — confirm against the encoder).

        Returns:
            List of ``Decoder`` outputs ``("disp", i)``, one per ``i`` in
            ``self.scales``.
        """
        dino_features = x[-1]
        features = [resize_layer(dino_features) for resize_layer in self.resize_layers]
        outputs = self.decoder(features)
        return [outputs[("disp", i)] for i in self.scales]