import torch
import torch.nn as nn
import torch.nn.functional as F


class GLAM(nn.Module):
    """Global-Local Attention Module (GLAM) that produces a refined feature map.

    Four attention branches are computed from an input ``x`` of shape
    ``(B, C, H, W)``; ``C`` must equal ``in_channels``:

    * local channel attention  -> ``lca_out``
    * local spatial attention  -> ``lsa_out`` (applied on top of ``lca_out``)
    * global channel attention -> ``gca_out``
    * global spatial attention -> ``gsa_out`` (applied on top of ``gca_out``)

    The result is ``w_l * lsa_out + w_g * gsa_out + x``. Here ``w_l`` and
    ``w_g`` are learnable scalars initialised to 1.0. The output has the same
    shape as ``x``.

    Args:
        in_channels: Number of channels of the input feature map.
        reduction_ratio: Channel reduction factor of the bottleneck layers in
            both channel-attention branches. The bottleneck width is
            ``in_channels // reduction_ratio``, clamped to at least 1.
    """

    def __init__(self, in_channels, reduction_ratio=8):
        super(GLAM, self).__init__()
        # Clamp to >= 1: when in_channels < reduction_ratio the plain floor
        # division gives 0, producing zero-width layers whose outputs are just
        # their biases, i.e. the channel branches would ignore the input.
        hidden = max(1, in_channels // reduction_ratio)

        # --- Local Channel Attention ---
        self.local_channel_conv = nn.Conv2d(in_channels, hidden, kernel_size=1)
        self.local_channel_act = nn.Sigmoid()
        self.local_channel_expand = nn.Conv2d(hidden, in_channels, kernel_size=1)

        # --- Local Spatial Attention ---
        # Dilation-3 and dilation-5 3x3 convs; padding == dilation keeps H, W.
        # Their outputs are concatenated with x and merged back to in_channels.
        self.local_spatial_conv3 = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, padding=3, dilation=3
        )
        self.local_spatial_conv5 = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, padding=5, dilation=5
        )
        self.local_spatial_merge = nn.Conv2d(in_channels * 3, in_channels, kernel_size=1)
        self.local_spatial_act = nn.Sigmoid()

        # --- Global Channel Attention ---
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.global_channel_fc1 = nn.Linear(in_channels, hidden)
        self.global_channel_fc2 = nn.Linear(hidden, in_channels)
        self.global_channel_act = nn.Sigmoid()

        # --- Global Spatial Attention ---
        self.global_spatial_conv = nn.Conv2d(in_channels, 1, kernel_size=1)
        self.global_spatial_softmax = nn.Softmax(dim=-1)

        # --- Weighted parameters initialization ---
        # Learnable scalar weights for fusing the local and global branches.
        self.local_attention_weight = nn.Parameter(torch.tensor(1.0))
        self.global_attention_weight = nn.Parameter(torch.tensor(1.0))

    def _local_attention(self, x):
        """Return the local branch: channel attention followed by spatial attention."""
        # Local Channel Attention
        # NOTE(review): the sigmoid sits between the reduce and expand convs,
        # so ``lca`` is NOT bounded to [0, 1] (unlike an SE-style gate with the
        # sigmoid last). Kept as-is to preserve trained behaviour; confirm intent.
        lca = self.local_channel_conv(x)
        lca = self.local_channel_act(lca)
        lca = self.local_channel_expand(lca)
        lca_out = lca * x

        # Local Spatial Attention: multi-dilation context merged with x.
        lsa3 = self.local_spatial_conv3(x)
        lsa5 = self.local_spatial_conv5(x)
        lsa_cat = torch.cat([x, lsa3, lsa5], dim=1)
        lsa = self.local_spatial_act(self.local_spatial_merge(lsa_cat))
        # Residual: attended features plus the channel-attended features.
        return lsa * lca_out + lca_out

    def _global_attention(self, x):
        """Return the global branch: channel attention followed by spatial attention."""
        B, C, H, W = x.size()

        # Global Channel Attention: pool -> FC -> ReLU -> FC -> sigmoid gate.
        gca = self.global_avg_pool(x).view(B, C)
        gca = F.relu(self.global_channel_fc1(gca), inplace=True)
        gca = self.global_channel_act(self.global_channel_fc2(gca))
        gca_out = gca.view(B, C, 1, 1) * x

        # Global Spatial Attention: softmax over all H*W positions, so each
        # sample's map sums to 1.
        # NOTE(review): with a sum-to-1 map, per-position weights are ~1/(H*W),
        # so ``gsa * gca_out`` is a small modulation of the residual term.
        gsa = self.global_spatial_conv(x)  # [B, 1, H, W]
        gsa = gsa.view(B, -1)  # [B, H*W]
        gsa = self.global_spatial_softmax(gsa).view(B, 1, H, W)
        return gsa * gca_out + gca_out

    def forward(self, x):
        """Refine ``x`` (shape ``(B, in_channels, H, W)``); output has the same shape."""
        lsa_out = self._local_attention(x)
        gsa_out = self._global_attention(x)
        # Fuse: weighted local + weighted global branches plus identity residual.
        return (
            lsa_out * self.local_attention_weight
            + gsa_out * self.global_attention_weight
            + x
        )