# evo_model.py - HF-compatible wrapper for EvoTransformer

import os

import torch
from torch import nn
from transformers import PreTrainedModel, PretrainedConfig

from model import EvoTransformer  # assumes your core model is in model.py


class EvoTransformerConfig(PretrainedConfig):
    model_type = "evo-transformer"

    def __init__(
        self,
        vocab_size=30522,
        d_model=256,
        nhead=4,
        dim_feedforward=512,
        num_hidden_layers=4,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.nhead = nhead
        self.dim_feedforward = dim_feedforward
        self.num_hidden_layers = num_hidden_layers


class EvoTransformerForClassification(PreTrainedModel):
    config_class = EvoTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        # Build the core EvoTransformer from the config hyperparameters.
        self.model = EvoTransformer(
            vocab_size=config.vocab_size,
            d_model=config.d_model,
            nhead=config.nhead,
            dim_feedforward=config.dim_feedforward,
            num_layers=config.num_hidden_layers,
        )

    def forward(self, input_ids):
        return self.model(input_ids)

    def save_pretrained(self, save_directory):
        # Make sure the target directory exists, then persist the raw state dict
        # alongside the config so the pair can be reloaded together.
        os.makedirs(save_directory, exist_ok=True)
        torch.save(self.model.state_dict(), f"{save_directory}/pytorch_model.bin")
        self.config.save_pretrained(save_directory)

    @classmethod
    def from_pretrained(cls, load_directory):
        # Rebuild the model from the saved config, then restore the weights.
        config = EvoTransformerConfig.from_pretrained(load_directory)
        model = cls(config)
        model.model.load_state_dict(
            torch.load(f"{load_directory}/pytorch_model.bin", map_location="cpu")
        )
        return model
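
A minimal usage sketch of the wrapper above. It assumes EvoTransformer's forward accepts a (batch, seq_len) tensor of token ids; the "evo_ckpt" directory name and the random dummy input are placeholders, not part of the original code:

# usage_example.py - hypothetical round-trip: save, reload, run a forward pass
import torch

from evo_model import EvoTransformerConfig, EvoTransformerForClassification

config = EvoTransformerConfig()                         # default hyperparameters
model = EvoTransformerForClassification(config)

model.save_pretrained("evo_ckpt")                       # writes pytorch_model.bin + config.json
restored = EvoTransformerForClassification.from_pretrained("evo_ckpt")

dummy_ids = torch.randint(0, config.vocab_size, (1, 16))  # fake batch of token ids
outputs = restored(dummy_ids)                            # shape depends on EvoTransformer's head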