import torch
import torch.nn as nn
import torch.nn.functional as F


def window_partition(x, window_size):
    """
    Split a feature map into non-overlapping windows.

    x: (B, H, W, C)
    Returns windows of shape (num_windows*B, window_size*window_size, C).
    """
    B, H, W, C = x.shape
    # Split H and W into (num_windows_h, window_size) and (num_windows_w, window_size).
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # Bring the two window-index axes together, then flatten each window into a
    # sequence of window_size * window_size tokens.
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    windows = x.view(-1, window_size * window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Reverse of window_partition.

    windows: (num_windows*B, window_size*window_size, C)
    Returns: (B, H, W, C)
    """
    # Recover the batch size from the total number of windows.
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    x = x.view(B, H, W, -1)
    return x
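# Illustrative round-trip check (added for clarity, not part of the original
# code; the tensor sizes are arbitrary assumptions): window_reverse exactly
# undoes window_partition whenever H and W are divisible by window_size.
#
#   x = torch.randn(2, 8, 8, 16)                      # (B, H, W, C)
#   w = window_partition(x, window_size=4)            # -> (2 * 4, 16, 16)
#   assert torch.equal(window_reverse(w, 4, 8, 8), x)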
class SwinWindowAttention(nn.Module):
    """
    A simplified Swin-like window attention block:
      1) partition the input into non-overlapping windows,
      2) run multi-head self-attention within each window,
      3) merge the windows back into a feature map.
    """

    def __init__(self, embed_dim, window_size, num_heads, dropout=0.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.window_size = window_size
        self.num_heads = num_heads

        self.mha = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout,
                                         batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: (B, C, H, W) -> (B, H, W, C) for window partitioning.
        B, C, H, W = x.shape
        x = x.permute(0, 2, 3, 1).contiguous()

        # Pad H and W up to multiples of the window size.
        pad_h = (self.window_size - H % self.window_size) % self.window_size
        pad_w = (self.window_size - W % self.window_size) % self.window_size
        if pad_h or pad_w:
            x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
        Hp, Wp = x.shape[1], x.shape[2]

        # Self-attention is computed independently within each window.
        windows = window_partition(x, self.window_size)
        attn_windows, _ = self.mha(windows, windows, windows)
        attn_windows = self.dropout(attn_windows)

        # Merge the windows back and strip any padding.
        x = window_reverse(attn_windows, self.window_size, Hp, Wp)
        if pad_h or pad_w:
            x = x[:, :H, :W, :].contiguous()

        # Back to (B, C, H, W).
        x = x.permute(0, 3, 1, 2).contiguous()
        return x
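# Minimal usage sketch (added for illustration; all sizes below are arbitrary
# assumptions, not values from the original source). The 10x10 input is not a
# multiple of the window size, so this also exercises the padding/cropping path.
if __name__ == "__main__":
    attn = SwinWindowAttention(embed_dim=32, window_size=4, num_heads=4)
    dummy = torch.randn(2, 32, 10, 10)      # (B, C, H, W)
    out = attn(dummy)
    print(out.shape)                        # torch.Size([2, 32, 10, 10])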