import torch
from torch import nn, Tensor
from typing import List
from einops import rearrange

from .blocks import conv3x3, conv1x1, Conv2dLayerNorm, _init_weights


class MultiScale(nn.Module):
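    """Multi-scale feature mixing block.

    The input map is expanded into parallel branches: a 1x1 projection of the
    original features plus one (optionally grouped) 3x3 convolution with
    dilation ``scale`` for every entry in ``scales``.  The branches are fused
    per spatial location by single-query attention, where the query is the
    mean over all branches, and the fused map is refined by a pointwise MLP.
    Both stages are applied as pre-normalised residual updates, so the output
    keeps the input shape (B, C, H, W).

    Args:
        channels: number of input/output channels; must be divisible by
            ``heads`` and ``groups``.
        scales: dilation rates of the 3x3 branches.
        heads: number of attention heads used to fuse the branches.
        groups: group count of the dilated 3x3 convolutions.
        mlp_ratio: hidden-channel expansion factor of the refinement MLP.
    """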
    def __init__(
        self,
        channels: int,
        scales: List[int],
        heads: int = 8,
        groups: int = 1,
        mlp_ratio: float = 4.0,
    ) -> None:
        super().__init__()
        assert channels > 0, "channels should be a positive integer"
        assert (
            isinstance(scales, (list, tuple))
            and len(scales) > 0
            and all(scale > 0 for scale in scales)
        ), "scales should be a non-empty list or tuple of positive integers"
        assert heads > 0 and channels % heads == 0, "heads should be a positive integer and channels should be divisible by heads"
        assert groups > 0 and channels % groups == 0, "groups should be a positive integer and channels should be divisible by groups"
        assert mlp_ratio > 0, "mlp_ratio should be a positive number"
        scales = sorted(scales)
        self.scales = scales
        self.num_scales = len(scales) + 1  # +1 for the original feature map
        self.heads = heads
        self.groups = groups

        # modules that generate multi-scale feature maps
        self.scale_0 = nn.Sequential(
            conv1x1(channels, channels, stride=1, bias=False),
            Conv2dLayerNorm(channels),
            nn.GELU(),
        )
        for scale in scales:
            setattr(self, f"conv_{scale}", nn.Sequential(
                conv3x3(
                    in_channels=channels,
                    out_channels=channels,
                    stride=1,
                    groups=groups,
                    dilation=scale,
                    bias=False,
                ),
                conv1x1(channels, channels, stride=1, bias=False) if groups > 1 else nn.Identity(),
                Conv2dLayerNorm(channels),
                nn.GELU(),
            ))
        
        # modules that fuse multi-scale feature maps
        self.norm_attn = Conv2dLayerNorm(channels)
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_scales + 1, channels, 1, 1) / channels ** 0.5)
        self.to_q = conv1x1(channels, channels, stride=1, bias=False)
        self.to_k = conv1x1(channels, channels, stride=1, bias=False)
        self.to_v = conv1x1(channels, channels, stride=1, bias=False)

        self.scale = (channels // heads) ** -0.5

        self.attend = nn.Softmax(dim=-1)
        
        self.to_out = conv1x1(channels, channels, stride=1)

        # modules that refine multi-scale feature maps
        self.norm_mlp = Conv2dLayerNorm(channels)
        self.mlp = nn.Sequential(
            conv1x1(channels, int(channels * mlp_ratio), stride=1),
            nn.GELU(),
            conv1x1(int(channels * mlp_ratio), channels, stride=1),
        )
        
        self.apply(_init_weights)
    
    def _forward_attn(self, x: Tensor) -> Tensor:
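        """Build the multi-scale branches and fuse them per pixel.

        The branches are stacked, prepended with their mean, shifted by a
        learned positional embedding and normalised.  The mean token serves as
        the single query at each spatial location and attends over all tokens,
        producing one fused (B, C, H, W) map.
        """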
        assert len(x.shape) == 4, f"Expected input to have shape (B, C, H, W), but got {x.shape}"
        branches = [self.scale_0(x)] + [getattr(self, f"conv_{scale}")(x) for scale in self.scales]

        x = torch.stack(branches, dim=1)  # (B, S, C, H, W)
        x = torch.cat([x.mean(dim=1, keepdim=True), x], dim=1)  # (B, S+1, C, H, W)
        x = x + self.pos_embed  # (B, S+1, C, H, W)

        x = rearrange(x, "B S C H W -> (B S) C H W")  # (B*(S+1), C, H, W)
        x = self.norm_attn(x)  # (B*(S+1), C, H, W)
        x = rearrange(x, "(B S) C H W -> B S C H W", S=self.num_scales + 1)  # (B, S+1, C, H, W)
        
        q = self.to_q(x[:, 0])  # (B, C, H, W)
        k = self.to_k(rearrange(x, "B S C H W -> (B S) C H W"))
        v = self.to_v(rearrange(x, "B S C H W -> (B S) C H W"))

        q = rearrange(q, "B (h d) H W -> B h H W 1 d", h=self.heads)
        k = rearrange(k, "(B S) (h d) H W -> B h H W S d", S=self.num_scales + 1, h=self.heads)
        v = rearrange(v, "(B S) (h d) H W -> B h H W S d", S=self.num_scales + 1, h=self.heads)

        attn = q @ k.transpose(-2, -1) * self.scale  # (B, h, H, W, 1, S+1)
        attn = self.attend(attn)  # (B, h, H, W, 1, S+1)
        out = attn @ v  # (B, h, H, W, 1, d)

        out = rearrange(out, "B h H W 1 d -> B (h d) H W")  # (B, C, H, W)

        out = self.to_out(out)  # (B, C, H, W)
        return out
    
    def _forward_mlp(self, x: Tensor) -> Tensor:
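        """Refine the fused map with a pre-normalised pointwise MLP."""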
        assert len(x.shape) == 4, f"Expected input to have shape (B, C, H, W), but got {x.shape}"
        x = self.norm_mlp(x)
        x = self.mlp(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
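        """Apply attention fusion and MLP refinement as residual updates."""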
        x = x + self._forward_attn(x)
        x = x + self._forward_mlp(x)
        return x
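

if __name__ == "__main__":
    # Illustrative smoke test; the channel count, scales and input size are
    # arbitrary example values.  Because of the relative import of `.blocks`,
    # this runs only when the file is executed as a module inside its package
    # (e.g. `python -m <package>.<this_module>`).
    block = MultiScale(channels=64, scales=[1, 2, 3], heads=8, groups=1, mlp_ratio=4.0)
    dummy = torch.randn(2, 64, 32, 32)  # (B, C, H, W)
    out = block(dummy)
    # The residual updates in `forward` preserve the input shape.
    assert out.shape == dummy.shape
    print(out.shape)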