# NOTE(review): "Spaces: Sleeping" banner below is page-scrape residue from the
# Hugging Face Spaces UI, not part of the program — kept here as a comment so
# the file remains valid Python.
# watchdog.py
#
# Module-level setup: load the tokenizer, initialize Firebase once, and
# open a Firestore client used by manual_retrain() below.
import torch
from evo_model import EvoTransformerForClassification, EvoTransformerConfig
from transformers import BertTokenizer
import firebase_admin
from firebase_admin import credentials, firestore
import os
from datetime import datetime

# Tokenizer used to encode feedback examples for the classifier.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Initialize Firebase exactly once — _apps is non-empty after the first
# initialize_app(), so this guard makes the module safe to re-import/reload.
if not firebase_admin._apps:
    cred = credentials.Certificate("firebase_key.json")
    firebase_admin.initialize_app(cred)

# Firestore client shared by the retraining routine.
db = firestore.client()
def manual_retrain():
    """Fine-tune the EvoTransformer classifier on Firestore feedback logs.

    Streams documents from the ``evo_feedback_logs`` collection, keeps only
    complete records, builds a binary dataset (label 0 = "Solution 1" was
    correct, label 1 = anything else), trains for 3 full-batch epochs, and
    saves the weights to ``trained_model.pt``.

    Returns:
        bool: True if training ran and the model was saved; False if no
        usable feedback was found or any exception occurred.
    """
    try:
        # Fetch feedback logs from Firestore.
        docs = db.collection("evo_feedback_logs").stream()
        data = []
        for doc in docs:
            d = doc.to_dict()
            # Skip incomplete records — all four fields are required.
            if all(k in d for k in ["goal", "solution_1", "solution_2", "correct_answer"]):
                # 0 if "Solution 1" was marked correct, otherwise 1.
                label = 0 if d["correct_answer"] == "Solution 1" else 1
                combined = f"{d['goal']} [SEP] {d['solution_1']} [SEP] {d['solution_2']}"
                data.append((combined, label))

        if not data:
            print("No valid training data found.")
            return False

        # Tokenize every example in one padded batch.
        inputs = tokenizer([x[0] for x in data], padding=True, truncation=True, return_tensors="pt")
        labels = torch.tensor([x[1] for x in data])

        # Build a fresh model from scratch (weights are NOT resumed from disk).
        config = {
            "vocab_size": tokenizer.vocab_size,
            "d_model": 256,
            "nhead": 4,
            "dim_feedforward": 512,
            "num_hidden_layers": 4,
        }
        model_config = EvoTransformerConfig(**config)
        model = EvoTransformerForClassification(model_config)

        # Loss + optimizer.
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

        # Simple full-batch fine-tune for 3 epochs.
        model.train()
        for epoch in range(3):
            optimizer.zero_grad()
            # NOTE(review): attention_mask is not passed, so padded positions
            # are treated as real tokens — confirm EvoTransformer's expected
            # forward() signature before changing this.
            outputs = model(inputs["input_ids"])
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            print(f"[Epoch {epoch+1}] Loss: {loss.item():.4f}")

        # Persist the updated weights.
        torch.save(model.state_dict(), "trained_model.pt")
        print("Evo updated via retrain from feedback!")
        return True

    except Exception as e:
        # Boundary handler: retraining is best-effort; report and signal failure.
        print(f"[Retrain Error] {e}")
        return False