dhruv2842 committed
Commit 109c9ca · verified · 1 Parent(s): d10a3b5

Update glam_module.py

Files changed (1)
  1. glam_module.py +75 -71
glam_module.py CHANGED
@@ -1,71 +1,75 @@
- class GLAM(nn.Module):
-     """
-     Global-Local Attention Module (GLAM) that produces a refined feature map.
-     """
-     def __init__(self, in_channels, reduction_ratio=8):
-         super(GLAM, self).__init__()
-
-         # --- Local Channel Attention ---
-         self.local_channel_conv = nn.Conv2d(in_channels, in_channels // reduction_ratio, kernel_size=1)
-         self.local_channel_act = nn.Sigmoid()
-         self.local_channel_expand = nn.Conv2d(in_channels // reduction_ratio, in_channels, kernel_size=1)
-
-         # --- Local Spatial Attention ---
-         # 3-dilated and 5-dilated convs, merged with the input
-         self.local_spatial_conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=3, dilation=3)
-         self.local_spatial_conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=5, dilation=5)
-         self.local_spatial_merge = nn.Conv2d(in_channels * 3, in_channels, kernel_size=1)
-         self.local_spatial_act = nn.Sigmoid()
-
-         # --- Global Channel Attention ---
-         self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
-         self.global_channel_fc1 = nn.Linear(in_channels, in_channels // reduction_ratio)
-         self.global_channel_fc2 = nn.Linear(in_channels // reduction_ratio, in_channels)
-         self.global_channel_act = nn.Sigmoid()
-
-         # --- Global Spatial Attention ---
-         self.global_spatial_conv = nn.Conv2d(in_channels, 1, kernel_size=1)
-         self.global_spatial_softmax = nn.Softmax(dim=-1)
-
-         # --- Weighted parameters initialization ---
-         self.local_attention_weight = nn.Parameter(torch.tensor(1.0))
-         self.global_attention_weight = nn.Parameter(torch.tensor(1.0))
-
-     def forward(self, x):
-         # Local Channel Attention
-         lca = self.local_channel_conv(x)
-         lca = self.local_channel_act(lca)
-         lca = self.local_channel_expand(lca)
-         lca_out = lca * x
-
-         # Local Spatial Attention
-         lsa3 = self.local_spatial_conv3(x)
-         lsa5 = self.local_spatial_conv5(x)
-         lsa_cat = torch.cat([x, lsa3, lsa5], dim=1)
-         lsa = self.local_spatial_merge(lsa_cat)
-         lsa = self.local_spatial_act(lsa)
-         lsa_out = lsa * lca_out
-         lsa_out = lsa_out + lca_out
-
-         # Global Channel Attention
-         B, C, H, W = x.size()
-         gca = self.global_avg_pool(x).view(B, C)
-         gca = F.relu(self.global_channel_fc1(gca), inplace=True)
-         gca = self.global_channel_fc2(gca)
-         gca = self.global_channel_act(gca)
-         gca = gca.view(B, C, 1, 1)
-         gca_out = gca * x
-
-         # Global Spatial Attention
-         gsa = self.global_spatial_conv(x)  # [B, 1, H, W]
-         gsa = gsa.view(B, -1)              # [B, H*W]
-         gsa = self.global_spatial_softmax(gsa)
-         gsa = gsa.view(B, 1, H, W)
-         gsa_out = gsa * gca_out
-         gsa_out = gsa_out + gca_out
-
-         # Fuse
-         out = lsa_out * self.local_attention_weight + gsa_out * self.global_attention_weight + x
-         return out

+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class GLAM(nn.Module):
+     """
+     Global-Local Attention Module (GLAM) that produces a refined feature map.
+     """
+     def __init__(self, in_channels, reduction_ratio=8):
+         super(GLAM, self).__init__()
+
+         # --- Local Channel Attention ---
+         self.local_channel_conv = nn.Conv2d(in_channels, in_channels // reduction_ratio, kernel_size=1)
+         self.local_channel_act = nn.Sigmoid()
+         self.local_channel_expand = nn.Conv2d(in_channels // reduction_ratio, in_channels, kernel_size=1)
+
+         # --- Local Spatial Attention ---
+         # 3-dilated and 5-dilated convs, merged with the input
+         self.local_spatial_conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=3, dilation=3)
+         self.local_spatial_conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=5, dilation=5)
+         self.local_spatial_merge = nn.Conv2d(in_channels * 3, in_channels, kernel_size=1)
+         self.local_spatial_act = nn.Sigmoid()
+
+         # --- Global Channel Attention ---
+         self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
+         self.global_channel_fc1 = nn.Linear(in_channels, in_channels // reduction_ratio)
+         self.global_channel_fc2 = nn.Linear(in_channels // reduction_ratio, in_channels)
+         self.global_channel_act = nn.Sigmoid()
+
+         # --- Global Spatial Attention ---
+         self.global_spatial_conv = nn.Conv2d(in_channels, 1, kernel_size=1)
+         self.global_spatial_softmax = nn.Softmax(dim=-1)
+
+         # --- Weighted parameters initialization ---
+         self.local_attention_weight = nn.Parameter(torch.tensor(1.0))
+         self.global_attention_weight = nn.Parameter(torch.tensor(1.0))
+
+     def forward(self, x):
+         # Local Channel Attention
+         lca = self.local_channel_conv(x)
+         lca = self.local_channel_act(lca)
+         lca = self.local_channel_expand(lca)
+         lca_out = lca * x
+
+         # Local Spatial Attention
+         lsa3 = self.local_spatial_conv3(x)
+         lsa5 = self.local_spatial_conv5(x)
+         lsa_cat = torch.cat([x, lsa3, lsa5], dim=1)
+         lsa = self.local_spatial_merge(lsa_cat)
+         lsa = self.local_spatial_act(lsa)
+         lsa_out = lsa * lca_out
+         lsa_out = lsa_out + lca_out
+
+         # Global Channel Attention
+         B, C, H, W = x.size()
+         gca = self.global_avg_pool(x).view(B, C)
+         gca = F.relu(self.global_channel_fc1(gca), inplace=True)
+         gca = self.global_channel_fc2(gca)
+         gca = self.global_channel_act(gca)
+         gca = gca.view(B, C, 1, 1)
+         gca_out = gca * x
+
+         # Global Spatial Attention
+         gsa = self.global_spatial_conv(x)  # [B, 1, H, W]
+         gsa = gsa.view(B, -1)              # [B, H*W]
+         gsa = self.global_spatial_softmax(gsa)
+         gsa = gsa.view(B, 1, H, W)
+         gsa_out = gsa * gca_out
+         gsa_out = gsa_out + gca_out
+
+         # Fuse
+         out = lsa_out * self.local_attention_weight + gsa_out * self.global_attention_weight + x
+         return out
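
A minimal smoke test for the updated file (a sketch, not part of the commit; the batch size, channel count, and spatial size below are illustrative assumptions):

import torch
from glam_module import GLAM

# Illustrative input: batch of 2, 64 channels, 32x32 feature map.
x = torch.randn(2, 64, 32, 32)
glam = GLAM(in_channels=64, reduction_ratio=8)

out = glam(x)
print(out.shape)  # torch.Size([2, 64, 32, 32]) -- GLAM is shape-preserving

Before this commit, glam_module.py had no imports, so `import glam_module` failed with `NameError: name 'nn' is not defined` as soon as the class body was evaluated; with the three import lines added at the top, the module can be imported and used as above.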