HemanM committed on
Commit
ceea26e
·
verified ·
1 Parent(s): fdef447

Update evo_model.py

Browse files
Files changed (1) hide show
  1. evo_model.py +19 -9
evo_model.py CHANGED
@@ -1,8 +1,7 @@
1
  import torch
2
  import torch.nn as nn
3
- import torch.nn.functional as F
4
 
5
- class EvoTransformerV22(nn.Module):
6
  def __init__(self, d_model=384, num_heads=6, ffn_dim=1024, num_layers=6, memory_enabled=False):
7
  super().__init__()
8
  self.embedding = nn.Embedding(30522, d_model)
@@ -17,12 +16,6 @@ class EvoTransformerV22(nn.Module):
17
  if memory_enabled:
18
  self.memory_token = nn.Parameter(torch.zeros(1, 1, d_model))
19
  self.memory_proj = nn.Linear(d_model, d_model)
20
- self.pooling = nn.AdaptiveAvgPool1d(1)
21
- self.classifier = nn.Sequential(
22
- nn.Linear(d_model, 128),
23
- nn.ReLU(),
24
- nn.Linear(128, 2)
25
- )
26
 
27
  def forward(self, input_ids):
28
  x = self.embedding(input_ids)
@@ -30,6 +23,23 @@ class EvoTransformerV22(nn.Module):
30
  mem = self.memory_token.expand(x.size(0), -1, -1)
31
  x = torch.cat([mem, x], dim=1)
32
  x = self.transformer(x)
33
- x = x.permute(0, 2, 1) # for pooling
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  x = self.pooling(x).squeeze(-1)
35
  return self.classifier(x)
 
1
  import torch
2
  import torch.nn as nn
 
3
 
4
+ class EvoEncoder(nn.Module):
5
  def __init__(self, d_model=384, num_heads=6, ffn_dim=1024, num_layers=6, memory_enabled=False):
6
  super().__init__()
7
  self.embedding = nn.Embedding(30522, d_model)
 
16
  if memory_enabled:
17
  self.memory_token = nn.Parameter(torch.zeros(1, 1, d_model))
18
  self.memory_proj = nn.Linear(d_model, d_model)
 
 
 
 
 
 
19
 
20
  def forward(self, input_ids):
21
  x = self.embedding(input_ids)
 
23
  mem = self.memory_token.expand(x.size(0), -1, -1)
24
  x = torch.cat([mem, x], dim=1)
25
  x = self.transformer(x)
26
+ return x
27
+
28
+ class EvoTransformerV22(nn.Module):
29
+ def __init__(self):
30
+ super().__init__()
31
+ self.encoder = EvoEncoder(
32
+ d_model=384,
33
+ num_heads=6,
34
+ ffn_dim=1024,
35
+ num_layers=6,
36
+ memory_enabled=True
37
+ )
38
+ self.pooling = nn.AdaptiveAvgPool1d(1)
39
+ self.classifier = nn.Linear(384, 2)
40
+
41
+ def forward(self, input_ids):
42
+ x = self.encoder(input_ids)
43
+ x = x.permute(0, 2, 1) # [B, D, T]
44
  x = self.pooling(x).squeeze(-1)
45
  return self.classifier(x)