HemanM committed on
Commit
6e70824
·
verified ·
1 Parent(s): 112a2a9

Create watchdog.py

Browse files
Files changed (1) hide show
  1. watchdog.py +122 -0
watchdog.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # watchdog.py
2
+
3
import datetime
import os
import time
import traceback

import firebase_admin
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from firebase_admin import credentials, firestore
from torch.utils.data import Dataset, DataLoader

from model import EvoTransformer  # make sure this is in your project
14
+
15
+ # ✅ Firebase Setup
16
# Initialise the default Firebase app exactly once. `get_app()` raises
# ValueError when no default app exists yet, which is the public way to probe
# for prior initialisation (instead of reading the private `_apps` registry).
try:
    firebase_admin.get_app()
except ValueError:
    cred = credentials.Certificate("evotransformer-firebase-adminsdk-fbsvc-37a4b838aa.json")
    firebase_admin.initialize_app(cred)

# Firestore client shared by the functions below.
db = firestore.client()
# Firestore collection that holds the user feedback documents.
COLLECTION = "evo_feedback_logs"
# Local file persisting the timestamp of the newest feedback already fetched.
LAST_CHECK_FILE = "last_feedback_timestamp.txt"
23
+
24
+ # ✅ Dataset for training
25
class EvoDataset(Dataset):
    """Map-style dataset over feedback records.

    Each record is a mapping with the keys ``goal``, ``solution1``,
    ``solution2`` and ``correct``.
    """

    def __init__(self, data):
        """Keep a reference to the indexable collection of feedback records."""
        self.data = data

    def __getitem__(self, idx):
        """Return ``(text, label)`` for the record at position ``idx``.

        ``text`` joins goal and both solutions with `` [SEP] ``; ``label`` is
        0 when ``correct`` equals ``"Solution 1"`` and 1 for any other value.
        """
        item = self.data[idx]
        x = f"{item['goal']} [SEP] {item['solution1']} [SEP] {item['solution2']}"
        y = 0 if item['correct'] == "Solution 1" else 1
        return x, y

    def __len__(self):
        """Return the number of feedback records."""
        return len(self.data)
37
+
38
+ # ✅ Dummy tokenizer (replace with your tokenizer if needed)
39
def tokenize(text, max_len=256):
    """Encode ``text`` as a 1-D integer tensor of character codes.

    Only the first ``max_len`` characters are used; each one is mapped to
    ``ord(char) % 128``. The result is not padded, so its length is
    ``min(len(text), max_len)``.

    Args:
        text: String to encode.
        max_len: Maximum number of characters to keep (default 256).

    Returns:
        A ``torch.long`` tensor of shape ``(min(len(text), max_len),)``.
    """
    codes = [ord(ch) % 128 for ch in text[:max_len]]
    # Pin the dtype: an empty list would otherwise produce a float32 tensor,
    # which is inconsistent with the int64 tensor returned for non-empty text.
    return torch.tensor(codes, dtype=torch.long)
41
+
42
+ # ✅ Fetch new data
43
def fetch_new_feedback():
    """Fetch feedback documents newer than the persisted checkpoint.

    Reads the last-seen timestamp from ``LAST_CHECK_FILE`` (falling back to
    the Unix epoch when the file is missing or empty), queries ``COLLECTION``
    for documents whose ``timestamp`` field is greater than it, and converts
    every document that has the ``goal``, ``sol1``, ``sol2`` and ``correct``
    fields into a training record.

    The checkpoint is advanced to the newest timestamp among *all* fetched
    documents, including ones skipped for missing fields, so malformed
    documents are not re-downloaded on every poll.

    Returns:
        A list of dicts with keys ``goal``, ``solution1``, ``solution2`` and
        ``correct``; empty when nothing new (or nothing usable) was found.
    """
    epoch = "1970-01-01T00:00:00Z"
    try:
        with open(LAST_CHECK_FILE, "r", encoding="utf-8") as f:
            # An empty checkpoint file is treated like a missing one.
            last_ts = f.read().strip() or epoch
    except FileNotFoundError:
        last_ts = epoch

    # NOTE(review): timestamps are compared as strings here, so this assumes
    # the `timestamp` field is stored as a sortable string — confirm against
    # the code that writes the feedback documents.
    query = db.collection(COLLECTION).where("timestamp", ">", last_ts)

    required = ("goal", "sol1", "sol2", "correct")
    feedbacks = []
    latest_ts = last_ts
    for doc in query.stream():
        data = doc.to_dict()
        latest_ts = max(latest_ts, data.get("timestamp", last_ts))
        if all(k in data for k in required):
            feedbacks.append({
                "goal": data["goal"],
                "solution1": data["sol1"],
                "solution2": data["sol2"],
                "correct": data["correct"],
            })

    # NOTE(review): the checkpoint is written before the caller trains on the
    # returned records; if training fails afterwards they are not re-fetched.
    if latest_ts != last_ts:
        with open(LAST_CHECK_FILE, "w", encoding="utf-8") as f:
            f.write(latest_ts)

    return feedbacks
