HemanM committed
Commit 5876a92 · verified · 1 Parent(s): 92432bf

Update watchdog.py

Files changed (1)
  1. watchdog.py +43 -41
watchdog.py CHANGED
@@ -1,47 +1,44 @@
- from evo_model import EvoTransformerForClassification
- from transformers import AutoTokenizer
  import torch
- import torch.nn as nn
- import torch.optim as optim
- from torch.utils.data import Dataset, DataLoader
- from firebase_admin import firestore

- class EvoDataset(Dataset):
-     def __init__(self, texts, labels, tokenizer, max_length=64):
-         self.encodings = tokenizer(texts, truncation=True, padding=True, max_length=max_length)
-         self.labels = labels

-     def __getitem__(self, idx):
-         input_ids = torch.tensor(self.encodings["input_ids"][idx])
-         label = torch.tensor(self.labels[idx])
-         return input_ids, label

-     def __len__(self):
-         return len(self.labels)

  def manual_retrain():
      try:
-         db = firestore.client()
          docs = db.collection("evo_feedback_logs").stream()
-
-         goals, solution1, solution2, labels = [], [], [], []
          for doc in docs:
              d = doc.to_dict()
              if all(k in d for k in ["goal", "solution_1", "solution_2", "correct_answer"]):
-                 goals.append(d["goal"])
-                 solution1.append(d["solution_1"])
-                 solution2.append(d["solution_2"])
-                 labels.append(0 if d["correct_answer"] == "Solution 1" else 1)

-         if not goals:
-             print("[Retrain Error] No training data found.")
              return False

-         tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
-         texts = [f"{g} [SEP] {s1} [SEP] {s2}" for g, s1, s2 in zip(goals, solution1, solution2)]
-         dataset = EvoDataset(texts, labels, tokenizer)
-         loader = DataLoader(dataset, batch_size=4, shuffle=True)

          config = {
              "vocab_size": tokenizer.vocab_size,
              "d_model": 256,
@@ -49,23 +46,28 @@ def manual_retrain():
              "dim_feedforward": 512,
              "num_hidden_layers": 4
          }
-         model = EvoTransformerForClassification.from_config_dict(config)
-         model.train()

-         optimizer = optim.AdamW(model.parameters(), lr=1e-4)
-         criterion = nn.CrossEntropyLoss()

          for epoch in range(3):
-             for input_ids, label in loader:
-                 logits = model(input_ids)
-                 loss = criterion(logits, label)
-                 loss.backward()
-                 optimizer.step()
-                 optimizer.zero_grad()

-         model.save_pretrained("trained_evo")
-         print("✅ Retraining complete.")
          return True

      except Exception as e:
          print(f"[Retrain Error] {e}")
          return False
 
+ # watchdog.py
+
  import torch
+ from evo_model import EvoTransformerForClassification, EvoTransformerConfig
+ from transformers import BertTokenizer
+ import firebase_admin
+ from firebase_admin import credentials, firestore
+ import os
+ from datetime import datetime

+ # ✅ Load tokenizer
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

+ # ✅ Init Firebase
+ if not firebase_admin._apps:
+     cred = credentials.Certificate("firebase_key.json")
+     firebase_admin.initialize_app(cred)

+ db = firestore.client()

  def manual_retrain():
      try:
+         # 🔍 Fetch feedback logs
          docs = db.collection("evo_feedback_logs").stream()
+         data = []
          for doc in docs:
              d = doc.to_dict()
              if all(k in d for k in ["goal", "solution_1", "solution_2", "correct_answer"]):
+                 label = 0 if d["correct_answer"] == "Solution 1" else 1
+                 combined = f"{d['goal']} [SEP] {d['solution_1']} [SEP] {d['solution_2']}"
+                 data.append((combined, label))

+         if not data:
+             print("❌ No valid training data found.")
              return False

+         # ✅ Tokenize
+         inputs = tokenizer([x[0] for x in data], padding=True, truncation=True, return_tensors="pt")
+         labels = torch.tensor([x[1] for x in data])

+         # ✅ Load config + model
          config = {
              "vocab_size": tokenizer.vocab_size,
              "d_model": 256,
@@ -49,23 +46,28 @@ def manual_retrain():
              "dim_feedforward": 512,
              "num_hidden_layers": 4
          }
+         model_config = EvoTransformerConfig(**config)
+         model = EvoTransformerForClassification(model_config)

+         # ✅ Loss + optimizer
+         criterion = torch.nn.CrossEntropyLoss()
+         optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

+         # ✅ Train (simple 3-epoch fine-tune)
+         model.train()
          for epoch in range(3):
+             optimizer.zero_grad()
+             outputs = model(inputs["input_ids"])
+             loss = criterion(outputs, labels)
+             loss.backward()
+             optimizer.step()
+             print(f"[Epoch {epoch+1}] Loss: {loss.item():.4f}")

+         # ✅ Save model
+         torch.save(model.state_dict(), "trained_model.pt")
+         print("✅ Evo updated via retrain from feedback!")
          return True
+
      except Exception as e:
          print(f"[Retrain Error] {e}")
          return False
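
For context, here is a minimal, hypothetical sketch (not part of this commit) of the load side of the round trip: rebuild the model from the same config, restore the trained_model.pt state dict that manual_retrain() writes, and score one goal/solution pair formatted exactly like the training texts. The goal/solution strings below are invented, and one config entry is hidden between the two hunks above, so it is flagged as an assumption in the comments.

import torch
from transformers import BertTokenizer
from evo_model import EvoTransformerConfig, EvoTransformerForClassification

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Rebuild the config used in watchdog.py. NOTE: the diff elides one entry
# (old line 48 / new line 45); whatever it is, it must match the value used
# at training time or the state dict will not load.
config = EvoTransformerConfig(
    vocab_size=tokenizer.vocab_size,
    d_model=256,
    dim_feedforward=512,
    num_hidden_layers=4,
)

model = EvoTransformerForClassification(config)
model.load_state_dict(torch.load("trained_model.pt"))
model.eval()

# Hypothetical inputs, built with the same "[SEP]" template as training.
goal = "Reach the roof of a building"
solution_1 = "Take the stairs"
solution_2 = "Jump out of a window"
enc = tokenizer(f"{goal} [SEP] {solution_1} [SEP] {solution_2}",
                padding=True, truncation=True, return_tensors="pt")

with torch.no_grad():
    logits = model(enc["input_ids"])

# Label 0 means "Solution 1" wins, mirroring the labeling rule above.
print("Solution 1" if logits.argmax(dim=-1).item() == 0 else "Solution 2")

Because the commit saves a raw state_dict rather than a save_pretrained() directory, the load side has to reconstruct the config itself; keeping that config in one shared place avoids a silent shape mismatch between training and inference.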