HemanM committed on
Commit 0eff587 · verified · 1 Parent(s): 54bb1a5

Update evo_model.py

Files changed (1):
  1. evo_model.py +8 -6
evo_model.py CHANGED
@@ -6,23 +6,25 @@ class EvoEncoder(nn.Module):
     def __init__(self, d_model=512, num_heads=8, ffn_dim=1024, num_layers=6, memory_enabled=True):
         super().__init__()
         self.embedding = nn.Embedding(30522, d_model)
-        self.positional_encoding = nn.Parameter(torch.zeros(1, 512, d_model))
-        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads, dim_feedforward=ffn_dim, batch_first=True)
+        encoder_layer = nn.TransformerEncoderLayer(
+            d_model=d_model,
+            nhead=num_heads,
+            dim_feedforward=ffn_dim,
+            batch_first=True
+        )
         self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
-        self.norm = nn.LayerNorm(d_model)
         self.memory_enabled = memory_enabled
         if memory_enabled:
             self.memory_proj = nn.Linear(d_model, d_model)
             self.memory_token = nn.Parameter(torch.zeros(1, 1, d_model))
 
     def forward(self, input_ids):
-        x = self.embedding(input_ids) + self.positional_encoding[:, :input_ids.size(1), :]
+        x = self.embedding(input_ids)
         if self.memory_enabled:
             mem = self.memory_token.expand(x.size(0), -1, -1)
             x = torch.cat([mem, x], dim=1)
         x = self.transformer(x)
-        x = self.norm(x)
-        return x[:, 0] # Return [CLS]-like token (memory or first token)
+        return x[:, 0] # Return memory token or first token
 
 class EvoTransformer(nn.Module):
     def __init__(self, d_model=512, num_heads=8, ffn_dim=1024, num_layers=6, num_classes=1, memory_enabled=True):
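
After this commit, EvoEncoder relies on raw token embeddings plus the optional memory token; the learned positional encoding and final LayerNorm are gone. Below is a minimal smoke-test sketch of the updated encoder, assuming the module lives in evo_model.py as shown above; the batch of token IDs is invented for illustration, and the 30522 vocabulary size matches the hard-coded nn.Embedding.

    import torch
    from evo_model import EvoEncoder

    # Hypothetical batch: 2 sequences of 16 token IDs drawn from the
    # 30522-entry vocabulary hard-coded into the embedding layer.
    input_ids = torch.randint(0, 30522, (2, 16))

    encoder = EvoEncoder(d_model=512, num_heads=8, ffn_dim=1024,
                         num_layers=6, memory_enabled=True)
    encoder.eval()

    with torch.no_grad():
        pooled = encoder(input_ids)

    # With memory_enabled=True a memory token is prepended, so the
    # transformer sees 17 positions and forward() returns the memory
    # token's final hidden state.
    print(pooled.shape)  # torch.Size([2, 512])

The expected shape follows directly from the diff: the embedding gives (2, 16, 512), the prepended memory token makes it (2, 17, 512), and x[:, 0] pools that to (2, 512).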