HemanM commited on
Commit
c605926
·
verified ·
1 Parent(s): e897bf0

Update watchdog.py

Browse files
Files changed (1) hide show
  1. watchdog.py +65 -33
watchdog.py CHANGED
@@ -1,35 +1,67 @@
 
 
 
1
  import torch
2
- import pandas as pd
3
- from evo_model import EvoTransformer, train_evo_transformer
4
- from datasets import load_dataset
5
- import os
6
-
7
- def manual_retrain():
8
- try:
9
- # Load feedback data from Firestore
10
- from google.cloud import firestore
11
- db = firestore.Client.from_service_account_json("firebase_key.json")
12
- docs = db.collection("evo_feedback_logs").stream()
13
- data = [doc.to_dict() for doc in docs if "goal" in doc.to_dict()]
14
- if not data:
15
- print("No feedback data available.")
16
- return False
17
-
18
- # Convert to training format
19
- rows = []
20
- for d in data:
21
- question = d["goal"]
22
- option1 = d["sol1"]
23
- option2 = d["sol2"]
24
- correct = d["correct"]
25
- label = 0 if correct == "Solution 1" else 1
26
- rows.append((question, option1, option2, label))
27
- df = pd.DataFrame(rows, columns=["goal", "sol1", "sol2", "label"])
28
-
29
- # Train the Evo model (minimal epochs to simulate update)
30
- train_evo_transformer(df, epochs=1)
31
-
32
- return True
33
- except Exception as e:
34
- print(f"[Retrain Error] {e}")
35
  return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from firebase_admin import firestore
2
+ from evo_model import EvoTransformerForClassification
3
+ from transformers import AutoTokenizer
4
  import torch
5
+ from torch.utils.data import Dataset, DataLoader
6
+ import torch.nn as nn
7
+ import torch.optim as optim
8
+
9
class EvoDataset(Dataset):
    """Torch dataset pairing pre-tokenized prompt texts with integer labels."""

    def __init__(self, texts, labels, tokenizer, max_length=64):
        # Tokenize everything up front; padding=True pads all sequences to the
        # longest one, so every item comes out the same length.
        self.encodings = tokenizer(texts, truncation=True, padding=True, max_length=max_length)
        self.labels = labels

    def __getitem__(self, idx):
        # Only input_ids are returned; any attention mask produced by the
        # tokenizer is left unused.
        return (
            torch.tensor(self.encodings["input_ids"][idx]),
            torch.tensor(self.labels[idx]),
        )

    def __len__(self):
        return len(self.labels)
21
+
22
def train_evo_transformer():
    """Retrain the Evo classifier from Firestore feedback logs.

    Reads documents from the "evo_feedback_logs" collection, builds a
    binary-classification dataset (goal text plus two candidate solutions),
    runs a short fine-tuning loop, and saves the model to "trained_evo".

    Returns:
        bool: True when retraining completed and the model was saved;
        False when no usable data exists or the fetch/training failed.
    """
    # Guard the Firestore fetch: a network/auth failure should log and return
    # False rather than crash the watchdog (the previous manual_retrain()
    # caught and logged exceptions the same way).
    try:
        db = firestore.client()
        docs = db.collection("evo_feedback_logs").stream()

        goals, solution1, solution2, labels = [], [], [], []
        for doc in docs:
            d = doc.to_dict()
            # Keep only fully-formed feedback records.
            if all(k in d for k in ["goal", "solution_1", "solution_2", "correct_answer"]):
                goals.append(d["goal"])
                solution1.append(d["solution_1"])
                solution2.append(d["solution_2"])
                # Label 0 == "Solution 1"; any other value maps to Solution 2.
                labels.append(0 if d["correct_answer"] == "Solution 1" else 1)
    except Exception as e:
        print(f"[Retrain Error] Failed to load feedback data: {e}")
        return False

    if not goals:
        print("[Retrain Error] No training data found.")
        return False

    try:
        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        # One flat text per example: goal and both candidates, [SEP]-joined.
        texts = [f"{g} [SEP] {s1} [SEP] {s2}" for g, s1, s2 in zip(goals, solution1, solution2)]
        dataset = EvoDataset(texts, labels, tokenizer)
        loader = DataLoader(dataset, batch_size=4, shuffle=True)

        config = {
            "vocab_size": tokenizer.vocab_size,
            "d_model": 256,
            "nhead": 4,
            "dim_feedforward": 512,
            "num_hidden_layers": 4
        }
        model = EvoTransformerForClassification.from_config_dict(config)
        model.train()

        optimizer = optim.AdamW(model.parameters(), lr=1e-4)
        criterion = nn.CrossEntropyLoss()

        for epoch in range(3):  # quick training
            for input_ids, label in loader:
                optimizer.zero_grad()  # clear stale grads before this step
                logits = model(input_ids)
                loss = criterion(logits, label)
                loss.backward()
                optimizer.step()

        model.save_pretrained("trained_evo")
        print("✅ Retraining complete.")
        return True
    except Exception as e:
        print(f"[Retrain Error] Training failed: {e}")
        return False