HemanM committed
Commit 54bb1a5 · verified · 1 Parent(s): d23a6f2

Update evo_model.py

Files changed (1)
  1. evo_model.py +19 -16
evo_model.py CHANGED
@@ -3,30 +3,33 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 class EvoEncoder(nn.Module):
-    def __init__(self, d_model=384, nhead=6, dim_feedforward=1024, num_layers=6):
+    def __init__(self, d_model=512, num_heads=8, ffn_dim=1024, num_layers=6, memory_enabled=True):
         super().__init__()
-        self.embedding = nn.Embedding(30522, d_model)  # BERT-base vocab size
-        encoder_layer = nn.TransformerEncoderLayer(
-            d_model=d_model,
-            nhead=nhead,
-            dim_feedforward=dim_feedforward,
-            batch_first=True,
-        )
+        self.embedding = nn.Embedding(30522, d_model)  # BERT-base vocab size
+        self.positional_encoding = nn.Parameter(torch.zeros(1, 512, d_model))
+        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads, dim_feedforward=ffn_dim, batch_first=True)
         self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
-        self.memory_proj = nn.Linear(d_model, d_model)
+        self.norm = nn.LayerNorm(d_model)
+        self.memory_enabled = memory_enabled
+        if memory_enabled:
+            self.memory_proj = nn.Linear(d_model, d_model)
+            self.memory_token = nn.Parameter(torch.zeros(1, 1, d_model))
 
     def forward(self, input_ids):
-        x = self.embedding(input_ids)
+        x = self.embedding(input_ids) + self.positional_encoding[:, :input_ids.size(1), :]
+        if self.memory_enabled:
+            mem = self.memory_token.expand(x.size(0), -1, -1)
+            x = torch.cat([mem, x], dim=1)
         x = self.transformer(x)
-        x = x.mean(dim=1)
-        return self.memory_proj(x)
+        x = self.norm(x)
+        return x[:, 0]  # Return [CLS]-like token (memory or first token)
 
 class EvoTransformer(nn.Module):
-    def __init__(self, d_model=384):
+    def __init__(self, d_model=512, num_heads=8, ffn_dim=1024, num_layers=6, num_classes=1, memory_enabled=True):
         super().__init__()
-        self.encoder = EvoEncoder(d_model=d_model)
-        self.classifier = nn.Linear(d_model, 2)
+        self.encoder = EvoEncoder(d_model, num_heads, ffn_dim, num_layers, memory_enabled)
+        self.classifier = nn.Linear(d_model, num_classes)
 
     def forward(self, input_ids):
         x = self.encoder(input_ids)
-        return x
+        return self.classifier(x)
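
For a quick sanity check of the updated interface, here is a minimal smoke test. It is a sketch rather than part of the commit: it assumes the file is importable as evo_model and that `import torch` appears in the lines above this hunk; the batch size and sequence length are arbitrary.

    import torch
    from evo_model import EvoTransformer  # module updated in this commit

    # Hypothetical example: a batch of 4 sequences of 128 token ids drawn
    # from the 30522-entry (BERT-base) vocabulary used by the embedding.
    input_ids = torch.randint(0, 30522, (4, 128))

    model = EvoTransformer()   # defaults: d_model=512, num_classes=1, memory_enabled=True
    logits = model(input_ids)  # encoder returns the memory token's state, shape (4, 512)
    print(logits.shape)        # torch.Size([4, 1]) after the linear classifier head

With memory_enabled=True the encoder prepends a learned memory token and the classifier reads that token's final state, replacing the mean-over-positions pooling used before this commit.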