# glaucoma/efficientnet_transformer_glam.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
# -------------------------------
# 1. SWIN WINDOW UTILS
# -------------------------------
def window_partition(x, window_size):
"""Partitions input tensor into windows of shape (B * num_windows, window_size*window_size, C)."""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
windows = x.view(-1, window_size * window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
"""Reverses the window partition operation."""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
x = x.view(B, H, W, -1)
return x
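
# Illustrative round-trip check (a sketch, not part of the module's API):
# for an input of shape (2, 14, 14, 96) and window_size=7, window_partition
# yields (2 * 4, 49, 96) windows and window_reverse restores the original tensor.
#
#   x = torch.randn(2, 14, 14, 96)
#   w = window_partition(x, 7)        # -> torch.Size([8, 49, 96])
#   y = window_reverse(w, 7, 14, 14)  # -> torch.Size([2, 14, 14, 96])
#   assert torch.equal(x, y)
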
# -------------------------------
# 2. SWIN WINDOW ATTENTION
# -------------------------------
class SwinWindowAttention(nn.Module):
"""Swin-style window attention block."""
def __init__(self, embed_dim, window_size, num_heads, dropout=0.0):
super(SwinWindowAttention, self).__init__()
self.embed_dim = embed_dim
self.window_size = window_size
self.num_heads = num_heads
self.mha = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
"""Perform multi-head self-attn within windows."""
B, C, H, W = x.shape
x = x.permute(0, 2, 3, 1).contiguous()
pad_h = (self.window_size - H % self.window_size) % self.window_size
pad_w = (self.window_size - W % self.window_size) % self.window_size
if pad_h or pad_w:
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
Hp, Wp = x.shape[1], x.shape[2]
windows = window_partition(x, self.window_size) # (B*n_wins, win_size^2, C)
attn_windows, _ = self.mha(windows, windows, windows)
attn_windows = self.dropout(attn_windows)
x = window_reverse(attn_windows, self.window_size, Hp, Wp)
if pad_h or pad_w:
x = x[:, :H, :W, :].contiguous()
return x.permute(0, 3, 1, 2).contiguous()
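
# Usage sketch (assumed shapes, for illustration only): the block takes a
# channels-first feature map and pads H/W internally, so they need not be
# multiples of window_size.
#
#   attn = SwinWindowAttention(embed_dim=512, window_size=7, num_heads=8)
#   feats = torch.randn(2, 512, 10, 10)   # padded to 14x14 inside forward()
#   out = attn(feats)                     # -> torch.Size([2, 512, 10, 10])
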
# -------------------------------
# 3. GLAM
# -------------------------------
class GLAM(nn.Module):
"""Global-Local Attention Module (GLAM)."""
def __init__(self, in_channels, reduction_ratio=8):
super(GLAM, self).__init__()
# Local Channel Attention
self.local_channel_conv = nn.Conv2d(in_channels, in_channels // reduction_ratio, kernel_size=1)
self.local_channel_act = nn.Sigmoid()
self.local_channel_expand = nn.Conv2d(in_channels // reduction_ratio, in_channels, kernel_size=1)
# Local Spatial Attention
self.local_spatial_conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=3, dilation=3)
self.local_spatial_conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=5, dilation=5)
self.local_spatial_merge = nn.Conv2d(in_channels * 3, in_channels, kernel_size=1)
self.local_spatial_act = nn.Sigmoid()
# Global Channel Attention
self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
self.global_channel_fc1 = nn.Linear(in_channels, in_channels // reduction_ratio)
self.global_channel_fc2 = nn.Linear(in_channels // reduction_ratio, in_channels)
self.global_channel_act = nn.Sigmoid()
# Global Spatial Attention
self.global_spatial_conv = nn.Conv2d(in_channels, 1, kernel_size=1)
self.global_spatial_softmax = nn.Softmax(dim=-1)
# Weights
self.local_attention_weight = nn.Parameter(torch.tensor(1.0))
self.global_attention_weight = nn.Parameter(torch.tensor(1.0))
def forward(self, x):
# Local Channel Attention
lca = self.local_channel_conv(x)
lca = self.local_channel_act(lca)
lca = self.local_channel_expand(lca)
lca_out = lca * x
# Local Spatial Attention
lsa3 = self.local_spatial_conv3(x)
lsa5 = self.local_spatial_conv5(x)
lsa_cat = torch.cat([x, lsa3, lsa5], dim=1)
lsa = self.local_spatial_merge(lsa_cat)
lsa = self.local_spatial_act(lsa)
lsa_out = lsa * lca_out
lsa_out = lsa_out + lca_out
# Global Channel Attention
B, C, H, W = x.size()
gca = self.global_avg_pool(x).view(B, C)
gca = F.relu(self.global_channel_fc1(gca), inplace=True)
gca = self.global_channel_fc2(gca)
gca = self.global_channel_act(gca)
gca = gca.view(B, C, 1, 1)
gca_out = gca * x
# Global Spatial Attention
gsa = self.global_spatial_conv(x) # [B, 1, H, W]
gsa = gsa.view(B, -1) # [B, H*W]
gsa = self.global_spatial_softmax(gsa)
gsa = gsa.view(B, 1, H, W)
gsa_out = gsa * gca_out
gsa_out = gsa_out + gca_out
# Final Fusion
out = lsa_out * self.local_attention_weight + gsa_out * self.global_attention_weight + x
return out
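
# Usage sketch (illustrative only): GLAM preserves the input shape, so it can
# be applied to any intermediate feature map whose channel count matches
# in_channels.
#
#   glam = GLAM(in_channels=512, reduction_ratio=8)
#   feats = torch.randn(2, 512, 7, 7)
#   out = glam(feats)                     # -> torch.Size([2, 512, 7, 7])
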
# -------------------------------
# 4. EFFICIENTNETB0_TRANSFORMERGLAM
# -------------------------------
class EfficientNetb0_TransformerGLAM(nn.Module):
"""EfficientNet-B0 + Swin-style Transformer + GLAM + Self-Adaptive Gating."""
def __init__(self,
num_classes=3,
embed_dim=512,
num_heads=8,
mlp_dim=512,
dropout=0.5,
window_size=7,
reduction_ratio=8):
super(EfficientNetb0_TransformerGLAM, self).__init__()
# EfficientNet Backbone
        efficientnet = models.efficientnet_b0(weights=None)  # randomly initialised backbone; pretrained weights can be loaded separately
self.feature_extractor = nn.Sequential(*list(efficientnet.features.children()))
self.conv1x1 = nn.Conv2d(1280, embed_dim, kernel_size=1)
# Transformer path
self.pre_attn_norm = nn.LayerNorm(embed_dim)
self.swin_attn = SwinWindowAttention(embed_dim, window_size, num_heads, dropout)
self.post_attn_norm = nn.LayerNorm(embed_dim)
# GLAM path
self.glam = GLAM(in_channels=embed_dim, reduction_ratio=reduction_ratio)
# Self-adaptive gating
self.gate_fc = nn.Linear(embed_dim, 1)
# Final classification
self.dropout = nn.Dropout(dropout)
self.fc = nn.Linear(embed_dim, num_classes)
def forward(self, x):
# Backbone
feats = self.feature_extractor(x) # [B, 1280, H', W']
feats = self.conv1x1(feats) # [B, embed_dim, H', W']
B, C, H, W = feats.shape
# Transformer path
x_perm = feats.permute(0, 2, 3, 1).contiguous()
x_norm = self.pre_attn_norm(x_perm) # LN
x_norm = x_norm.permute(0, 3, 1, 2).contiguous()
x_norm = self.dropout(x_norm)
T = self.swin_attn(x_norm)
T_perm = T.permute(0, 2, 3, 1).contiguous()
T_norm = self.post_attn_norm(T_perm) # LN
T_out = T_norm.permute(0, 3, 1, 2).contiguous()
# GLAM path
G_out = self.glam(feats)
# Gating
gap_feats = F.adaptive_avg_pool2d(feats, (1, 1)).view(B, C)
g = torch.sigmoid(self.gate_fc(gap_feats))
g = g.view(B, 1, 1, 1)
# Final Fusion
F_out = g * T_out + (1 - g) * G_out
pooled = F.adaptive_avg_pool2d(F_out, (1, 1)).view(B, -1)
return self.fc(pooled)
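

# Minimal smoke test (a sketch; the 224x224 input resolution is an assumption,
# not something this file mandates). With a 224x224 RGB input, EfficientNet-B0
# produces 7x7 feature maps, which matches the default window_size=7.
if __name__ == "__main__":
    model = EfficientNetb0_TransformerGLAM(num_classes=3)
    model.eval()  # disable dropout for the shape check
    dummy = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        logits = model(dummy)
    print(logits.shape)  # expected: torch.Size([1, 3])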