# EvoTransformer-v2.1 / evo_model.py
# HemanM's picture
# Update evo_model.py
# daeebb8 verified
# raw
# history blame
# 1.67 kB
# βœ… evo_model.py – HF-compatible wrapper for EvoTransformer
import os

import torch
from torch import nn
from transformers import PreTrainedModel, PretrainedConfig

from model import EvoTransformer  # assumes your core model is in model.py
class EvoTransformerConfig(PretrainedConfig):
    """Hyperparameter container for the EvoTransformer architecture.

    Mirrors the constructor arguments of the core ``EvoTransformer`` in
    ``model.py`` so the model can be rebuilt from a serialized config.

    Args:
        vocab_size: Size of the token vocabulary (default matches BERT's 30522).
        d_model: Hidden/embedding dimension of the transformer.
        nhead: Number of attention heads per layer.
        dim_feedforward: Inner dimension of each feed-forward sub-layer.
        num_hidden_layers: Number of stacked transformer layers.
        **kwargs: Forwarded to ``PretrainedConfig`` (e.g. ``num_labels``).
    """

    model_type = "evo-transformer"

    def __init__(
        self,
        vocab_size=30522,
        d_model=256,
        nhead=4,
        dim_feedforward=512,
        num_hidden_layers=4,
        **kwargs
    ):
        # Let the HF base class consume its own kwargs first, then record
        # the architecture-specific fields on the instance.
        super().__init__(**kwargs)
        self.num_hidden_layers = num_hidden_layers
        self.dim_feedforward = dim_feedforward
        self.nhead = nhead
        self.d_model = d_model
        self.vocab_size = vocab_size
class EvoTransformerForClassification(PreTrainedModel):
    """HF-compatible wrapper exposing the core ``EvoTransformer`` model.

    Delegates the forward pass to ``model.EvoTransformer`` and implements
    minimal ``save_pretrained`` / ``from_pretrained`` persistence using a
    raw state-dict file (``pytorch_model.bin``) plus the HF config JSON.
    """

    config_class = EvoTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        # Note: config.num_hidden_layers maps onto the core model's
        # num_layers constructor argument.
        self.model = EvoTransformer(
            vocab_size=config.vocab_size,
            d_model=config.d_model,
            nhead=config.nhead,
            dim_feedforward=config.dim_feedforward,
            num_layers=config.num_hidden_layers
        )

    def forward(self, input_ids, **kwargs):
        """Run the core model on ``input_ids``.

        ``**kwargs`` absorbs extra arguments HF pipelines routinely pass
        (``attention_mask``, ``labels``, ...); the core model only consumes
        ``input_ids``, so they are intentionally ignored rather than raising
        a TypeError.
        """
        return self.model(input_ids)

    def save_pretrained(self, save_directory, **kwargs):
        """Save the state dict and config to ``save_directory``.

        Fix: create the directory first — the original crashed with
        FileNotFoundError when it did not exist. ``**kwargs`` keeps the
        override compatible with HF callers (e.g. Trainer passes
        ``safe_serialization``); those options are not supported here and
        are ignored.
        """
        os.makedirs(save_directory, exist_ok=True)
        torch.save(
            self.model.state_dict(),
            os.path.join(save_directory, "pytorch_model.bin"),
        )
        self.config.save_pretrained(save_directory)

    @classmethod
    def from_pretrained(cls, load_directory, **kwargs):
        """Rebuild the model from a directory written by ``save_pretrained``.

        Fix: ``map_location="cpu"`` lets weights saved on a GPU machine load
        on CPU-only hosts; the HF base class then moves the module to the
        requested device. ``**kwargs`` tolerates extra HF loading options
        (ignored here).
        """
        config = EvoTransformerConfig.from_pretrained(load_directory)
        model = cls(config)
        state_dict = torch.load(
            os.path.join(load_directory, "pytorch_model.bin"),
            map_location="cpu",
        )
        model.model.load_state_dict(state_dict)
        return model