HemanM committed
Commit d23a6f2 · verified · 1 Parent(s): 27c948e

Update evo_model.py

Files changed (1)
  1. evo_model.py +19 -26
evo_model.py CHANGED
@@ -1,39 +1,32 @@
  import torch
  import torch.nn as nn
- from torch.nn import TransformerEncoder, TransformerEncoderLayer
+ import torch.nn.functional as F
  
  class EvoEncoder(nn.Module):
-     def __init__(self, vocab_size=30522, d_model=512, nhead=8, num_layers=6, dim_feedforward=1024, dropout=0.1):
-         super(EvoEncoder, self).__init__()
-         self.embedding = nn.Embedding(vocab_size, d_model)
-         self.memory_token = nn.Parameter(torch.zeros(1, 1, d_model))
-         self.positional_encoding = nn.Parameter(torch.zeros(1, 512, d_model))
- 
-         encoder_layer = TransformerEncoderLayer(d_model=d_model, nhead=nhead,
-                                                 dim_feedforward=dim_feedforward, dropout=dropout)
-         self.transformer = TransformerEncoder(encoder_layer, num_layers=num_layers)
-         self.norm = nn.LayerNorm(d_model)
+     def __init__(self, d_model=384, nhead=6, dim_feedforward=1024, num_layers=6):
+         super().__init__()
+         self.embedding = nn.Embedding(30522, d_model)  # BERT-base vocab size
+         encoder_layer = nn.TransformerEncoderLayer(
+             d_model=d_model,
+             nhead=nhead,
+             dim_feedforward=dim_feedforward,
+             batch_first=True,
+         )
+         self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
+         self.memory_proj = nn.Linear(d_model, d_model)
  
      def forward(self, input_ids):
          x = self.embedding(input_ids)
-         bsz = x.size(0)
- 
-         # Add memory token
-         mem_token = self.memory_token.expand(bsz, -1, -1)  # [B, 1, D]
-         x = torch.cat([mem_token, x], dim=1)
- 
-         x = x + self.positional_encoding[:, :x.size(1), :]
          x = self.transformer(x)
-         x = self.norm(x)
-         return x[:, 0]  # return memory token output
+         x = x.mean(dim=1)
+         return self.memory_proj(x)
  
  class EvoTransformer(nn.Module):
-     def __init__(self, vocab_size=30522, d_model=512, nhead=8, num_layers=6,
-                  dim_feedforward=1024, dropout=0.1):
-         super(EvoTransformer, self).__init__()
-         self.encoder = EvoEncoder(vocab_size, d_model, nhead, num_layers, dim_feedforward, dropout)
-         self.classifier = nn.Linear(d_model, 1)
+     def __init__(self, d_model=384):
+         super().__init__()
+         self.encoder = EvoEncoder(d_model=d_model)
+         self.classifier = nn.Linear(d_model, 2)
  
      def forward(self, input_ids):
          x = self.encoder(input_ids)
-         return x  # shape: [batch, d_model]
+         return x
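
A quick sanity check of the updated module (a hypothetical usage sketch, not part of the commit; the random token IDs stand in for real tokenizer output). Note that EvoTransformer.forward returns the pooled encoder features and never applies the classifier head defined in __init__, so logits have to be computed by calling model.classifier explicitly:

import torch
from evo_model import EvoTransformer

model = EvoTransformer(d_model=384)
input_ids = torch.randint(0, 30522, (4, 16))  # [batch=4, seq_len=16] of dummy token IDs

features = model(input_ids)           # [4, 384]: mean-pooled, projected encoder output
logits = model.classifier(features)   # [4, 2]; forward() does not apply this head itself
print(features.shape, logits.shape)   # torch.Size([4, 384]) torch.Size([4, 2])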