HemanM committed on
Commit
3489232
·
verified ·
1 Parent(s): e7e30db

Update watchdog.py

Browse files
Files changed (1) hide show
  1. watchdog.py +41 -22
watchdog.py CHANGED
@@ -2,7 +2,8 @@ import os
2
  import torch
3
  import firebase_admin
4
  from firebase_admin import credentials, firestore
5
- from model import SimpleEvoModel
 
6
 
7
  # Initialize Firebase if not already initialized
8
  if not firebase_admin._apps:
@@ -11,45 +12,63 @@ if not firebase_admin._apps:
11
 
12
  db = firestore.client()
13
 
14
def fetch_training_data():
    """Build (inputs, labels) training tensors from Firestore feedback logs.

    Each document's prompt+winner text is turned into a crude 768-dim
    character-code vector ("simulated encoding"), padded/truncated to
    exactly 768 entries.

    Returns:
        A pair ``(X, y)`` where ``X`` is a float32 tensor of shape
        ``[N, 768]`` and ``y`` is an int64 label tensor of shape ``[N]``.
    """
    stream = db.collection("evo_feedback").stream()

    feature_rows, label_rows = [], []
    for snapshot in stream:
        record = snapshot.to_dict()
        goal = record.get("prompt", "")
        winner = record.get("winner", "")
        # Documents without a recorded winner carry no label — skip them.
        if not winner:
            continue
        # Simulated encoding: normalized character codes, fixed width 768.
        vec = [float(ord(ch) % 256) / 255.0 for ch in (goal + winner)]
        vec = vec[:768] + [0.0] * max(0, 768 - len(vec))  # pad/truncate
        # NOTE(review): label heuristic — class 0 when the winner string
        # contains the character "1", else class 1; presumably winners are
        # tagged like "Response 1"/"Response 2" — confirm against the writer.
        feature_rows.append(vec)
        label_rows.append(0 if "1" in winner else 1)

    return (
        torch.tensor(feature_rows, dtype=torch.float32),
        torch.tensor(label_rows, dtype=torch.long),
    )
 
 
 
 
 
 
 
32
 
33
def retrain_and_save():
    """Retrain SimpleEvoModel on the latest feedback and save its weights.

    Fetches encoded feedback from Firestore, runs a short full-batch
    training loop, and writes the state dict to
    ``trained_model/pytorch_model.bin``. Aborts (with a warning) when
    fewer than two examples are available.
    """
    X, y = fetch_training_data()
    # Need at least two examples for a meaningful gradient step.
    if len(X) < 2:
        print("⚠️ Not enough training data.")
        return

    model = SimpleEvoModel()
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Full-batch training: one optimizer step per epoch over all examples.
    for _epoch in range(5):
        optimizer.zero_grad()
        logits = model(X)
        loss = loss_fn(logits, y)
        loss.backward()
        optimizer.step()

    # Save retrained weights to trained_model/
    os.makedirs("trained_model", exist_ok=True)
    torch.save(model.state_dict(), "trained_model/pytorch_model.bin")
    print("✅ EvoTransformer retrained and saved to trained_model/")
54
 
55
  if __name__ == "__main__":
 
2
  import torch
3
  import firebase_admin
4
  from firebase_admin import credentials, firestore
5
+ from evo_model import EvoTransformerForClassification, EvoTransformerConfig
6
+ from transformers import BertTokenizer
7
 
8
  # Initialize Firebase if not already initialized
9
  if not firebase_admin._apps:
 
12
 
13
  db = firestore.client()
14
 
15
def fetch_training_data(tokenizer):
    """Pull feedback docs from Firestore and tokenize them into tensors.

    Args:
        tokenizer: a HuggingFace tokenizer (e.g. ``BertTokenizer``) used to
            encode each ``"prompt [SEP] winner"`` pair to length 128.

    Returns:
        ``(input_ids, attention_masks, labels)`` as stacked tensors, or
        ``(None, None, None)`` when no usable documents were found.
    """
    docs = db.collection("evo_feedback").stream()

    ids_acc, mask_acc, label_acc = [], [], []
    for doc in docs:
        record = doc.to_dict()
        prompt = record.get("prompt", "")
        winner = record.get("winner", "")
        # Skip documents missing either field — both are needed for a pair.
        if not (winner and prompt):
            continue

        enc = tokenizer(
            prompt + " [SEP] " + winner,
            truncation=True,
            padding="max_length",
            max_length=128,
            return_tensors="pt",
        )
        ids_acc.append(enc["input_ids"][0])
        mask_acc.append(enc["attention_mask"][0])
        # NOTE(review): label heuristic — class 0 when the winner string
        # contains the character "1", else class 1; presumably winners are
        # tagged like "Response 1"/"Response 2" — confirm against the writer.
        label_acc.append(0 if "1" in winner else 1)

    if not ids_acc:
        return None, None, None

    return (
        torch.stack(ids_acc),
        torch.stack(mask_acc),
        torch.tensor(label_acc, dtype=torch.long),
    )
46
 
47
def retrain_and_save():
    """Retrain EvoTransformer on the latest Firestore feedback and persist it.

    Loads a BERT tokenizer, fetches tokenized feedback pairs, runs a short
    full-batch training loop, and saves the model under ``trained_model/``.
    Aborts (with a warning) when fewer than two examples are available.
    """
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    input_ids, attention_masks, labels = fetch_training_data(tokenizer)

    # Need at least two examples for a meaningful gradient step.
    if input_ids is None or len(input_ids) < 2:
        print("⚠️ Not enough training data.")
        return

    model = EvoTransformerForClassification(EvoTransformerConfig())
    model.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
    loss_fn = torch.nn.CrossEntropyLoss()

    # Full-batch training: one optimizer step per epoch over all examples.
    for epoch in range(3):
        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask=attention_masks)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch+1}: Loss = {loss.item():.4f}")

    os.makedirs("trained_model", exist_ok=True)
    model.save_pretrained("trained_model")
    print("✅ EvoTransformer retrained and saved to trained_model/")
73
 
74
  if __name__ == "__main__":