dhruv2842 committed
Commit bc30c66 · verified · 1 Parent(s): 2b314ce

Update efficientnet_transformer_glam.py

Files changed (1):
  1. efficientnet_transformer_glam.py +201 -201
efficientnet_transformer_glam.py CHANGED
@@ -1,201 +1,201 @@
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torchvision.models as models
-
-
- # -------------------------------
- # 1. SWIN WINDOW UTILS
- # -------------------------------
- def window_partition(x, window_size):
-     """Partitions input tensor into windows of shape (B * num_windows, window_size*window_size, C)."""
-     B, H, W, C = x.shape
-     x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
-     x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
-     windows = x.view(-1, window_size * window_size, C)
-     return windows
-
- def window_reverse(windows, window_size, H, W):
-     """Reverses the window partition operation."""
-     B = int(windows.shape[0] / (H * W / window_size / window_size))
-     x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
-     x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
-     x = x.view(B, H, W, -1)
-     return x
-
-
- # -------------------------------
- # 2. SWIN WINDOW ATTENTION
- # -------------------------------
- class SwinWindowAttention(nn.Module):
-     """Swin-style window attention block."""
-     def __init__(self, embed_dim, window_size, num_heads, dropout=0.0):
-         super(SwinWindowAttention, self).__init__()
-         self.embed_dim = embed_dim
-         self.window_size = window_size
-         self.num_heads = num_heads
-         self.mha = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
-         self.dropout = nn.Dropout(dropout)
-
-     def forward(self, x):
-         """Perform multi-head self-attn within windows."""
-         B, C, H, W = x.shape
-         x = x.permute(0, 2, 3, 1).contiguous()
-
-         pad_h = (self.window_size - H % self.window_size) % self.window_size
-         pad_w = (self.window_size - W % self.window_size) % self.window_size
-         if pad_h or pad_w:
-             x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
-         Hp, Wp = x.shape[1], x.shape[2]
-
-         windows = window_partition(x, self.window_size)  # (B*n_wins, win_size^2, C)
-
-         attn_windows, _ = self.mha(windows, windows, windows)
-         attn_windows = self.dropout(attn_windows)
-
-         x = window_reverse(attn_windows, self.window_size, Hp, Wp)
-
-         if pad_h or pad_w:
-             x = x[:, :H, :W, :].contiguous()
-
-         return x.permute(0, 3, 1, 2).contiguous()
-
-
- # -------------------------------
- # 3. GLAM
- # -------------------------------
- class GLAM(nn.Module):
-     """Global-Local Attention Module (GLAM)."""
-     def __init__(self, in_channels, reduction_ratio=8):
-         super(GLAM, self).__init__()
-
-         # Local Channel Attention
-         self.local_channel_conv = nn.Conv2d(in_channels, in_channels // reduction_ratio, kernel_size=1)
-         self.local_channel_act = nn.Sigmoid()
-         self.local_channel_expand = nn.Conv2d(in_channels // reduction_ratio, in_channels, kernel_size=1)
-
-         # Local Spatial Attention
-         self.local_spatial_conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=3, dilation=3)
-         self.local_spatial_conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=5, dilation=5)
-         self.local_spatial_merge = nn.Conv2d(in_channels * 3, in_channels, kernel_size=1)
-         self.local_spatial_act = nn.Sigmoid()
-
-         # Global Channel Attention
-         self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
-         self.global_channel_fc1 = nn.Linear(in_channels, in_channels // reduction_ratio)
-         self.global_channel_fc2 = nn.Linear(in_channels // reduction_ratio, in_channels)
-         self.global_channel_act = nn.Sigmoid()
-
-         # Global Spatial Attention
-         self.global_spatial_conv = nn.Conv2d(in_channels, 1, kernel_size=1)
-         self.global_spatial_softmax = nn.Softmax(dim=-1)
-
-         # Weights
-         self.local_attention_weight = nn.Parameter(torch.tensor(1.0))
-         self.global_attention_weight = nn.Parameter(torch.tensor(1.0))
-
-     def forward(self, x):
-         # Local Channel Attention
-         lca = self.local_channel_conv(x)
-         lca = self.local_channel_act(lca)
-         lca = self.local_channel_expand(lca)
-         lca_out = lca * x
-
-         # Local Spatial Attention
-         lsa3 = self.local_spatial_conv3(x)
-         lsa5 = self.local_spatial_conv5(x)
-         lsa_cat = torch.cat([x, lsa3, lsa5], dim=1)
-         lsa = self.local_spatial_merge(lsa_cat)
-         lsa = self.local_spatial_act(lsa)
-         lsa_out = lsa * lca_out
-         lsa_out = lsa_out + lca_out
-
-         # Global Channel Attention
-         B, C, H, W = x.size()
-         gca = self.global_avg_pool(x).view(B, C)
-         gca = F.relu(self.global_channel_fc1(gca), inplace=True)
-         gca = self.global_channel_fc2(gca)
-         gca = self.global_channel_act(gca)
-         gca = gca.view(B, C, 1, 1)
-         gca_out = gca * x
-
-         # Global Spatial Attention
-         gsa = self.global_spatial_conv(x)  # [B, 1, H, W]
-         gsa = gsa.view(B, -1)  # [B, H*W]
-         gsa = self.global_spatial_softmax(gsa)
-         gsa = gsa.view(B, 1, H, W)
-         gsa_out = gsa * gca_out
-         gsa_out = gsa_out + gca_out
-
-         # Final Fusion
-         out = lsa_out * self.local_attention_weight + gsa_out * self.global_attention_weight + x
-         return out
-
-
- # -------------------------------
- # 4. EFFICIENTNETB0_TRANSFORMERGLAM
- # -------------------------------
- class EfficientNetb0_TransformerGLAM(nn.Module):
-     """EfficientNet-B0 + Swin-style Transformer + GLAM + Self-Adaptive Gating."""
-     def __init__(self,
-                  num_classes=3,
-                  embed_dim=512,
-                  num_heads=8,
-                  mlp_dim=512,
-                  dropout=0.5,
-                  window_size=7,
-                  reduction_ratio=8):
-         super(EfficientNetb0_TransformerGLAM, self).__init__()
-
-         # EfficientNet Backbone
-         efficientnet = models.efficientnet_b0(pretrained=True)
-         self.feature_extractor = nn.Sequential(*list(efficientnet.features.children()))
-         self.conv1x1 = nn.Conv2d(1280, embed_dim, kernel_size=1)
-
-         # Transformer path
-         self.pre_attn_norm = nn.LayerNorm(embed_dim)
-         self.swin_attn = SwinWindowAttention(embed_dim, window_size, num_heads, dropout)
-         self.post_attn_norm = nn.LayerNorm(embed_dim)
-
-         # GLAM path
-         self.glam = GLAM(in_channels=embed_dim, reduction_ratio=reduction_ratio)
-
-         # Self-adaptive gating
-         self.gate_fc = nn.Linear(embed_dim, 1)
-
-         # Final classification
-         self.dropout = nn.Dropout(dropout)
-         self.fc = nn.Linear(embed_dim, num_classes)
-
-     def forward(self, x):
-         # Backbone
-         feats = self.feature_extractor(x)  # [B, 1280, H', W']
-         feats = self.conv1x1(feats)  # [B, embed_dim, H', W']
-
-         B, C, H, W = feats.shape
-
-         # Transformer path
-         x_perm = feats.permute(0, 2, 3, 1).contiguous()
-         x_norm = self.pre_attn_norm(x_perm)  # LN
-         x_norm = x_norm.permute(0, 3, 1, 2).contiguous()
-         x_norm = self.dropout(x_norm)
-
-         T = self.swin_attn(x_norm)
-
-         T_perm = T.permute(0, 2, 3, 1).contiguous()
-         T_norm = self.post_attn_norm(T_perm)  # LN
-         T_out = T_norm.permute(0, 3, 1, 2).contiguous()
-
-         # GLAM path
-         G_out = self.glam(feats)
-
-         # Gating
-         gap_feats = F.adaptive_avg_pool2d(feats, (1, 1)).view(B, C)
-         g = torch.sigmoid(self.gate_fc(gap_feats))
-         g = g.view(B, 1, 1, 1)
-
-         # Final Fusion
-         F_out = g * T_out + (1 - g) * G_out
-         pooled = F.adaptive_avg_pool2d(F_out, (1, 1)).view(B, -1)
-
-         return self.fc(pooled)
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision.models as models
+
+
+ # -------------------------------
+ # 1. SWIN WINDOW UTILS
+ # -------------------------------
+ def window_partition(x, window_size):
+     """Partitions input tensor into windows of shape (B * num_windows, window_size*window_size, C)."""
+     B, H, W, C = x.shape
+     x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+     x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
+     windows = x.view(-1, window_size * window_size, C)
+     return windows
+
+ def window_reverse(windows, window_size, H, W):
+     """Reverses the window partition operation."""
+     B = int(windows.shape[0] / (H * W / window_size / window_size))
+     x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+     x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
+     x = x.view(B, H, W, -1)
+     return x
+
+
+ # -------------------------------
+ # 2. SWIN WINDOW ATTENTION
+ # -------------------------------
+ class SwinWindowAttention(nn.Module):
+     """Swin-style window attention block."""
+     def __init__(self, embed_dim, window_size, num_heads, dropout=0.0):
+         super(SwinWindowAttention, self).__init__()
+         self.embed_dim = embed_dim
+         self.window_size = window_size
+         self.num_heads = num_heads
+         self.mha = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         """Perform multi-head self-attn within windows."""
+         B, C, H, W = x.shape
+         x = x.permute(0, 2, 3, 1).contiguous()
+
+         pad_h = (self.window_size - H % self.window_size) % self.window_size
+         pad_w = (self.window_size - W % self.window_size) % self.window_size
+         if pad_h or pad_w:
+             x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
+         Hp, Wp = x.shape[1], x.shape[2]
+
+         windows = window_partition(x, self.window_size)  # (B*n_wins, win_size^2, C)
+
+         attn_windows, _ = self.mha(windows, windows, windows)
+         attn_windows = self.dropout(attn_windows)
+
+         x = window_reverse(attn_windows, self.window_size, Hp, Wp)
+
+         if pad_h or pad_w:
+             x = x[:, :H, :W, :].contiguous()
+
+         return x.permute(0, 3, 1, 2).contiguous()
+
+
+ # -------------------------------
+ # 3. GLAM
+ # -------------------------------
+ class GLAM(nn.Module):
+     """Global-Local Attention Module (GLAM)."""
+     def __init__(self, in_channels, reduction_ratio=8):
+         super(GLAM, self).__init__()
+
+         # Local Channel Attention
+         self.local_channel_conv = nn.Conv2d(in_channels, in_channels // reduction_ratio, kernel_size=1)
+         self.local_channel_act = nn.Sigmoid()
+         self.local_channel_expand = nn.Conv2d(in_channels // reduction_ratio, in_channels, kernel_size=1)
+
+         # Local Spatial Attention
+         self.local_spatial_conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=3, dilation=3)
+         self.local_spatial_conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=5, dilation=5)
+         self.local_spatial_merge = nn.Conv2d(in_channels * 3, in_channels, kernel_size=1)
+         self.local_spatial_act = nn.Sigmoid()
+
+         # Global Channel Attention
+         self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
+         self.global_channel_fc1 = nn.Linear(in_channels, in_channels // reduction_ratio)
+         self.global_channel_fc2 = nn.Linear(in_channels // reduction_ratio, in_channels)
+         self.global_channel_act = nn.Sigmoid()
+
+         # Global Spatial Attention
+         self.global_spatial_conv = nn.Conv2d(in_channels, 1, kernel_size=1)
+         self.global_spatial_softmax = nn.Softmax(dim=-1)
+
+         # Weights
+         self.local_attention_weight = nn.Parameter(torch.tensor(1.0))
+         self.global_attention_weight = nn.Parameter(torch.tensor(1.0))
+
+     def forward(self, x):
+         # Local Channel Attention
+         lca = self.local_channel_conv(x)
+         lca = self.local_channel_act(lca)
+         lca = self.local_channel_expand(lca)
+         lca_out = lca * x
+
+         # Local Spatial Attention
+         lsa3 = self.local_spatial_conv3(x)
+         lsa5 = self.local_spatial_conv5(x)
+         lsa_cat = torch.cat([x, lsa3, lsa5], dim=1)
+         lsa = self.local_spatial_merge(lsa_cat)
+         lsa = self.local_spatial_act(lsa)
+         lsa_out = lsa * lca_out
+         lsa_out = lsa_out + lca_out
+
+         # Global Channel Attention
+         B, C, H, W = x.size()
+         gca = self.global_avg_pool(x).view(B, C)
+         gca = F.relu(self.global_channel_fc1(gca), inplace=True)
+         gca = self.global_channel_fc2(gca)
+         gca = self.global_channel_act(gca)
+         gca = gca.view(B, C, 1, 1)
+         gca_out = gca * x
+
+         # Global Spatial Attention
+         gsa = self.global_spatial_conv(x)  # [B, 1, H, W]
+         gsa = gsa.view(B, -1)  # [B, H*W]
+         gsa = self.global_spatial_softmax(gsa)
+         gsa = gsa.view(B, 1, H, W)
+         gsa_out = gsa * gca_out
+         gsa_out = gsa_out + gca_out
+
+         # Final Fusion
+         out = lsa_out * self.local_attention_weight + gsa_out * self.global_attention_weight + x
+         return out
+
+
+ # -------------------------------
+ # 4. EFFICIENTNETB0_TRANSFORMERGLAM
+ # -------------------------------
+ class EfficientNetb0_TransformerGLAM(nn.Module):
+     """EfficientNet-B0 + Swin-style Transformer + GLAM + Self-Adaptive Gating."""
+     def __init__(self,
+                  num_classes=3,
+                  embed_dim=512,
+                  num_heads=8,
+                  mlp_dim=512,
+                  dropout=0.5,
+                  window_size=7,
+                  reduction_ratio=8):
+         super(EfficientNetb0_TransformerGLAM, self).__init__()
+
+         # EfficientNet Backbone
+         efficientnet = models.efficientnet_b0(weights=None)  # no pretrained weights (replaces the deprecated pretrained=True flag)
+         self.feature_extractor = nn.Sequential(*list(efficientnet.features.children()))
+         self.conv1x1 = nn.Conv2d(1280, embed_dim, kernel_size=1)
+
+         # Transformer path
+         self.pre_attn_norm = nn.LayerNorm(embed_dim)
+         self.swin_attn = SwinWindowAttention(embed_dim, window_size, num_heads, dropout)
+         self.post_attn_norm = nn.LayerNorm(embed_dim)
+
+         # GLAM path
+         self.glam = GLAM(in_channels=embed_dim, reduction_ratio=reduction_ratio)
+
+         # Self-adaptive gating
+         self.gate_fc = nn.Linear(embed_dim, 1)
+
+         # Final classification
+         self.dropout = nn.Dropout(dropout)
+         self.fc = nn.Linear(embed_dim, num_classes)
+
+     def forward(self, x):
+         # Backbone
+         feats = self.feature_extractor(x)  # [B, 1280, H', W']
+         feats = self.conv1x1(feats)  # [B, embed_dim, H', W']
+
+         B, C, H, W = feats.shape
+
+         # Transformer path
+         x_perm = feats.permute(0, 2, 3, 1).contiguous()
+         x_norm = self.pre_attn_norm(x_perm)  # LN
+         x_norm = x_norm.permute(0, 3, 1, 2).contiguous()
+         x_norm = self.dropout(x_norm)
+
+         T = self.swin_attn(x_norm)
+
+         T_perm = T.permute(0, 2, 3, 1).contiguous()
+         T_norm = self.post_attn_norm(T_perm)  # LN
+         T_out = T_norm.permute(0, 3, 1, 2).contiguous()
+
+         # GLAM path
+         G_out = self.glam(feats)
+
+         # Gating
+         gap_feats = F.adaptive_avg_pool2d(feats, (1, 1)).view(B, C)
+         g = torch.sigmoid(self.gate_fc(gap_feats))
+         g = g.view(B, 1, 1, 1)
+
+         # Final Fusion
+         F_out = g * T_out + (1 - g) * G_out
+         pooled = F.adaptive_avg_pool2d(F_out, (1, 1)).view(B, -1)
+
+         return self.fc(pooled)
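A minimal smoke test for the updated module, assuming the file above is importable as efficientnet_transformer_glam from the current path; the batch size and the 224x224 input resolution are illustrative assumptions, not part of the commit:

import torch

from efficientnet_transformer_glam import EfficientNetb0_TransformerGLAM

# Build the model as defined in this commit; weights=None means the EfficientNet-B0
# backbone starts from random initialization, so no download is triggered.
model = EfficientNetb0_TransformerGLAM(num_classes=3)
model.eval()  # disable dropout for a deterministic forward pass

# Dummy batch: 2 RGB images at 224x224. EfficientNet-B0 reduces this to a 7x7 feature
# grid, which matches the default window_size=7 of the Swin-style attention path.
dummy = torch.randn(2, 3, 224, 224)

with torch.no_grad():
    logits = model(dummy)

print(logits.shape)  # expected: torch.Size([2, 3])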