# NOTE: removed web-scrape residue (HF Spaces page chrome: run-state banner,
# file size, commit hash, and line-number gutter) that was not part of the
# Python source and made the file unparseable.
import torch
import torchvision
from torch import nn
from scenedino.models.backbones.monodepth2 import Decoder
class NoDecoder(nn.Module):
    """Pass-through 'decoder' that only resizes the last backbone feature map.

    Takes the final entry of the incoming feature list, resizes it to
    ``image_size`` with the requested interpolation, and optionally
    L2-normalizes it along the channel dimension.
    """

    # Supported interpolation names mapped to torchvision modes.
    _MODES = {
        "nearest": torchvision.transforms.InterpolationMode.NEAREST,
        "bilinear": torchvision.transforms.InterpolationMode.BILINEAR,
        "bicubic": torchvision.transforms.InterpolationMode.BICUBIC,
    }

    def __init__(self, image_size, interpolation, normalize_features):
        """
        Args:
            image_size: target (H, W) passed to torchvision Resize.
            interpolation: one of 'nearest', 'bilinear', 'bicubic'.
            normalize_features: if True, L2-normalize output over dim 1.
        """
        super().__init__()
        try:
            inter_mode = self._MODES[interpolation]
        except KeyError:
            raise NotImplementedError(f"Interpolation mode \"{interpolation}\" not implemented!")
        self.image_size = image_size
        self.resize_tf = torchvision.transforms.Resize(size=image_size, interpolation=inter_mode)
        self.normalize_features = normalize_features

    def forward(self, x):
        """Resize (and optionally normalize) the last feature map in x.

        Returns a single-element list to mimic a multi-scale decoder output.
        """
        out = self.resize_tf(x[-1])
        if self.normalize_features:
            # Unit-normalize each feature vector along the channel axis.
            out = out / torch.linalg.norm(out, dim=1, keepdim=True)
        return [out]
class SimpleFeaturePyramidDecoder(nn.Module):
    """Builds a 5-level feature pyramid from a single backbone feature map
    and decodes it with a monodepth2-style ``Decoder``.

    The single (e.g. DINO) feature map is projected to 8x/4x/2x upsampled,
    same-resolution, and 2x downsampled levels to emulate the multi-scale
    output of a conventional CNN encoder.

    Args:
        latent_size: channel count of the incoming backbone feature map.
        num_ch_enc: per-level channel counts expected by the Decoder
            (indices 0..4 are used, one per pyramid level).
        num_ch_dec: per-level decoder channel counts (clamped below to d_out).
        d_out: number of output channels per returned scale.
        scales: scale indices whose ("disp", i) outputs are returned.
        use_skips: whether the Decoder uses skip connections.
        device: device on which the projection layers are created.
    """

    def __init__(self,
                 latent_size,
                 num_ch_enc,
                 num_ch_dec,
                 d_out,
                 scales,
                 use_skips,
                 device):
        super().__init__()
        self.scales = scales
        # BUG FIX: this was a plain Python list, so the conv layers were
        # invisible to nn.Module — their parameters were not returned by
        # model.parameters() (hence never trained), absent from state_dict
        # (hence never checkpointed), and not moved by model.to(device).
        # nn.ModuleList registers them as proper submodules.
        self.resize_layers = nn.ModuleList([
            nn.ConvTranspose2d(in_channels=latent_size, out_channels=num_ch_enc[0], kernel_size=8, stride=8, padding=0, device=device),
            nn.ConvTranspose2d(in_channels=latent_size, out_channels=num_ch_enc[1], kernel_size=4, stride=4, padding=0, device=device),
            nn.ConvTranspose2d(in_channels=latent_size, out_channels=num_ch_enc[2], kernel_size=2, stride=2, padding=0, device=device),
            nn.Conv2d(in_channels=latent_size, out_channels=num_ch_enc[3], kernel_size=3, stride=1, padding=1, device=device),
            nn.Conv2d(in_channels=latent_size, out_channels=num_ch_enc[4], kernel_size=3, stride=2, padding=1, device=device),
        ])
        # Ensure every decoder stage has at least d_out channels.
        num_ch_dec = [max(d_out, chns) for chns in num_ch_dec]
        self.decoder = Decoder(
            num_ch_enc=num_ch_enc,
            num_ch_dec=num_ch_dec,
            d_out=d_out,
            scales=scales,
            use_skips=use_skips,
            extra_outs=0,
        )

    def forward(self, x):
        """Decode the last feature map in x into per-scale 'disp' outputs.

        Returns a list with one tensor per entry in ``self.scales``.
        """
        dino_features = x[-1]
        # One projection per pyramid level, all fed from the same map.
        features = [layer(dino_features) for layer in self.resize_layers]
        outputs = self.decoder(features)
        return [outputs[("disp", i)] for i in self.scales]