import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

def window_partition(x, window_size):
    """Partitions input tensor into windows of shape (B * num_windows, window_size*window_size, C)."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    windows = x.view(-1, window_size * window_size, C)
    return windows

def window_reverse(windows, window_size, H, W):
    """Reverses the window partition operation."""
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    x = x.view(B, H, W, -1)
    return x

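# Illustrative round-trip sketch (the `_demo_*` helpers are additions for
# documentation, not part of the original module): partitioning followed by
# reversing should reconstruct the input tensor exactly.
def _demo_window_roundtrip():
    x = torch.randn(1, 8, 8, 3)
    windows = window_partition(x, window_size=2)              # -> (16, 4, 3)
    x_rec = window_reverse(windows, window_size=2, H=8, W=8)  # -> (1, 8, 8, 3)
    assert windows.shape == (16, 4, 3)
    assert torch.equal(x, x_rec)
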
class SwinWindowAttention(nn.Module):
    """Swin-style window attention block."""

    def __init__(self, embed_dim, window_size, num_heads, dropout=0.0):
        super(SwinWindowAttention, self).__init__()
        self.embed_dim = embed_dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.mha = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Perform multi-head self-attention within non-overlapping windows."""
        B, C, H, W = x.shape
        x = x.permute(0, 2, 3, 1).contiguous()  # (B, H, W, C)

        # Pad H and W up to multiples of the window size.
        pad_h = (self.window_size - H % self.window_size) % self.window_size
        pad_w = (self.window_size - W % self.window_size) % self.window_size
        if pad_h or pad_w:
            x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
        Hp, Wp = x.shape[1], x.shape[2]

        # Self-attention is computed independently within each window.
        windows = window_partition(x, self.window_size)
        attn_windows, _ = self.mha(windows, windows, windows)
        attn_windows = self.dropout(attn_windows)

        # Merge the windows back and remove any padding.
        x = window_reverse(attn_windows, self.window_size, Hp, Wp)
        if pad_h or pad_w:
            x = x[:, :H, :W, :].contiguous()

        return x.permute(0, 3, 1, 2).contiguous()  # (B, C, H, W)

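# Illustrative usage sketch (documentation-only helper): the block is
# shape-preserving, and inputs whose spatial size is not a multiple of
# window_size are padded internally and cropped back afterwards.
def _demo_swin_window_attention():
    attn = SwinWindowAttention(embed_dim=32, window_size=7, num_heads=4)
    x = torch.randn(2, 32, 10, 10)  # 10 is not a multiple of 7
    out = attn(x)
    assert out.shape == x.shape
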
class GLAM(nn.Module):
    """Global-Local Attention Module (GLAM)."""

    def __init__(self, in_channels, reduction_ratio=8):
        super(GLAM, self).__init__()

        # Local channel attention: 1x1 bottleneck followed by expansion.
        self.local_channel_conv = nn.Conv2d(in_channels, in_channels // reduction_ratio, kernel_size=1)
        self.local_channel_act = nn.Sigmoid()
        self.local_channel_expand = nn.Conv2d(in_channels // reduction_ratio, in_channels, kernel_size=1)

        # Local spatial attention: dilated 3x3 convolutions at two rates.
        self.local_spatial_conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=3, dilation=3)
        self.local_spatial_conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=5, dilation=5)
        self.local_spatial_merge = nn.Conv2d(in_channels * 3, in_channels, kernel_size=1)
        self.local_spatial_act = nn.Sigmoid()

        # Global channel attention: squeeze-and-excitation-style MLP.
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.global_channel_fc1 = nn.Linear(in_channels, in_channels // reduction_ratio)
        self.global_channel_fc2 = nn.Linear(in_channels // reduction_ratio, in_channels)
        self.global_channel_act = nn.Sigmoid()

        # Global spatial attention: softmax over all spatial positions.
        self.global_spatial_conv = nn.Conv2d(in_channels, 1, kernel_size=1)
        self.global_spatial_softmax = nn.Softmax(dim=-1)

        # Learnable scalar weights for fusing the local and global branches.
        self.local_attention_weight = nn.Parameter(torch.tensor(1.0))
        self.global_attention_weight = nn.Parameter(torch.tensor(1.0))

    def forward(self, x):
        # Local channel attention.
        lca = self.local_channel_conv(x)
        lca = self.local_channel_act(lca)
        lca = self.local_channel_expand(lca)
        lca_out = lca * x

        # Local spatial attention, applied on top of the local channel output.
        lsa3 = self.local_spatial_conv3(x)
        lsa5 = self.local_spatial_conv5(x)
        lsa_cat = torch.cat([x, lsa3, lsa5], dim=1)
        lsa = self.local_spatial_merge(lsa_cat)
        lsa = self.local_spatial_act(lsa)
        lsa_out = lsa * lca_out
        lsa_out = lsa_out + lca_out

        # Global channel attention.
        B, C, H, W = x.size()
        gca = self.global_avg_pool(x).view(B, C)
        gca = F.relu(self.global_channel_fc1(gca), inplace=True)
        gca = self.global_channel_fc2(gca)
        gca = self.global_channel_act(gca)
        gca = gca.view(B, C, 1, 1)
        gca_out = gca * x

        # Global spatial attention, applied on top of the global channel output.
        gsa = self.global_spatial_conv(x)
        gsa = gsa.view(B, -1)
        gsa = self.global_spatial_softmax(gsa)
        gsa = gsa.view(B, 1, H, W)
        gsa_out = gsa * gca_out
        gsa_out = gsa_out + gca_out

        # Weighted fusion of the local and global branches plus a residual path.
        out = lsa_out * self.local_attention_weight + gsa_out * self.global_attention_weight + x
        return out

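# Illustrative usage sketch (documentation-only helper): GLAM preserves the
# input shape, so it can be inserted into a feature pipeline without changing
# downstream dimensions.
def _demo_glam():
    glam = GLAM(in_channels=32, reduction_ratio=8)
    x = torch.randn(2, 32, 16, 16)
    out = glam(x)
    assert out.shape == x.shape
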
class EfficientNetb0_TransformerGLAM(nn.Module):
    """EfficientNet-B0 + Swin-style Transformer + GLAM + Self-Adaptive Gating."""

    def __init__(self,
                 num_classes=3,
                 embed_dim=512,
                 num_heads=8,
                 mlp_dim=512,  # currently unused
                 dropout=0.5,
                 window_size=7,
                 reduction_ratio=8):
        super(EfficientNetb0_TransformerGLAM, self).__init__()

        # EfficientNet-B0 backbone; its final feature map has 1280 channels.
        efficientnet = models.efficientnet_b0(weights=None)
        self.feature_extractor = nn.Sequential(*list(efficientnet.features.children()))
        self.conv1x1 = nn.Conv2d(1280, embed_dim, kernel_size=1)

        # Transformer branch: LayerNorm -> Swin-style window attention -> LayerNorm.
        self.pre_attn_norm = nn.LayerNorm(embed_dim)
        self.swin_attn = SwinWindowAttention(embed_dim, window_size, num_heads, dropout)
        self.post_attn_norm = nn.LayerNorm(embed_dim)

        # GLAM branch.
        self.glam = GLAM(in_channels=embed_dim, reduction_ratio=reduction_ratio)

        # Self-adaptive gate between the two branches.
        self.gate_fc = nn.Linear(embed_dim, 1)

        # Classifier head.
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        # Backbone features, projected to the embedding dimension.
        feats = self.feature_extractor(x)
        feats = self.conv1x1(feats)
        B, C, H, W = feats.shape

        # Transformer branch: pre-norm, window attention, post-norm.
        x_perm = feats.permute(0, 2, 3, 1).contiguous()
        x_norm = self.pre_attn_norm(x_perm)
        x_norm = x_norm.permute(0, 3, 1, 2).contiguous()
        x_norm = self.dropout(x_norm)

        T = self.swin_attn(x_norm)

        T_perm = T.permute(0, 2, 3, 1).contiguous()
        T_norm = self.post_attn_norm(T_perm)
        T_out = T_norm.permute(0, 3, 1, 2).contiguous()

        # GLAM branch.
        G_out = self.glam(feats)

        # Self-adaptive gating: one scalar gate per sample from pooled features.
        gap_feats = F.adaptive_avg_pool2d(feats, (1, 1)).view(B, C)
        g = torch.sigmoid(self.gate_fc(gap_feats))
        g = g.view(B, 1, 1, 1)

        # Gated fusion of the two branches, then global pooling and classification.
        F_out = g * T_out + (1 - g) * G_out
        pooled = F.adaptive_avg_pool2d(F_out, (1, 1)).view(B, -1)
        return self.fc(pooled)

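# Minimal smoke test (an illustrative sketch, assuming 224x224 RGB inputs, for
# which EfficientNet-B0 yields a 7x7 feature map that matches window_size=7).
if __name__ == "__main__":
    model = EfficientNetb0_TransformerGLAM(num_classes=3)
    model.eval()
    dummy = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        logits = model(dummy)
    print(logits.shape)  # expected: torch.Size([2, 3])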