dhruv2842 committed
Commit c764113 · verified · 1 Parent(s): c35b02c

Update densenet_withglam.py

Files changed (1)
  1. densenet_withglam.py +82 -81
densenet_withglam.py CHANGED
@@ -1,82 +1,83 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class GLAM(nn.Module):
-    """
-    Global-Local Attention Module (GLAM) that produces a refined feature map.
-    """
-    def __init__(self, in_channels, reduction_ratio=8):
-        super(GLAM, self).__init__()
-
-        # --- Local Channel Attention ---
-        self.local_channel_conv = nn.Conv2d(in_channels, in_channels // reduction_ratio, kernel_size=1)
-        self.local_channel_act = nn.Sigmoid()
-        self.local_channel_expand = nn.Conv2d(in_channels // reduction_ratio, in_channels, kernel_size=1)
-
-        # --- Local Spatial Attention ---
-        # 3-dilated and 5-dilated convs, then a 1x1 merge
-        self.local_spatial_conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=3, dilation=3)
-        self.local_spatial_conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=5, dilation=5)
-        self.local_spatial_merge = nn.Conv2d(in_channels * 3, in_channels, kernel_size=1)
-        self.local_spatial_act = nn.Sigmoid()
-
-        # --- Global Channel Attention ---
-        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
-        self.global_channel_fc1 = nn.Linear(in_channels, in_channels // reduction_ratio)
-        self.global_channel_fc2 = nn.Linear(in_channels // reduction_ratio, in_channels)
-        self.global_channel_act = nn.Sigmoid()
-
-        # --- Global Spatial Attention ---
-        self.global_spatial_conv = nn.Conv2d(in_channels, 1, kernel_size=1)
-        self.global_spatial_softmax = nn.Softmax(dim=-1)
-
-
-        # --- Weighted parameters initialization ---
-        self.local_attention_weight = nn.Parameter(torch.tensor(1.0))
-        self.global_attention_weight = nn.Parameter(torch.tensor(1.0))
-
-
-    def forward(self, x):
-        # Local Channel Attention
-        lca = self.local_channel_conv(x)
-        lca = self.local_channel_act(lca)
-        lca = self.local_channel_expand(lca)
-        lca_out = lca * x
-
-        # Local Spatial Attention
-        lsa3 = self.local_spatial_conv3(x)
-        lsa5 = self.local_spatial_conv5(x)
-        lsa_cat = torch.cat([x, lsa3, lsa5], dim=1)
-        lsa = self.local_spatial_merge(lsa_cat)
-        lsa = self.local_spatial_act(lsa)
-        lsa_out = lsa * lca_out
-        lsa_out = lsa_out + lca_out
-
-        # Global Channel Attention
-        B, C, H, W = x.size()
-        gca = self.global_avg_pool(x).view(B, C)
-        gca = F.relu(self.global_channel_fc1(gca), inplace=True)
-        gca = self.global_channel_fc2(gca)
-        gca = self.global_channel_act(gca)
-        gca = gca.view(B, C, 1, 1)
-        gca_out = gca * x
-
-        # Global Spatial Attention
-        gsa = self.global_spatial_conv(x)  # [B, 1, H, W]
-        gsa = gsa.view(B, -1)  # [B, H*W]
-        gsa = self.global_spatial_softmax(gsa)
-        gsa = gsa.view(B, 1, H, W)
-        gsa_out = gsa * gca_out
-        gsa_out = gsa_out + gca_out
-
-        # Fuse
-        out = lsa_out * self.local_attention_weight + gsa_out * self.global_attention_weight + x
-        return out
-def get_model_with_attention(model_name, num_classes):
-    if model_name == 'densenet169':
-        model = models.densenet169(pretrained=True)
-        in_channels = model.classifier.in_features
-        model.features = nn.Sequential(model.features, nn.ReLU(inplace=True), GLAM(in_channels))
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.models as models
+
+
+class GLAM(nn.Module):
+    """
+    Global-Local Attention Module (GLAM) that produces a refined feature map.
+    """
+    def __init__(self, in_channels, reduction_ratio=8):
+        super(GLAM, self).__init__()
+
+        # --- Local Channel Attention ---
+        self.local_channel_conv = nn.Conv2d(in_channels, in_channels // reduction_ratio, kernel_size=1)
+        self.local_channel_act = nn.Sigmoid()
+        self.local_channel_expand = nn.Conv2d(in_channels // reduction_ratio, in_channels, kernel_size=1)
+
+        # --- Local Spatial Attention ---
+        # 3-dilated and 5-dilated convs, then a 1x1 merge
+        self.local_spatial_conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=3, dilation=3)
+        self.local_spatial_conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=5, dilation=5)
+        self.local_spatial_merge = nn.Conv2d(in_channels * 3, in_channels, kernel_size=1)
+        self.local_spatial_act = nn.Sigmoid()
+
+        # --- Global Channel Attention ---
+        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.global_channel_fc1 = nn.Linear(in_channels, in_channels // reduction_ratio)
+        self.global_channel_fc2 = nn.Linear(in_channels // reduction_ratio, in_channels)
+        self.global_channel_act = nn.Sigmoid()
+
+        # --- Global Spatial Attention ---
+        self.global_spatial_conv = nn.Conv2d(in_channels, 1, kernel_size=1)
+        self.global_spatial_softmax = nn.Softmax(dim=-1)
+
+
+        # --- Weighted parameters initialization ---
+        self.local_attention_weight = nn.Parameter(torch.tensor(1.0))
+        self.global_attention_weight = nn.Parameter(torch.tensor(1.0))
+
+
+    def forward(self, x):
+        # Local Channel Attention
+        lca = self.local_channel_conv(x)
+        lca = self.local_channel_act(lca)
+        lca = self.local_channel_expand(lca)
+        lca_out = lca * x
+
+        # Local Spatial Attention
+        lsa3 = self.local_spatial_conv3(x)
+        lsa5 = self.local_spatial_conv5(x)
+        lsa_cat = torch.cat([x, lsa3, lsa5], dim=1)
+        lsa = self.local_spatial_merge(lsa_cat)
+        lsa = self.local_spatial_act(lsa)
+        lsa_out = lsa * lca_out
+        lsa_out = lsa_out + lca_out
+
+        # Global Channel Attention
+        B, C, H, W = x.size()
+        gca = self.global_avg_pool(x).view(B, C)
+        gca = F.relu(self.global_channel_fc1(gca), inplace=True)
+        gca = self.global_channel_fc2(gca)
+        gca = self.global_channel_act(gca)
+        gca = gca.view(B, C, 1, 1)
+        gca_out = gca * x
+
+        # Global Spatial Attention
+        gsa = self.global_spatial_conv(x)  # [B, 1, H, W]
+        gsa = gsa.view(B, -1)  # [B, H*W]
+        gsa = self.global_spatial_softmax(gsa)
+        gsa = gsa.view(B, 1, H, W)
+        gsa_out = gsa * gca_out
+        gsa_out = gsa_out + gca_out
+
+        # Fuse
+        out = lsa_out * self.local_attention_weight + gsa_out * self.global_attention_weight + x
+        return out
+def get_model_with_attention(model_name, num_classes):
+    if model_name == 'densenet169':
+        model = models.densenet169(pretrained=True)
+        in_channels = model.classifier.in_features
+        model.features = nn.Sequential(model.features, nn.ReLU(inplace=True), GLAM(in_channels))
         model.classifier = nn.Linear(in_channels, num_classes)
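
As a sanity check, here is a minimal smoke test of the GLAM module as committed (a sketch, assuming densenet_withglam.py is on the import path; the channel count and tensor sizes are illustrative):

import torch
from densenet_withglam import GLAM

# GLAM is shape-preserving: the fused output keeps the input's
# [B, C, H, W] shape, so it can be appended to any feature extractor.
glam = GLAM(in_channels=64, reduction_ratio=8)
x = torch.randn(2, 64, 7, 7)
out = glam(x)
assert out.shape == x.shape  # torch.Size([2, 64, 7, 7])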
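
And a usage sketch for the builder. Note this is hedged: the committed get_model_with_attention sets model.classifier but never returns the model, so the sketch assumes a trailing "return model" is added; num_classes=2 and the 224x224 input are illustrative, and pretrained=True downloads the DenseNet-169 weights on first use:

import torch
from densenet_withglam import get_model_with_attention

# Assumes get_model_with_attention ends with "return model"
# (the committed version does not return the model it builds).
model = get_model_with_attention('densenet169', num_classes=2)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 2])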