import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


# ==================================================
# 1. SWIN WINDOW UTILS
# ==================================================
def window_partition(x, window_size):
    """Partitions input tensor into windows of shape (B * num_windows, window_size*window_size, C)."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    windows = x.view(-1, window_size * window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """Reverses the window partition operation."""
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    x = x.view(B, H, W, -1)
    return x


# ==================================================
# 2. SWIN WINDOW ATTENTION
# ==================================================
class SwinWindowAttention(nn.Module):
    """Swin-style window attention block."""

    def __init__(self, embed_dim, window_size, num_heads, dropout=0.0):
        super(SwinWindowAttention, self).__init__()
        self.embed_dim = embed_dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.mha = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Perform multi-head self-attn within windows."""
        B, C, H, W = x.shape
        x = x.permute(0, 2, 3, 1).contiguous()

        pad_h = (self.window_size - H % self.window_size) % self.window_size
        pad_w = (self.window_size - W % self.window_size) % self.window_size
        if pad_h or pad_w:
            x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
        Hp, Wp = x.shape[1], x.shape[2]

        windows = window_partition(x, self.window_size)  # (B*n_wins, win_size^2, C)
        attn_windows, _ = self.mha(windows, windows, windows)
        attn_windows = self.dropout(attn_windows)

        x = window_reverse(attn_windows, self.window_size, Hp, Wp)
        if pad_h or pad_w:
            x = x[:, :H, :W, :].contiguous()
        return x.permute(0, 3, 1, 2).contiguous()
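
# Illustrative sanity check (added example, not part of the original model code):
# it shows that window_partition/window_reverse are exact inverses and that
# SwinWindowAttention preserves the (B, C, H, W) shape even when H and W are not
# multiples of the window size (exercising the padding branch). Tensor sizes and
# the helper's name are arbitrary choices for the demo.
def _demo_swin_window_attention():
    x = torch.randn(2, 8, 8, 32)                          # (B, H, W, C)
    wins = window_partition(x, window_size=4)             # (2*4, 16, 32)
    assert torch.equal(window_reverse(wins, 4, 8, 8), x)  # exact round-trip

    attn = SwinWindowAttention(embed_dim=32, window_size=4, num_heads=4)
    y = attn(torch.randn(2, 32, 10, 10))                  # 10x10 is padded to 12x12 internally
    assert y.shape == (2, 32, 10, 10)
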
# ==================================================
# 3. GLAM
# ==================================================
class GLAM(nn.Module):
    """Global-Local Attention Module (GLAM)."""

    def __init__(self, in_channels, reduction_ratio=8):
        super(GLAM, self).__init__()
        # Local Channel Attention
        self.local_channel_conv = nn.Conv2d(in_channels, in_channels // reduction_ratio, kernel_size=1)
        self.local_channel_act = nn.Sigmoid()
        self.local_channel_expand = nn.Conv2d(in_channels // reduction_ratio, in_channels, kernel_size=1)

        # Local Spatial Attention
        self.local_spatial_conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=3, dilation=3)
        self.local_spatial_conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=5, dilation=5)
        self.local_spatial_merge = nn.Conv2d(in_channels * 3, in_channels, kernel_size=1)
        self.local_spatial_act = nn.Sigmoid()

        # Global Channel Attention
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.global_channel_fc1 = nn.Linear(in_channels, in_channels // reduction_ratio)
        self.global_channel_fc2 = nn.Linear(in_channels // reduction_ratio, in_channels)
        self.global_channel_act = nn.Sigmoid()

        # Global Spatial Attention
        self.global_spatial_conv = nn.Conv2d(in_channels, 1, kernel_size=1)
        self.global_spatial_softmax = nn.Softmax(dim=-1)

        # Weights
        self.local_attention_weight = nn.Parameter(torch.tensor(1.0))
        self.global_attention_weight = nn.Parameter(torch.tensor(1.0))

    def forward(self, x):
        # Local Channel Attention
        lca = self.local_channel_conv(x)
        lca = self.local_channel_act(lca)
        lca = self.local_channel_expand(lca)
        lca_out = lca * x

        # Local Spatial Attention
        lsa3 = self.local_spatial_conv3(x)
        lsa5 = self.local_spatial_conv5(x)
        lsa_cat = torch.cat([x, lsa3, lsa5], dim=1)
        lsa = self.local_spatial_merge(lsa_cat)
        lsa = self.local_spatial_act(lsa)
        lsa_out = lsa * lca_out
        lsa_out = lsa_out + lca_out

        # Global Channel Attention
        B, C, H, W = x.size()
        gca = self.global_avg_pool(x).view(B, C)
        gca = F.relu(self.global_channel_fc1(gca), inplace=True)
        gca = self.global_channel_fc2(gca)
        gca = self.global_channel_act(gca)
        gca = gca.view(B, C, 1, 1)
        gca_out = gca * x

        # Global Spatial Attention
        gsa = self.global_spatial_conv(x)   # [B, 1, H, W]
        gsa = gsa.view(B, -1)               # [B, H*W]
        gsa = self.global_spatial_softmax(gsa)
        gsa = gsa.view(B, 1, H, W)
        gsa_out = gsa * gca_out
        gsa_out = gsa_out + gca_out

        # Final Fusion
        out = lsa_out * self.local_attention_weight + gsa_out * self.global_attention_weight + x
        return out


# ==================================================
# 4. FUSION BLOCK
# ==================================================
class FusionBlock(nn.Module):
    """Combines Transformer and GLAM outputs using gating."""

    def __init__(self):
        super(FusionBlock, self).__init__()

    def forward(self, g, T_out, G_out):
        """Perform final gating fusion."""
        return g * T_out + (1 - g) * G_out
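
# Illustrative sanity check (added example, not part of the original model code):
# GLAM is shape-preserving, so it can be applied to any (B, C, H, W) feature map,
# and FusionBlock blends two same-shaped maps with a per-sample gate g in [0, 1]
# that broadcasts over C, H, W. The channel count and spatial size are arbitrary.
def _demo_glam_and_fusion():
    feats = torch.randn(2, 64, 14, 14)
    glam_out = GLAM(in_channels=64, reduction_ratio=8)(feats)
    assert glam_out.shape == feats.shape

    g = torch.full((2, 1, 1, 1), 0.25)            # gate value per sample
    fused = FusionBlock()(g, feats, glam_out)     # 0.25 * T_out + 0.75 * G_out
    assert fused.shape == feats.shape
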
# ==================================================
# 5. EFFICIENTNETB0_TRANSFORMERGLAM
# ==================================================
class EfficientNetb0_TransformerGLAM(nn.Module):
    """EfficientNet-B0 + Swin-style Transformer + GLAM + Self-Adaptive Gating."""

    def __init__(self, num_classes=3, embed_dim=512, num_heads=8, mlp_dim=512,
                 dropout=0.5, window_size=7, reduction_ratio=8):
        super(EfficientNetb0_TransformerGLAM, self).__init__()
        # EfficientNet Backbone
        efficientnet = models.efficientnet_b0(weights=None)
        self.feature_extractor = nn.Sequential(*list(efficientnet.features.children()))
        self.conv1x1 = nn.Conv2d(1280, embed_dim, kernel_size=1)

        # Transformer path
        self.pre_attn_norm = nn.LayerNorm(embed_dim)
        self.swin_attn = SwinWindowAttention(embed_dim, window_size, num_heads, dropout)
        self.post_attn_norm = nn.LayerNorm(embed_dim)

        # GLAM path
        self.glam = GLAM(in_channels=embed_dim, reduction_ratio=reduction_ratio)

        # Gating
        self.gate_fc = nn.Linear(embed_dim, 1)

        # Final Fusion
        self.fusion_block = FusionBlock()

        # Final classification
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        # Backbone
        feats = self.feature_extractor(x)   # [B, 1280, H', W']
        feats = self.conv1x1(feats)         # [B, embed_dim, H', W']
        B, C, H, W = feats.shape

        # Transformer path
        x_perm = feats.permute(0, 2, 3, 1).contiguous()
        x_norm = self.pre_attn_norm(x_perm)     # LN
        x_norm = x_norm.permute(0, 3, 1, 2).contiguous()
        x_norm = self.dropout(x_norm)
        T = self.swin_attn(x_norm)
        T_perm = T.permute(0, 2, 3, 1).contiguous()
        T_norm = self.post_attn_norm(T_perm)    # LN
        T_out = T_norm.permute(0, 3, 1, 2).contiguous()

        # GLAM path
        G_out = self.glam(feats)

        # Gating
        gap_feats = F.adaptive_avg_pool2d(feats, (1, 1)).view(B, C)
        g = torch.sigmoid(self.gate_fc(gap_feats))
        g = g.view(B, 1, 1, 1)

        # Final Fusion
        F_out = self.fusion_block(g, T_out, G_out)

        # Save final feature map for Grad-CAM
        self.last_feature = F_out

        pooled = F.adaptive_avg_pool2d(F_out, (1, 1)).view(B, -1)
        return self.fc(pooled)
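
# Illustrative smoke test (added example, not part of the original file): builds
# the full model with randomly initialised weights (weights=None, so no download)
# and runs one dummy 224x224 batch. With a 224x224 input, EfficientNet-B0's final
# feature map is 7x7, which matches the default window_size=7, so the Swin
# attention covers the whole map without padding.
if __name__ == "__main__":
    model = EfficientNetb0_TransformerGLAM(num_classes=3)
    model.eval()
    with torch.no_grad():
        logits = model(torch.randn(2, 3, 224, 224))
    print(logits.shape)                 # torch.Size([2, 3])
    print(model.last_feature.shape)     # torch.Size([2, 512, 7, 7]), kept for Grad-CAM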