dhruv2842 committed
Commit c54a3f6 · verified · 1 Parent(s): 0ac5b89

Upload densenet_withglam.py

Files changed (1): densenet_withglam.py +82 -0
densenet_withglam.py ADDED
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models  # required by get_model_with_attention below


class GLAM(nn.Module):
    """
    Global-Local Attention Module (GLAM) that produces a refined feature map.
    """
    def __init__(self, in_channels, reduction_ratio=8):
        super(GLAM, self).__init__()

        # --- Local Channel Attention ---
        self.local_channel_conv = nn.Conv2d(in_channels, in_channels // reduction_ratio, kernel_size=1)
        self.local_channel_act = nn.Sigmoid()
        self.local_channel_expand = nn.Conv2d(in_channels // reduction_ratio, in_channels, kernel_size=1)

        # --- Local Spatial Attention ---
        # Two dilated 3x3 convs (dilation 3 and 5), merged with the input via a 1x1 conv
        self.local_spatial_conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=3, dilation=3)
        self.local_spatial_conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=5, dilation=5)
        self.local_spatial_merge = nn.Conv2d(in_channels * 3, in_channels, kernel_size=1)
        self.local_spatial_act = nn.Sigmoid()

        # --- Global Channel Attention ---
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.global_channel_fc1 = nn.Linear(in_channels, in_channels // reduction_ratio)
        self.global_channel_fc2 = nn.Linear(in_channels // reduction_ratio, in_channels)
        self.global_channel_act = nn.Sigmoid()

        # --- Global Spatial Attention ---
        self.global_spatial_conv = nn.Conv2d(in_channels, 1, kernel_size=1)
        self.global_spatial_softmax = nn.Softmax(dim=-1)

        # --- Learnable fusion weights, initialized to 1.0 ---
        self.local_attention_weight = nn.Parameter(torch.tensor(1.0))
        self.global_attention_weight = nn.Parameter(torch.tensor(1.0))

    def forward(self, x):
        # Local Channel Attention
        lca = self.local_channel_conv(x)
        lca = self.local_channel_act(lca)
        lca = self.local_channel_expand(lca)
        lca_out = lca * x

        # Local Spatial Attention
        lsa3 = self.local_spatial_conv3(x)
        lsa5 = self.local_spatial_conv5(x)
        lsa_cat = torch.cat([x, lsa3, lsa5], dim=1)
        lsa = self.local_spatial_merge(lsa_cat)
        lsa = self.local_spatial_act(lsa)
        lsa_out = lsa * lca_out
        lsa_out = lsa_out + lca_out

        # Global Channel Attention
        B, C, H, W = x.size()
        gca = self.global_avg_pool(x).view(B, C)
        gca = F.relu(self.global_channel_fc1(gca), inplace=True)
        gca = self.global_channel_fc2(gca)
        gca = self.global_channel_act(gca)
        gca = gca.view(B, C, 1, 1)
        gca_out = gca * x

        # Global Spatial Attention: softmax over all H*W positions
        gsa = self.global_spatial_conv(x)  # [B, 1, H, W]
        gsa = gsa.view(B, -1)              # [B, H*W]
        gsa = self.global_spatial_softmax(gsa)
        gsa = gsa.view(B, 1, H, W)
        gsa_out = gsa * gca_out
        gsa_out = gsa_out + gca_out

        # Fuse local and global branches with learnable weights, plus a residual connection
        out = lsa_out * self.local_attention_weight + gsa_out * self.global_attention_weight + x
        return out


def get_model_with_attention(model_name, num_classes):
    """Wrap a torchvision DenseNet backbone with a GLAM block before the classifier."""
    if model_name == 'densenet169':
        model = models.densenet169(pretrained=True)
        in_channels = model.classifier.in_features
        model.features = nn.Sequential(model.features, nn.ReLU(inplace=True), GLAM(in_channels))
        model.classifier = nn.Linear(in_channels, num_classes)
        return model
    raise ValueError(f"Unsupported model_name: {model_name}")
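
For quick verification, a minimal usage sketch might look like the following. It is not part of the commit: the num_classes value and tensor sizes are illustrative, and it assumes densenet_withglam.py is importable from the working directory (the pretrained DenseNet-169 weights are downloaded on first run).

import torch
from densenet_withglam import GLAM, get_model_with_attention

# GLAM is shape-preserving: output matches the input [B, C, H, W]
glam = GLAM(in_channels=64)
x = torch.randn(2, 64, 32, 32)
assert glam(x).shape == x.shape

# The wrapped DenseNet-169 maps images to num_classes logits
model = get_model_with_attention('densenet169', num_classes=4)  # num_classes is hypothetical
logits = model(torch.randn(2, 3, 224, 224))
print(logits.shape)  # torch.Size([2, 4])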