HemanM commited on
Commit
f997552
·
verified ·
1 Parent(s): 1e3a2f8

Update watchdog.py

Browse files
Files changed (1) hide show
  1. watchdog.py +31 -123
watchdog.py CHANGED
@@ -1,127 +1,35 @@
1
- # watchdog.py
2
-
3
- import firebase_admin
4
- from firebase_admin import credentials, firestore
5
- import pandas as pd
6
  import torch
7
- from torch.utils.data import Dataset, DataLoader
8
- import torch.nn as nn
9
- import torch.optim as optim
10
- from model import EvoTransformer # make sure this is in your project
11
- import time
12
- import datetime
13
  import os
14
 
15
- # ✅ Firebase Setup
16
- if not firebase_admin._apps:
17
- cred = credentials.Certificate("evotransformer-firebase-adminsdk-fbsvc-37a4b838aa.json")
18
- firebase_admin.initialize_app(cred)
19
-
20
- db = firestore.client()
21
- COLLECTION = "evo_feedback_logs"
22
- LAST_CHECK_FILE = "last_feedback_timestamp.txt"
23
-
24
- # ✅ Dataset for training
25
- class EvoDataset(Dataset):
26
- def __init__(self, data):
27
- self.data = data
28
-
29
- def __getitem__(self, idx):
30
- item = self.data[idx]
31
- x = f"{item['goal']} [SEP] {item['solution1']} [SEP] {item['solution2']}"
32
- y = 0 if item['correct'] == "Solution 1" else 1
33
- return x, y
34
-
35
- def __len__(self):
36
- return len(self.data)
37
-
38
- # ✅ Dummy tokenizer (replace with your tokenizer if needed)
39
- def tokenize(text):
40
- return torch.tensor([ord(c) % 128 for c in text[:256]])
41
-
42
- # ✅ Fetch new data
43
- def fetch_new_feedback():
44
- if os.path.exists(LAST_CHECK_FILE):
45
- with open(LAST_CHECK_FILE, "r") as f:
46
- last_ts = f.read().strip()
47
- else:
48
- last_ts = "1970-01-01T00:00:00Z"
49
-
50
- query = db.collection(COLLECTION).where("timestamp", ">", last_ts)
51
- docs = list(query.stream())
52
-
53
- feedbacks = []
54
- latest_ts = last_ts
55
- for doc in docs:
56
- data = doc.to_dict()
57
- if all(k in data for k in ["goal", "sol1", "sol2", "correct"]):
58
- feedbacks.append({
59
- "goal": data["goal"],
60
- "solution1": data["sol1"],
61
- "solution2": data["sol2"],
62
- "correct": data["correct"]
63
- })
64
- latest_ts = max(latest_ts, data.get("timestamp", last_ts))
65
-
66
- if feedbacks:
67
- with open(LAST_CHECK_FILE, "w") as f:
68
- f.write(latest_ts)
69
-
70
- return feedbacks
71
-
72
- # ✅ Train Evo on new data
73
- def train_on_feedback(feedbacks):
74
- if not feedbacks:
75
- print("No new feedback to train on.")
76
- return
77
-
78
- print(f"🔁 Retraining on {len(feedbacks)} new examples...")
79
-
80
- dataset = EvoDataset(feedbacks)
81
- dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
82
-
83
- model = EvoTransformer()
84
- if os.path.exists("trained_model.pt"):
85
- model.load_state_dict(torch.load("trained_model.pt"))
86
-
87
- criterion = nn.CrossEntropyLoss()
88
- optimizer = optim.Adam(model.parameters(), lr=0.001)
89
-
90
- model.train()
91
- for epoch in range(3): # quick fine-tuning
92
- total_loss = 0
93
- correct = 0
94
- for inputs, labels in dataloader:
95
- inputs = torch.stack([tokenize(x) for x in inputs])
96
- optimizer.zero_grad()
97
- outputs = model(inputs)
98
- loss = criterion(outputs, labels)
99
- loss.backward()
100
- optimizer.step()
101
- total_loss += loss.item()
102
- correct += (outputs.argmax(dim=1) == labels).sum().item()
103
-
104
- acc = correct / len(dataset)
105
- print(f"Epoch {epoch+1}: Loss={total_loss:.4f}, Accuracy={acc:.2%}")
106
-
107
- torch.save(model.state_dict(), "trained_model.pt")
108
- print("✅ Updated model saved.")
109
-
110
- # ✅ Watch Loop
111
- def watch():
112
- print("🧠 Evo Watchdog started...")
113
- while True:
114
- try:
115
- new_data = fetch_new_feedback()
116
- train_on_feedback(new_data)
117
- except Exception as e:
118
- print(f"⚠️ Error: {str(e)}")
119
- time.sleep(60) # check every 60 seconds
120
-
121
  def manual_retrain():
122
- new_data = fetch_new_feedback()
123
- train_on_feedback(new_data)
124
-
125
- # Optional: only run loop if executed directly
126
- if __name__ == "__main__":
127
- watch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
+ import pandas as pd
3
+ from evo_model import EvoTransformer, train_evo_transformer
4
+ from datasets import load_dataset
 
 
 
5
  import os
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  def manual_retrain():
8
+ try:
9
+ # Load feedback data from Firestore
10
+ from google.cloud import firestore
11
+ db = firestore.Client.from_service_account_json("firebase_key.json")
12
+ docs = db.collection("evo_feedback_logs").stream()
13
+ data = [doc.to_dict() for doc in docs if "goal" in doc.to_dict()]
14
+ if not data:
15
+ print("No feedback data available.")
16
+ return False
17
+
18
+ # Convert to training format
19
+ rows = []
20
+ for d in data:
21
+ question = d["goal"]
22
+ option1 = d["sol1"]
23
+ option2 = d["sol2"]
24
+ correct = d["correct"]
25
+ label = 0 if correct == "Solution 1" else 1
26
+ rows.append((question, option1, option2, label))
27
+ df = pd.DataFrame(rows, columns=["goal", "sol1", "sol2", "label"])
28
+
29
+ # Train the Evo model (minimal epochs to simulate update)
30
+ train_evo_transformer(df, epochs=1)
31
+
32
+ return True
33
+ except Exception as e:
34
+ print(f"[Retrain Error] {e}")
35
+ return False