HemanM commited on
Commit
c4358b8
·
verified ·
1 Parent(s): fb120ec

Update evo_model.py

Browse files
Files changed (1) hide show
  1. evo_model.py +36 -0
evo_model.py CHANGED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class EvoTransformerV22(nn.Module):
6
+ def __init__(self, d_model=384, num_heads=6, ffn_dim=1024, num_layers=6, memory_enabled=False):
7
+ super().__init__()
8
+ self.embedding = nn.Embedding(30522, d_model)
9
+ self.memory_enabled = memory_enabled
10
+ self.memory_token = nn.Parameter(torch.zeros(1, 1, d_model)) if memory_enabled else None
11
+
12
+ encoder_layer = nn.TransformerEncoderLayer(
13
+ d_model=d_model,
14
+ nhead=num_heads,
15
+ dim_feedforward=ffn_dim,
16
+ batch_first=True
17
+ )
18
+ self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
19
+ self.pool = nn.AdaptiveAvgPool1d(1)
20
+
21
+ self.classifier = nn.Sequential(
22
+ nn.Linear(d_model, 128),
23
+ nn.ReLU(),
24
+ nn.Linear(128, 2) # Binary classification
25
+ )
26
+
27
+ def forward(self, input_ids):
28
+ x = self.embedding(input_ids)
29
+
30
+ if self.memory_enabled and self.memory_token is not None:
31
+ mem = self.memory_token.expand(x.size(0), 1, x.size(2))
32
+ x = torch.cat([mem, x], dim=1)
33
+
34
+ x = self.transformer(x)
35
+ x = self.pool(x.transpose(1, 2)).squeeze(-1)
36
+ return self.classifier(x)