HemanM committed on
Commit
fdef447
·
verified ·
1 Parent(s): 6a1a7af

Update evo_model.py

Browse files
Files changed (1) hide show
  1. evo_model.py +11 -13
evo_model.py CHANGED
@@ -2,8 +2,8 @@ import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
 
5
- class EvoEncoder(nn.Module):
6
- def __init__(self, d_model=512, num_heads=8, ffn_dim=1024, num_layers=6, memory_enabled=True):
7
  super().__init__()
8
  self.embedding = nn.Embedding(30522, d_model)
9
  encoder_layer = nn.TransformerEncoderLayer(
@@ -15,8 +15,14 @@ class EvoEncoder(nn.Module):
15
  self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
16
  self.memory_enabled = memory_enabled
17
  if memory_enabled:
18
- self.memory_proj = nn.Linear(d_model, d_model)
19
  self.memory_token = nn.Parameter(torch.zeros(1, 1, d_model))
 
 
 
 
 
 
 
20
 
21
  def forward(self, input_ids):
22
  x = self.embedding(input_ids)
@@ -24,14 +30,6 @@ class EvoEncoder(nn.Module):
24
  mem = self.memory_token.expand(x.size(0), -1, -1)
25
  x = torch.cat([mem, x], dim=1)
26
  x = self.transformer(x)
27
- return x[:, 0] # Return memory token or first token
28
-
29
- class EvoTransformer(nn.Module):
30
- def __init__(self, d_model=512, num_heads=8, ffn_dim=1024, num_layers=6, num_classes=1, memory_enabled=True):
31
- super().__init__()
32
- self.encoder = EvoEncoder(d_model, num_heads, ffn_dim, num_layers, memory_enabled)
33
- self.classifier = nn.Linear(d_model, num_classes)
34
-
35
- def forward(self, input_ids):
36
- x = self.encoder(input_ids)
37
  return self.classifier(x)
 
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
 
5
+ class EvoTransformerV22(nn.Module):
6
+ def __init__(self, d_model=384, num_heads=6, ffn_dim=1024, num_layers=6, memory_enabled=False):
7
  super().__init__()
8
  self.embedding = nn.Embedding(30522, d_model)
9
  encoder_layer = nn.TransformerEncoderLayer(
 
15
  self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
16
  self.memory_enabled = memory_enabled
17
  if memory_enabled:
 
18
  self.memory_token = nn.Parameter(torch.zeros(1, 1, d_model))
19
+ self.memory_proj = nn.Linear(d_model, d_model)
20
+ self.pooling = nn.AdaptiveAvgPool1d(1)
21
+ self.classifier = nn.Sequential(
22
+ nn.Linear(d_model, 128),
23
+ nn.ReLU(),
24
+ nn.Linear(128, 2)
25
+ )
26
 
27
  def forward(self, input_ids):
28
  x = self.embedding(input_ids)
 
30
  mem = self.memory_token.expand(x.size(0), -1, -1)
31
  x = torch.cat([mem, x], dim=1)
32
  x = self.transformer(x)
33
+ x = x.permute(0, 2, 1) # for pooling
34
+ x = self.pooling(x).squeeze(-1)
 
 
 
 
 
 
 
 
35
  return self.classifier(x)