dhruv2842 committed on
Commit 7903c70 · verified · 1 Parent(s): 86809b9

Update efficientnet_transformer_glam.py

Files changed (1):
  efficientnet_transformer_glam.py (+30 -14)
efficientnet_transformer_glam.py CHANGED

@@ -4,9 +4,9 @@ import torch.nn.functional as F
 import torchvision.models as models
 
 
-# -------------------------------
+# ==================================================
 # 1. SWIN WINDOW UTILS
-# -------------------------------
+# ==================================================
 def window_partition(x, window_size):
     """Partitions input tensor into windows of shape (B * num_windows, window_size*window_size, C)."""
     B, H, W, C = x.shape
@@ -25,9 +25,9 @@ def window_reverse(windows, window_size, H, W):
     return x
 
 
-# -------------------------------
+# ==================================================
 # 2. SWIN WINDOW ATTENTION
-# -------------------------------
+# ==================================================
 class SwinWindowAttention(nn.Module):
     """Swin-style window attention block."""
     def __init__(self, embed_dim, window_size, num_heads, dropout=0.0):
@@ -62,9 +62,9 @@ class SwinWindowAttention(nn.Module):
         return x.permute(0, 3, 1, 2).contiguous()
 
 
-# -------------------------------
+# ==================================================
 # 3. GLAM
-# -------------------------------
+# ==================================================
class GLAM(nn.Module):
     """Global-Local Attention Module (GLAM)."""
     def __init__(self, in_channels, reduction_ratio=8):
@@ -133,9 +133,22 @@ class GLAM(nn.Module):
         return out
 
 
-# -------------------------------
-# 4. EFFICIENTNETB0_TRANSFORMERGLAM
-# -------------------------------
+# ==================================================
+# 4. FUSION BLOCK
+# ==================================================
+class FusionBlock(nn.Module):
+    """Combines Transformer and GLAM outputs using gating."""
+    def __init__(self):
+        super(FusionBlock, self).__init__()
+
+    def forward(self, g, T_out, G_out):
+        """Perform final gating fusion."""
+        return g * T_out + (1 - g) * G_out
+
+
+# ==================================================
+# 5. EFFICIENTNETB0_TRANSFORMERGLAM
+# ==================================================
 class EfficientNetb0_TransformerGLAM(nn.Module):
     """EfficientNet-B0 + Swin-style Transformer + GLAM + Self-Adaptive Gating."""
     def __init__(self,
@@ -161,11 +174,11 @@ class EfficientNetb0_TransformerGLAM(nn.Module):
         # GLAM path
         self.glam = GLAM(in_channels=embed_dim, reduction_ratio=reduction_ratio)
 
-        # Self-adaptive gating
+        # Gating
         self.gate_fc = nn.Linear(embed_dim, 1)
 
-        # Final feature output
-        self.final_feature_layer = nn.Identity()
+        # Final Fusion
+        self.fusion_block = FusionBlock()
 
         # Final classification
         self.dropout = nn.Dropout(dropout)
@@ -199,9 +212,12 @@ class EfficientNetb0_TransformerGLAM(nn.Module):
         g = g.view(B, 1, 1, 1)
 
         # Final Fusion
-        F_out = g * T_out + (1 - g) * G_out
-        # ✅ Save the spatial feature map
+        F_out = self.fusion_block(g, T_out, G_out)
+
+        # Save final feature map for Grad-CAM
         self.last_feature = F_out
+
         pooled = F.adaptive_avg_pool2d(F_out, (1, 1)).view(B, -1)
+
         return self.fc(pooled)
 
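For reference, the new FusionBlock can be exercised in isolation. A minimal sketch, assuming random stand-in feature maps and a sigmoid gate; the channel count of 112 is illustrative only, since the diff never fixes embed_dim:

import torch
import torch.nn as nn

class FusionBlock(nn.Module):
    """Combines Transformer and GLAM outputs using gating (as added in this commit)."""
    def forward(self, g, T_out, G_out):
        # Convex combination: g weights the Transformer branch, (1 - g) the GLAM branch.
        return g * T_out + (1 - g) * G_out

B, C, H, W = 2, 112, 7, 7                     # C = 112 is an assumption, not from the diff
T_out = torch.randn(B, C, H, W)               # stand-in for the Transformer branch
G_out = torch.randn(B, C, H, W)               # stand-in for the GLAM branch
g = torch.sigmoid(torch.randn(B, 1, 1, 1))    # per-sample gate in (0, 1)

F_out = FusionBlock()(g, T_out, G_out)
print(F_out.shape)                            # torch.Size([2, 112, 7, 7])

# Sanity check: g == 1 returns the Transformer branch unchanged.
assert torch.allclose(FusionBlock()(torch.ones(B, 1, 1, 1), T_out, G_out), T_out)

Because g has shape (B, 1, 1, 1), the gate is a single scalar per sample; broadcasting applies it uniformly across all channels and spatial positions of both branches.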
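The new comment states that the fused map is saved for Grad-CAM, but no Grad-CAM code appears in this file. A hypothetical sketch of how last_feature could be consumed downstream, assuming a constructed model and an images batch (both names are placeholders, not defined in this diff):

import torch

# `model` and `images` are assumed to exist; nothing in this file constructs them.
model.eval()
logits = model(images)                  # forward pass populates model.last_feature
model.last_feature.retain_grad()        # keep gradients on the non-leaf fused map
score = logits.gather(1, logits.argmax(dim=1, keepdim=True)).sum()
score.backward()

# Grad-CAM: weight each channel by its average gradient, then ReLU and normalize.
weights = model.last_feature.grad.mean(dim=(2, 3), keepdim=True)    # (B, C, 1, 1)
cam = torch.relu((weights * model.last_feature).sum(dim=1))         # (B, H, W)
cam = cam / cam.amax(dim=(1, 2), keepdim=True).clamp(min=1e-8)      # scale to [0, 1]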
223