HemanM commited on
Commit
a040341
Β·
verified Β·
1 Parent(s): 5f733b4

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +36 -0
model.py CHANGED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # βœ… model.py
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ class EvoTransformerBlock(nn.Module):
6
+ def __init__(self, d_model, nhead, dim_feedforward):
7
+ super().__init__()
8
+ self.layer = nn.TransformerEncoderLayer(
9
+ d_model=d_model,
10
+ nhead=nhead,
11
+ dim_feedforward=dim_feedforward,
12
+ batch_first=True
13
+ )
14
+
15
+ def forward(self, x):
16
+ return self.layer(x)
17
+
18
+ class EvoTransformer(nn.Module):
19
+ def __init__(self, vocab_size, d_model=256, nhead=4, dim_feedforward=512, num_layers=4):
20
+ super().__init__()
21
+ self.embedding = nn.Embedding(vocab_size, d_model)
22
+ self.encoder = nn.Sequential(*[
23
+ EvoTransformerBlock(d_model, nhead, dim_feedforward) for _ in range(num_layers)
24
+ ])
25
+ self.pooler = nn.AdaptiveAvgPool1d(1)
26
+ self.classifier = nn.Sequential(
27
+ nn.Linear(d_model, d_model // 2),
28
+ nn.ReLU(),
29
+ nn.Linear(d_model // 2, 2)
30
+ )
31
+
32
+ def forward(self, x):
33
+ x = self.embedding(x)
34
+ x = self.encoder(x)
35
+ x = self.pooler(x.transpose(1, 2)).squeeze(-1)
36
+ return self.classifier(x)