# glaucoma / glam_module.py
# Author: dhruv2842 — last commit 109c9ca ("Update glam_module.py"), 3.11 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
class GLAM(nn.Module):
    """Global-Local Attention Module (GLAM) that produces a refined feature map.

    Two independent branches are computed from the same input feature map and
    fused with learnable scalar weights plus an identity residual:

    * Local branch: a per-position channel gate (two 1x1 convs), then a
      spatial gate computed from the input and two dilated 3x3 convs
      (dilation 3 and 5).
    * Global branch: a squeeze-and-excitation style channel gate built from
      global average pooling, then a spatial softmax over all H*W positions.

    The output is ``w_local * local + w_global * global + x`` and has the same
    shape as ``x``.

    Args:
        in_channels: Number of channels of the input feature map.
        reduction_ratio: Bottleneck reduction factor for the channel-attention
            layers. The bottleneck width is ``in_channels // reduction_ratio``,
            clamped to at least 1.

    Raises:
        ValueError: If ``in_channels`` or ``reduction_ratio`` is not positive.
    """

    def __init__(self, in_channels: int, reduction_ratio: int = 8):
        super().__init__()
        if in_channels < 1:
            raise ValueError(f"in_channels must be positive, got {in_channels}")
        if reduction_ratio < 1:
            raise ValueError(f"reduction_ratio must be positive, got {reduction_ratio}")
        # Clamp so that in_channels < reduction_ratio does not create zero-width
        # bottleneck layers (which would reduce the channel gates to bias-only
        # outputs). For in_channels >= reduction_ratio the width is unchanged,
        # so existing state_dicts keep loading.
        hidden = max(1, in_channels // reduction_ratio)

        # --- Local Channel Attention (per position, 1x1 convs, no pooling) ---
        self.local_channel_conv = nn.Conv2d(in_channels, hidden, kernel_size=1)
        self.local_channel_act = nn.Sigmoid()
        self.local_channel_expand = nn.Conv2d(hidden, in_channels, kernel_size=1)

        # --- Local Spatial Attention ---
        # padding == dilation keeps H and W unchanged for a 3x3 kernel, so the
        # outputs can be concatenated with x along the channel axis.
        self.local_spatial_conv3 = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, padding=3, dilation=3
        )
        self.local_spatial_conv5 = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, padding=5, dilation=5
        )
        # Merges [x, conv3(x), conv5(x)] (3 * in_channels) back to in_channels.
        self.local_spatial_merge = nn.Conv2d(in_channels * 3, in_channels, kernel_size=1)
        self.local_spatial_act = nn.Sigmoid()

        # --- Global Channel Attention (squeeze-and-excitation style) ---
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.global_channel_fc1 = nn.Linear(in_channels, hidden)
        self.global_channel_fc2 = nn.Linear(hidden, in_channels)
        self.global_channel_act = nn.Sigmoid()

        # --- Global Spatial Attention ---
        # Single-channel map, softmax-normalised over the flattened H*W axis.
        self.global_spatial_conv = nn.Conv2d(in_channels, 1, kernel_size=1)
        self.global_spatial_softmax = nn.Softmax(dim=-1)

        # --- Learnable fusion weights for the two branches ---
        self.local_attention_weight = nn.Parameter(torch.tensor(1.0))
        self.global_attention_weight = nn.Parameter(torch.tensor(1.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply local and global attention and fuse them with a residual.

        Args:
            x: Feature map of shape ``(B, in_channels, H, W)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        local_out = self._local_attention(x)
        global_out = self._global_attention(x)
        return (
            local_out * self.local_attention_weight
            + global_out * self.global_attention_weight
            + x
        )

    def _local_attention(self, x: torch.Tensor) -> torch.Tensor:
        """Local branch: per-position channel gate, then dilated spatial gate."""
        # NOTE(review): the sigmoid sits *between* the two 1x1 convs, so the
        # weights produced by local_channel_expand are not bounded to [0, 1]
        # (unlike the global channel gate). Kept as-is: moving it would change
        # the outputs of any weights trained with this module.
        lca = self.local_channel_expand(self.local_channel_act(self.local_channel_conv(x)))
        lca_out = lca * x

        multi_scale = torch.cat(
            [x, self.local_spatial_conv3(x), self.local_spatial_conv5(x)], dim=1
        )
        lsa = self.local_spatial_act(self.local_spatial_merge(multi_scale))
        # Gate the channel-attended features and keep a residual path.
        return lsa * lca_out + lca_out

    def _global_attention(self, x: torch.Tensor) -> torch.Tensor:
        """Global branch: pooled channel gate, then softmax spatial gate."""
        b, c, h, w = x.size()
        gca = self.global_avg_pool(x).view(b, c)
        gca = F.relu(self.global_channel_fc1(gca), inplace=True)
        gca = self.global_channel_act(self.global_channel_fc2(gca)).view(b, c, 1, 1)
        gca_out = gca * x

        # Softmax over all H*W positions: spatial weights sum to 1 per sample.
        gsa = self.global_spatial_conv(x).view(b, -1)  # [B, H*W]
        gsa = self.global_spatial_softmax(gsa).view(b, 1, h, w)
        # Gate the channel-attended features and keep a residual path.
        return gsa * gca_out + gca_out