HemanM commited on
Commit
163f93d
·
verified ·
1 Parent(s): 2b652a8

Update evo_model.py

Browse files
Files changed (1) hide show
  1. evo_model.py +3 -9
evo_model.py CHANGED
@@ -3,7 +3,7 @@ import torch.nn as nn
3
  import torch.nn.functional as F
4
 
5
  class EvoEncoder(nn.Module):
6
- def __init__(self, d_model=512, num_heads=8, ffn_dim=1024, num_layers=6, memory_enabled=True):
7
  super().__init__()
8
  self.embedding = nn.Embedding(30522, d_model)
9
  encoder_layer = nn.TransformerEncoderLayer(
@@ -14,20 +14,14 @@ class EvoEncoder(nn.Module):
14
  )
15
  self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
16
  self.memory_enabled = memory_enabled
17
- if memory_enabled:
18
- self.memory_proj = nn.Linear(d_model, d_model)
19
- self.memory_token = nn.Parameter(torch.zeros(1, 1, d_model))
20
 
21
  def forward(self, input_ids):
22
  x = self.embedding(input_ids)
23
- if self.memory_enabled:
24
- mem = self.memory_token.expand(x.size(0), -1, -1)
25
- x = torch.cat([mem, x], dim=1)
26
  x = self.transformer(x)
27
- return x[:, 0] # Return memory token or first token
28
 
29
  class EvoTransformer(nn.Module):
30
- def __init__(self, d_model=512, num_heads=8, ffn_dim=1024, num_layers=6, num_classes=1, memory_enabled=True):
31
  super().__init__()
32
  self.encoder = EvoEncoder(d_model, num_heads, ffn_dim, num_layers, memory_enabled)
33
  self.classifier = nn.Linear(d_model, num_classes)
 
3
  import torch.nn.functional as F
4
 
5
  class EvoEncoder(nn.Module):
6
+ def __init__(self, d_model=384, num_heads=6, ffn_dim=1024, num_layers=6, memory_enabled=False):
7
  super().__init__()
8
  self.embedding = nn.Embedding(30522, d_model)
9
  encoder_layer = nn.TransformerEncoderLayer(
 
14
  )
15
  self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
16
  self.memory_enabled = memory_enabled
 
 
 
17
 
18
  def forward(self, input_ids):
19
  x = self.embedding(input_ids)
 
 
 
20
  x = self.transformer(x)
21
+ return x[:, 0] # first token
22
 
23
  class EvoTransformer(nn.Module):
24
+ def __init__(self, d_model=384, num_heads=6, ffn_dim=1024, num_layers=6, num_classes=2, memory_enabled=False):
25
  super().__init__()
26
  self.encoder = EvoEncoder(d_model, num_heads, ffn_dim, num_layers, memory_enabled)
27
  self.classifier = nn.Linear(d_model, num_classes)