import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig


class EvoTransformerConfig(PretrainedConfig):
    def __init__(
        self,
        hidden_size=384,
        num_layers=6,
        num_labels=2,
        num_heads=6,
        ffn_dim=1024,
        use_memory=False,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_labels = num_labels
        self.num_heads = num_heads
        self.ffn_dim = ffn_dim
        self.use_memory = use_memory


class EvoTransformerForClassification(PreTrainedModel):
    config_class = EvoTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Expose architecture attributes for the dashboard
        self.num_layers = config.num_layers
        self.num_heads = config.num_heads
        self.ffn_dim = config.ffn_dim
        self.use_memory = config.use_memory

        self.embedding = nn.Embedding(30522, config.hidden_size)  # BERT vocab size

        # batch_first defaults to False, so each layer expects [seq_len, batch, hidden_size]
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_heads,
                dim_feedforward=config.ffn_dim
            )
            for _ in range(config.num_layers)
        ])

        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size, 256),
            nn.ReLU(),
            nn.Linear(256, config.num_labels)
        )

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, labels=None):
        x = self.embedding(input_ids)  # [batch, seq_len, hidden_size]
        x = x.transpose(0, 1)          # Transformer expects [seq_len, batch, hidden_size]
        # src_key_padding_mask is True at padded positions, shape [batch, seq_len]
        padding_mask = (attention_mask == 0) if attention_mask is not None else None
        for layer in self.layers:
            x = layer(x, src_key_padding_mask=padding_mask)
        x = x.mean(dim=0)  # mean pooling over seq_len (note: padded positions are included)
        logits = self.classifier(x)
        if labels is not None:
            loss = nn.functional.cross_entropy(logits, labels)
            return loss, logits
        return logits

    def save_pretrained(self, save_directory):
        import os
        os.makedirs(save_directory, exist_ok=True)
        torch.save(self.state_dict(), f"{save_directory}/pytorch_model.bin")
        with open(f"{save_directory}/config.json", "w") as f:
            f.write(self.config.to_json_string())

    @classmethod
    def from_pretrained(cls, load_directory):
        config_path = f"{load_directory}/config.json"
        model_path = f"{load_directory}/pytorch_model.bin"
        config = EvoTransformerConfig.from_json_file(config_path)
        model = cls(config)
        model.load_state_dict(torch.load(model_path, map_location="cpu"))
        return model
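

# --- Quick usage sketch (illustrative, with dummy token IDs; the checkpoint
# directory name "evo_checkpoint" is a placeholder) ---
if __name__ == "__main__":
    config = EvoTransformerConfig()
    model = EvoTransformerForClassification(config)

    input_ids = torch.randint(0, 30522, (2, 16))          # [batch=2, seq_len=16]
    attention_mask = torch.ones(2, 16, dtype=torch.long)  # no padding in this toy batch

    # Inference path: no labels, returns logits only
    logits = model(input_ids, attention_mask)
    print(logits.shape)  # torch.Size([2, 2])

    # Training path: passing labels returns (loss, logits)
    labels = torch.tensor([0, 1])
    loss, logits = model(input_ids, attention_mask, labels)

    # Save/load round trip via the custom methods above
    model.save_pretrained("evo_checkpoint")
    reloaded = EvoTransformerForClassification.from_pretrained("evo_checkpoint")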