HemanM commited on
Commit
ffbab0d
·
verified ·
1 Parent(s): d73ff46

Update evo_model.py

Browse files
Files changed (1) hide show
  1. evo_model.py +3 -4
evo_model.py CHANGED
@@ -3,7 +3,7 @@ import torch.nn as nn
3
  import torch.nn.functional as F
4
 
5
  class EvoEncoder(nn.Module):
6
- def __init__(self, d_model=384, num_heads=6, ffn_dim=1024, num_layers=6, memory_enabled=True):
7
  super().__init__()
8
  self.embedding = nn.Embedding(30522, d_model)
9
  self.memory_enabled = memory_enabled
@@ -31,14 +31,13 @@ class EvoEncoder(nn.Module):
31
  x = self.transformer(x)
32
  return x
33
 
34
-
35
  class EvoTransformerV22(nn.Module):
36
  def __init__(self):
37
  super().__init__()
38
- self.encoder = EvoEncoder(memory_enabled=True)
39
  self.pool = nn.AdaptiveAvgPool1d(1)
40
  self.classifier = nn.Sequential(
41
- nn.Linear(384, 128),
42
  nn.ReLU(),
43
  nn.Linear(128, 2)
44
  )
 
3
  import torch.nn.functional as F
4
 
5
  class EvoEncoder(nn.Module):
6
+ def __init__(self, d_model=512, num_heads=8, ffn_dim=1024, num_layers=6, memory_enabled=True):
7
  super().__init__()
8
  self.embedding = nn.Embedding(30522, d_model)
9
  self.memory_enabled = memory_enabled
 
31
  x = self.transformer(x)
32
  return x
33
 
 
34
  class EvoTransformerV22(nn.Module):
35
  def __init__(self):
36
  super().__init__()
37
+ self.encoder = EvoEncoder(d_model=512, num_heads=8, ffn_dim=1024, num_layers=6, memory_enabled=True)
38
  self.pool = nn.AdaptiveAvgPool1d(1)
39
  self.classifier = nn.Sequential(
40
+ nn.Linear(512, 128),
41
  nn.ReLU(),
42
  nn.Linear(128, 2)
43
  )