# EvoTransformer-v2.1 / init_model.py
# Author: HemanM — commit 34f85d1 (verified), "Update init_model.py", 1.49 kB
# NOTE(review): the lines above were Hugging Face page chrome captured by the
# scrape; kept here as comments so the file is valid Python.
import torch
from evo_model import EvoTransformerForClassification, EvoTransformerConfig
from transformers import AutoTokenizer
from torch.utils.data import DataLoader, TensorDataset
import os
def retrain_model(examples=None, labels=None, epochs=2, lr=1e-4, output_dir="trained_model"):
    """Fine-tune the EvoTransformer classifier on a small supervised set and save it.

    Args:
        examples: List of prompt strings. Defaults to the built-in 3-example set.
        labels: List of int class labels aligned with ``examples``
            (0 means "Option 1 is correct"). Defaults to ``[0, 0, 0]``.
        epochs: Number of passes over the data (default 2, as before).
        lr: Adam learning rate (default 1e-4, as before).
        output_dir: Directory the retrained model is written to
            (default "trained_model", as before).
    """
    print("🔄 Starting Evo retrain...")

    # Fall back to the original hard-coded sample retraining data.
    if examples is None:
        examples = [
            "Goal: House on fire. Option 1: Exit house. Option 2: Stay in house.",
            "Goal: Wet floor. Option 1: Walk slowly. Option 2: Run fast.",
            "Goal: Loud music. Option 1: Turn it down. Option 2: Ignore it.",
        ]
    if labels is None:
        labels = [0, 0, 0]  # Option 1 is correct for every built-in example

    if len(examples) != len(labels):
        raise ValueError("examples and labels must have the same length")

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = EvoTransformerForClassification(EvoTransformerConfig())

    # Tokenize the whole batch at once; padding/truncation keeps tensors rectangular.
    inputs = tokenizer(examples, padding=True, truncation=True, return_tensors="pt")
    labels_tensor = torch.tensor(labels)
    dataset = TensorDataset(inputs["input_ids"], inputs["attention_mask"], labels_tensor)
    dataloader = DataLoader(dataset, batch_size=2)

    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for _epoch in range(epochs):
        for input_ids, attention_mask, labels_batch in dataloader:
            optimizer.zero_grad()
            # Model is expected to return (loss, logits) when labels are supplied —
            # presumably; confirm against evo_model's forward() signature.
            loss, _ = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels_batch)
            loss.backward()
            optimizer.step()

    os.makedirs(output_dir, exist_ok=True)
    model.save_pretrained(output_dir)
    print("✅ Evo retrained and saved.")