Spaces:
Running
Running
# watchdog.py | |
import firebase_admin | |
from firebase_admin import credentials, firestore | |
import pandas as pd | |
import torch | |
from torch.utils.data import Dataset, DataLoader | |
import torch.nn as nn | |
import torch.optim as optim | |
from model import EvoTransformer # make sure this is in your project | |
import time | |
import datetime | |
import os | |
# Firebase setup: initialise the *default* Firebase app once per process and
# expose a Firestore client plus the names used by the feedback poller.
# The service-account key path may be overridden with the FIREBASE_CREDENTIALS
# environment variable; the default is the file name previously hard-coded here.
# NOTE(review): service-account keys should not live in source control.
FIREBASE_CREDENTIALS = os.environ.get(
    "FIREBASE_CREDENTIALS",
    "evotransformer-firebase-adminsdk-fbsvc-37a4b838aa.json",
)
try:
    # Public API check for the default app (the one firestore.client() uses);
    # get_app() raises ValueError when it has not been initialised yet.
    firebase_admin.get_app()
except ValueError:
    cred = credentials.Certificate(FIREBASE_CREDENTIALS)
    firebase_admin.initialize_app(cred)
db = firestore.client()
# Firestore collection holding user feedback documents.
COLLECTION = "evo_feedback_logs"
# Local file that stores the timestamp high-water mark between polls.
LAST_CHECK_FILE = "last_feedback_timestamp.txt"
# Dataset wrapper used for fine-tuning.
class EvoDataset(Dataset):
    """Map-style dataset over pairwise-preference feedback records.

    Each record is a mapping with ``goal``, ``solution1``, ``solution2`` and
    ``correct`` keys. Item ``idx`` is a ``(text, label)`` pair where ``text``
    is the three fields joined by ``" [SEP] "`` and ``label`` is ``0`` when
    ``correct == "Solution 1"`` and ``1`` otherwise.
    """

    def __init__(self, data):
        # Stored as given; must support len() and integer indexing.
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        goal, first, second = (
            record['goal'],
            record['solution1'],
            record['solution2'],
        )
        text = f"{goal} [SEP] {first} [SEP] {second}"
        # NOTE(review): any value other than exactly "Solution 1" (including
        # typos or unexpected labels) is silently mapped to class 1.
        label = int(record['correct'] != "Solution 1")
        return text, label
# Placeholder character-level tokenizer (replace with a real tokenizer if needed).
def tokenize(text, max_len=256, pad_id=0):
    """Encode ``text`` as a fixed-length LongTensor of character codes.

    Each of the first ``max_len`` characters becomes ``ord(c) % 128``; shorter
    inputs are right-padded with ``pad_id`` so every result has shape
    ``(max_len,)``. A fixed length is required because the training loop
    batches these tensors together, which fails on mismatched lengths.

    Args:
        text: the string to encode.
        max_len: output length; longer text is truncated (default 256).
        pad_id: value used for padding positions (default 0).

    Returns:
        torch.LongTensor of shape ``(max_len,)``.
    """
    codes = [ord(ch) % 128 for ch in text[:max_len]]
    # NOTE(review): pad_id 0 also equals the code for chr(0)/chr(128)/...;
    # confirm whether EvoTransformer treats 0 as padding.
    codes.extend([pad_id] * (max_len - len(codes)))
    # Explicit dtype: torch.tensor([]) would otherwise be float32.
    return torch.tensor(codes, dtype=torch.long)
