# ✅ evo_model.py – HF-compatible wrapper for EvoTransformer
import os

import torch
from torch import nn
from transformers import PreTrainedModel, PretrainedConfig

from model import EvoTransformer  # assumes your core model is in model.py

# Name of the weights file written to / read from a checkpoint directory.
_WEIGHTS_FILENAME = "pytorch_model.bin"


class EvoTransformerConfig(PretrainedConfig):
    """Configuration holding the hyperparameters passed to ``EvoTransformer``.

    Every argument is stored as an attribute of the same name. Any extra
    keyword arguments are forwarded to ``PretrainedConfig.__init__``.

    Args:
        vocab_size: Forwarded to ``EvoTransformer`` as ``vocab_size``.
        d_model: Forwarded to ``EvoTransformer`` as ``d_model``.
        nhead: Forwarded to ``EvoTransformer`` as ``nhead``.
        dim_feedforward: Forwarded to ``EvoTransformer`` as
            ``dim_feedforward``.
        num_hidden_layers: Forwarded to ``EvoTransformer`` as ``num_layers``.
        **kwargs: Passed through to ``PretrainedConfig``.
    """

    model_type = "evo-transformer"

    def __init__(
        self,
        vocab_size=30522,
        d_model=256,
        nhead=4,
        dim_feedforward=512,
        num_hidden_layers=4,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.nhead = nhead
        self.dim_feedforward = dim_feedforward
        self.num_hidden_layers = num_hidden_layers


class EvoTransformerForClassification(PreTrainedModel):
    """``PreTrainedModel`` wrapper that owns an ``EvoTransformer`` as ``self.model``.

    NOTE(review): ``save_pretrained`` and ``from_pretrained`` below replace the
    ``PreTrainedModel`` implementations with much narrower signatures (a single
    directory argument). Callers that pass the extra keyword arguments the
    base-class methods accept will get a ``TypeError`` -- confirm this is
    intended before using the class with tooling that relies on them.
    """

    config_class = EvoTransformerConfig

    def __init__(self, config):
        """Build the wrapped ``EvoTransformer`` from ``config``.

        Args:
            config: An ``EvoTransformerConfig`` (or compatible object) providing
                ``vocab_size``, ``d_model``, ``nhead``, ``dim_feedforward`` and
                ``num_hidden_layers``.
        """
        super().__init__(config)
        self.model = EvoTransformer(
            vocab_size=config.vocab_size,
            d_model=config.d_model,
            nhead=config.nhead,
            dim_feedforward=config.dim_feedforward,
            num_layers=config.num_hidden_layers,
        )

    def forward(self, input_ids):
        """Run the wrapped model on ``input_ids`` and return its output unchanged.

        The shape and meaning of the result are defined by ``EvoTransformer``,
        which is not visible from this file.
        """
        return self.model(input_ids)

    def save_pretrained(self, save_directory):
        """Write the wrapped model's weights and the config to ``save_directory``.

        Creates ``save_directory`` (including parents) if it does not exist,
        then writes ``pytorch_model.bin`` containing ``self.model.state_dict()``
        and saves ``self.config`` alongside it.

        Args:
            save_directory: Target directory path.
        """
        # torch.save does not create missing directories, so do it up front.
        os.makedirs(save_directory, exist_ok=True)
        weights_path = os.path.join(save_directory, _WEIGHTS_FILENAME)
        # Only the inner model's state dict is stored, so the keys carry no
        # "model." prefix; from_pretrained relies on this layout.
        torch.save(self.model.state_dict(), weights_path)
        self.config.save_pretrained(save_directory)

    @classmethod
    def from_pretrained(cls, load_directory):
        """Rebuild a model from a directory written by ``save_pretrained``.

        Args:
            load_directory: Directory containing the saved config and
                ``pytorch_model.bin``.

        Returns:
            A new instance of ``cls`` with the saved weights loaded into its
            wrapped model.
        """
        config = EvoTransformerConfig.from_pretrained(load_directory)
        model = cls(config)
        weights_path = os.path.join(load_directory, _WEIGHTS_FILENAME)
        # map_location="cpu" lets checkpoints saved from a CUDA model load on
        # hosts without a GPU; the freshly built model is on CPU anyway.
        # SECURITY: torch.load unpickles the file. Only load checkpoints from
        # trusted sources, or pass weights_only=True if the installed torch
        # version supports it.
        state_dict = torch.load(weights_path, map_location="cpu")
        model.model.load_state_dict(state_dict)
        return model