HemanM committed on
Commit 5ed25f6 · verified · 1 Parent(s): cad50da

Update evo_model.py

Files changed (1)
  1. evo_model.py +7 -8
evo_model.py CHANGED
@@ -3,17 +3,18 @@ import torch.nn as nn
 from torch.nn import TransformerEncoder, TransformerEncoderLayer
 
 class EvoTransformer(nn.Module):
-    def __init__(self, vocab_size=30522, d_model=384, nhead=6, num_layers=6, dim_feedforward=1024, dropout=0.1, num_labels=2):
+    def __init__(self, vocab_size=30522, d_model=512, nhead=8, num_layers=6, dim_feedforward=1024, dropout=0.1):
         super(EvoTransformer, self).__init__()
         self.embedding = nn.Embedding(vocab_size, d_model)
         self.memory_token = nn.Parameter(torch.zeros(1, 1, d_model))
-
+
         encoder_layer = TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout)
         self.transformer = TransformerEncoder(encoder_layer, num_layers=num_layers)
-        self.norm = nn.LayerNorm(d_model)
-
+
         self.memory_proj = nn.Linear(d_model, d_model)
-        self.classifier = nn.Linear(d_model, num_labels)
+        self.norm = nn.LayerNorm(d_model)
+
+        self.classifier = nn.Linear(d_model, 1)  # Matches saved model: output is a single logit
 
     def forward(self, input_ids):
         x = self.embedding(input_ids)
@@ -25,6 +26,4 @@ class EvoTransformer(nn.Module):
         x = self.norm(x)
 
         memory_output = self.memory_proj(x[:, 0])
-        logits = self.classifier(memory_output)
-
-        return logits
+        return memory_output
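For reference, a runnable sketch of the module after this commit follows. The diff elides the body of forward between the embedding lookup and the final LayerNorm, so the marked middle section is an assumption, not code from this commit: prepending the learned memory_token to each sequence and transposing to the sequence-first layout that nn.TransformerEncoder expects by default are guesses at the missing lines.

import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer

class EvoTransformer(nn.Module):
    def __init__(self, vocab_size=30522, d_model=512, nhead=8, num_layers=6,
                 dim_feedforward=1024, dropout=0.1):
        super(EvoTransformer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.memory_token = nn.Parameter(torch.zeros(1, 1, d_model))

        encoder_layer = TransformerEncoderLayer(
            d_model=d_model, nhead=nhead,
            dim_feedforward=dim_feedforward, dropout=dropout)
        self.transformer = TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.memory_proj = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.classifier = nn.Linear(d_model, 1)  # single-logit head, kept to match the saved checkpoint

    def forward(self, input_ids):
        x = self.embedding(input_ids)                      # (batch, seq, d_model)
        # --- assumed middle of forward (elided in the diff) ---
        mem = self.memory_token.expand(x.size(0), -1, -1)  # broadcast memory token to the batch
        x = torch.cat([mem, x], dim=1)                     # prepend it to every sequence
        x = self.transformer(x.transpose(0, 1)).transpose(0, 1)  # encoder defaults to (seq, batch, d_model)
        # --- end of assumed section ---
        x = self.norm(x)
        memory_output = self.memory_proj(x[:, 0])          # pooled state of the memory token
        return memory_output

Since forward no longer calls self.classifier, any code that expects a logit has to apply the head itself:

model = EvoTransformer()
input_ids = torch.randint(0, 30522, (2, 16))  # (batch, seq_len) dummy token ids
embedding = model(input_ids)                  # -> (2, 512) memory embedding
logit = model.classifier(embedding)           # -> (2, 1); the caller applies the head now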