HemanM committed (verified)
Commit 75ff33c · 1 Parent(s): 7665fc0

Update watchdog.py

Files changed (1)
  1. watchdog.py +14 -16
watchdog.py CHANGED
@@ -1,10 +1,12 @@
+# watchdog.py
+
 import torch
 from transformers import AutoTokenizer
 from evo_model import EvoTransformerForClassification
 from firebase_admin import firestore
 import pandas as pd

-# Load tokenizer once
+# Load tokenizer
 tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

 def load_feedback_data():
@@ -24,8 +26,7 @@ def load_feedback_data():

 def encode(goal, sol1, sol2):
     prompt = f"Goal: {goal} Option 1: {sol1} Option 2: {sol2}"
-    encoded = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
-    return encoded.input_ids, encoded.attention_mask
+    return tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)

 def manual_retrain():
     try:
@@ -36,28 +37,25 @@ def manual_retrain():

         model = EvoTransformerForClassification.from_pretrained("trained_model")
         optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
-        loss_fn = torch.nn.CrossEntropyLoss()

         model.train()
         for _, row in data.sample(frac=1).iterrows():
-            input_ids, attention_mask = encode(row["goal"], row["sol1"], row["sol2"])
-            label = torch.tensor([row["label"]])
-
-            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
-            logits = outputs.logits if hasattr(outputs, "logits") else outputs
+            encoded = encode(row["goal"], row["sol1"], row["sol2"])
+            labels = torch.tensor([row["label"]])
+            outputs = model(input_ids=encoded["input_ids"], attention_mask=encoded["attention_mask"], labels=labels)

-            if logits.ndim == 2 and label.ndim == 1:
-                loss = loss_fn(logits, label)
-                optimizer.zero_grad()
-                loss.backward()
-                optimizer.step()
+            if isinstance(outputs, tuple):
+                loss = outputs[0]
             else:
-                print("[Retrain Warning] Shape mismatch, skipping one example.")
+                loss = outputs
+
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()

         model.save_pretrained("trained_model")
         print("✅ Evo retrained and saved.")
         return True
-
     except Exception as e:
         print(f"[Retrain Error] {e}")
         return False
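
For context, the training step this commit moves to can be sketched in isolation. The snippet below is a minimal, self-contained illustration, not the repo's code: it substitutes Hugging Face's stock AutoModelForSequenceClassification for the repo-local EvoTransformerForClassification (whose forward is assumed, as the diff implies, to accept a `labels` kwarg and return either a tuple with the loss first or the loss itself), and the goal/option prompt and label are made up for the example.

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Stand-in for EvoTransformerForClassification; a stock HF classifier
# also accepts a `labels` kwarg in forward().
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=2
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

model.train()
# One feedback example: the label picks the better of the two options.
encoded = tokenizer(
    "Goal: boil water Option 1: use a kettle Option 2: use a fridge",
    return_tensors="pt", padding=True, truncation=True,
)
labels = torch.tensor([0])

# With `labels` supplied, the model computes the loss internally, which
# is what lets the commit drop the external CrossEntropyLoss and the
# logits/label shape-mismatch guard.
outputs = model(
    input_ids=encoded["input_ids"],
    attention_mask=encoded["attention_mask"],
    labels=labels,
)
# The repo's model apparently returns the bare loss (or a tuple with the
# loss first); HF output objects expose it as `outputs.loss` instead.
loss = outputs[0] if isinstance(outputs, tuple) else outputs.loss

optimizer.zero_grad()
loss.backward()
optimizer.step()

One behavioral trade-off worth noting: the old loop skipped shape-mismatched examples with a warning and kept going, while the new loop lets any per-example failure propagate to the broad except in manual_retrain(), which logs it and aborts the whole retrain.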