# -------------------------------
# 2. SWIN-STYLE TRANSFORMER UTILS
# -------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


def window_partition(x, window_size):
    """
    x: (B, H, W, C)
    Returns windows of shape: (num_windows*B, window_size*window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # permute to gather patches
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    # merge dimension
    windows = x.view(-1, window_size * window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Reverse of window_partition.
    windows: (num_windows*B, window_size*window_size, C)
    Returns: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    x = x.view(B, H, W, -1)
    return x


class SwinWindowAttention(nn.Module):
    """
    A simplified Swin-like window attention block:
      1) Partition input into windows
      2) Perform multi-head self-attn
      3) Merge back
    """
    def __init__(self, embed_dim, window_size, num_heads, dropout=0.0):
        super(SwinWindowAttention, self).__init__()
        self.embed_dim = embed_dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.mha = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: (B, C, H, W) --> rearrange to (B, H, W, C)
        B, C, H, W = x.shape
        x = x.permute(0, 2, 3, 1).contiguous()

        # pad if needed so H, W are multiples of window_size
        pad_h = (self.window_size - H % self.window_size) % self.window_size
        pad_w = (self.window_size - W % self.window_size) % self.window_size
        if pad_h or pad_w:
            x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
        Hp, Wp = x.shape[1], x.shape[2]

        # Partition into windows
        windows = window_partition(x, self.window_size)
        # shape: (num_windows*B, window_size*window_size, C)

        # Multi-head self-attn
        attn_windows, _ = self.mha(windows, windows, windows)
        attn_windows = self.dropout(attn_windows)

        # Reverse window partition
        x = window_reverse(attn_windows, self.window_size, Hp, Wp)

        # Remove padding if added
        if pad_h or pad_w:
            x = x[:, :H, :W, :].contiguous()

        # back to (B, C, H, W)
        x = x.permute(0, 3, 1, 2).contiguous()
        return x
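

# --- Usage sketch (illustrative only) ---
# A minimal smoke test for SwinWindowAttention. The hyperparameters below
# (embed_dim=96, window_size=7, num_heads=3) and the 30x30 input are arbitrary
# example values, not settings from any particular Swin configuration; 30 is
# deliberately not a multiple of 7 so the padding/cropping branch is exercised.
if __name__ == "__main__":
    attn = SwinWindowAttention(embed_dim=96, window_size=7, num_heads=3)
    dummy = torch.randn(2, 96, 30, 30)   # (B, C, H, W); H, W not multiples of window_size
    out = attn(dummy)
    # Padding is removed after window_reverse, so the spatial size is restored.
    assert out.shape == dummy.shape
    print(out.shape)                      # torch.Size([2, 96, 30, 30])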