EvoPlatformV3 / retrain_from_feedback.py
HemanM's picture
Update retrain_from_feedback.py
ab2f547 verified
raw
history blame
2.31 kB
# retrain_from_feedback.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from transformers import AutoTokenizer
from evo_architecture import mutate_genome, log_genome, default_config
from evo_model import EvoTransformerV22
import os
MODEL_PATH = "evo_hellaswag.pt"  # where train_evo() writes the retrained state_dict
CSV_PATH = "feedback_log.csv"  # feedback rows that train_evo() learns from

# Columns every feedback row must provide to be usable for training.
_FEEDBACK_COLUMNS = ("question", "option1", "option2", "answer")


def _build_training_pairs(df):
    """Convert feedback rows into ``(input_text, label)`` training pairs.

    Rows with a missing value in any required column are dropped, and the
    remaining cells are coerced to ``str`` so numeric or empty CSV cells
    cannot break the ``.strip()`` comparison.

    Args:
        df: DataFrame containing at least the columns in ``_FEEDBACK_COLUMNS``.

    Returns:
        A list of ``(text, label)`` tuples where ``text`` is
        ``"<question> [SEP] <option>"`` and ``label`` is ``1.0`` when the
        answer equals ``option2`` (ignoring surrounding whitespace), else
        ``0.0``.
    """
    rows = df[list(_FEEDBACK_COLUMNS)].dropna().astype(str)
    pairs = []
    for question, opt1, opt2, answer in rows.itertuples(index=False, name=None):
        picked_second = answer.strip() == opt2.strip()
        # NOTE(review): the text always carries the option selected by the
        # label, so the target only encodes *which slot* was correct. This
        # mirrors the existing training scheme; confirm it matches how the
        # model is queried at inference time.
        chosen = opt2 if picked_second else opt1
        pairs.append((f"{question} [SEP] {chosen}", 1.0 if picked_second else 0.0))
    return pairs


def train_evo():
    """Retrain a freshly mutated Evo model on the logged user feedback.

    Reads ``CSV_PATH``, mutates the default genome, trains a new
    ``EvoTransformerV22`` for one pass over the usable feedback rows (one
    optimizer step per row), saves the weights to ``MODEL_PATH`` and logs the
    genome with ``1 - average_loss`` as a crude fitness score.

    Returns:
        None. When the feedback file is missing, empty, lacks a required
        column, or has no complete rows, a warning is printed and nothing is
        trained or written.
    """
    if not os.path.exists(CSV_PATH):
        print("⚠️ No feedback_log.csv file found.")
        return

    try:
        df = pd.read_csv(CSV_PATH)
    except pd.errors.EmptyDataError:
        # A zero-byte file has no header row for pandas to parse; treat it
        # the same as a file with a header but no rows.
        df = pd.DataFrame()
    if df.empty:
        print("⚠️ feedback_log.csv is empty.")
        return

    missing = [col for col in _FEEDBACK_COLUMNS if col not in df.columns]
    if missing:
        print("⚠️ feedback_log.csv is missing columns:", ", ".join(missing))
        return

    # Parse the rows before building the model so that bad data never costs
    # a model construction or overwrites the saved weights.
    examples = _build_training_pairs(df)
    if not examples:
        print("⚠️ feedback_log.csv has no complete feedback rows.")
        return

    # Step 1: Evolve new architecture
    base_config = default_config()
    evolved_config = mutate_genome(base_config)
    print("🧬 New mutated config:", evolved_config)

    # Step 2: Initialize model with evolved config
    model = EvoTransformerV22(
        num_layers=evolved_config["num_layers"],
        num_heads=evolved_config["num_heads"],
        ffn_dim=evolved_config["ffn_dim"],
        memory_enabled=evolved_config["memory_enabled"],
    )
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    model.train()

    # Step 3: Train on feedback
    total_loss = 0.0
    for input_text, target in examples:
        encoded = tokenizer(
            input_text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=128,
        )
        logits = model(encoded["input_ids"])
        label = torch.tensor([target])
        # Flatten instead of squeeze(): squeeze() collapses a single-example
        # output to a 0-d tensor, which the loss rejects because the target
        # has shape (1,). Assumes the model emits one logit per example —
        # TODO confirm against EvoTransformerV22.forward.
        loss = F.binary_cross_entropy_with_logits(logits.reshape(-1), label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Step 4: Save new model weights
    # NOTE(review): these weights belong to the *mutated* architecture; any
    # loader presumably has to rebuild the model with the same config.
    torch.save(model.state_dict(), MODEL_PATH)
    print("✅ Evo model retrained and saved.")

    # Step 5: Log genome and score (loss as proxy)
    # Average over the rows actually trained on, not the raw CSV length.
    avg_loss = total_loss / len(examples)
    score = 1.0 - avg_loss  # Use (1 - loss) as crude fitness; may be negative
    log_genome(evolved_config, performance=round(score, 4))
    print("🧬 Genome logged with score:", round(score, 4))