# EvoTransformer-v2.1 / watchdog.py
import torch
from transformers import AutoTokenizer
from evo_model import EvoTransformerForClassification
from firebase_admin import firestore
import pandas as pd
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
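
# Note: firestore.client() below uses the default Firebase app, so firebase_admin.initialize_app()
# must have been called elsewhere (e.g. during app startup) with valid credentials.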

def load_feedback_data():
    """Pull labeled feedback examples from the Firestore `evo_feedback_logs` collection."""
    db = firestore.client()
    docs = db.collection("evo_feedback_logs").stream()
    data = []
    for doc in docs:
        d = doc.to_dict()
        # Keep only complete records that contain both options and the human verdict.
        if all(k in d for k in ["goal", "solution_1", "solution_2", "correct_answer"]):
            data.append((
                d["goal"],
                d["solution_1"],
                d["solution_2"],
                # Label 0 means "Solution 1" was preferred, label 1 means "Solution 2".
                0 if d["correct_answer"] == "Solution 1" else 1
            ))
    return pd.DataFrame(data, columns=["goal", "sol1", "sol2", "label"])
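
# For reference, a single `evo_feedback_logs` document is expected to look roughly like the
# sketch below (field names come from the code above; the values are illustrative only):
#
#   {
#       "goal": "Start a campfire",
#       "solution_1": "Use dry twigs and a lighter",
#       "solution_2": "Pour water over the wood",
#       "correct_answer": "Solution 1",
#   }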

def encode(goal, sol1, sol2):
    """Build one prompt from the goal and both candidate solutions, then tokenize it."""
    prompt = f"Goal: {goal} Option 1: {sol1} Option 2: {sol2}"
    # Only the input_ids tensor is returned; the model is called with it directly below.
    return tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).input_ids
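
# Example (illustrative arguments): encode("Boil water", "Use a kettle", "Use a freezer")
# returns a LongTensor of shape (1, sequence_length), passed as-is to the model in manual_retrain.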

def manual_retrain():
    """Fine-tune the saved EvoTransformer on the accumulated feedback and persist it."""
    try:
        data = load_feedback_data()
        if data.empty:
            print("[Retrain Error] No training data found.")
            return False
        # Resume from the previously saved EvoTransformer checkpoint.
        model = EvoTransformerForClassification.from_pretrained("trained_model")
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
        loss_fn = torch.nn.CrossEntropyLoss()
        model.train()
        # Shuffle the feedback and take one optimizer step per example (batch size 1).
        for _, row in data.sample(frac=1).iterrows():
            inputs = encode(row["goal"], row["sol1"], row["sol2"])
            label = torch.tensor([row["label"]])
            outputs = model(inputs)
            # The model may return a tuple, an object with a `.logits` attribute, or raw logits.
            if isinstance(outputs, tuple):
                logits = outputs[0]
            elif hasattr(outputs, "logits"):
                logits = outputs.logits
            else:
                logits = outputs
            if logits.ndim == 2 and label.ndim == 1:
                loss = loss_fn(logits, label)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            else:
                print("[Retrain Warning] Shape mismatch, skipping one example.")
        # Persist the updated weights so the next inference run picks them up.
        model.save_pretrained("trained_model")
        print("✅ Evo retrained and saved.")
        return True
    except Exception as e:
        print(f"[Retrain Error] {e}")
        return False
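
# Minimal manual-run sketch (an assumption, not part of the deployed app): running this file
# directly would need a Firebase app with valid credentials before firestore.client() is
# called inside load_feedback_data().
if __name__ == "__main__":
    import firebase_admin
    if not firebase_admin._apps:  # skip if the host app already initialized Firebase
        firebase_admin.initialize_app()  # uses GOOGLE_APPLICATION_CREDENTIALS by default
    manual_retrain()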