# Poll Firestore for feedback newer than the stored high-water mark.
def fetch_new_feedback():
    """Return well-formed feedback documents added since the last poll.

    Reads the high-water-mark timestamp from LAST_CHECK_FILE (falling back to
    "1970-01-01T00:00:00Z" when the file is missing or empty), queries
    COLLECTION for documents whose "timestamp" is greater, and maps the
    documents that have "goal", "sol1", "sol2" and "correct" onto the keys
    EvoDataset expects. The mark is advanced past *every* returned document,
    including malformed ones, so they are not re-fetched on each poll, and it
    is written back whenever it moved.

    NOTE(review): timestamps are compared with max() and persisted as text,
    which assumes they are stored as uniformly formatted ISO-8601 strings;
    Firestore Timestamp values would not match the string query. Confirm.
    NOTE(review): the mark is persisted before the caller trains, so if
    training then fails these examples are never retried.

    Returns:
        list of dicts with "goal", "solution1", "solution2", "correct" keys.
    """
    default_ts = "1970-01-01T00:00:00Z"
    last_ts = default_ts
    if os.path.exists(LAST_CHECK_FILE):
        with open(LAST_CHECK_FILE, "r", encoding="utf-8") as f:
            last_ts = f.read().strip() or default_ts

    query = db.collection(COLLECTION).where("timestamp", ">", last_ts)

    feedbacks = []
    latest_ts = last_ts
    for doc in query.stream():
        data = doc.to_dict()
        latest_ts = max(latest_ts, data.get("timestamp", last_ts))
        if all(k in data for k in ("goal", "sol1", "sol2", "correct")):
            feedbacks.append({
                "goal": data["goal"],
                "solution1": data["sol1"],
                "solution2": data["sol2"],
                "correct": data["correct"],
            })

    # Persist whenever the mark advanced, not only when valid feedback arrived.
    if latest_ts != last_ts:
        with open(LAST_CHECK_FILE, "w", encoding="utf-8") as f:
            f.write(latest_ts)
    return feedbacks
# Fine-tune the model on freshly fetched feedback.
def train_on_feedback(feedbacks):
    """Fine-tune EvoTransformer on ``feedbacks`` and save the new weights.

    Starts from "trained_model.pt" when it exists (otherwise from a freshly
    constructed EvoTransformer), runs 3 epochs of Adam (lr=0.001) with
    cross-entropy loss over shuffled batches of 4, prints per-epoch summed
    loss and accuracy, and overwrites "trained_model.pt" with the result.
    Does nothing except print a message when ``feedbacks`` is empty.

    Args:
        feedbacks: list of dicts with "goal", "solution1", "solution2" and
            "correct" keys, as returned by fetch_new_feedback().
    """
    if not feedbacks:
        print("No new feedback to train on.")
        return

    print(f"π Retraining on {len(feedbacks)} new examples...")
    dataset = EvoDataset(feedbacks)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    model_path = "trained_model.pt"
    model = EvoTransformer()
    if os.path.exists(model_path):
        # map_location: a checkpoint saved on GPU must still load on a CPU host;
        # the model here is never moved off the CPU.
        model.load_state_dict(torch.load(model_path, map_location="cpu"))

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    model.train()

    for epoch in range(3):  # quick fine-tuning
        total_loss = 0
        correct = 0
        for texts, labels in dataloader:
            # Token tensors can differ in length; pad to the batch maximum
            # (a no-op when tokenize() already returns fixed-length tensors).
            # torch.stack would raise on mismatched lengths.
            inputs = nn.utils.rnn.pad_sequence(
                [tokenize(t) for t in texts], batch_first=True
            )
            optimizer.zero_grad()
            # assumes EvoTransformer maps (batch, seq) LongTensor -> (batch, n_classes)
            # logits -- TODO confirm against model.py
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
        acc = correct / len(dataset)
        print(f"Epoch {epoch+1}: Loss={total_loss:.4f}, Accuracy={acc:.2%}")

    # Write to a temp file then atomically swap it in, so a crash mid-save
    # cannot leave a truncated checkpoint behind.
    tmp_path = model_path + ".tmp"
    torch.save(model.state_dict(), tmp_path)
    os.replace(tmp_path, model_path)
    print("β Updated model saved.")
# Long-running polling loop.
def watch(interval=60):
    """Poll for new feedback forever, retraining whenever some arrives.

    Each cycle calls fetch_new_feedback() and then train_on_feedback(). Any
    ``Exception`` raised during a cycle is reported with its traceback and the
    loop continues. The loop sleeps ``interval`` seconds between cycles and
    only stops on a non-Exception interrupt such as KeyboardInterrupt.

    Args:
        interval: seconds to sleep between polls (default 60).
    """
    import traceback

    print("π§ Evo Watchdog started...")
    while True:
        try:
            new_data = fetch_new_feedback()
            train_on_feedback(new_data)
        except Exception as e:
            # Keep the daemon alive across failures, but keep the full
            # traceback so the root cause is not lost.
            print(f"β οΈ Error: {str(e)}")
            traceback.print_exc()
        time.sleep(interval)
def manual_retrain():
    """Run exactly one fetch-then-train cycle on demand, without polling."""
    train_on_feedback(fetch_new_feedback())
# Start the polling loop only when executed as a script, never on import.
if __name__ == "__main__":
    watch()