Update efficientnet_transformer_glam.py
Browse files
efficientnet_transformer_glam.py
CHANGED
@@ -15,6 +15,7 @@ def window_partition(x, window_size):
|
|
15 |
windows = x.view(-1, window_size * window_size, C)
|
16 |
return windows
|
17 |
|
|
|
18 |
def window_reverse(windows, window_size, H, W):
|
19 |
"""Reverses the window partition operation."""
|
20 |
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
@@ -163,6 +164,9 @@ class EfficientNetb0_TransformerGLAM(nn.Module):
|
|
163 |
# Self-adaptive gating
|
164 |
self.gate_fc = nn.Linear(embed_dim, 1)
|
165 |
|
|
|
|
|
|
|
166 |
# Final classification
|
167 |
self.dropout = nn.Dropout(dropout)
|
168 |
self.fc = nn.Linear(embed_dim, num_classes)
|
@@ -183,7 +187,7 @@ class EfficientNetb0_TransformerGLAM(nn.Module):
|
|
183 |
T = self.swin_attn(x_norm)
|
184 |
|
185 |
T_perm = T.permute(0, 2, 3, 1).contiguous()
|
186 |
-
T_norm = self.post_attn_norm(T_perm)
|
187 |
T_out = T_norm.permute(0, 3, 1, 2).contiguous()
|
188 |
|
189 |
# GLAM path
|
@@ -196,6 +200,8 @@ class EfficientNetb0_TransformerGLAM(nn.Module):
|
|
196 |
|
197 |
# Final Fusion
|
198 |
F_out = g * T_out + (1 - g) * G_out
|
|
|
199 |
pooled = F.adaptive_avg_pool2d(F_out, (1, 1)).view(B, -1)
|
200 |
|
201 |
return self.fc(pooled)
|
|
|
|
15 |
windows = x.view(-1, window_size * window_size, C)
|
16 |
return windows
|
17 |
|
18 |
+
|
19 |
def window_reverse(windows, window_size, H, W):
|
20 |
"""Reverses the window partition operation."""
|
21 |
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
|
|
164 |
# Self-adaptive gating
|
165 |
self.gate_fc = nn.Linear(embed_dim, 1)
|
166 |
|
167 |
+
# Final feature output
|
168 |
+
self.final_feature_layer = nn.Identity()
|
169 |
+
|
170 |
# Final classification
|
171 |
self.dropout = nn.Dropout(dropout)
|
172 |
self.fc = nn.Linear(embed_dim, num_classes)
|
|
|
187 |
T = self.swin_attn(x_norm)
|
188 |
|
189 |
T_perm = T.permute(0, 2, 3, 1).contiguous()
|
190 |
+
T_norm = self.post_attn_norm(T_perm) # LN
|
191 |
T_out = T_norm.permute(0, 3, 1, 2).contiguous()
|
192 |
|
193 |
# GLAM path
|
|
|
200 |
|
201 |
# Final Fusion
|
202 |
F_out = g * T_out + (1 - g) * G_out
|
203 |
+
F_out = self.final_feature_layer(F_out) # ✅ Final feature map for Grad-CAM
|
204 |
pooled = F.adaptive_avg_pool2d(F_out, (1, 1)).view(B, -1)
|
205 |
|
206 |
return self.fc(pooled)
|
207 |
+
|