# Source: Hugging Face Space file view, commit daeebb8 (1,666 bytes).
# ✅ evo_model.py – HF-compatible wrapper for EvoTransformer
import os

import torch
from torch import nn
from transformers import PreTrainedModel, PretrainedConfig

from model import EvoTransformer # assumes your core model is in model.py
class EvoTransformerConfig(PretrainedConfig):
    """Hugging Face configuration holding the EvoTransformer hyper-parameters.

    Each named argument is stored as an attribute of the same name.
    ``EvoTransformerForClassification`` forwards these values to the core
    ``EvoTransformer`` constructor; see model.py for their exact meaning.
    Any remaining keyword arguments go straight to
    ``PretrainedConfig.__init__``.
    """

    model_type = "evo-transformer"

    def __init__(
        self,
        vocab_size: int = 30522,
        d_model: int = 256,
        nhead: int = 4,
        dim_feedforward: int = 512,
        num_hidden_layers: int = 4,
        **kwargs,
    ):
        # Let the base class consume the generic HF options before the
        # model-specific attributes are attached.
        super().__init__(**kwargs)
        model_hparams = {
            "vocab_size": vocab_size,
            "d_model": d_model,
            "nhead": nhead,
            "dim_feedforward": dim_feedforward,
            "num_hidden_layers": num_hidden_layers,
        }
        for attr_name, attr_value in model_hparams.items():
            setattr(self, attr_name, attr_value)
class EvoTransformerForClassification(PreTrainedModel):
    """``PreTrainedModel`` wrapper that exposes the core ``EvoTransformer``.

    The inner module is built from an :class:`EvoTransformerConfig` and kept
    in ``self.model``. Saving and loading use a minimal directory layout:
    the config written by ``PretrainedConfig.save_pretrained`` plus the
    inner module's ``state_dict`` stored with ``torch.save``.
    """

    config_class = EvoTransformerConfig

    # Name of the weights file inside a checkpoint directory; shared by
    # save_pretrained and from_pretrained so the two cannot drift apart.
    _EVO_WEIGHTS_FILE = "pytorch_model.bin"

    def __init__(self, config):
        """Build the wrapped EvoTransformer from ``config``'s hyper-parameters."""
        super().__init__(config)
        self.model = EvoTransformer(
            vocab_size=config.vocab_size,
            d_model=config.d_model,
            nhead=config.nhead,
            dim_feedforward=config.dim_feedforward,
            # The config uses the HF attribute name; EvoTransformer expects num_layers.
            num_layers=config.num_hidden_layers,
        )

    def forward(self, input_ids):
        """Run the wrapped EvoTransformer on ``input_ids`` and return its output as-is."""
        return self.model(input_ids)

    def save_pretrained(self, save_directory):
        """Write the inner module's weights and the config to ``save_directory``.

        The directory is created if it does not exist yet.
        """
        # torch.save does not create parent directories, so make sure the
        # target exists before writing the weights.
        os.makedirs(save_directory, exist_ok=True)
        weights_path = os.path.join(save_directory, self._EVO_WEIGHTS_FILE)
        torch.save(self.model.state_dict(), weights_path)
        self.config.save_pretrained(save_directory)

    @classmethod
    def from_pretrained(cls, load_directory, map_location="cpu"):
        """Rebuild a model from a directory written by :meth:`save_pretrained`.

        Args:
            load_directory: Directory holding the config and weights file.
            map_location: Passed to ``torch.load``. Defaults to ``"cpu"`` so
                checkpoints saved on a GPU can be loaded on CPU-only hosts;
                ``load_state_dict`` then copies the tensors into the freshly
                constructed model's parameters.

        Returns:
            A new ``EvoTransformerForClassification`` with the stored weights.
        """
        config = EvoTransformerConfig.from_pretrained(load_directory)
        model = cls(config)
        weights_path = os.path.join(load_directory, cls._EVO_WEIGHTS_FILE)
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources (consider weights_only=True on torch >= 1.13).
        state_dict = torch.load(weights_path, map_location=map_location)
        model.model.load_state_dict(state_dict)
        return model