Update efficientnet_transformer_glam.py
Browse files- efficientnet_transformer_glam.py +30 -14
efficientnet_transformer_glam.py
CHANGED
@@ -4,9 +4,9 @@ import torch.nn.functional as F
|
|
4 |
import torchvision.models as models
|
5 |
|
6 |
|
7 |
-
#
|
8 |
# 1. SWIN WINDOW UTILS
|
9 |
-
#
|
10 |
def window_partition(x, window_size):
|
11 |
"""Partitions input tensor into windows of shape (B * num_windows, window_size*window_size, C)."""
|
12 |
B, H, W, C = x.shape
|
@@ -25,9 +25,9 @@ def window_reverse(windows, window_size, H, W):
|
|
25 |
return x
|
26 |
|
27 |
|
28 |
-
#
|
29 |
# 2. SWIN WINDOW ATTENTION
|
30 |
-
#
|
31 |
class SwinWindowAttention(nn.Module):
|
32 |
"""Swin-style window attention block."""
|
33 |
def __init__(self, embed_dim, window_size, num_heads, dropout=0.0):
|
@@ -62,9 +62,9 @@ class SwinWindowAttention(nn.Module):
|
|
62 |
return x.permute(0, 3, 1, 2).contiguous()
|
63 |
|
64 |
|
65 |
-
#
|
66 |
# 3. GLAM
|
67 |
-
#
|
68 |
class GLAM(nn.Module):
|
69 |
"""Global-Local Attention Module (GLAM)."""
|
70 |
def __init__(self, in_channels, reduction_ratio=8):
|
@@ -133,9 +133,22 @@ class GLAM(nn.Module):
|
|
133 |
return out
|
134 |
|
135 |
|
136 |
-
#
|
137 |
-
# 4.
|
138 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
class EfficientNetb0_TransformerGLAM(nn.Module):
|
140 |
"""EfficientNet-B0 + Swin-style Transformer + GLAM + Self-Adaptive Gating."""
|
141 |
def __init__(self,
|
@@ -161,11 +174,11 @@ class EfficientNetb0_TransformerGLAM(nn.Module):
|
|
161 |
# GLAM path
|
162 |
self.glam = GLAM(in_channels=embed_dim, reduction_ratio=reduction_ratio)
|
163 |
|
164 |
-
#
|
165 |
self.gate_fc = nn.Linear(embed_dim, 1)
|
166 |
|
167 |
-
# Final
|
168 |
-
self.
|
169 |
|
170 |
# Final classification
|
171 |
self.dropout = nn.Dropout(dropout)
|
@@ -199,9 +212,12 @@ class EfficientNetb0_TransformerGLAM(nn.Module):
|
|
199 |
g = g.view(B, 1, 1, 1)
|
200 |
|
201 |
# Final Fusion
|
202 |
-
F_out = g
|
203 |
-
|
|
|
204 |
self.last_feature = F_out
|
|
|
205 |
pooled = F.adaptive_avg_pool2d(F_out, (1, 1)).view(B, -1)
|
|
|
206 |
return self.fc(pooled)
|
207 |
|
|
|
4 |
import torchvision.models as models
|
5 |
|
6 |
|
7 |
+
# ==================================================
|
8 |
# 1. SWIN WINDOW UTILS
|
9 |
+
# ==================================================
|
10 |
def window_partition(x, window_size):
|
11 |
"""Partitions input tensor into windows of shape (B * num_windows, window_size*window_size, C)."""
|
12 |
B, H, W, C = x.shape
|
|
|
25 |
return x
|
26 |
|
27 |
|
28 |
+
# ==================================================
|
29 |
# 2. SWIN WINDOW ATTENTION
|
30 |
+
# ==================================================
|
31 |
class SwinWindowAttention(nn.Module):
|
32 |
"""Swin-style window attention block."""
|
33 |
def __init__(self, embed_dim, window_size, num_heads, dropout=0.0):
|
|
|
62 |
return x.permute(0, 3, 1, 2).contiguous()
|
63 |
|
64 |
|
65 |
+
# ==================================================
|
66 |
# 3. GLAM
|
67 |
+
# ==================================================
|
68 |
class GLAM(nn.Module):
|
69 |
"""Global-Local Attention Module (GLAM)."""
|
70 |
def __init__(self, in_channels, reduction_ratio=8):
|
|
|
133 |
return out
|
134 |
|
135 |
|
136 |
+
# ==================================================
|
137 |
+
# 4. FUSION BLOCK
|
138 |
+
# ==================================================
|
139 |
+
class FusionBlock(nn.Module):
|
140 |
+
"""Combines Transformer and GLAM outputs using gating."""
|
141 |
+
def __init__(self):
|
142 |
+
super(FusionBlock, self).__init__()
|
143 |
+
|
144 |
+
def forward(self, g, T_out, G_out):
|
145 |
+
"""Perform final gating fusion."""
|
146 |
+
return g * T_out + (1 - g) * G_out
|
147 |
+
|
148 |
+
|
149 |
+
# ==================================================
|
150 |
+
# 5. EFFICIENTNETB0_TRANSFORMERGLAM
|
151 |
+
# ==================================================
|
152 |
class EfficientNetb0_TransformerGLAM(nn.Module):
|
153 |
"""EfficientNet-B0 + Swin-style Transformer + GLAM + Self-Adaptive Gating."""
|
154 |
def __init__(self,
|
|
|
174 |
# GLAM path
|
175 |
self.glam = GLAM(in_channels=embed_dim, reduction_ratio=reduction_ratio)
|
176 |
|
177 |
+
# Gating
|
178 |
self.gate_fc = nn.Linear(embed_dim, 1)
|
179 |
|
180 |
+
# Final Fusion
|
181 |
+
self.fusion_block = FusionBlock()
|
182 |
|
183 |
# Final classification
|
184 |
self.dropout = nn.Dropout(dropout)
|
|
|
212 |
g = g.view(B, 1, 1, 1)
|
213 |
|
214 |
# Final Fusion
|
215 |
+
F_out = self.fusion_block(g, T_out, G_out)
|
216 |
+
|
217 |
+
# Save final feature map for Grad-CAM
|
218 |
self.last_feature = F_out
|
219 |
+
|
220 |
pooled = F.adaptive_avg_pool2d(F_out, (1, 1)).view(B, -1)
|
221 |
+
|
222 |
return self.fc(pooled)
|
223 |
|