EvoTransformer-v2.1 / init_model.py
HemanM's picture
Update init_model.py
99f288e verified
raw
history blame
895 Bytes
import os
from transformers import BertTokenizer
from evo_model import EvoTransformerConfig, EvoTransformerForClassification
def initialize_and_save_model(
    model_dir: str = "trained_model",
    tokenizer_name: str = "bert-base-uncased",
):
    """Create a freshly initialized EvoTransformer classifier and save it with a tokenizer.

    The model is built from a default ``EvoTransformerConfig``; no training
    happens here. The model and a BERT tokenizer are both written to
    ``model_dir`` via their ``save_pretrained`` methods, so they can later be
    loaded from one directory.

    Args:
        model_dir: Output directory. It is created if it does not already
            exist.
        tokenizer_name: Identifier or local path passed to
            ``BertTokenizer.from_pretrained``.
    """
    # Fetch the tokenizer first. It is the step that may need network or
    # cache access, so a failure here aborts before anything is written and
    # avoids leaving a directory that has weights but no tokenizer files.
    tokenizer = BertTokenizer.from_pretrained(tokenizer_name)

    config = EvoTransformerConfig()
    model = EvoTransformerForClassification(config)

    os.makedirs(model_dir, exist_ok=True)
    model.save_pretrained(model_dir)
    tokenizer.save_pretrained(model_dir)

    print(f"✅ EvoTransformer and tokenizer initialized and saved to '{model_dir}/'")
def load_model(model_dir: str = "trained_model"):
    """Load an EvoTransformer classifier from a ``save_pretrained`` directory.

    Args:
        model_dir: Directory holding the saved model. The default matches the
            directory written by ``initialize_and_save_model``.

    Returns:
        The object returned by ``EvoTransformerForClassification.from_pretrained``.
        The tokenizer saved in the same directory is not loaded here.
    """
    return EvoTransformerForClassification.from_pretrained(model_dir)
# Script entry point: executing this file directly (not importing it) writes
# a freshly initialized, untrained model and tokenizer to disk.
if __name__ == "__main__":
    initialize_and_save_model()