# glaucoma/swin_module.py
# -------------------------------
# 2. SWIN-STYLE TRANSFORMER UTILS
# -------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


def window_partition(x, window_size):
    """
    x: (B, H, W, C)
    Returns windows of shape: (num_windows*B, window_size*window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # Permute so the two window-grid axes sit next to each other.
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    # Merge batch and grid axes into one, and each window into a token sequence.
    windows = x.view(-1, window_size * window_size, C)
    return windows
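
# Shape sanity check (illustrative sketch, not part of the module): with
# B = 2, H = W = 8, C = 96 and window_size = 4, an 8x8 map splits into a
# 2x2 grid of 4x4 windows, so window_partition yields (2*2*2, 16, 96):
#
#   x = torch.randn(2, 8, 8, 96)
#   windows = window_partition(x, window_size=4)
#   assert windows.shape == (8, 16, 96)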


def window_reverse(windows, window_size, H, W):
    """
    Reverse of window_partition.
    windows: (num_windows*B, window_size*window_size, C)
    Returns: (B, H, W, C)
    """
    # Recover the batch size from the total number of windows.
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    x = x.view(B, H, W, -1)
    return x
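
# Round-trip sketch (illustrative, assuming H and W are already multiples of
# window_size): window_reverse should invert window_partition exactly, since
# both are pure view/permute operations.
#
#   x = torch.randn(2, 8, 8, 96)
#   w = window_partition(x, window_size=4)
#   x_back = window_reverse(w, window_size=4, H=8, W=8)
#   assert torch.equal(x, x_back)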


class SwinWindowAttention(nn.Module):
    """
    A simplified Swin-like window attention block:
      1) Partition the input into windows
      2) Perform multi-head self-attention within each window
      3) Merge the windows back
    """

    def __init__(self, embed_dim, window_size, num_heads, dropout=0.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.mha = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: (B, C, H, W) -> rearrange to (B, H, W, C) for window partitioning.
        B, C, H, W = x.shape
        x = x.permute(0, 2, 3, 1).contiguous()

        # Pad on the bottom/right so H and W become multiples of window_size.
        pad_h = (self.window_size - H % self.window_size) % self.window_size
        pad_w = (self.window_size - W % self.window_size) % self.window_size
        if pad_h or pad_w:
            # F.pad pads the last dimensions first: (C, W, H) for a (B, H, W, C) tensor.
            x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
        Hp, Wp = x.shape[1], x.shape[2]

        # Partition into windows: (num_windows*B, window_size*window_size, C).
        windows = window_partition(x, self.window_size)

        # Multi-head self-attention within each window.
        attn_windows, _ = self.mha(windows, windows, windows)
        attn_windows = self.dropout(attn_windows)

        # Reverse the window partition back to (B, Hp, Wp, C).
        x = window_reverse(attn_windows, self.window_size, Hp, Wp)

        # Remove the padding if it was added.
        if pad_h or pad_w:
            x = x[:, :H, :W, :].contiguous()

        # Back to (B, C, H, W).
        x = x.permute(0, 3, 1, 2).contiguous()
        return x
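

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the module's API).
    # The sizes below (embed_dim=96, num_heads=4, window_size=7, a 2x96x30x30
    # input) are assumptions chosen for demonstration only; 30 is deliberately
    # not a multiple of 7 so the padding path in forward() is exercised.
    attn = SwinWindowAttention(embed_dim=96, window_size=7, num_heads=4, dropout=0.1)
    dummy = torch.randn(2, 96, 30, 30)  # (B, C, H, W)
    out = attn(dummy)
    # Window attention preserves the spatial and channel dimensions.
    assert out.shape == dummy.shape
    print(out.shape)  # torch.Size([2, 96, 30, 30])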