HemanM committed
Commit 92432bf (verified) · Parent(s): 923fd6e

Update watchdog.py

Files changed (1): watchdog.py +45 −41
watchdog.py CHANGED
@@ -1,10 +1,10 @@
-from firebase_admin import firestore
 from evo_model import EvoTransformerForClassification
 from transformers import AutoTokenizer
 import torch
-from torch.utils.data import Dataset, DataLoader
 import torch.nn as nn
 import torch.optim as optim
+from torch.utils.data import Dataset, DataLoader
+from firebase_admin import firestore

 class EvoDataset(Dataset):
     def __init__(self, texts, labels, tokenizer, max_length=64):
@@ -19,49 +19,53 @@ class EvoDataset(Dataset):
     def __len__(self):
         return len(self.labels)

-def train_evo_transformer():
-    db = firestore.client()
-    docs = db.collection("evo_feedback_logs").stream()
-
-    goals, solution1, solution2, labels = [], [], [], []
-    for doc in docs:
-        d = doc.to_dict()
-        if all(k in d for k in ["goal", "solution_1", "solution_2", "correct_answer"]):
-            goals.append(d["goal"])
-            solution1.append(d["solution_1"])
-            solution2.append(d["solution_2"])
-            labels.append(0 if d["correct_answer"] == "Solution 1" else 1)
-
-    if not goals:
-        print("[Retrain Error] No training data found.")
-        return False
-
-    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
-    texts = [f"{g} [SEP] {s1} [SEP] {s2}" for g, s1, s2 in zip(goals, solution1, solution2)]
-    dataset = EvoDataset(texts, labels, tokenizer)
-    loader = DataLoader(dataset, batch_size=4, shuffle=True)
-
-    config = {
-        "vocab_size": tokenizer.vocab_size,
-        "d_model": 256,
-        "nhead": 4,
-        "dim_feedforward": 512,
-        "num_hidden_layers": 4
-    }
-    model = EvoTransformerForClassification.from_config_dict(config)
-    model.train()
-
-    optimizer = optim.AdamW(model.parameters(), lr=1e-4)
-    criterion = nn.CrossEntropyLoss()
-
-    for epoch in range(3):  # quick training
-        for input_ids, label in loader:
-            logits = model(input_ids)
-            loss = criterion(logits, label)
-            loss.backward()
-            optimizer.step()
-            optimizer.zero_grad()
-
-    model.save_pretrained("trained_evo")
-    print("✅ Retraining complete.")
-    return True
+def manual_retrain():
+    try:
+        db = firestore.client()
+        docs = db.collection("evo_feedback_logs").stream()
+
+        goals, solution1, solution2, labels = [], [], [], []
+        for doc in docs:
+            d = doc.to_dict()
+            if all(k in d for k in ["goal", "solution_1", "solution_2", "correct_answer"]):
+                goals.append(d["goal"])
+                solution1.append(d["solution_1"])
+                solution2.append(d["solution_2"])
+                labels.append(0 if d["correct_answer"] == "Solution 1" else 1)
+
+        if not goals:
+            print("[Retrain Error] No training data found.")
+            return False
+
+        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+        texts = [f"{g} [SEP] {s1} [SEP] {s2}" for g, s1, s2 in zip(goals, solution1, solution2)]
+        dataset = EvoDataset(texts, labels, tokenizer)
+        loader = DataLoader(dataset, batch_size=4, shuffle=True)
+
+        config = {
+            "vocab_size": tokenizer.vocab_size,
+            "d_model": 256,
+            "nhead": 4,
+            "dim_feedforward": 512,
+            "num_hidden_layers": 4
+        }
+        model = EvoTransformerForClassification.from_config_dict(config)
+        model.train()
+
+        optimizer = optim.AdamW(model.parameters(), lr=1e-4)
+        criterion = nn.CrossEntropyLoss()
+
+        for epoch in range(3):
+            for input_ids, label in loader:
+                logits = model(input_ids)
+                loss = criterion(logits, label)
+                loss.backward()
+                optimizer.step()
+                optimizer.zero_grad()
+
+        model.save_pretrained("trained_evo")
+        print("✅ Retraining complete.")
+        return True
+    except Exception as e:
+        print(f"[Retrain Error] {e}")
+        return False
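
The rename from train_evo_transformer() to manual_retrain(), together with the new try/except, turns retraining into a call that reports success or failure as a plain boolean instead of raising. A minimal caller sketch follows; it is not part of the commit, and the credential path, the initialization guard, and the import of watchdog are assumptions about how the surrounding app is wired:

    import firebase_admin
    from firebase_admin import credentials

    from watchdog import manual_retrain

    # firestore.client() inside manual_retrain() requires an initialized
    # firebase_admin app; the credential path here is illustrative.
    if not firebase_admin._apps:  # _apps is firebase_admin's registry of initialized apps
        firebase_admin.initialize_app(credentials.Certificate("serviceAccount.json"))

    # manual_retrain() traps its own exceptions, so the caller only
    # branches on the returned bool.
    if manual_retrain():
        print("New weights saved under trained_evo/")
    else:
        print("Retraining failed or found no feedback logs; see messages above.")

Returning False on both the empty-dataset and exception paths keeps the function safe to call from a UI handler or a scheduled job, at the cost of hiding the exception type from the caller.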