# EvoTransformer-v2.1 / watchdog.py
# HemanM's picture
# Update watchdog.py
# e7e30db verified
# raw
# history blame
# 1.75 kB
import os
import torch
import firebase_admin
from firebase_admin import credentials, firestore
from model import SimpleEvoModel
# Initialize the default Firebase app only if it does not exist yet.
# firebase_admin.get_app() is the public lookup for the default app and raises
# ValueError when that app has not been initialized. The private
# firebase_admin._apps registry is avoided because it is also non-empty when
# only a *named* app exists, which would skip initialization here and make
# firestore.client() fail for lack of a default app.
try:
    firebase_admin.get_app()
except ValueError:
    # Service-account key is read from the current working directory.
    cred = credentials.Certificate("firebase_key.json")
    firebase_admin.initialize_app(cred)

# Module-level Firestore client used by fetch_training_data().
db = firestore.client()
def fetch_training_data():
    """Build a training set from the ``evo_feedback`` Firestore collection.

    Every document with a truthy ``winner`` field becomes one example:

    * features: the characters of ``prompt + winner`` mapped to
      ``(ord(c) % 256) / 255.0``, then truncated or zero-padded to exactly
      768 values (a simulated encoding, not a learned embedding);
    * label: ``0`` when the winner string contains ``"1"``, otherwise ``1``.

    Documents without a winner are skipped. A ``prompt`` that is missing or
    stored as null is treated as an empty string.

    Returns:
        A ``(inputs, labels)`` tuple of tensors. ``inputs`` is float32 with
        one 768-value row per usable document; ``labels`` is int64 with one
        entry per usable document. Both have length 0 when no document
        qualifies.
    """
    feature_dim = 768
    inputs, labels = [], []

    for doc in db.collection("evo_feedback").stream():
        data = doc.to_dict()
        winner = data.get("winner", "")
        if not winner:
            continue

        # dict.get's default only covers a *missing* key; a field stored as
        # null comes back as None and would break the concatenation below.
        goal = data.get("prompt") or ""

        # Truncate before encoding so overlong text is not encoded needlessly.
        text = (goal + winner)[:feature_dim]
        vector = [float(ord(c) % 256) / 255.0 for c in text]
        vector += [0.0] * (feature_dim - len(vector))  # zero-pad short rows

        inputs.append(vector)
        labels.append(0 if "1" in winner else 1)

    return (
        torch.tensor(inputs, dtype=torch.float32),
        torch.tensor(labels, dtype=torch.long),
    )
def retrain_and_save():
    """Retrain a fresh ``SimpleEvoModel`` on the feedback data and save it.

    Loads the data set via ``fetch_training_data()``. With fewer than two
    examples it prints a warning and returns without training. Otherwise it
    runs five full-batch Adam steps (lr=1e-3) against a cross-entropy loss
    and writes the model's ``state_dict`` to
    ``trained_model/pytorch_model.bin``, creating the directory if needed.

    Returns:
        None.
    """
    features, targets = fetch_training_data()
    if len(features) < 2:
        print("⚠️ Not enough training data.")
        return

    model = SimpleEvoModel()
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Each iteration is one gradient step over the entire data set.
    for _ in range(5):
        optimizer.zero_grad()
        batch_loss = criterion(model(features), targets)
        batch_loss.backward()
        optimizer.step()

    # Persist only the weights (state_dict), not the whole module object.
    os.makedirs("trained_model", exist_ok=True)
    weights = model.state_dict()
    torch.save(weights, "trained_model/pytorch_model.bin")
    print("✅ EvoTransformer retrained and saved to trained_model/")
# Run one retraining pass when this file is executed as a script; importing
# the module does not call retrain_and_save().
if __name__ == "__main__":
    retrain_and_save()