# watchdog.py
import torch
import pandas as pd
from transformers import AutoTokenizer
from firebase_admin import firestore

from evo_model import EvoTransformerForClassification

# Shared tokenizer used to build classification prompts.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

def load_feedback_data():
    """Pull labelled feedback rows from Firestore into a DataFrame."""
    db = firestore.client()
    docs = db.collection("evo_feedback_logs").stream()
    data = []
    for doc in docs:
        d = doc.to_dict()
        # Keep only complete records that carry both options and a verdict.
        if all(k in d for k in ["goal", "solution_1", "solution_2", "correct_answer"]):
            data.append((
                d["goal"],
                d["solution_1"],
                d["solution_2"],
                0 if d["correct_answer"] == "Solution 1" else 1,
            ))
    return pd.DataFrame(data, columns=["goal", "sol1", "sol2", "label"])

def encode(goal, sol1, sol2):
    """Tokenize a goal plus its two candidate solutions as one prompt."""
    prompt = f"Goal: {goal} Option 1: {sol1} Option 2: {sol2}"
    return tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)

def manual_retrain():
    """Fine-tune the saved model on fresh feedback and persist the result."""
    try:
        data = load_feedback_data()
        if data.empty:
            print("[Retrain Error] No training data found.")
            return False

        model = EvoTransformerForClassification.from_pretrained("trained_model")
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
        model.train()

        # Single pass over the shuffled feedback, one example per step.
        for _, row in data.sample(frac=1).iterrows():
            encoded = encode(row["goal"], row["sol1"], row["sol2"])
            labels = torch.tensor([row["label"]])
            outputs = model(
                input_ids=encoded["input_ids"],
                attention_mask=encoded["attention_mask"],
                labels=labels,
            )
            # The model may return (loss, logits) or the loss tensor alone.
            loss = outputs[0] if isinstance(outputs, tuple) else outputs

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        model.save_pretrained("trained_model")
        print("✅ Evo retrained and saved.")
        return True
    except Exception as e:
        print(f"[Retrain Error] {e}")
        return False
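

# Usage sketch (assumption): firestore.client() requires firebase_admin to be
# initialised elsewhere before manual_retrain() runs. The guard below is
# illustrative, not part of the original module, and the credential path
# "serviceAccountKey.json" is hypothetical; adjust it for your deployment.
if __name__ == "__main__":
    import firebase_admin
    from firebase_admin import credentials

    # Initialise the default app once if no app has been created yet.
    if not firebase_admin._apps:
        cred = credentials.Certificate("serviceAccountKey.json")
        firebase_admin.initialize_app(cred)

    manual_retrain()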