HemanM committed · Commit a568566 · verified · 1 Parent(s): 57dfdd4

Update evo_model.py

Files changed (1):
  1. evo_model.py (+18 −19)
evo_model.py CHANGED
@@ -1,10 +1,18 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 class EvoEncoder(nn.Module):
-    def __init__(self, d_model=384, num_heads=6, ffn_dim=1024, num_layers=6, memory_enabled=True):
+    def __init__(self, d_model=512, num_heads=8, ffn_dim=1024, num_layers=6, memory_enabled=True):
         super().__init__()
         self.embedding = nn.Embedding(30522, d_model)
+        self.memory_enabled = memory_enabled
+        if memory_enabled:
+            self.memory_proj = nn.Linear(d_model, d_model)
+            self.memory_token = nn.Parameter(torch.zeros(1, 1, d_model))
+        else:
+            self.memory_token = None
+
         encoder_layer = nn.TransformerEncoderLayer(
             d_model=d_model,
             nhead=num_heads,
@@ -12,34 +20,25 @@ class EvoEncoder(nn.Module):
             batch_first=True
         )
         self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
-        self.memory_enabled = memory_enabled
-        if memory_enabled:
-            self.memory_token = nn.Parameter(torch.zeros(1, 1, d_model))
-            self.memory_proj = nn.Linear(d_model, d_model)
 
     def forward(self, input_ids):
         x = self.embedding(input_ids)
-        if self.memory_enabled:
-            mem = self.memory_token.expand(x.size(0), -1, -1)
+
+        if self.memory_enabled and self.memory_token is not None:
+            mem = self.memory_token.expand(x.size(0), 1, x.size(2))
             x = torch.cat([mem, x], dim=1)
+
         x = self.transformer(x)
         return x
 
 class EvoTransformerV22(nn.Module):
     def __init__(self):
         super().__init__()
-        self.encoder = EvoEncoder(
-            d_model=384,
-            num_heads=6,
-            ffn_dim=1024,
-            num_layers=6,
-            memory_enabled=True
-        )
-        self.pooling = nn.AdaptiveAvgPool1d(1)
-        self.classifier = nn.Linear(384, 2)
+        self.encoder = EvoEncoder(d_model=512, num_heads=8, ffn_dim=1024, num_layers=6, memory_enabled=True)
+        self.pool = nn.AdaptiveAvgPool1d(1)
+        self.classifier = nn.Linear(512, 1)  # ✅ Matches checkpoint
 
     def forward(self, input_ids):
         x = self.encoder(input_ids)
-        x = x.permute(0, 2, 1)  # [B, D, T]
-        x = self.pooling(x).squeeze(-1)
-        return self.classifier(x)
+        x = self.pool(x.transpose(1, 2)).squeeze(-1)
+        return self.classifier(x)  # Output: [batch_size, 1]
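For reference, a minimal shape check of the updated model, assuming evo_model.py is on the import path; the batch size and sequence length below are arbitrary and not part of the commit:

import torch
from evo_model import EvoTransformerV22  # assumes evo_model.py is importable

model = EvoTransformerV22()
model.eval()

# Dummy batch: 4 sequences of 16 token IDs from the 30522-entry (BERT-sized) vocab.
input_ids = torch.randint(0, 30522, (4, 16))

with torch.no_grad():
    logits = model(input_ids)

print(logits.shape)  # torch.Size([4, 1])

The memory token prepends one position (the encoder sees T+1 tokens), but the adaptive average pool collapses the sequence axis, so the output is [batch_size, 1] either way.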
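Since the head is now nn.Linear(512, 1) to match the checkpoint, a strict state-dict load should succeed; a sketch, where "evo_v22.pt" is a placeholder filename, not a file named in this commit:

import torch
from evo_model import EvoTransformerV22

model = EvoTransformerV22()
# "evo_v22.pt" is a placeholder path, not taken from this commit.
state_dict = torch.load("evo_v22.pt", map_location="cpu")
# strict=True should now pass: the d_model=512 embeddings, the [1, 1, 512]
# memory_token, and the 512 -> 1 classifier all match the checkpoint shapes.
model.load_state_dict(state_dict, strict=True)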