import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig

class EvoTransformerConfig(PretrainedConfig):
    def __init__(self, hidden_size=384, num_layers=6, num_labels=2, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_labels = num_labels

class EvoTransformerForClassification(PreTrainedModel):
    config_class = EvoTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.embedding = nn.Embedding(30522, config.hidden_size)  # BERT vocab size
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=config.hidden_size, nhead=6, dim_feedforward=1024)
            for _ in range(config.num_layers)
        ])
        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size, 256),
            nn.ReLU(),
            nn.Linear(256, config.num_labels),
        )
        self.init_weights()

    def forward(self, input_ids, attention_mask=None, labels=None):
        x = self.embedding(input_ids)  # [batch, seq_len, hidden_size]
        x = x.transpose(0, 1)  # TransformerEncoderLayer defaults to [seq_len, batch, hidden_size]
        # Positions where attention_mask == 0 (padding) are excluded from self-attention
        padding_mask = (attention_mask == 0) if attention_mask is not None else None
        for layer in self.layers:
            x = layer(x, src_key_padding_mask=padding_mask)
        x = x.mean(dim=0)  # mean pooling over seq_len (padded positions still contribute here)
        logits = self.classifier(x)
        if labels is not None:
            loss = nn.functional.cross_entropy(logits, labels)
            return loss, logits
        return logits

    def save_pretrained(self, save_directory):
        import os
        os.makedirs(save_directory, exist_ok=True)
        torch.save(self.state_dict(), f"{save_directory}/pytorch_model.bin")
        with open(f"{save_directory}/config.json", "w") as f:
            f.write(self.config.to_json_string())

    @classmethod
    def from_pretrained(cls, load_directory):
        config_path = f"{load_directory}/config.json"
        model_path = f"{load_directory}/pytorch_model.bin"
        config = EvoTransformerConfig.from_json_file(config_path)
        model = cls(config)
        model.load_state_dict(torch.load(model_path, map_location="cpu"))
        return model
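
A minimal usage sketch, assuming the bert-base-uncased tokenizer (which matches the hard-coded vocabulary size of 30522); the save directory "evo_model" is an illustrative path, not part of the original code:

from transformers import AutoTokenizer

# Tokenizer assumed to be bert-base-uncased to match nn.Embedding(30522, ...)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

config = EvoTransformerConfig(hidden_size=384, num_layers=6, num_labels=2)
model = EvoTransformerForClassification(config)
model.eval()

batch = tokenizer(["an example sentence", "another one"], padding=True, return_tensors="pt")
with torch.no_grad():
    logits = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
print(logits.shape)  # torch.Size([2, 2])

# Round-trip through the custom save/load helpers ("evo_model" is a hypothetical directory)
model.save_pretrained("evo_model")
reloaded = EvoTransformerForClassification.from_pretrained("evo_model")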