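"""Convolutional refinement heads.

Three variants trading capacity for compute: ``ConvRefine`` (full
``BasicBlock``), ``LightConvRefine`` (``LightBasicBlock``), and
``LighterConvRefine`` (depthwise-separable residual block).
"""
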
from typing import Callable, Optional

from torch import nn, Tensor

from .utils import _init_weights
from .blocks import BasicBlock, LightBasicBlock, conv1x1, conv3x3


class ConvRefine(nn.Module):
    """Refines features with a full residual ``BasicBlock``."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        norm_layer: Optional[Callable[..., nn.Module]] = nn.BatchNorm2d,
        activation: nn.Module = nn.ReLU(inplace=True),
        groups: int = 1,
    ) -> None:
        super().__init__()
        self.refine = BasicBlock(
            in_channels=in_channels,
            out_channels=out_channels,
            norm_layer=norm_layer,
            activation=activation,
            groups=groups,
        )
        self.apply(_init_weights)

    def forward(self, x: Tensor) -> Tensor:
        return self.refine(x)


class LightConvRefine(nn.Module):
    """Refines features with the cheaper ``LightBasicBlock``."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        norm_layer: Optional[Callable[..., nn.Module]] = nn.BatchNorm2d,
        activation: nn.Module = nn.ReLU(inplace=True),
    ) -> None:
        super().__init__()
        self.refine = LightBasicBlock(
            in_channels=in_channels,
            out_channels=out_channels,
            norm_layer=norm_layer,
            activation=activation,
        )
        self.apply(_init_weights)

    def forward(self, x: Tensor) -> Tensor:
        return self.refine(x)


class LighterConvRefine(nn.Module):
    """Refines features with a depthwise-separable residual block."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        norm_layer: Optional[Callable[..., nn.Module]] = nn.BatchNorm2d,
        activation: nn.Module = nn.ReLU(inplace=True),
    ) -> None:
        super().__init__()
        # depthwise separable convolution
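        # A 3x3 depthwise conv followed by a 1x1 pointwise conv costs roughly
        # 9*C_in + C_in*C_out weights instead of the 9*C_in*C_out of a
        # standard 3x3 convolution, hence the "lighter" variant.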
        self.conv1 = conv3x3(
            in_channels=in_channels,
            out_channels=in_channels,
            stride=1,
            groups=in_channels,
            bias=not norm_layer,
        )
        self.norm1 = norm_layer(in_channels) if norm_layer else nn.Identity()
        self.act1 = activation

        self.conv2 = conv1x1(
            in_channels=in_channels,
            out_channels=out_channels,
            stride=1,
            bias=not norm_layer,
        )
        self.norm2 = norm_layer(out_channels) if norm_layer else nn.Identity()
        self.act2 = activation

        # project the identity branch with a 1x1 conv when channel counts differ
        if in_channels != out_channels:
            self.downsample = nn.Sequential(
                conv1x1(in_channels, out_channels, stride=1, bias=not norm_layer),
                norm_layer(out_channels) if norm_layer else nn.Identity(),
            )
        else:
            self.downsample = nn.Identity()

        self.apply(_init_weights)

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.norm1(out)
        out = self.act1(out)

        out = self.conv2(out)
        out = self.norm2(out)

        # residual connection: add the (possibly projected) input back in
        out += self.downsample(identity)
        out = self.act2(out)
        return out
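

# Minimal smoke test (run as a module, e.g. ``python -m <package>.<this_module>``,
# so the relative imports resolve). It assumes BasicBlock and LightBasicBlock
# preserve spatial size at stride 1; the channel sizes are illustrative only.
if __name__ == "__main__":
    import torch

    x = torch.randn(2, 64, 32, 32)
    for refine in (
        ConvRefine(64, 128),
        LightConvRefine(64, 128),
        LighterConvRefine(64, 128),
    ):
        assert refine(x).shape == (2, 128, 32, 32)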