Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torchvision.models as models | |
class GLAM(nn.Module):
    """
    Global-Local Attention Module (GLAM) that produces a refined feature map.

    Four attention branches are computed from the same input feature map:

    - local channel attention;
    - local spatial attention, built from dilated 3x3 convs;
    - global channel attention, squeeze-and-excitation style;
    - global spatial attention, a softmax over all H*W positions.

    The local and global results are fused with two learnable scalar weights,
    and the input is added back as an identity term. The output has the same
    shape as the input.

    Args:
        in_channels: Number of channels C of the input feature map.
        reduction_ratio: Divisor applied to ``in_channels`` to size the
            channel-attention bottlenecks. The bottleneck width is clamped to
            at least 1 so that ``in_channels < reduction_ratio`` does not
            produce zero-width layers.
    """

    def __init__(self, in_channels, reduction_ratio=8):
        super().__init__()
        # Plain floor division gives 0 whenever in_channels < reduction_ratio,
        # which would create zero-width bottleneck layers; clamp to 1 instead.
        # For in_channels >= reduction_ratio this equals the original width.
        hidden_channels = max(1, in_channels // reduction_ratio)

        # --- Local Channel Attention ---
        self.local_channel_conv = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
        self.local_channel_act = nn.Sigmoid()
        self.local_channel_expand = nn.Conv2d(hidden_channels, in_channels, kernel_size=1)

        # --- Local Spatial Attention ---
        # Two dilated 3x3 convs (dilation 3 and 5, "same" padding). Their
        # outputs are concatenated with the input (hence in_channels * 3) and
        # merged back to in_channels by a 1x1 conv.
        self.local_spatial_conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=3, dilation=3)
        self.local_spatial_conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=5, dilation=5)
        self.local_spatial_merge = nn.Conv2d(in_channels * 3, in_channels, kernel_size=1)
        self.local_spatial_act = nn.Sigmoid()

        # --- Global Channel Attention ---
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.global_channel_fc1 = nn.Linear(in_channels, hidden_channels)
        self.global_channel_fc2 = nn.Linear(hidden_channels, in_channels)
        self.global_channel_act = nn.Sigmoid()

        # --- Global Spatial Attention ---
        self.global_spatial_conv = nn.Conv2d(in_channels, 1, kernel_size=1)
        self.global_spatial_softmax = nn.Softmax(dim=-1)

        # --- Learnable fusion weights (both initialized to 1.0) ---
        self.local_attention_weight = nn.Parameter(torch.tensor(1.0))
        self.global_attention_weight = nn.Parameter(torch.tensor(1.0))

    def _local_attention(self, x):
        """Return the local-attention feature map; same shape as ``x``."""
        # Local channel attention: 1x1 reduce -> sigmoid -> 1x1 expand.
        # NOTE(review): unlike the global channel branch below (reduce ->
        # ReLU -> expand -> sigmoid), the sigmoid here sits between the two
        # convs, so the resulting gate is not bounded to [0, 1]. This is kept
        # as-is because changing it would alter the output of weights trained
        # with this definition. Confirm whether it was intended.
        lca = self.local_channel_act(self.local_channel_conv(x))
        lca = self.local_channel_expand(lca)
        lca_out = lca * x

        # Local spatial attention: merge the input with the two dilated-conv
        # responses, squash with a sigmoid, then gate the channel-attended
        # features while keeping them as a residual path.
        lsa3 = self.local_spatial_conv3(x)
        lsa5 = self.local_spatial_conv5(x)
        lsa = self.local_spatial_merge(torch.cat([x, lsa3, lsa5], dim=1))
        lsa = self.local_spatial_act(lsa)
        return lsa * lca_out + lca_out

    def _global_attention(self, x):
        """Return the global-attention feature map; same shape as ``x``."""
        B, C, H, W = x.size()

        # Global channel attention: average pool -> FC -> ReLU -> FC -> sigmoid,
        # producing one gate per channel that is broadcast over H and W.
        gca = self.global_avg_pool(x).view(B, C)
        gca = F.relu(self.global_channel_fc1(gca), inplace=True)
        gca = self.global_channel_act(self.global_channel_fc2(gca))
        gca_out = gca.view(B, C, 1, 1) * x

        # Global spatial attention: softmax over all H*W positions of a
        # single-channel score map.
        # NOTE(review): the softmax weights sum to 1 per sample, so each weight
        # averages 1/(H*W). The gsa * gca_out term is therefore small relative
        # to the gca_out residual; confirm this is the intended scaling.
        gsa = self.global_spatial_conv(x).view(B, -1)  # [B, H*W]
        gsa = self.global_spatial_softmax(gsa).view(B, 1, H, W)
        return gsa * gca_out + gca_out

    def forward(self, x):
        """
        Apply GLAM to ``x``.

        Args:
            x: 4-D tensor unpacked as (B, C, H, W), with C == in_channels.

        Returns:
            Tensor of the same shape as ``x``:
            ``w_local * local + w_global * global + x``.
        """
        local_out = self._local_attention(x)
        global_out = self._global_attention(x)
        return (
            local_out * self.local_attention_weight
            + global_out * self.global_attention_weight
            + x
        )
def get_model_with_attention(model_name, num_classes):
    """
    Build a classification backbone with a GLAM block appended to its features.

    Only ``'densenet169'`` is supported. The DenseNet-169 is created without
    pretrained weights. Its ``features`` stage is wrapped as
    ``Sequential(original_features, ReLU, GLAM)``, and its classifier is
    replaced by a fresh ``nn.Linear(in_features, num_classes)``.

    Args:
        model_name: Name of the backbone to build.
        num_classes: Number of output classes for the new classifier head.

    Returns:
        The modified ``torch.nn.Module``.

    Raises:
        ValueError: If ``model_name`` is not supported. Previously this path
            left ``model`` unbound instead of reporting the problem.
    """
    if model_name != 'densenet169':
        raise ValueError(
            f"Unsupported model_name {model_name!r}; expected 'densenet169'."
        )

    # Calling with no arguments gives random initialization on both the old
    # (``pretrained=False`` default) and new (``weights=None`` default)
    # torchvision APIs, and avoids the deprecated ``pretrained`` keyword.
    model = models.densenet169()
    in_channels = model.classifier.in_features

    # Keep this exact Sequential layout. The submodule indices (features.0 /
    # features.2) become state_dict key prefixes, so changing them would break
    # loading checkpoints saved with this layout.
    model.features = nn.Sequential(model.features, nn.ReLU(inplace=True), GLAM(in_channels))
    model.classifier = nn.Linear(in_channels, num_classes)
    return model