HemanM commited on
Commit
d29dccb
·
verified ·
1 Parent(s): 9d9f2ca

Update watchdog.py

Browse files
Files changed (1) hide show
  1. watchdog.py +9 -10
watchdog.py CHANGED
@@ -4,7 +4,7 @@ from evo_model import EvoTransformerForClassification
4
  from firebase_admin import firestore
5
  import pandas as pd
6
 
7
- # Load tokenizer
8
  tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
9
 
10
  def load_feedback_data():
@@ -24,7 +24,8 @@ def load_feedback_data():
24
 
25
  def encode(goal, sol1, sol2):
26
  prompt = f"Goal: {goal} Option 1: {sol1} Option 2: {sol2}"
27
- return tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).input_ids
 
28
 
29
  def manual_retrain():
30
  try:
@@ -39,15 +40,12 @@ def manual_retrain():
39
 
40
  model.train()
41
  for _, row in data.sample(frac=1).iterrows():
42
- inputs = encode(row["goal"], row["sol1"], row["sol2"])
43
  label = torch.tensor([row["label"]])
44
- outputs = model(inputs)
45
- if isinstance(outputs, tuple):
46
- logits = outputs[0]
47
- elif hasattr(outputs, "logits"):
48
- logits = outputs.logits
49
- else:
50
- logits = outputs
51
  if logits.ndim == 2 and label.ndim == 1:
52
  loss = loss_fn(logits, label)
53
  optimizer.zero_grad()
@@ -59,6 +57,7 @@ def manual_retrain():
59
  model.save_pretrained("trained_model")
60
  print("✅ Evo retrained and saved.")
61
  return True
 
62
  except Exception as e:
63
  print(f"[Retrain Error] {e}")
64
  return False
 
4
  from firebase_admin import firestore
5
  import pandas as pd
6
 
7
+ # Load tokenizer once
8
  tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
9
 
10
  def load_feedback_data():
 
24
 
25
  def encode(goal, sol1, sol2):
26
  prompt = f"Goal: {goal} Option 1: {sol1} Option 2: {sol2}"
27
+ encoded = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
28
+ return encoded.input_ids, encoded.attention_mask
29
 
30
  def manual_retrain():
31
  try:
 
40
 
41
  model.train()
42
  for _, row in data.sample(frac=1).iterrows():
43
+ input_ids, attention_mask = encode(row["goal"], row["sol1"], row["sol2"])
44
  label = torch.tensor([row["label"]])
45
+
46
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask)
47
+ logits = outputs.logits if hasattr(outputs, "logits") else outputs
48
+
 
 
 
49
  if logits.ndim == 2 and label.ndim == 1:
50
  loss = loss_fn(logits, label)
51
  optimizer.zero_grad()
 
57
  model.save_pretrained("trained_model")
58
  print("✅ Evo retrained and saved.")
59
  return True
60
+
61
  except Exception as e:
62
  print(f"[Retrain Error] {e}")
63
  return False