HemanM committed · Commit cae5830 · verified · 1 Parent(s): da42a90

Update watchdog.py

Files changed (1): watchdog.py +45 -71
watchdog.py CHANGED
@@ -1,90 +1,64 @@
- # watchdog.py
-
- import firebase_admin
- from firebase_admin import credentials, firestore
  import torch
- import torch.nn as nn
- import torch.optim as optim
- from transformers import BertTokenizer
- from torch.utils.data import DataLoader, Dataset
- from evo_model import EvoTransformerForClassification, EvoTransformerConfig
-
- # Initialize Firebase
- if not firebase_admin._apps:
-     cred = credentials.Certificate("firebase_key.json")
-     firebase_admin.initialize_app(cred)

- db = firestore.client()
- tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

- # Dataset for training
- class FeedbackDataset(Dataset):
-     def __init__(self, records, tokenizer, max_length=64):
-         self.records = records
-         self.tokenizer = tokenizer
-         self.max_length = max_length
-         self.label_map = {"Solution 1": 0, "Solution 2": 1}

-     def __len__(self):
-         return len(self.records)

-     def __getitem__(self, idx):
-         row = self.records[idx]
-         combined = f"Goal: {row['goal']} Option 1: {row['solution_1']} Option 2: {row['solution_2']}"
-         inputs = self.tokenizer(combined, padding="max_length", truncation=True,
-                                 max_length=self.max_length, return_tensors="pt")
-         label = self.label_map[row["correct_answer"]]
-         return {
-             "input_ids": inputs["input_ids"].squeeze(0),
-             "attention_mask": inputs["attention_mask"].squeeze(0),
-             "labels": torch.tensor(label)
-         }
-
- # Manual retrain trigger
  def manual_retrain():
      try:
-         # Step 1: Fetch feedback data from Firestore
-         docs = db.collection("evo_feedback_logs").stream()
-         feedback_data = [doc.to_dict() for doc in docs if "goal" in doc.to_dict()]
-
-         if len(feedback_data) < 5:
-             print("[Retrain Skipped] Not enough feedback.")
              return False

-         # Step 2: Load tokenizer and dataset
-         dataset = FeedbackDataset(feedback_data, tokenizer)
-         loader = DataLoader(dataset, batch_size=4, shuffle=True)

-         # Step 3: Load model
-         config = EvoTransformerConfig()
-         model = EvoTransformerForClassification(config)
          model.train()
-
-         # Step 4: Define optimizer and loss
-         optimizer = optim.Adam(model.parameters(), lr=2e-5)
-         loss_fn = nn.CrossEntropyLoss()
-
-         # Step 5: Train
-         for epoch in range(3):
-             total_loss = 0
-             for batch in loader:
                  optimizer.zero_grad()
-                 input_ids = batch["input_ids"]
-                 attention_mask = batch["attention_mask"]
-                 labels = batch["labels"]
-
-                 logits = model(input_ids)
-                 loss = loss_fn(logits, labels)
                  loss.backward()
                  optimizer.step()
-                 total_loss += loss.item()
-             print(f"[Retrain] Epoch {epoch + 1} Loss: {total_loss:.4f}")

-         # Step 6: Save updated model
-         torch.save(model.state_dict(), "trained_model.pt")
-         print("✅ Evo updated with latest feedback.")
          return True
-
      except Exception as e:
          print(f"[Retrain Error] {e}")
          return False
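The rewrite below swaps the batched DataLoader training above (3 epochs, batch size 4, lr 2e-5) for a single shuffled pass over individual examples (lr 1e-4), and moves persistence from a raw state dict to a save_pretrained/from_pretrained directory. A minimal sketch of the two persistence styles, assuming EvoTransformerForClassification implements the Hugging Face PreTrainedModel interface that the new calls imply:

import torch
from evo_model import EvoTransformerForClassification, EvoTransformerConfig

# Removed convention: weights alone in a .pt file; config rebuilt in code.
model = EvoTransformerForClassification(EvoTransformerConfig())
torch.save(model.state_dict(), "trained_model.pt")
model.load_state_dict(torch.load("trained_model.pt"))

# Added convention: config and weights saved together in one directory.
model.save_pretrained("trained_model")
model = EvoTransformerForClassification.from_pretrained("trained_model")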
  import torch
+ from transformers import AutoTokenizer
+ from evo_model import EvoTransformerForClassification
+ from firebase_admin import firestore
+ import pandas as pd

+ # Load tokenizer
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

+ def load_feedback_data():
+     db = firestore.client()
+     docs = db.collection("evo_feedback_logs").stream()
+     data = []
+     for doc in docs:
+         d = doc.to_dict()
+         if all(k in d for k in ["goal", "solution_1", "solution_2", "correct_answer"]):
+             data.append((
+                 d["goal"],
+                 d["solution_1"],
+                 d["solution_2"],
+                 0 if d["correct_answer"] == "Solution 1" else 1
+             ))
+     return pd.DataFrame(data, columns=["goal", "sol1", "sol2", "label"])
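load_feedback_data keeps only documents that carry all four keys, and maps any correct_answer other than "Solution 1" to label 1. A document shaped like this (field values are hypothetical) yields one row of the returned DataFrame:

# Hypothetical evo_feedback_logs document that passes the key check above.
example_doc = {
    "goal": "Start a fire without matches",
    "solution_1": "Focus sunlight through a lens onto tinder",
    "solution_2": "Wave the tinder around until it warms up",
    "correct_answer": "Solution 1",  # -> label 0; any other value -> label 1
}
# -> one DataFrame row: (goal, sol1, sol2, label=0)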
+ def encode(goal, sol1, sol2):
+     prompt = f"Goal: {goal} Option 1: {sol1} Option 2: {sol2}"
+     return tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).input_ids
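Note that encode returns input_ids only; the attention mask the tokenizer also produces is discarded, and the loop below calls model(inputs) with IDs alone. If the model accepts a mask (an assumption about EvoTransformerForClassification, not something this file shows), a hypothetical variant could pass it through:

def encode_with_mask(goal, sol1, sol2):
    # Same prompt format as encode(), but keep the attention mask too.
    prompt = f"Goal: {goal} Option 1: {sol1} Option 2: {sol2}"
    enc = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    return enc.input_ids, enc.attention_mask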
  def manual_retrain():
      try:
+         data = load_feedback_data()
+         if data.empty:
+             print("[Retrain Error] No training data found.")
              return False

+         model = EvoTransformerForClassification.from_pretrained("trained_model")
+         optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
+         loss_fn = torch.nn.CrossEntropyLoss()

          model.train()
+         for _, row in data.sample(frac=1).iterrows():
+             inputs = encode(row["goal"], row["sol1"], row["sol2"])
+             label = torch.tensor([row["label"]])
+             outputs = model(inputs)
+             if isinstance(outputs, tuple):
+                 logits = outputs[0]
+             elif hasattr(outputs, "logits"):
+                 logits = outputs.logits
+             else:
+                 logits = outputs
+             if logits.ndim == 2 and label.ndim == 1:
+                 loss = loss_fn(logits, label)
                  optimizer.zero_grad()
                  loss.backward()
                  optimizer.step()
+             else:
+                 print("[Retrain Warning] Shape mismatch, skipping one example.")

+         model.save_pretrained("trained_model")
+         print("✅ Evo retrained and saved.")
          return True
      except Exception as e:
          print(f"[Retrain Error] {e}")
          return False
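Unlike the removed version, the new module never initializes Firebase itself, so firestore.client() will raise unless a default app already exists when manual_retrain() runs. A minimal caller sketch, reusing the credential path from the removed code (whether that path still applies in deployment is an assumption):

import firebase_admin
from firebase_admin import credentials
from watchdog import manual_retrain

# firestore.client() needs an initialized default app; the removed version
# did this inside watchdog.py, the new one leaves it to the caller.
if not firebase_admin._apps:
    cred = credentials.Certificate("firebase_key.json")  # path from the removed code
    firebase_admin.initialize_app(cred)

print("Retrained." if manual_retrain() else "Skipped or failed.")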