# watchdog.py
import firebase_admin
from firebase_admin import credentials, firestore
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import BertTokenizer
from torch.utils.data import DataLoader, Dataset

from evo_model import EvoTransformerForClassification, EvoTransformerConfig

# Initialize Firebase once; firebase_admin raises if the default app is initialized twice.
if not firebase_admin._apps:
    cred = credentials.Certificate("firebase_key.json")
    firebase_admin.initialize_app(cred)

db = firestore.client()
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Dataset built from logged feedback records
class FeedbackDataset(Dataset):
    def __init__(self, records, tokenizer, max_length=64):
        self.records = records
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Map the stored winner label to a class index for cross-entropy.
        self.label_map = {"Solution 1": 0, "Solution 2": 1}

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        row = self.records[idx]
        # Pack the goal and both candidate solutions into a single input sequence.
        combined = f"Goal: {row['goal']} Option 1: {row['solution_1']} Option 2: {row['solution_2']}"
        inputs = self.tokenizer(combined, padding="max_length", truncation=True,
                                max_length=self.max_length, return_tensors="pt")
        label = self.label_map[row["correct_answer"]]
        return {
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
            "labels": torch.tensor(label),
        }
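
# Illustrative only: the Firestore record shape FeedbackDataset assumes, inferred
# from the keys read in __getitem__ above (the example values are made up):
# {
#     "goal": "Cross the river without a bridge",
#     "solution_1": "Build a raft from driftwood",
#     "solution_2": "Swim across at the narrowest point",
#     "correct_answer": "Solution 1",   # must be "Solution 1" or "Solution 2"
# }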

# Manual retrain trigger
def manual_retrain():
    try:
        # Step 1: Fetch feedback from Firestore, keeping only complete records
        # so __getitem__ cannot fail on a missing key or an unknown label.
        docs = db.collection("evo_feedback_logs").stream()
        required = ("goal", "solution_1", "solution_2", "correct_answer")
        feedback_data = [
            d for d in (doc.to_dict() for doc in docs)
            if all(k in d for k in required)
            and d["correct_answer"] in ("Solution 1", "Solution 2")
        ]
        if len(feedback_data) < 5:
            print("[Retrain Skipped] Not enough feedback.")
            return False

        # Step 2: Build the dataset and loader (the tokenizer is loaded at module level)
        dataset = FeedbackDataset(feedback_data, tokenizer)
        loader = DataLoader(dataset, batch_size=4, shuffle=True)

        # Step 3: Build the model. Note this is freshly initialized each time;
        # the previously saved trained_model.pt is not reloaded first.
        config = EvoTransformerConfig()
        model = EvoTransformerForClassification(config)
        model.train()

        # Step 4: Define optimizer and loss
        optimizer = optim.Adam(model.parameters(), lr=2e-5)
        loss_fn = nn.CrossEntropyLoss()
        # Step 5: Train for a few epochs (no device handling here; runs on CPU as written)
        for epoch in range(3):
            total_loss = 0
            for batch in loader:
                optimizer.zero_grad()
                input_ids = batch["input_ids"]
                attention_mask = batch["attention_mask"]
                labels = batch["labels"]
                # The Evo model is called with input_ids only; attention_mask is
                # collected but unused, since the forward signature defined in
                # evo_model is not shown in this file.
                logits = model(input_ids)
                loss = loss_fn(logits, labels)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            print(f"[Retrain] Epoch {epoch + 1} Loss: {total_loss:.4f}")
        # Step 6: Save the updated weights
        torch.save(model.state_dict(), "trained_model.pt")
        print("✅ Evo updated with latest feedback.")
        return True

    except Exception as e:
        print(f"[Retrain Error] {e}")
        return False
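
# Minimal sketch of a direct invocation, assuming watchdog.py is run as a
# standalone one-off job rather than imported by the serving app:
if __name__ == "__main__":
    manual_retrain()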