EvoTransformer-v2.1 / watchdog.py
HemanM's picture
Update watchdog.py
5876a92 verified
raw
history blame
2.39 kB
# watchdog.py
import torch
from evo_model import EvoTransformerForClassification, EvoTransformerConfig
from transformers import BertTokenizer
import firebase_admin
from firebase_admin import credentials, firestore
import os
from datetime import datetime
# --- One-time module setup ---------------------------------------------------

# Firebase Admin SDK must be initialized exactly once per process; the SDK
# raises if initialize_app() is called on an already-initialized app, so we
# guard on its private app registry.
if not firebase_admin._apps:
    firebase_admin.initialize_app(credentials.Certificate("firebase_key.json"))
db = firestore.client()

# BERT tokenizer used to encode feedback text for EvoTransformer.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
def manual_retrain():
    """Fine-tune EvoTransformer from Firestore feedback logs.

    Streams documents from the ``evo_feedback_logs`` collection, builds a
    binary-classification dataset (which of two candidate solutions was the
    correct one), trains a model for 3 full-batch epochs, and saves the
    weights to ``trained_model.pt``.

    Returns:
        bool: True on success; False when no usable data was found or any
        step raised (the error is printed, not re-raised — this is a
        best-effort background retrain hook).
    """
    try:
        # 🔁 Fetch feedback logs and keep only well-formed entries.
        docs = db.collection("evo_feedback_logs").stream()
        required = ("goal", "solution_1", "solution_2", "correct_answer")
        # Strict label mapping. Previously any unexpected correct_answer
        # value (typo, empty string, None) was silently labeled 1
        # ("Solution 2"), poisoning the training data; now such rows are
        # skipped instead.
        label_map = {"Solution 1": 0, "Solution 2": 1}
        data = []
        for doc in docs:
            d = doc.to_dict()
            if not all(k in d for k in required):
                continue
            if d["correct_answer"] not in label_map:
                continue
            combined = f"{d['goal']} [SEP] {d['solution_1']} [SEP] {d['solution_2']}"
            data.append((combined, label_map[d["correct_answer"]]))

        if not data:
            print("❌ No valid training data found.")
            return False

        # Tokenize the whole dataset as a single padded batch.
        inputs = tokenizer([x[0] for x in data], padding=True, truncation=True, return_tensors="pt")
        labels = torch.tensor([x[1] for x in data])

        # Build model config + model.
        # NOTE(review): a fresh randomly-initialized model is trained each
        # time — previously saved weights are never loaded. Confirm this
        # from-scratch behavior is intended.
        config = {
            "vocab_size": tokenizer.vocab_size,
            "d_model": 256,
            "nhead": 4,
            "dim_feedforward": 512,
            "num_hidden_layers": 4,
        }
        model_config = EvoTransformerConfig(**config)
        model = EvoTransformerForClassification(model_config)

        # Loss + optimizer.
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

        # Simple 3-epoch full-batch fine-tune (one optimizer step per epoch).
        model.train()
        for epoch in range(3):
            optimizer.zero_grad()
            outputs = model(inputs["input_ids"])
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            print(f"[Epoch {epoch+1}] Loss: {loss.item():.4f}")

        # Persist the retrained weights.
        torch.save(model.state_dict(), "trained_model.pt")
        print("βœ… Evo updated via retrain from feedback!")
        return True
    except Exception as e:
        # Boundary handler: retraining is best-effort; report and signal failure.
        print(f"[Retrain Error] {e}")
        return False