dhruv2842 commited on
Commit
d6ee28e
Β·
verified Β·
1 Parent(s): 393c032

Update glam_efficientnet_model.py

Browse files
Files changed (1) hide show
  1. glam_efficientnet_model.py +95 -106
glam_efficientnet_model.py CHANGED
@@ -1,106 +1,95 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from transformers import PreTrainedModel, PretrainedConfig, EfficientNetModel
5
- from typing import Optional, Union
6
-
7
- # --------------------------------------------------
8
- # Import your GLAM, SwinWindowAttention blocks here
9
- # --------------------------------------------------
10
- # from .glam_module import GLAM
11
- # from .swin_module import SwinWindowAttention
12
-
13
- from glam_module import GLAM
14
- from swin_module import SwinWindowAttention
15
-
16
- class GLAMEfficientNetConfig(PretrainedConfig):
17
- """Hugging Face-style configuration for GLAM EfficientNet."""
18
- model_type = "glam_efficientnet"
19
-
20
- def __init__(self,
21
- num_classes: int = 3,
22
- embed_dim: int = 512,
23
- num_heads: int = 8,
24
- window_size: int = 7,
25
- reduction_ratio: int = 8,
26
- dropout: float = 0.5,
27
- **kwargs):
28
- super().__init__(**kwargs)
29
- self.num_classes = num_classes
30
- self.embed_dim = embed_dim
31
- self.num_heads = num_heads
32
- self.window_size = window_size
33
- self.reduction_ratio = reduction_ratio
34
- self.dropout = dropout
35
-
36
-
37
- class GLAMEfficientNetForClassification(PreTrainedModel):
38
- """Hugging Face-style Model for EfficientNet + GLAM + Swin Architecture."""
39
- config_class = GLAMEfficientNetConfig
40
-
41
- def __init__(self, config: GLAMEfficientNetConfig, glam_module_cls, swin_module_cls):
42
- super().__init__(config)
43
-
44
- # βœ… 1) Hugging Face EfficientNet Backbone
45
- self.features = EfficientNetModel.from_pretrained("google/efficientnet-b0")
46
-
47
- # βœ… 1x1 conv for channel adjustment
48
- self.conv1x1 = nn.Conv2d(1280, config.embed_dim, kernel_size=1)
49
-
50
- # βœ… 2) Swin Attention Block
51
- self.swin_attn = swin_module_cls(
52
- embed_dim=config.embed_dim,
53
- window_size=config.window_size,
54
- num_heads=config.num_heads,
55
- dropout=config.dropout
56
- )
57
- self.pre_attn_norm = nn.LayerNorm(config.embed_dim)
58
- self.post_attn_norm = nn.LayerNorm(config.embed_dim)
59
-
60
- # βœ… 3) GLAM Block
61
- self.glam = glam_module_cls(in_channels=config.embed_dim, reduction_ratio=config.reduction_ratio)
62
-
63
- # βœ… 4) Self-Adaptive Gating
64
- self.gate_fc = nn.Linear(config.embed_dim, 1)
65
-
66
- # βœ… Final classification
67
- self.dropout = nn.Dropout(config.dropout)
68
- self.classifier = nn.Linear(config.embed_dim, config.num_classes)
69
-
70
- def forward(self, pixel_values, labels=None, **kwargs):
71
- """Perform forward pass."""
72
- # βœ… 1) EfficientNet Backbone
73
- backbone_output = self.features(pixel_values) # Returns BaseModelOutput
74
- feats = backbone_output.last_hidden_state # [B, C, H', W']
75
- feats = self.conv1x1(feats) # Adjust channel dims
76
- B, C, H, W = feats.shape
77
-
78
- # βœ… 2) Transformer Branch
79
- x_perm = feats.permute(0, 2, 3, 1).contiguous() # [B, H', W', C]
80
- x_norm = self.pre_attn_norm(x_perm).permute(0, 3, 1, 2).contiguous()
81
- x_norm = self.dropout(x_norm)
82
-
83
- T_out = self.swin_attn(x_norm) # [B, C, H', W']
84
-
85
- T_out = self.post_attn_norm(T_out.permute(0, 2, 3, 1).contiguous())
86
- T_out = T_out.permute(0, 3, 1, 2).contiguous()
87
-
88
- # βœ… 3) GLAM Branch
89
- G_out = self.glam(feats)
90
-
91
- # βœ… 4) Self-Adaptive Gating
92
- gap_feats = F.adaptive_avg_pool2d(feats, (1, 1)).view(B, C)
93
- g = torch.sigmoid(self.gate_fc(gap_feats)).view(B, 1, 1, 1)
94
-
95
- F_out = g * T_out + (1 - g) * G_out
96
-
97
- # βœ… Final Pooling & Classifier
98
- pooled = F.adaptive_avg_pool2d(F_out, (1, 1)).view(B, -1)
99
- logits = self.classifier(self.dropout(pooled))
100
-
101
- loss = None
102
- if labels is not None:
103
- loss = F.cross_entropy(logits, labels)
104
-
105
- return {"loss": loss, "logits": logits}
106
-
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchvision import models
5
+ from typing import Optional, Union
6
+
7
+ from glam_module import GLAM
8
+ from swin_module import SwinWindowAttention
9
+
10
+
11
+ class GLAMEfficientNetConfig:
12
+ """Hugging Face-style configuration for GLAM EfficientNet."""
13
+ def __init__(self,
14
+ num_classes: int = 3,
15
+ embed_dim: int = 512,
16
+ num_heads: int = 8,
17
+ window_size: int = 7,
18
+ reduction_ratio: int = 8,
19
+ dropout: float = 0.5,
20
+ **kwargs):
21
+ super().__init__(**kwargs)
22
+ self.num_classes = num_classes
23
+ self.embed_dim = embed_dim
24
+ self.num_heads = num_heads
25
+ self.window_size = window_size
26
+ self.reduction_ratio = reduction_ratio
27
+ self.dropout = dropout
28
+
29
+
30
+ class GLAMEfficientNetForClassification(nn.Module):
31
+ """EfficientNet (torchvision) + GLAM + Swin Architecture for Classification."""
32
+ def __init__(self, config: GLAMEfficientNetConfig, glam_module_cls, swin_module_cls):
33
+ super().__init__()
34
+
35
+ # βœ… 1) Torchvision EfficientNet Backbone
36
+ efficientnet = models.efficientnet_b0(pretrained=False) # No Hugging Face!
37
+ self.features = efficientnet.features
38
+
39
+ # βœ… 1x1 conv for channel adjustment
40
+ self.conv1x1 = nn.Conv2d(1280, config.embed_dim, kernel_size=1)
41
+
42
+ # βœ… 2) Swin Attention Block
43
+ self.swin_attn = swin_module_cls(
44
+ embed_dim=config.embed_dim,
45
+ window_size=config.window_size,
46
+ num_heads=config.num_heads,
47
+ dropout=config.dropout
48
+ )
49
+ self.pre_attn_norm = nn.LayerNorm(config.embed_dim)
50
+ self.post_attn_norm = nn.LayerNorm(config.embed_dim)
51
+
52
+ # βœ… 3) GLAM Block
53
+ self.glam = glam_module_cls(in_channels=config.embed_dim, reduction_ratio=config.reduction_ratio)
54
+
55
+ # βœ… 4) Self-Adaptive Gating
56
+ self.gate_fc = nn.Linear(config.embed_dim, 1)
57
+
58
+ # βœ… Final classification
59
+ self.dropout = nn.Dropout(config.dropout)
60
+ self.classifier = nn.Linear(config.embed_dim, config.num_classes)
61
+
62
+ def forward(self, pixel_values, labels=None, **kwargs):
63
+ """Perform forward pass."""
64
+ # βœ… 1) EfficientNet Backbone
65
+ feats = self.features(pixel_values) # [B, 1280, H', W']
66
+ feats = self.conv1x1(feats) # [B, embed_dim, H', W']
67
+ B, C, H, W = feats.shape
68
+
69
+ # βœ… 2) Transformer Branch
70
+ x_perm = feats.permute(0, 2, 3, 1).contiguous() # [B, H', W', C]
71
+ x_norm = self.pre_attn_norm(x_perm).permute(0, 3, 1, 2).contiguous()
72
+ x_norm = self.dropout(x_norm)
73
+
74
+ T_out = self.swin_attn(x_norm) # [B, C, H', W']
75
+ T_out = self.post_attn_norm(T_out.permute(0, 2, 3, 1).contiguous())
76
+ T_out = T_out.permute(0, 3, 1, 2).contiguous()
77
+
78
+ # βœ… 3) GLAM Branch
79
+ G_out = self.glam(feats)
80
+
81
+ # βœ… 4) Self-Adaptive Gating
82
+ gap_feats = F.adaptive_avg_pool2d(feats, (1, 1)).view(B, C)
83
+ g = torch.sigmoid(self.gate_fc(gap_feats)).view(B, 1, 1, 1)
84
+
85
+ F_out = g * T_out + (1 - g) * G_out
86
+
87
+ # βœ… Final Pooling & Classifier
88
+ pooled = F.adaptive_avg_pool2d(F_out, (1, 1)).view(B, -1)
89
+ logits = self.classifier(self.dropout(pooled))
90
+
91
+ loss = None
92
+ if labels is not None:
93
+ loss = F.cross_entropy(logits, labels)
94
+
95
+ return {"loss": loss, "logits": logits}