dhruv2842 committed
Commit 393c032 · verified · 1 parent: 109c9ca

Update swin_module.py

Files changed (1):
  swin_module.py (+75, -72)
swin_module.py CHANGED
@@ -1,72 +1,75 @@
- # -------------------------------
- # 2. SWIN-STYLE TRANSFORMER UTILS
- # -------------------------------
- def window_partition(x, window_size):
-     """
-     x: (B, H, W, C)
-     Returns windows of shape: (num_windows*B, window_size*window_size, C)
-     """
-     B, H, W, C = x.shape
-     x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
-     # permute to gather patches
-     x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
-     # merge dimension
-     windows = x.view(-1, window_size * window_size, C)
-     return windows
-
- def window_reverse(windows, window_size, H, W):
-     """
-     Reverse of window_partition.
-     windows: (num_windows*B, window_size*window_size, C)
-     Returns: (B, H, W, C)
-     """
-     B = int(windows.shape[0] / (H * W / window_size / window_size))
-     x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
-     x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
-     x = x.view(B, H, W, -1)
-     return x
-
- class SwinWindowAttention(nn.Module):
-     """
-     A simplified Swin-like window attention block:
-     1) Partition input into windows
-     2) Perform multi-head self-attn
-     3) Merge back
-     """
-     def __init__(self, embed_dim, window_size, num_heads, dropout=0.0):
-         super(SwinWindowAttention, self).__init__()
-         self.embed_dim = embed_dim
-         self.window_size = window_size
-         self.num_heads = num_heads
-
-         self.mha = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
-         self.dropout = nn.Dropout(dropout)
-
-     def forward(self, x):
-         # x: (B, C, H, W) --> rearrange to (B, H, W, C)
-         B, C, H, W = x.shape
-         x = x.permute(0, 2, 3, 1).contiguous()
-
-         # pad if needed so H, W are multiples of window_size
-         pad_h = (self.window_size - H % self.window_size) % self.window_size
-         pad_w = (self.window_size - W % self.window_size) % self.window_size
-         if pad_h or pad_w:
-             x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
-
-         Hp, Wp = x.shape[1], x.shape[2]
-         # Partition into windows
-         windows = window_partition(x, self.window_size)  # shape: (num_windows*B, window_size*window_size, C)
-         # Multi-head self-attn
-         attn_windows, _ = self.mha(windows, windows, windows)
-         attn_windows = self.dropout(attn_windows)
-
-         # Reverse window partition
-         x = window_reverse(attn_windows, self.window_size, Hp, Wp)
-
-         # Remove padding if added
-         if pad_h or pad_w:
-             x = x[:, :H, :W, :].contiguous()
-
-         # back to (B, C, H, W)
-         x = x.permute(0, 3, 1, 2).contiguous()
-         return x
+ # -------------------------------
+ # 2. SWIN-STYLE TRANSFORMER UTILS
+ # -------------------------------
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ def window_partition(x, window_size):
+     """
+     x: (B, H, W, C)
+     Returns windows of shape: (num_windows*B, window_size*window_size, C)
+     """
+     B, H, W, C = x.shape
+     x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+     # permute to gather patches
+     x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
+     # merge dimension
+     windows = x.view(-1, window_size * window_size, C)
+     return windows
+
+ def window_reverse(windows, window_size, H, W):
+     """
+     Reverse of window_partition.
+     windows: (num_windows*B, window_size*window_size, C)
+     Returns: (B, H, W, C)
+     """
+     B = int(windows.shape[0] / (H * W / window_size / window_size))
+     x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+     x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
+     x = x.view(B, H, W, -1)
+     return x
+
+ class SwinWindowAttention(nn.Module):
+     """
+     A simplified Swin-like window attention block:
+     1) Partition input into windows
+     2) Perform multi-head self-attn
+     3) Merge back
+     """
+     def __init__(self, embed_dim, window_size, num_heads, dropout=0.0):
+         super(SwinWindowAttention, self).__init__()
+         self.embed_dim = embed_dim
+         self.window_size = window_size
+         self.num_heads = num_heads
+
+         self.mha = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         # x: (B, C, H, W) --> rearrange to (B, H, W, C)
+         B, C, H, W = x.shape
+         x = x.permute(0, 2, 3, 1).contiguous()
+
+         # pad if needed so H, W are multiples of window_size
+         pad_h = (self.window_size - H % self.window_size) % self.window_size
+         pad_w = (self.window_size - W % self.window_size) % self.window_size
+         if pad_h or pad_w:
+             x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
+
+         Hp, Wp = x.shape[1], x.shape[2]
+         # Partition into windows
+         windows = window_partition(x, self.window_size)  # shape: (num_windows*B, window_size*window_size, C)
+         # Multi-head self-attn
+         attn_windows, _ = self.mha(windows, windows, windows)
+         attn_windows = self.dropout(attn_windows)
+
+         # Reverse window partition
+         x = window_reverse(attn_windows, self.window_size, Hp, Wp)
+
+         # Remove padding if added
+         if pad_h or pad_w:
+             x = x[:, :H, :W, :].contiguous()
+
+         # back to (B, C, H, W)
+         x = x.permute(0, 3, 1, 2).contiguous()
+         return x
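
For quick reference, a minimal usage sketch of the updated module (a sketch only: it assumes swin_module.py is importable from the working directory, and the tensor sizes are illustrative rather than taken from this repository):

import torch
from swin_module import SwinWindowAttention, window_partition, window_reverse

# Round trip: partitioning and then reversing should reproduce the input
# exactly when H and W are multiples of the window size.
x = torch.randn(2, 16, 16, 32)                    # (B, H, W, C)
windows = window_partition(x, window_size=8)      # (2 * 4, 8 * 8, 32)
restored = window_reverse(windows, 8, 16, 16)     # (2, 16, 16, 32)
assert torch.equal(x, restored)

# Forward pass through the attention block. 15 x 13 is not a multiple of
# window_size=8, so the F.pad branch runs and the output is cropped back
# to the original spatial size.
attn = SwinWindowAttention(embed_dim=32, window_size=8, num_heads=4, dropout=0.1)
attn.eval()                                       # disable dropout for the shape check
feat = torch.randn(2, 32, 15, 13)                 # (B, C, H, W)
out = attn(feat)
print(out.shape)                                  # torch.Size([2, 32, 15, 13])

Note that, as the class docstring says, this is a simplified block: attention is computed only within non-overlapping windows, and the shifted-window step of the full Swin design is not part of this module.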