# ✅ evo_model.py – HF-compatible wrapper for EvoTransformer

import os

import torch
from transformers import PreTrainedModel, PretrainedConfig

from model import EvoTransformer  # assumes the core model is defined in model.py

class EvoTransformerConfig(PretrainedConfig):
    """Configuration mirroring the constructor arguments of the core EvoTransformer."""

    model_type = "evo-transformer"

    def __init__(
        self,
        vocab_size=30522,
        d_model=256,
        nhead=4,
        dim_feedforward=512,
        num_hidden_layers=4,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.nhead = nhead
        self.dim_feedforward = dim_feedforward
        self.num_hidden_layers = num_hidden_layers

class EvoTransformerForClassification(PreTrainedModel):
    """Hugging Face-style wrapper that delegates the forward pass to the core EvoTransformer."""

    config_class = EvoTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = EvoTransformer(
            vocab_size=config.vocab_size,
            d_model=config.d_model,
            nhead=config.nhead,
            dim_feedforward=config.dim_feedforward,
            num_layers=config.num_hidden_layers
        )

    def forward(self, input_ids, attention_mask=None, **kwargs):
        # The core EvoTransformer consumes only input_ids; attention_mask and any
        # extra keyword arguments passed by HF pipelines/Trainer are accepted but ignored.
        return self.model(input_ids)

    def save_pretrained(self, save_directory):
        # Persist the raw state dict alongside the config so the pair can be reloaded below.
        os.makedirs(save_directory, exist_ok=True)
        torch.save(self.model.state_dict(), f"{save_directory}/pytorch_model.bin")
        self.config.save_pretrained(save_directory)

    @classmethod
    def from_pretrained(cls, load_directory):
        # Rebuild the wrapper from the saved config, then restore the core model's weights.
        config = EvoTransformerConfig.from_pretrained(load_directory)
        model = cls(config)
        state_dict = torch.load(f"{load_directory}/pytorch_model.bin", map_location="cpu")
        model.model.load_state_dict(state_dict)
        return model
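
# Illustrative usage sketch (not part of the module's API): builds the wrapper with the
# default config, runs a dummy batch, and round-trips it through save_pretrained /
# from_pretrained. The directory name, batch shape, and the assumption that the core
# model returns a logits tensor are made up for the example; tokenization is omitted.
if __name__ == "__main__":
    config = EvoTransformerConfig()
    model = EvoTransformerForClassification(config)

    dummy_input_ids = torch.randint(0, config.vocab_size, (2, 16))  # batch of 2, seq len 16
    logits = model(dummy_input_ids)  # assumes the core EvoTransformer returns a tensor
    print("output shape:", tuple(logits.shape))

    model.save_pretrained("evo_checkpoint")  # hypothetical output directory
    reloaded = EvoTransformerForClassification.from_pretrained("evo_checkpoint")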