HemanM committed (verified)
Commit daeebb8 · 1 Parent(s): 97d373b

Update evo_model.py

Files changed (1):
  1. evo_model.py +52 -0
evo_model.py CHANGED
@@ -0,0 +1,52 @@
+ # ✅ evo_model.py – HF-compatible wrapper for EvoTransformer
+
+ import os
+
+ import torch
+ from transformers import PreTrainedModel, PretrainedConfig
+
+ from model import EvoTransformer  # assumes your core model is in model.py
+
+
+ class EvoTransformerConfig(PretrainedConfig):
+     """EvoTransformer hyperparameters, stored in the Hugging Face config format."""
+
+     model_type = "evo-transformer"
+
+     def __init__(
+         self,
+         vocab_size=30522,
+         d_model=256,
+         nhead=4,
+         dim_feedforward=512,
+         num_hidden_layers=4,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.vocab_size = vocab_size
+         self.d_model = d_model
+         self.nhead = nhead
+         self.dim_feedforward = dim_feedforward
+         self.num_hidden_layers = num_hidden_layers
+
+
+ class EvoTransformerForClassification(PreTrainedModel):
+     """Wraps the core EvoTransformer so it plugs into the transformers API."""
+
+     config_class = EvoTransformerConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.model = EvoTransformer(
+             vocab_size=config.vocab_size,
+             d_model=config.d_model,
+             nhead=config.nhead,
+             dim_feedforward=config.dim_feedforward,
+             num_layers=config.num_hidden_layers,
+         )
+
+     def forward(self, input_ids):
+         # Delegate directly to the core model; the wrapper adds no extra head.
+         return self.model(input_ids)
+
+     def save_pretrained(self, save_directory):
+         # Persist weights and config side by side, mirroring the HF layout.
+         os.makedirs(save_directory, exist_ok=True)
+         torch.save(self.model.state_dict(), f"{save_directory}/pytorch_model.bin")
+         self.config.save_pretrained(save_directory)
+
+     @classmethod
+     def from_pretrained(cls, load_directory):
+         # Rebuild the model from the saved config, then restore the weights.
+         config = EvoTransformerConfig.from_pretrained(load_directory)
+         model = cls(config)
+         model.model.load_state_dict(
+             torch.load(f"{load_directory}/pytorch_model.bin", map_location="cpu")
+         )
+         return model
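
For reference, a minimal usage sketch of the wrapper added in this commit. It assumes the core EvoTransformer returns classification logits for a batch of token IDs; the "bert-base-uncased" tokenizer (chosen only because it matches the default vocab_size of 30522) and the "evo_checkpoint" directory are illustrative, not part of the commit.

# usage_sketch.py – exercise the wrapper end to end (hypothetical example)
import torch
from transformers import AutoTokenizer

from evo_model import EvoTransformerConfig, EvoTransformerForClassification

# Tokenizer choice is an assumption; any vocabulary matching config.vocab_size works.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Build a fresh model from the default config and run one forward pass.
config = EvoTransformerConfig()
model = EvoTransformerForClassification(config)
model.eval()

inputs = tokenizer("Which option is the safer choice?", return_tensors="pt")
with torch.no_grad():
    logits = model(inputs["input_ids"])  # forward() takes input_ids only

# Round-trip through the custom save/load overrides ("evo_checkpoint" is a placeholder path).
model.save_pretrained("evo_checkpoint")
reloaded = EvoTransformerForClassification.from_pretrained("evo_checkpoint")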