import torch
import torchvision
from torch import nn

from scenedino.models.backbones.monodepth2 import Decoder


class NoDecoder(nn.Module):
    """Pass-through decoder: resize the last backbone feature map to the target
    image size and optionally L2-normalize it along the channel dimension."""

    def __init__(self, image_size, interpolation, normalize_features):
        super().__init__()

        # Map the interpolation name onto the corresponding torchvision enum.
        match interpolation:
            case 'nearest':
                inter_mode = torchvision.transforms.InterpolationMode.NEAREST
            case 'bilinear':
                inter_mode = torchvision.transforms.InterpolationMode.BILINEAR
            case 'bicubic':
                inter_mode = torchvision.transforms.InterpolationMode.BICUBIC
            case _:
                raise NotImplementedError(f"Interpolation mode \"{interpolation}\" not implemented!")

        self.image_size = image_size
        self.resize_tf = torchvision.transforms.Resize(size=image_size, interpolation=inter_mode)
        self.normalize_features = normalize_features

    def forward(self, x):
        # Use only the last (coarsest) feature map from the backbone.
        features = x[-1]
        resized_features = self.resize_tf(features)

        if self.normalize_features:
            # L2-normalize each spatial feature vector along the channel dimension.
            resized_features = resized_features / torch.linalg.norm(resized_features, dim=1, keepdim=True)

        return [resized_features]
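
# Minimal usage sketch for NoDecoder (illustrative only; the shapes and sizes
# below are assumptions, not values taken from this code base):
#
#     decoder = NoDecoder(image_size=(192, 640), interpolation='bilinear',
#                         normalize_features=True)
#     feats = [torch.randn(2, 384, 24, 80)]   # (B, C, h, w) backbone features
#     out, = decoder(feats)                   # (2, 384, 192, 640), unit channel norm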


class SimpleFeaturePyramidDecoder(nn.Module):
    """Expand single-scale latent features (e.g. from a DINO backbone) into a
    five-level feature pyramid and decode it with a monodepth2-style decoder."""

    def __init__(self,
                 latent_size,
                 num_ch_enc,
                 num_ch_dec,
                 d_out,
                 scales,
                 use_skips,
                 device):
        super().__init__()

        self.scales = scales
        # Project the latent features to the encoder channel counts at relative
        # scales 8x, 4x, 2x, 1x and 0.5x. An nn.ModuleList (rather than a plain
        # Python list) is required so the layers are registered as submodules:
        # otherwise their parameters would be invisible to .parameters(),
        # .state_dict() and .to(device).
        self.resize_layers = nn.ModuleList([
            nn.ConvTranspose2d(in_channels=latent_size, out_channels=num_ch_enc[0], kernel_size=8, stride=8, padding=0, device=device),
            nn.ConvTranspose2d(in_channels=latent_size, out_channels=num_ch_enc[1], kernel_size=4, stride=4, padding=0, device=device),
            nn.ConvTranspose2d(in_channels=latent_size, out_channels=num_ch_enc[2], kernel_size=2, stride=2, padding=0, device=device),
            nn.Conv2d(in_channels=latent_size, out_channels=num_ch_enc[3], kernel_size=3, stride=1, padding=1, device=device),
            nn.Conv2d(in_channels=latent_size, out_channels=num_ch_enc[4], kernel_size=3, stride=2, padding=1, device=device),
        ])

        # Make sure every decoder stage has at least d_out channels.
        num_ch_dec = [max(d_out, chns) for chns in num_ch_dec]
        self.decoder = Decoder(
            num_ch_enc=num_ch_enc,
            num_ch_dec=num_ch_dec,
            d_out=d_out,
            scales=scales,
            use_skips=use_skips,
            extra_outs=0,
        )

    def forward(self, x):
        dino_features = x[-1]
        # Build one pyramid level per resize layer, all from the same input.
        features = [resize_layer(dino_features) for resize_layer in self.resize_layers]

        outputs = self.decoder(features)
        # Collect the decoder's "disp" outputs for the requested scales.
        return [outputs[("disp", i)] for i in self.scales]
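
# Minimal usage sketch for SimpleFeaturePyramidDecoder (illustrative only; the
# channel counts and scales below are assumptions chosen to match the five
# resize layers above, not values from this code base):
#
#     decoder = SimpleFeaturePyramidDecoder(
#         latent_size=384,
#         num_ch_enc=[64, 64, 128, 256, 512],
#         num_ch_dec=[16, 32, 64, 128, 256],
#         d_out=64,
#         scales=range(4),
#         use_skips=True,
#         device='cpu',
#     )
#     x = [torch.randn(2, 384, 24, 80)]   # single-scale DINO features
#     outs = decoder(x)                   # one (B, d_out, H_i, W_i) map per scale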