import torch
import torch.nn as nn
import torch.nn.functional as F

class GLAM(nn.Module):
    """
    Global-Local Attention Module (GLAM) that produces a refined feature map.
    """
    def __init__(self, in_channels, reduction_ratio=8):
        super().__init__()
        
        # --- Local Channel Attention ---
        self.local_channel_conv = nn.Conv2d(in_channels, in_channels // reduction_ratio, kernel_size=1)
        self.local_channel_act = nn.Sigmoid()
        self.local_channel_expand = nn.Conv2d(in_channels // reduction_ratio, in_channels, kernel_size=1)
        
        # --- Local Spatial Attention ---
        # Two 3x3 convs at dilations 3 and 5, merged with the input by a 1x1 conv
        self.local_spatial_conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=3, dilation=3)
        self.local_spatial_conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=5, dilation=5)
        self.local_spatial_merge = nn.Conv2d(in_channels * 3, in_channels, kernel_size=1)
        self.local_spatial_act = nn.Sigmoid()
        
        # --- Global Channel Attention ---
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.global_channel_fc1 = nn.Linear(in_channels, in_channels // reduction_ratio)
        self.global_channel_fc2 = nn.Linear(in_channels // reduction_ratio, in_channels)
        self.global_channel_act = nn.Sigmoid()
        
        # --- Global Spatial Attention ---
        self.global_spatial_conv = nn.Conv2d(in_channels, 1, kernel_size=1)
        self.global_spatial_softmax = nn.Softmax(dim=-1)
        
        
        # --- Learnable fusion weights ---
        self.local_attention_weight = nn.Parameter(torch.tensor(1.0))
        self.global_attention_weight = nn.Parameter(torch.tensor(1.0))

    def forward(self, x):
        # Local Channel Attention: 1x1 bottleneck -> sigmoid gate -> 1x1 expand
        lca = self.local_channel_conv(x)      # [B, C/r, H, W]
        lca = self.local_channel_act(lca)     # sigmoid gate
        lca = self.local_channel_expand(lca)  # [B, C, H, W]
        lca_out = lca * x                     # reweight the input features

        # Local Spatial Attention: multi-dilation context with residual gating
        lsa3 = self.local_spatial_conv3(x)           # dilation-3 context
        lsa5 = self.local_spatial_conv5(x)           # dilation-5 context
        lsa_cat = torch.cat([x, lsa3, lsa5], dim=1)  # [B, 3C, H, W]
        lsa = self.local_spatial_merge(lsa_cat)      # [B, C, H, W]
        lsa = self.local_spatial_act(lsa)
        lsa_out = lsa * lca_out + lca_out            # (1 + attention) residual

        # Global Channel Attention: squeeze (GAP) -> bottleneck MLP -> sigmoid
        B, C, H, W = x.size()
        gca = self.global_avg_pool(x).view(B, C)                  # [B, C]
        gca = F.relu(self.global_channel_fc1(gca), inplace=True)  # [B, C/r]
        gca = self.global_channel_fc2(gca)                        # [B, C]
        gca = self.global_channel_act(gca)
        gca = gca.view(B, C, 1, 1)                                # broadcast shape
        gca_out = gca * x

        # Global Spatial Attention: softmax over all H*W positions
        gsa = self.global_spatial_conv(x)  # [B, 1, H, W]
        gsa = gsa.view(B, -1)              # [B, H*W]
        gsa = self.global_spatial_softmax(gsa)
        gsa = gsa.view(B, 1, H, W)
        gsa_out = gsa * gca_out + gca_out  # (1 + attention) residual

        # Fuse the local and global branches with learnable scalar weights,
        # plus an identity shortcut
        out = (self.local_attention_weight * lsa_out
               + self.global_attention_weight * gsa_out
               + x)
        return out
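

# A minimal smoke test, included as an illustrative sketch (not part of the
# original module): run GLAM on a random feature map and confirm the output
# shape matches the input. The sizes below (batch 2, 64 channels, 32x32) are
# arbitrary example values.
if __name__ == "__main__":
    glam = GLAM(in_channels=64, reduction_ratio=8)
    dummy = torch.randn(2, 64, 32, 32)   # [B, C, H, W]
    refined = glam(dummy)
    assert refined.shape == dummy.shape  # attention reweights, never reshapes
    print(refined.shape)                 # torch.Size([2, 64, 32, 32])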