# -------------------------------
# 2. SWIN-STYLE TRANSFORMER UTILS
# -------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
def window_partition(x, window_size):
    """
    x: (B, H, W, C)
    Returns windows of shape: (num_windows*B, window_size*window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # bring the two window-grid axes next to each other, keeping the
    # intra-window row/column axes together
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    # flatten each window into a token sequence: (num_windows*B, window_size*window_size, C)
    windows = x.view(-1, window_size * window_size, C)
    return windows

def window_reverse(windows, window_size, H, W):
    """
    Reverse of window_partition.
    windows: (num_windows*B, window_size*window_size, C)
    Returns: (B, H, W, C)
    """
    # recover the batch size from the total number of windows
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    x = x.view(B, H, W, -1)
    return x
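
# Illustrative sketch (not part of the original module): a quick sanity check
# that window_reverse inverts window_partition when H and W are already
# multiples of window_size. The tensor sizes below are assumptions chosen
# only for demonstration.
def _window_roundtrip_check():
    B, H, W, C, ws = 2, 8, 8, 16, 4
    x = torch.randn(B, H, W, C)
    windows = window_partition(x, ws)          # (B * (H//ws) * (W//ws), ws*ws, C)
    x_rec = window_reverse(windows, ws, H, W)  # back to (B, H, W, C)
    assert torch.equal(x, x_rec), "window_reverse should exactly undo window_partition"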

class SwinWindowAttention(nn.Module):
    """
    A simplified Swin-like window attention block:
    1) Partition input into windows
    2) Perform multi-head self-attn
    3) Merge back
    """
    def __init__(self, embed_dim, window_size, num_heads, dropout=0.0):
        super(SwinWindowAttention, self).__init__()
        self.embed_dim = embed_dim
        self.window_size = window_size
        self.num_heads = num_heads
        
        self.mha = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        # x: (B, C, H, W) --> rearrange to (B, H, W, C)
        B, C, H, W = x.shape
        x = x.permute(0, 2, 3, 1).contiguous()
        
        # pad if needed so H, W are multiples of window_size
        pad_h = (self.window_size - H % self.window_size) % self.window_size
        pad_w = (self.window_size - W % self.window_size) % self.window_size
        if pad_h or pad_w:
            x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
        
        Hp, Wp = x.shape[1], x.shape[2]
        # Partition into windows
        windows = window_partition(x, self.window_size)  # shape: (num_windows*B, window_size*window_size, C)
        # Multi-head self-attn
        attn_windows, _ = self.mha(windows, windows, windows)
        attn_windows = self.dropout(attn_windows)
        
        # Reverse window partition
        x = window_reverse(attn_windows, self.window_size, Hp, Wp)
        
        # Remove padding if added
        if pad_h or pad_w:
            x = x[:, :H, :W, :].contiguous()
        
        # back to (B, C, H, W)
        x = x.permute(0, 3, 1, 2).contiguous()
        return x
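

# Minimal usage sketch (shapes are assumptions, not taken from the original
# file): run window attention on a feature map whose spatial size is not a
# multiple of the window size, so the padding path is exercised.
if __name__ == "__main__":
    attn = SwinWindowAttention(embed_dim=96, window_size=7, num_heads=3)
    feats = torch.randn(2, 96, 30, 30)  # (B, C, H, W); 30 is not a multiple of 7
    out = attn(feats)
    print(out.shape)  # expected: torch.Size([2, 96, 30, 30])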