# watchdog.py
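"""Watchdog retraining script for EvoTransformer.

Pulls human feedback logged to Firestore, builds a preference dataset
(a goal plus two candidate solutions, labeled with the preferred one),
fine-tunes EvoTransformerForClassification on it, and saves the weights
to trained_model.pt.
"""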
import firebase_admin
from firebase_admin import credentials, firestore
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import BertTokenizer
from torch.utils.data import DataLoader, Dataset
from evo_model import EvoTransformerForClassification, EvoTransformerConfig
# Initialize Firebase (the guard keeps re-imports from initializing twice)
if not firebase_admin._apps:
    cred = credentials.Certificate("firebase_key.json")
    firebase_admin.initialize_app(cred)

db = firestore.client()
# Tokenizer used to build the "goal + two options" input strings
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# Dataset that pairs each logged prompt with the user's preferred solution
class FeedbackDataset(Dataset):
    def __init__(self, records, tokenizer, max_length=64):
        self.records = records
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.label_map = {"Solution 1": 0, "Solution 2": 1}

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        row = self.records[idx]
        combined = f"Goal: {row['goal']} Option 1: {row['solution_1']} Option 2: {row['solution_2']}"
        inputs = self.tokenizer(
            combined,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        label = self.label_map[row["correct_answer"]]
        return {
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
            "labels": torch.tensor(label),
        }
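
# Illustrative shape of one Firestore feedback record, inferred from the
# field accesses above (values here are placeholders, not real data):
#   {
#       "goal": "...",
#       "solution_1": "...",
#       "solution_2": "...",
#       "correct_answer": "Solution 1",  # or "Solution 2"
#   }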
# Manual retrain trigger
def manual_retrain():
    try:
        # Step 1: Fetch feedback data from Firestore, skipping malformed docs
        docs = db.collection("evo_feedback_logs").stream()
        feedback_data = [doc.to_dict() for doc in docs if "goal" in doc.to_dict()]

        if len(feedback_data) < 5:
            print("[Retrain Skipped] Not enough feedback.")
            return False

        # Step 2: Build the dataset and loader from the logged feedback
        dataset = FeedbackDataset(feedback_data, tokenizer)
        loader = DataLoader(dataset, batch_size=4, shuffle=True)

        # Step 3: Initialize a fresh model (weights are re-initialized here,
        # not loaded from a previous checkpoint)
        config = EvoTransformerConfig()
        model = EvoTransformerForClassification(config)
        model.train()

        # Step 4: Define optimizer and loss
        optimizer = optim.Adam(model.parameters(), lr=2e-5)
        loss_fn = nn.CrossEntropyLoss()

        # Step 5: Train for a few epochs
        for epoch in range(3):
            total_loss = 0
            for batch in loader:
                optimizer.zero_grad()
                input_ids = batch["input_ids"]
                labels = batch["labels"]
                # Note: the tokenizer also produces an attention mask, but
                # this Evo forward pass is called with input_ids only.
                logits = model(input_ids)
                loss = loss_fn(logits, labels)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            print(f"[Retrain] Epoch {epoch + 1} Loss: {total_loss:.4f}")

        # Step 6: Save updated model weights
        torch.save(model.state_dict(), "trained_model.pt")
        print("✅ Evo updated with latest feedback.")
        return True

    except Exception as e:
        print(f"[Retrain Error] {e}")
        return False
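
# Minimal sketch of a command-line entry point. This is an assumption: the
# original file exposes manual_retrain() for the app to call, so running
# `python watchdog.py` directly is just a convenience for manual testing.
if __name__ == "__main__":
    manual_retrain()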