71
+
72
+ # ✅ Train Evo on new data
73
def train_on_feedback(feedbacks):
    """Fine-tune the model on newly collected feedback and save it.

    Builds an ``EvoDataset`` from ``feedbacks``, loads existing weights from
    ``trained_model.pt`` when that file exists, runs three epochs of Adam
    (lr=0.001) with cross-entropy loss at batch size 4, prints per-epoch
    summed loss and accuracy, and writes the updated state dict back to
    ``trained_model.pt``.

    Args:
        feedbacks: List of feedback records as accepted by ``EvoDataset``.
            When empty, a message is printed and nothing is trained or saved.

    Returns:
        None.
    """
    if not feedbacks:
        print("No new feedback to train on.")
        return

    print(f"🔁 Retraining on {len(feedbacks)} new examples...")

    dataset = EvoDataset(feedbacks)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    model = EvoTransformer()
    if os.path.exists("trained_model.pt"):
        # map_location keeps loading independent of the device the checkpoint
        # was saved from; the freshly built model lives on the CPU.
        model.load_state_dict(torch.load("trained_model.pt", map_location="cpu"))

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    model.train()
    for epoch in range(3):  # quick fine-tuning
        total_loss = 0.0
        correct = 0
        for texts, labels in dataloader:
            # tokenize() returns tensors of differing lengths, so torch.stack
            # would raise for any batch with unequal texts. Right-pad with 0
            # up to the longest sequence in the batch instead.
            # NOTE(review): assumes the model accepts (batch, seq_len) integer
            # input with variable seq_len — confirm against EvoTransformer.
            inputs = nn.utils.rnn.pad_sequence(
                [tokenize(t) for t in texts],
                batch_first=True,
                padding_value=0,
            )
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()

        acc = correct / len(dataset)
        print(f"Epoch {epoch+1}: Loss={total_loss:.4f}, Accuracy={acc:.2%}")

    torch.save(model.state_dict(), "trained_model.pt")
    print("✅ Updated model saved.")
109
+
110
+ # ✅ Watch Loop
111
def watch(interval_seconds=60):
    """Poll for new feedback forever and retrain whenever some arrives.

    Each cycle calls ``fetch_new_feedback()`` and passes the result to
    ``train_on_feedback()``. Any ``Exception`` raised by a cycle is reported
    (message plus full traceback) and the loop continues, so one failed poll
    does not stop the watchdog. The function never returns.

    Args:
        interval_seconds: Pause between polling cycles, in seconds
            (default 60).
    """
    print("🧠 Evo Watchdog started...")
    while True:
        try:
            new_data = fetch_new_feedback()
            train_on_feedback(new_data)
        except Exception as e:
            # Top-level boundary: keep the loop alive, but emit the traceback
            # so the failure can actually be diagnosed.
            print(f"⚠️ Error: {str(e)}")
            traceback.print_exc()
        time.sleep(interval_seconds)
120
+
121
# Start the polling loop only when executed as a script, not on import.
if __name__ == "__main__":
    watch()