HemanM commited on
Commit
f5a2b6c
·
verified ·
1 Parent(s): 163f93d

Update evo_model.py

Browse files
Files changed (1) hide show
  1. evo_model.py +10 -5
evo_model.py CHANGED
@@ -1,9 +1,8 @@
1
- import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
 
5
  class EvoEncoder(nn.Module):
6
- def __init__(self, d_model=384, num_heads=6, ffn_dim=1024, num_layers=6, memory_enabled=False):
7
  super().__init__()
8
  self.embedding = nn.Embedding(30522, d_model)
9
  encoder_layer = nn.TransformerEncoderLayer(
@@ -14,18 +13,24 @@ class EvoEncoder(nn.Module):
14
  )
15
  self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
16
  self.memory_enabled = memory_enabled
 
 
 
17
 
18
  def forward(self, input_ids):
19
  x = self.embedding(input_ids)
 
 
 
20
  x = self.transformer(x)
21
- return x[:, 0] # first token
22
 
23
  class EvoTransformer(nn.Module):
24
- def __init__(self, d_model=384, num_heads=6, ffn_dim=1024, num_layers=6, num_classes=2, memory_enabled=False):
25
  super().__init__()
26
  self.encoder = EvoEncoder(d_model, num_heads, ffn_dim, num_layers, memory_enabled)
27
  self.classifier = nn.Linear(d_model, num_classes)
28
 
29
  def forward(self, input_ids):
30
  x = self.encoder(input_ids)
31
- return self.classifier(x)
 
 
1
  import torch.nn as nn
2
  import torch.nn.functional as F
3
 
4
  class EvoEncoder(nn.Module):
5
+ def __init__(self, d_model=512, num_heads=8, ffn_dim=1024, num_layers=6, memory_enabled=True):
6
  super().__init__()
7
  self.embedding = nn.Embedding(30522, d_model)
8
  encoder_layer = nn.TransformerEncoderLayer(
 
13
  )
14
  self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
15
  self.memory_enabled = memory_enabled
16
+ if memory_enabled:
17
+ self.memory_proj = nn.Linear(d_model, d_model)
18
+ self.memory_token = nn.Parameter(torch.zeros(1, 1, d_model))
19
 
20
  def forward(self, input_ids):
21
  x = self.embedding(input_ids)
22
+ if self.memory_enabled:
23
+ mem = self.memory_token.expand(x.size(0), -1, -1)
24
+ x = torch.cat([mem, x], dim=1)
25
  x = self.transformer(x)
26
+ return x[:, 0] # Return memory token or first token
27
 
28
  class EvoTransformer(nn.Module):
29
+ def __init__(self, d_model=512, num_heads=8, ffn_dim=1024, num_layers=6, num_classes=1, memory_enabled=True):
30
  super().__init__()
31
  self.encoder = EvoEncoder(d_model, num_heads, ffn_dim, num_layers, memory_enabled)
32
  self.classifier = nn.Linear(d_model, num_classes)
33
 
34
  def forward(self, input_ids):
35
  x = self.encoder(input_ids)
36
+ return self.classifier(x)