File size: 3,114 Bytes
109c9ca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import torch
import torch.nn as nn
import torch.nn.functional as F
class GLAM(nn.Module):
    """
    Global-Local Attention Module (GLAM) that produces a refined feature map.

    Four attention branches are computed from the input ``x``. The local
    branches are channel attention followed by spatial attention; the global
    branches are the same pair. The two results are fused with a residual
    connection::

        out = w_local * local_out + w_global * global_out + x

    ``w_local`` and ``w_global`` are learnable scalars initialised to 1.

    Args:
        in_channels: Number of channels ``C`` of the input feature map.
        reduction_ratio: Divisor that sizes the bottleneck of the two
            channel-attention branches. The bottleneck width is
            ``max(1, in_channels // reduction_ratio)``.

    Raises:
        ValueError: If ``reduction_ratio`` is less than 1.

    Shape:
        Input ``(B, C, H, W)`` with ``C == in_channels``; output has the same
        shape.
    """

    def __init__(self, in_channels, reduction_ratio=8):
        super().__init__()
        if reduction_ratio < 1:
            raise ValueError(f"reduction_ratio must be >= 1, got {reduction_ratio}")
        # Bottleneck width for both channel-attention branches. Clamp to 1 so
        # that in_channels < reduction_ratio does not create zero-width layers,
        # which would make those branches degenerate.
        reduced = max(1, in_channels // reduction_ratio)

        # NOTE: modules are created in the original order so that seeded
        # parameter initialisation and state_dict keys are unchanged.

        # --- Local Channel Attention ---
        # Per-pixel 1x1 bottleneck: reduce -> sigmoid -> expand.
        self.local_channel_conv = nn.Conv2d(in_channels, reduced, kernel_size=1)
        self.local_channel_act = nn.Sigmoid()
        self.local_channel_expand = nn.Conv2d(reduced, in_channels, kernel_size=1)

        # --- Local Spatial Attention ---
        # Two 3x3 convs with dilation 3 and 5 (padding equals dilation, so H
        # and W are preserved). Their outputs are concatenated with the input
        # (3 * C channels) and merged back to C channels by a 1x1 conv.
        self.local_spatial_conv3 = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, padding=3, dilation=3
        )
        self.local_spatial_conv5 = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, padding=5, dilation=5
        )
        self.local_spatial_merge = nn.Conv2d(in_channels * 3, in_channels, kernel_size=1)
        self.local_spatial_act = nn.Sigmoid()

        # --- Global Channel Attention ---
        # Squeeze-and-excitation style: global average pool -> FC -> ReLU
        # (applied in forward) -> FC -> sigmoid.
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.global_channel_fc1 = nn.Linear(in_channels, reduced)
        self.global_channel_fc2 = nn.Linear(reduced, in_channels)
        self.global_channel_act = nn.Sigmoid()

        # --- Global Spatial Attention ---
        # One logit per pixel, softmax-normalised over all H*W positions.
        self.global_spatial_conv = nn.Conv2d(in_channels, 1, kernel_size=1)
        self.global_spatial_softmax = nn.Softmax(dim=-1)

        # --- Learnable fusion weights (initialised to 1) ---
        self.local_attention_weight = nn.Parameter(torch.tensor(1.0))
        self.global_attention_weight = nn.Parameter(torch.tensor(1.0))

    def _local_attention(self, x):
        """Return local channel attention refined by local spatial attention."""
        # NOTE(review): the sigmoid sits between the two 1x1 convs, so ``lca``
        # is an unbounded affine map of sigmoid outputs rather than a [0, 1]
        # gate (unlike the global channel branch, which ends in a sigmoid).
        # Kept as-is to preserve the behavior of trained weights; confirm intent.
        lca = self.local_channel_conv(x)
        lca = self.local_channel_act(lca)
        lca = self.local_channel_expand(lca)
        lca_out = lca * x

        # The spatial map is computed from the raw input, not from lca_out.
        lsa3 = self.local_spatial_conv3(x)
        lsa5 = self.local_spatial_conv5(x)
        lsa = self.local_spatial_merge(torch.cat([x, lsa3, lsa5], dim=1))
        lsa = self.local_spatial_act(lsa)

        # Gate the channel-attended features, then add them back (residual).
        return lsa * lca_out + lca_out

    def _global_attention(self, x):
        """Return global channel attention refined by global spatial attention."""
        B, C, H, W = x.size()

        gca = self.global_avg_pool(x).view(B, C)
        gca = F.relu(self.global_channel_fc1(gca), inplace=True)
        gca = self.global_channel_fc2(gca)
        gca = self.global_channel_act(gca).view(B, C, 1, 1)
        gca_out = gca * x

        gsa = self.global_spatial_conv(x).view(B, -1)  # [B, H*W]
        gsa = self.global_spatial_softmax(gsa).view(B, 1, H, W)

        # Gate the channel-attended features, then add them back (residual).
        return gsa * gca_out + gca_out

    def forward(self, x):
        """
        Apply GLAM to ``x`` of shape ``(B, C, H, W)``.

        Returns:
            A tensor of the same shape: the weighted sum of the local and
            global branch outputs plus the input itself.
        """
        local_out = self._local_attention(x)
        global_out = self._global_attention(x)
        return (
            local_out * self.local_attention_weight
            + global_out * self.global_attention_weight
            + x
        )
|