Update watchdog.py
watchdog.py CHANGED (+14 -16)
@@ -1,10 +1,12 @@
+# watchdog.py
+
 import torch
 from transformers import AutoTokenizer
 from evo_model import EvoTransformerForClassification
 from firebase_admin import firestore
 import pandas as pd
 
-# Load tokenizer
+# Load tokenizer
 tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
 
 def load_feedback_data():
@@ -24,8 +26,7 @@ def load_feedback_data():
 
 def encode(goal, sol1, sol2):
     prompt = f"Goal: {goal} Option 1: {sol1} Option 2: {sol2}"
-
-    return encoded.input_ids, encoded.attention_mask
+    return tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
 
 def manual_retrain():
     try:
@@ -36,28 +37,25 @@ def manual_retrain():
 
         model = EvoTransformerForClassification.from_pretrained("trained_model")
         optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
-        loss_fn = torch.nn.CrossEntropyLoss()
 
         model.train()
         for _, row in data.sample(frac=1).iterrows():
-
-
-
-            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
-            logits = outputs.logits if hasattr(outputs, "logits") else outputs
 
-            if
-            loss =
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
+            encoded = encode(row["goal"], row["sol1"], row["sol2"])
+            labels = torch.tensor([row["label"]])
+            outputs = model(input_ids=encoded["input_ids"], attention_mask=encoded["attention_mask"], labels=labels)
+
+            if isinstance(outputs, tuple):
+                loss = outputs[0]
             else:
-
+                loss = outputs
+
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
 
         model.save_pretrained("trained_model")
         print("✅ Evo retrained and saved.")
         return True
-
     except Exception as e:
         print(f"[Retrain Error] {e}")
         return False
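
The body of load_feedback_data() is unchanged by this commit and elided from the diff. As orientation only, a minimal sketch consistent with the imports (firebase_admin.firestore, pandas) and with the columns manual_retrain() reads might look like this; the collection name "feedback" and the field names are assumptions, not something the diff confirms:

# Sketch only: the real implementation is elided in the diff above.
# Assumes firebase_admin.initialize_app() has already run elsewhere,
# and a Firestore collection "feedback" whose documents contain
# goal / sol1 / sol2 / label fields.
def load_feedback_data():
    db = firestore.client()
    rows = [doc.to_dict() for doc in db.collection("feedback").stream()]
    return pd.DataFrame(rows, columns=["goal", "sol1", "sol2", "label"])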
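
The reworked encode() returns the tokenizer's BatchEncoding rather than an (input_ids, attention_mask) tuple, which is why the training loop now indexes the result by key. A quick usage check (the example strings are made up):

enc = encode("Dry wet clothes", "Hang them on a line.", "Seal them in a bag.")
print(enc["input_ids"].shape)       # e.g. torch.Size([1, 23]), a batch of one
print(enc["attention_mask"].shape)  # same shape as input_ids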
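
labels = torch.tensor([row["label"]]) likewise builds a batch of one and assumes the stored label is already an integer class index. If feedback were ever stored as option names instead (an assumption, not something the diff shows), a mapping step would be needed first:

# Hypothetical guard: map string labels to class indices, pass ints through.
LABEL_TO_ID = {"sol1": 0, "sol2": 1}
labels = torch.tensor([LABEL_TO_ID.get(row["label"], row["label"])])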
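
The isinstance(outputs, tuple) branch assumes EvoTransformerForClassification returns either a (loss, logits) tuple or a bare loss tensor when labels are supplied. If the model were ever switched to a Hugging Face-style ModelOutput, a slightly more defensive unpacking (not part of this commit) would be:

# Hypothetical variant: also handles objects exposing a .loss attribute.
if isinstance(outputs, tuple):
    loss = outputs[0]
elif getattr(outputs, "loss", None) is not None:
    loss = outputs.loss
else:
    loss = outputs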
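
Since manual_retrain() logs its own errors and returns a boolean, the caller (whatever UI or scheduler drives the watchdog is not shown in this commit) needs nothing more than:

# Hypothetical entry point, not part of the commit.
if __name__ == "__main__":
    ok = manual_retrain()
    raise SystemExit(0 if ok else 1)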