EvoTransformer-v2.1 / init_model.py
HemanM's picture
Update init_model.py
fd0302f verified
raw
history blame
906 Bytes
import os
from transformers import BertTokenizer
from evo_model import EvoTransformerConfig, EvoTransformerForClassification
def initialize_and_save_model(*, output_dir="trained_model",
                              tokenizer_name="bert-base-uncased"):
    """Create an EvoTransformer classifier and save it next to a tokenizer.

    The model is constructed directly from a fixed configuration; this
    function performs no training. Both the model and the tokenizer are
    written into the same directory.

    Args:
        output_dir: Directory that receives the model and tokenizer files.
            It is created if it does not exist. Defaults to
            ``"trained_model"``.
        tokenizer_name: Identifier handed to ``BertTokenizer.from_pretrained``.
            Defaults to ``"bert-base-uncased"``.

    Returns:
        None. The results are the files written to ``output_dir`` and a
        confirmation line printed to stdout.
    """
    # Step 1: describe the architecture.
    config = EvoTransformerConfig(
        hidden_size=384,
        num_layers=6,
        num_labels=2,
        num_heads=6,
        ffn_dim=1024,
        use_memory=False,
    )

    # Step 2: build the model from that configuration.
    model = EvoTransformerForClassification(config)

    # Step 3: make sure the target directory exists, then save the model.
    os.makedirs(output_dir, exist_ok=True)
    model.save_pretrained(output_dir)

    # Step 4: fetch the BERT tokenizer and store it alongside the model so
    # the directory can be loaded as one unit.
    tokenizer = BertTokenizer.from_pretrained(tokenizer_name)
    tokenizer.save_pretrained(output_dir)

    print(
        "✅ EvoTransformer and tokenizer initialized and saved to "
        f"'{output_dir}/'"
    )
# Run the initialization only when executed as a script, so importing this
# module has no side effects.
if __name__ == "__main__":
    initialize_and_save_model()