dhruv2842 commited on
Commit
7af96b6
·
verified ·
1 Parent(s): 2a590c3

Update efficientnet_transformer_glam.py

Browse files
Files changed (1) hide show
  1. efficientnet_transformer_glam.py +7 -1
efficientnet_transformer_glam.py CHANGED
@@ -15,6 +15,7 @@ def window_partition(x, window_size):
15
  windows = x.view(-1, window_size * window_size, C)
16
  return windows
17
 
 
18
  def window_reverse(windows, window_size, H, W):
19
  """Reverses the window partition operation."""
20
  B = int(windows.shape[0] / (H * W / window_size / window_size))
@@ -163,6 +164,9 @@ class EfficientNetb0_TransformerGLAM(nn.Module):
163
  # Self-adaptive gating
164
  self.gate_fc = nn.Linear(embed_dim, 1)
165
 
 
 
 
166
  # Final classification
167
  self.dropout = nn.Dropout(dropout)
168
  self.fc = nn.Linear(embed_dim, num_classes)
@@ -183,7 +187,7 @@ class EfficientNetb0_TransformerGLAM(nn.Module):
183
  T = self.swin_attn(x_norm)
184
 
185
  T_perm = T.permute(0, 2, 3, 1).contiguous()
186
- T_norm = self.post_attn_norm(T_perm) # LN
187
  T_out = T_norm.permute(0, 3, 1, 2).contiguous()
188
 
189
  # GLAM path
@@ -196,6 +200,8 @@ class EfficientNetb0_TransformerGLAM(nn.Module):
196
 
197
  # Final Fusion
198
  F_out = g * T_out + (1 - g) * G_out
 
199
  pooled = F.adaptive_avg_pool2d(F_out, (1, 1)).view(B, -1)
200
 
201
  return self.fc(pooled)
 
 
15
  windows = x.view(-1, window_size * window_size, C)
16
  return windows
17
 
18
+
19
  def window_reverse(windows, window_size, H, W):
20
  """Reverses the window partition operation."""
21
  B = int(windows.shape[0] / (H * W / window_size / window_size))
 
164
  # Self-adaptive gating
165
  self.gate_fc = nn.Linear(embed_dim, 1)
166
 
167
+ # Final feature output
168
+ self.final_feature_layer = nn.Identity()
169
+
170
  # Final classification
171
  self.dropout = nn.Dropout(dropout)
172
  self.fc = nn.Linear(embed_dim, num_classes)
 
187
  T = self.swin_attn(x_norm)
188
 
189
  T_perm = T.permute(0, 2, 3, 1).contiguous()
190
+ T_norm = self.post_attn_norm(T_perm) # LN
191
  T_out = T_norm.permute(0, 3, 1, 2).contiguous()
192
 
193
  # GLAM path
 
200
 
201
  # Final Fusion
202
  F_out = g * T_out + (1 - g) * G_out
203
+ F_out = self.final_feature_layer(F_out) # ✅ Final feature map for Grad-CAM
204
  pooled = F.adaptive_avg_pool2d(F_out, (1, 1)).view(B, -1)
205
 
206
  return self.fc(pooled)
207
+