HemanM committed
Commit 87069f6 · verified · 1 Parent(s): b39744e

Update evo_model.py

Files changed (1):
  evo_model.py  +26 -2
evo_model.py CHANGED
@@ -3,11 +3,23 @@ import torch.nn as nn
 from transformers import PreTrainedModel, PretrainedConfig
 
 class EvoTransformerConfig(PretrainedConfig):
-    def __init__(self, hidden_size=384, num_layers=6, num_labels=2, **kwargs):
+    def __init__(
+        self,
+        hidden_size=384,
+        num_layers=6,
+        num_labels=2,
+        num_heads=6,
+        ffn_dim=1024,
+        use_memory=False,
+        **kwargs
+    ):
         super().__init__(**kwargs)
         self.hidden_size = hidden_size
         self.num_layers = num_layers
         self.num_labels = num_labels
+        self.num_heads = num_heads
+        self.ffn_dim = ffn_dim
+        self.use_memory = use_memory
 
 class EvoTransformerForClassification(PreTrainedModel):
     config_class = EvoTransformerConfig
@@ -15,9 +27,20 @@ class EvoTransformerForClassification(PreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.config = config
+
+        # Expose architecture attributes for dashboard
+        self.num_layers = config.num_layers
+        self.num_heads = config.num_heads
+        self.ffn_dim = config.ffn_dim
+        self.use_memory = config.use_memory
+
         self.embedding = nn.Embedding(30522, config.hidden_size)  # BERT vocab size
         self.layers = nn.ModuleList([
-            nn.TransformerEncoderLayer(d_model=config.hidden_size, nhead=6, dim_feedforward=1024)
+            nn.TransformerEncoderLayer(
+                d_model=config.hidden_size,
+                nhead=config.num_heads,
+                dim_feedforward=config.ffn_dim
+            )
             for _ in range(config.num_layers)
         ])
         self.classifier = nn.Sequential(
@@ -25,6 +48,7 @@ class EvoTransformerForClassification(PreTrainedModel):
             nn.ReLU(),
             nn.Linear(256, config.num_labels)
         )
+
         self.init_weights()
 
     def forward(self, input_ids, attention_mask=None, labels=None):
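For illustration, a minimal usage sketch of the classes after this change. It assumes evo_model.py is importable from the repository root; the constructor arguments are simply the defaults introduced in the diff, and since the body of forward is not part of this commit, its return value is left unused.

import torch
from evo_model import EvoTransformerConfig, EvoTransformerForClassification  # assumed import path

# The config now carries the architecture knobs added in this commit.
config = EvoTransformerConfig(
    hidden_size=384,
    num_layers=6,
    num_labels=2,
    num_heads=6,      # nhead must divide hidden_size (384 / 6 = 64 per head)
    ffn_dim=1024,
    use_memory=False,
)

model = EvoTransformerForClassification(config)

# The same attributes are now mirrored on the model instance,
# e.g. for the dashboard mentioned in the diff comment.
print(model.num_layers, model.num_heads, model.ffn_dim, model.use_memory)

# Dummy batch of BERT-style token ids (vocab size 30522, matching the embedding).
input_ids = torch.randint(0, 30522, (2, 16))
_ = model(input_ids)  # signature per the diff; the forward body is outside this commit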