File size: 3,875 Bytes
a7dedf9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from torch import nn, Tensor
from torch.hub import load_state_dict_from_url
from typing import Optional

from .vgg import VGG
from .utils import make_vgg_layers, vgg_urls
from ..utils import _init_weights, ConvDownsample, _get_activation, _get_norm_layer

# Small numerical epsilon.
# NOTE(review): not referenced anywhere else in this file; presumably imported
# by other modules — confirm before removing.
EPS = 1e-6


# Encoder layer spec passed to make_vgg_layers (with dilation=1): integers are
# conv output-channel counts; "M" entries are interpreted by make_vgg_layers
# (presumably max-pool stages). There are three "M" entries, which is consistent
# with CSRNet recording an encoder output stride of 8 if each halves resolution.
encoder_cfg = [
    64, 64, "M",
    128, 128, "M",
    256, 256, 256, "M",
    512, 512, 512,
]
# Decoder layer spec passed to make_vgg_layers (with dilation=2) on 512 input
# channels; CSRNet uses the last entry as its decoder output channel count.
decoder_cfg = [
    512, 512, 512,
    256,
    128,
]


class CSRNet(nn.Module):
    """CSRNet-style network: VGG encoder, optional refiner, dilated decoder.

    The forward pass is ``decode(refine(encode(x)))``:

    * ``encode`` -- the VGG feature stack built from ``encoder_cfg``. It is
      initialised from the checkpoint at ``vgg_urls[model_name]``, which is
      fetched with ``load_state_dict_from_url`` while the model is being
      constructed. Its output stride is recorded as ``encoder_reduction`` (8).
    * ``refine`` -- ``nn.Identity`` when ``block_size == 8``. Otherwise it is
      one (``block_size == 16``) or two (``block_size == 32``)
      ``ConvDownsample`` stages. The resulting stride is recorded as
      ``refiner_reduction == block_size``.
    * ``decode`` -- dilation-2 layers built from ``decoder_cfg`` and
      initialised with ``_init_weights``. It reports ``decoder_channels``
      output channels.

    Args:
        model_name: ``"vgg16"`` or ``"vgg16_bn"``. Selects the pretrained
            weights, and whether batch norm is used in the encoder and decoder.
        block_size: Desired output stride: 8, 16, 32, or ``None`` (means 8).
        norm: Normalisation used by the refiner only. ``"bn"`` selects
            ``nn.BatchNorm2d`` and ``"ln"`` selects ``nn.LayerNorm``. Any other
            value uses ``_get_norm_layer(vgg)``.
        act: Activation used by the refiner only. ``"relu"`` selects
            ``nn.ReLU(inplace=True)`` and ``"gelu"`` selects ``nn.GELU()``.
            Any other value uses ``_get_activation(vgg)``.

    Raises:
        AssertionError: ``model_name`` or ``block_size`` is unsupported.
        ValueError: ``block_size`` is unsupported and ``assert`` statements
            are disabled (``python -O``).
    """

    def __init__(
        self,
        model_name: str,
        block_size: Optional[int] = None,
        norm: str = "none",
        act: str = "none"
    ) -> None:
        super().__init__()
        assert model_name in ["vgg16", "vgg16_bn"], f"Model name should be one of ['vgg16', 'vgg16_bn'], but got {model_name}."
        assert block_size is None or block_size in [8, 16, 32], f"block_size should be one of [8, 16, 32], but got {block_size}."
        self.model_name = model_name
        use_batch_norm = "bn" in model_name

        vgg = VGG(make_vgg_layers(encoder_cfg, in_channels=3, batch_norm=use_batch_norm, dilation=1))
        # strict=False: checkpoint keys with no counterpart in this truncated
        # model, and vice versa, are silently ignored.
        vgg.load_state_dict(load_state_dict_from_url(vgg_urls[model_name]), strict=False)
        self.encoder = vgg.features
        self.encoder_reduction = 8  # encoder_cfg has three "M" stages
        self.encoder_channels = 512  # last conv width in encoder_cfg
        self.block_size = block_size if block_size is not None else 8

        if norm == "bn":
            norm_layer = nn.BatchNorm2d
        elif norm == "ln":
            # NOTE(review): nn.LayerNorm normalises over trailing dims; confirm
            # ConvDownsample applies it sensibly to (N, C, H, W) feature maps.
            norm_layer = nn.LayerNorm
        else:
            norm_layer = _get_norm_layer(vgg)

        if act == "relu":
            activation = nn.ReLU(inplace=True)
        elif act == "gelu":
            activation = nn.GELU()
        else:
            activation = _get_activation(vgg)

        self.refiner = self._build_refiner(norm_layer, activation)
        self.refiner_channels = self.encoder_channels
        self.refiner_reduction = self.block_size

        decoder = make_vgg_layers(decoder_cfg, in_channels=self.refiner_channels, batch_norm=use_batch_norm, dilation=2)
        decoder.apply(_init_weights)
        self.decoder = decoder
        self.decoder_channels = decoder_cfg[-1]
        # The dilated decoder is treated as stride-preserving.
        self.decoder_reduction = self.refiner_reduction

    def _build_refiner(self, norm_layer, activation) -> nn.Module:
        """Return the module mapping encoder output (stride 8) to stride ``block_size``.

        Reads ``self.block_size``, ``self.encoder_reduction`` and
        ``self.encoder_channels``. Module nesting is deliberate and fixes the
        ``state_dict`` key layout:

        * 16 yields a bare ``ConvDownsample`` (keys ``refiner.*``).
        * 32 yields ``nn.Sequential`` of two stages (keys ``refiner.0.*``
          and ``refiner.1.*``).

        Raises:
            ValueError: ``block_size`` is not one of 8, 16, 32. Normally the
                constructor's ``assert`` catches this first, but that check
                is stripped under ``python -O``.
        """
        if self.block_size == self.encoder_reduction:
            return nn.Identity()

        def downsample() -> nn.Module:
            # Both stages receive the same norm class and the same activation
            # instance, matching the original construction.
            return ConvDownsample(
                in_channels=self.encoder_channels,
                out_channels=self.encoder_channels,
                norm_layer=norm_layer,
                activation=activation,
            )

        if self.block_size == 16:
            return downsample()
        if self.block_size == 32:
            return nn.Sequential(downsample(), downsample())
        raise ValueError(f"block_size should be one of [8, 16, 32], but got {self.block_size}.")

    def encode(self, x: Tensor) -> Tensor:
        """Run the VGG feature extractor (3 input channels, per ``make_vgg_layers``).

        Presumably ``x`` is an image batch of shape (N, 3, H, W); confirm with callers.
        """
        return self.encoder(x)

    def refine(self, x: Tensor) -> Tensor:
        """Apply the refiner (identity when ``block_size == 8``)."""
        return self.refiner(x)

    def decode(self, x: Tensor) -> Tensor:
        """Apply the dilated decoder layers built from ``decoder_cfg``."""
        return self.decoder(x)

    def forward(self, x: Tensor) -> Tensor:
        """Run ``encode``, ``refine`` and ``decode`` in sequence."""
        x = self.encode(x)
        x = self.refine(x)
        x = self.decode(x)
        return x


def _csrnet(block_size: Optional[int] = None, norm: str = "none", act: str = "none") -> CSRNet:
    """Build a ``CSRNet`` on the ``"vgg16"`` backbone (no batch norm in encoder/decoder).

    ``block_size``, ``norm`` and ``act`` are forwarded to ``CSRNet`` unchanged.
    """
    options = {"block_size": block_size, "norm": norm, "act": act}
    return CSRNet("vgg16", **options)

def _csrnet_bn(block_size: Optional[int] = None, norm: str = "none", act: str = "none") -> CSRNet:
    """Build a ``CSRNet`` on the ``"vgg16_bn"`` backbone (batch norm in encoder/decoder).

    ``block_size``, ``norm`` and ``act`` are forwarded to ``CSRNet`` unchanged.
    """
    options = {"block_size": block_size, "norm": norm, "act": act}
    return CSRNet("vgg16_bn", **options)