Spaces:
Running
Running
Update evo_model.py
Browse files- evo_model.py +52 -0
evo_model.py
CHANGED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ✅ evo_model.py – HF-compatible wrapper for EvoTransformer
|
2 |
+
|
3 |
+
import os

import torch
from torch import nn
from transformers import PreTrainedModel, PretrainedConfig

from model import EvoTransformer  # assumes your core model is in model.py
|
7 |
+
|
8 |
+
class EvoTransformerConfig(PretrainedConfig):
    """Hugging Face configuration object for the EvoTransformer model.

    Stores the hyperparameters needed to rebuild the core network:
    vocabulary size, model width, attention heads, feed-forward width,
    and encoder depth. Extra keyword arguments are forwarded to
    ``PretrainedConfig``.
    """

    model_type = "evo-transformer"

    def __init__(self, vocab_size=30522, d_model=256, nhead=4,
                 dim_feedforward=512, num_hidden_layers=4, **kwargs):
        super().__init__(**kwargs)
        # Persist every architecture hyperparameter on the config so that
        # PretrainedConfig serialization (config.json) round-trips them.
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.nhead = nhead
        self.dim_feedforward = dim_feedforward
        self.num_hidden_layers = num_hidden_layers
|
26 |
+
|
27 |
+
class EvoTransformerForClassification(PreTrainedModel):
    """HF-compatible wrapper that delegates to the core ``EvoTransformer``.

    Provides ``save_pretrained``/``from_pretrained`` round-tripping of the
    raw state dict plus the config, alongside a thin ``forward``.
    """

    config_class = EvoTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        # Config fields map 1:1 onto the core constructor, except
        # num_hidden_layers -> num_layers (naming differs in model.py).
        self.model = EvoTransformer(
            vocab_size=config.vocab_size,
            d_model=config.d_model,
            nhead=config.nhead,
            dim_feedforward=config.dim_feedforward,
            num_layers=config.num_hidden_layers,
        )

    def forward(self, input_ids):
        """Run the wrapped model on ``input_ids``.

        Returns whatever ``EvoTransformer.__call__`` returns — presumably
        classification logits; confirm against model.py.
        """
        return self.model(input_ids)

    def save_pretrained(self, save_directory):
        """Save model weights and config under ``save_directory``.

        Bug fix: the original called ``torch.save`` before anything had
        created ``save_directory``, so saving to a fresh path crashed
        (``config.save_pretrained`` would create the directory, but it
        runs second). Create it explicitly first.
        """
        os.makedirs(save_directory, exist_ok=True)
        torch.save(
            self.model.state_dict(),
            os.path.join(save_directory, "pytorch_model.bin"),
        )
        self.config.save_pretrained(save_directory)

    @classmethod
    def from_pretrained(cls, load_directory):
        """Rebuild a model from a directory written by ``save_pretrained``.

        Loads the config, constructs a fresh instance, then restores the
        saved state dict into the wrapped core model.
        """
        config = EvoTransformerConfig.from_pretrained(load_directory)
        model = cls(config)
        # Bug fix: map_location="cpu" so checkpoints saved on a GPU host
        # still load on CPU-only machines (original torch.load had none).
        state_dict = torch.load(
            os.path.join(load_directory, "pytorch_model.bin"),
            map_location="cpu",
        )
        model.model.load_state_dict(state_dict)
        return model
|