# Spaces: Running
import os

import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
class EvoTransformerConfig(PretrainedConfig):
    """Configuration for ``EvoTransformerForClassification``.

    Args:
        hidden_size: Width of the token embeddings and of every encoder layer.
        num_layers: Number of stacked ``nn.TransformerEncoderLayer`` modules.
        num_labels: Number of output classes of the classification head. When
            an ``id2label`` mapping is also supplied (as happens when a saved
            ``config.json`` is reloaded), ``PretrainedConfig`` keeps that
            mapping and derives the label count from it.
        **kwargs: Forwarded unchanged to ``transformers.PretrainedConfig``.
    """

    def __init__(self, hidden_size=384, num_layers=6, num_labels=2, **kwargs):
        # Let the base class own ``num_labels``. It is a property backed by
        # ``id2label``. Assigning it *after* ``super().__init__`` regenerated
        # ``id2label`` whenever the lengths differed, so a config saved with a
        # non-default label count reloaded with the default of 2 and the
        # classifier weights no longer matched the rebuilt head.
        super().__init__(num_labels=num_labels, **kwargs)
        self.hidden_size = hidden_size
        self.num_layers = num_layers
class EvoTransformerForClassification(PreTrainedModel):
    """Sequence classifier built from a token embedding and a Transformer encoder.

    The pipeline is: token embedding, a stack of ``nn.TransformerEncoderLayer``
    modules, mean pooling over the sequence, and a two-layer MLP head.

    The vocabulary size (30522), the number of attention heads (6) and the
    feed-forward width (1024) are fixed here rather than read from the config.
    Because the head count is fixed, ``config.hidden_size`` must be divisible
    by 6.
    """

    config_class = EvoTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # NOTE(review): 30522 equals the bert-base-uncased vocabulary size;
        # presumably that tokenizer is used. Confirm.
        self.embedding = nn.Embedding(30522, config.hidden_size)
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size, nhead=6, dim_feedforward=1024
            )
            for _ in range(config.num_layers)
        ])
        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size, 256),
            nn.ReLU(),
            nn.Linear(256, config.num_labels),
        )
        self.init_weights()

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Classify a batch of token sequences.

        Args:
            input_ids: Token ids, expected batch-first as ``(batch, seq_len)``.
                The transpose below relies on that layout.
            attention_mask: Optional tensor shaped like ``input_ids``. Positions
                equal to 0 are treated as padding and ignored by attention.
            labels: Optional class indices. When given, a cross-entropy loss is
                also returned.

        Returns:
            ``{"logits": ...}``, or ``{"loss": ..., "logits": ...}`` when
            ``labels`` is provided.
        """
        # The mask is identical for every layer, so build it once.
        # ``src_key_padding_mask`` expects True at positions to ignore.
        padding_mask = None if attention_mask is None else attention_mask == 0
        # The encoder layers use the default batch_first=False, so switch to
        # (seq_len, batch, hidden).
        hidden = self.embedding(input_ids).transpose(0, 1)
        for layer in self.layers:
            hidden = layer(hidden, src_key_padding_mask=padding_mask)
        # NOTE(review): this mean also averages padded positions. It is kept
        # as-is because existing trained weights depend on it. A masked mean
        # would make pooling independent of the padding length.
        pooled = hidden.mean(dim=0)
        logits = self.classifier(pooled)
        if labels is None:
            return {"logits": logits}
        loss = nn.functional.cross_entropy(logits, labels)
        return {"loss": loss, "logits": logits}

    def save_pretrained(self, save_directory, state_dict=None, **kwargs):
        """Write ``pytorch_model.bin`` and ``config.json`` into *save_directory*.

        Args:
            save_directory: Target directory. It is created if missing.
            state_dict: Optional state dict to save instead of
                ``self.state_dict()``. Callers using the
                ``PreTrainedModel.save_pretrained`` signature may pass one.
            **kwargs: Other ``PreTrainedModel.save_pretrained`` options (e.g.
                ``safe_serialization``). They are accepted only for signature
                compatibility and are ignored.
        """
        os.makedirs(save_directory, exist_ok=True)
        if state_dict is None:
            state_dict = self.state_dict()
        torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
        config_path = os.path.join(save_directory, "config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            f.write(self.config.to_json_string())

    @classmethod
    def from_pretrained(cls, load_directory):
        """Rebuild a model from a directory written by ``save_pretrained``.

        Without ``@classmethod`` this method could not be called on the class
        at all, so switching the returned model to evaluation mode breaks no
        existing callers. This follows the ``PreTrainedModel.from_pretrained``
        convention. Call ``model.train()`` to fine-tune further.

        Args:
            load_directory: Directory containing ``config.json`` and
                ``pytorch_model.bin``.

        Returns:
            The model with weights loaded onto the CPU, in eval mode.
        """
        config = cls.config_class.from_json_file(
            os.path.join(load_directory, "config.json")
        )
        model = cls(config)
        # torch.load unpickles the file. Only load checkpoints from trusted
        # sources. NOTE(review): consider weights_only=True (torch >= 1.13).
        state_dict = torch.load(
            os.path.join(load_directory, "pytorch_model.bin"), map_location="cpu"
        )
        model.load_state_dict(state_dict)
        model.eval()
        return model