Spaces:
Runtime error
Runtime error
Update evo_model.py
Browse files- evo_model.py +26 -2
evo_model.py
CHANGED
@@ -3,11 +3,23 @@ import torch.nn as nn
|
|
3 |
from transformers import PreTrainedModel, PretrainedConfig
|
4 |
|
5 |
class EvoTransformerConfig(PretrainedConfig):
|
6 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
super().__init__(**kwargs)
|
8 |
self.hidden_size = hidden_size
|
9 |
self.num_layers = num_layers
|
10 |
self.num_labels = num_labels
|
|
|
|
|
|
|
11 |
|
12 |
class EvoTransformerForClassification(PreTrainedModel):
|
13 |
config_class = EvoTransformerConfig
|
@@ -15,9 +27,20 @@ class EvoTransformerForClassification(PreTrainedModel):
|
|
15 |
def __init__(self, config):
|
16 |
super().__init__(config)
|
17 |
self.config = config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
self.embedding = nn.Embedding(30522, config.hidden_size) # BERT vocab size
|
19 |
self.layers = nn.ModuleList([
|
20 |
-
nn.TransformerEncoderLayer(
|
|
|
|
|
|
|
|
|
21 |
for _ in range(config.num_layers)
|
22 |
])
|
23 |
self.classifier = nn.Sequential(
|
@@ -25,6 +48,7 @@ class EvoTransformerForClassification(PreTrainedModel):
|
|
25 |
nn.ReLU(),
|
26 |
nn.Linear(256, config.num_labels)
|
27 |
)
|
|
|
28 |
self.init_weights()
|
29 |
|
30 |
def forward(self, input_ids, attention_mask=None, labels=None):
|
|
|
3 |
from transformers import PreTrainedModel, PretrainedConfig
|
4 |
|
5 |
class EvoTransformerConfig(PretrainedConfig):
    """Configuration for the Evo transformer text classifier.

    Holds the architecture hyperparameters consumed by
    ``EvoTransformerForClassification``: embedding width, encoder depth,
    attention head count, feed-forward width, label count, and a memory flag.
    Unknown keyword arguments are forwarded to the Hugging Face base config.
    """

    def __init__(
        self,
        hidden_size=384,
        num_layers=6,
        num_labels=2,
        num_heads=6,
        ffn_dim=1024,
        use_memory=False,
        **kwargs
    ):
        # Let the HF base class consume standard config keys (id2label, etc.).
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_labels = num_labels
        self.num_heads = num_heads
        self.ffn_dim = ffn_dim
        self.use_memory = use_memory
|
23 |
|
24 |
class EvoTransformerForClassification(PreTrainedModel):
|
25 |
config_class = EvoTransformerConfig
|
|
|
27 |
def __init__(self, config):
|
28 |
super().__init__(config)
|
29 |
self.config = config
|
30 |
+
|
31 |
+
# Expose architecture attributes for dashboard
|
32 |
+
self.num_layers = config.num_layers
|
33 |
+
self.num_heads = config.num_heads
|
34 |
+
self.ffn_dim = config.ffn_dim
|
35 |
+
self.use_memory = config.use_memory
|
36 |
+
|
37 |
self.embedding = nn.Embedding(30522, config.hidden_size) # BERT vocab size
|
38 |
self.layers = nn.ModuleList([
|
39 |
+
nn.TransformerEncoderLayer(
|
40 |
+
d_model=config.hidden_size,
|
41 |
+
nhead=config.num_heads,
|
42 |
+
dim_feedforward=config.ffn_dim
|
43 |
+
)
|
44 |
for _ in range(config.num_layers)
|
45 |
])
|
46 |
self.classifier = nn.Sequential(
|
|
|
48 |
nn.ReLU(),
|
49 |
nn.Linear(256, config.num_labels)
|
50 |
)
|
51 |
+
|
52 |
self.init_weights()
|
53 |
|
54 |
def forward(self, input_ids, attention_mask=None, labels=None):
|