HemanM committed on
Commit
d73ff46
·
verified ·
1 Parent(s): 5eef845

Update evo_model.py

Browse files
Files changed (1) hide show
  1. evo_model.py +23 -10
evo_model.py CHANGED
@@ -2,12 +2,16 @@ import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
 
5
- class EvoTransformerV22(nn.Module):
6
- def __init__(self, d_model=384, num_heads=6, ffn_dim=1024, num_layers=6, memory_enabled=False):
7
  super().__init__()
8
  self.embedding = nn.Embedding(30522, d_model)
9
  self.memory_enabled = memory_enabled
10
- self.memory_token = nn.Parameter(torch.zeros(1, 1, d_model)) if memory_enabled else None
 
 
 
 
11
 
12
  encoder_layer = nn.TransformerEncoderLayer(
13
  d_model=d_model,
@@ -16,13 +20,6 @@ class EvoTransformerV22(nn.Module):
16
  batch_first=True
17
  )
18
  self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
19
- self.pool = nn.AdaptiveAvgPool1d(1)
20
-
21
- self.classifier = nn.Sequential(
22
- nn.Linear(d_model, 128),
23
- nn.ReLU(),
24
- nn.Linear(128, 2) # Binary classification
25
- )
26
 
27
  def forward(self, input_ids):
28
  x = self.embedding(input_ids)
@@ -32,5 +29,21 @@ class EvoTransformerV22(nn.Module):
32
  x = torch.cat([mem, x], dim=1)
33
 
34
  x = self.transformer(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  x = self.pool(x.transpose(1, 2)).squeeze(-1)
36
  return self.classifier(x)
 
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
 
5
class EvoEncoder(nn.Module):
    """Transformer encoder over token ids with an optional learned memory token.

    When ``memory_enabled`` is True, a single learned embedding (the "memory
    token") is prepended to every sequence before the transformer stack, so the
    output sequence length is ``seq_len + 1``.

    NOTE(review): reconstructed from a garbled diff view; the ``nhead`` /
    ``dim_feedforward`` keyword lines and the memory branch of ``forward`` were
    hidden context lines — confirm against the repository copy.
    """

    def __init__(self, d_model=384, num_heads=6, ffn_dim=1024, num_layers=6,
                 memory_enabled=True):
        super().__init__()
        # 30522 matches the bert-base-uncased WordPiece vocabulary size —
        # presumably inputs come from a BERT tokenizer (TODO confirm).
        self.embedding = nn.Embedding(30522, d_model)
        self.memory_enabled = memory_enabled
        if memory_enabled:
            # NOTE(review): memory_proj is registered but never used in
            # forward() — dead parameters. Kept so existing checkpoints still
            # load; confirm whether it was meant to project the memory token.
            self.memory_proj = nn.Linear(d_model, d_model)
            self.memory_token = nn.Parameter(torch.zeros(1, 1, d_model))
        else:
            self.memory_token = None

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,          # hidden in the diff — presumed mapping
            dim_feedforward=ffn_dim,  # hidden in the diff — presumed mapping
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, input_ids, src_key_padding_mask=None):
        """Encode ``input_ids`` -> (batch, seq[+1 if memory], d_model).

        Args:
            input_ids: (batch, seq) long tensor of token ids.
            src_key_padding_mask: optional (batch, seq) bool tensor, True at
                padding positions. New, backward-compatible parameter; the
                default ``None`` preserves the original behavior exactly.
        """
        x = self.embedding(input_ids)
        if self.memory_enabled and self.memory_token is not None:
            # Broadcast the single learned token across the batch.
            mem = self.memory_token.expand(x.size(0), -1, -1)
            x = torch.cat([mem, x], dim=1)
            if src_key_padding_mask is not None:
                # The memory token is never padding: prepend a False column so
                # the mask stays aligned with the lengthened sequence.
                keep = torch.zeros(
                    src_key_padding_mask.size(0), 1,
                    dtype=torch.bool, device=src_key_padding_mask.device,
                )
                src_key_padding_mask = torch.cat([keep, src_key_padding_mask], dim=1)
        x = self.transformer(x, src_key_padding_mask=src_key_padding_mask)
        return x
35
class EvoTransformerV22(nn.Module):
    """Sequence classifier: EvoEncoder backbone + mean pooling + MLP head.

    Fix: the classifier's input width was a hard-coded ``384`` that silently
    duplicated EvoEncoder's ``d_model`` default — changing one without the
    other would crash at runtime. ``d_model`` (and ``num_classes``,
    ``memory_enabled``) are now parameters whose defaults reproduce the
    original behavior exactly.
    """

    def __init__(self, d_model=384, num_classes=2, memory_enabled=True):
        super().__init__()
        # Backbone width must match the classifier input below.
        self.encoder = EvoEncoder(d_model=d_model, memory_enabled=memory_enabled)
        # Mean-pool over the token dimension (applied after transpose).
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.Sequential(
            nn.Linear(d_model, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),  # default 2: binary classification
        )

    def forward(self, input_ids):
        """Return (batch, num_classes) logits for (batch, seq) token ids."""
        # (batch, seq[+1 memory token], d_model)
        x = self.encoder(input_ids)
        # AdaptiveAvgPool1d expects (batch, channels, length); squeeze back
        # to (batch, d_model) after pooling the length axis to 1.
        x = self.pool(x.transpose(1, 2)).squeeze(-1)
        return self.classifier(x)