import torch
import torch.nn as nn
import torch.nn.functional as F

class GLAM(nn.Module):
    """
    Global-Local Attention Module (GLAM) that produces a refined feature map.

    The input is passed through local channel, local spatial, global channel,
    and global spatial attention branches; the local and global results are
    fused with learnable scalar weights and added back to the input.
    """

    def __init__(self, in_channels, reduction_ratio=8):
        super(GLAM, self).__init__()

        # Local channel attention: 1x1 squeeze by reduction_ratio, sigmoid,
        # then a 1x1 expansion back to in_channels.
        self.local_channel_conv = nn.Conv2d(in_channels, in_channels // reduction_ratio, kernel_size=1)
        self.local_channel_act = nn.Sigmoid()
        self.local_channel_expand = nn.Conv2d(in_channels // reduction_ratio, in_channels, kernel_size=1)

        # Local spatial attention: two dilated 3x3 convs (dilation 3 and 5, with
        # matching padding so spatial size is preserved); their outputs are
        # concatenated with the input and merged back to in_channels by a 1x1 conv.
        self.local_spatial_conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=3, dilation=3)
        self.local_spatial_conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=5, dilation=5)
        self.local_spatial_merge = nn.Conv2d(in_channels * 3, in_channels, kernel_size=1)
        self.local_spatial_act = nn.Sigmoid()

        # Global channel attention: squeeze-and-excitation style bottleneck over
        # globally average-pooled channel descriptors.
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.global_channel_fc1 = nn.Linear(in_channels, in_channels // reduction_ratio)
        self.global_channel_fc2 = nn.Linear(in_channels // reduction_ratio, in_channels)
        self.global_channel_act = nn.Sigmoid()

        # Global spatial attention: a 1x1 conv produces a single-channel map that
        # is normalized with a softmax over all spatial positions.
        self.global_spatial_conv = nn.Conv2d(in_channels, 1, kernel_size=1)
        self.global_spatial_softmax = nn.Softmax(dim=-1)

        # Learnable scalar weights for fusing the local and global branches.
        self.local_attention_weight = nn.Parameter(torch.tensor(1.0))
        self.global_attention_weight = nn.Parameter(torch.tensor(1.0))

    def forward(self, x):
        # Local channel attention: 1x1 squeeze, sigmoid, 1x1 expand, then
        # elementwise rescaling of the input.
        lca = self.local_channel_conv(x)
        lca = self.local_channel_act(lca)
        lca = self.local_channel_expand(lca)
        lca_out = lca * x

        # Local spatial attention: multi-dilation context merged into a sigmoid map,
        # applied to the local channel output with a residual connection.
        lsa3 = self.local_spatial_conv3(x)
        lsa5 = self.local_spatial_conv5(x)
        lsa_cat = torch.cat([x, lsa3, lsa5], dim=1)
        lsa = self.local_spatial_merge(lsa_cat)
        lsa = self.local_spatial_act(lsa)
        lsa_out = lsa * lca_out
        lsa_out = lsa_out + lca_out

        # Global channel attention: SE-style reweighting of the input channels.
        B, C, H, W = x.size()
        gca = self.global_avg_pool(x).view(B, C)
        gca = F.relu(self.global_channel_fc1(gca), inplace=True)
        gca = self.global_channel_fc2(gca)
        gca = self.global_channel_act(gca)
        gca = gca.view(B, C, 1, 1)
        gca_out = gca * x

        # Global spatial attention: softmax over all H*W positions, applied to the
        # global channel output with a residual connection.
        gsa = self.global_spatial_conv(x)
        gsa = gsa.view(B, -1)
        gsa = self.global_spatial_softmax(gsa)
        gsa = gsa.view(B, 1, H, W)
        gsa_out = gsa * gca_out
        gsa_out = gsa_out + gca_out

        # Weighted fusion of the local and global branches plus a residual
        # connection to the original input.
        out = lsa_out * self.local_attention_weight + gsa_out * self.global_attention_weight + x
        return out
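

# A minimal smoke test; the channel count and input size below are arbitrary
# example values (not taken from any particular model configuration) chosen to
# show that GLAM preserves the shape of its input.
if __name__ == "__main__":
    glam = GLAM(in_channels=64, reduction_ratio=8)
    x = torch.randn(2, 64, 32, 32)  # (batch, channels, height, width)
    out = glam(x)
    # The refined feature map has the same shape as the input.
    print(x.shape, out.shape)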