Spaces:
Sleeping
Sleeping
Update retrain_from_feedback.py
Browse files- retrain_from_feedback.py +33 -19
retrain_from_feedback.py
CHANGED
@@ -3,15 +3,14 @@ import torch.nn as nn
|
|
3 |
import torch.optim as optim
|
4 |
from torch.utils.data import DataLoader, Dataset
|
5 |
from transformers import AutoTokenizer
|
6 |
-
|
7 |
from evo_architecture import mutate_genome, default_config, log_genome
|
8 |
from evo_model import EvoTransformerV22
|
9 |
-
import csv
|
10 |
-
import os
|
11 |
|
12 |
-
# Device setup
|
13 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
14 |
|
|
|
15 |
class FeedbackDataset(Dataset):
|
16 |
def __init__(self, tokenizer, data, max_len=128):
|
17 |
self.tokenizer = tokenizer
|
@@ -22,15 +21,14 @@ class FeedbackDataset(Dataset):
|
|
22 |
return len(self.samples)
|
23 |
|
24 |
def __getitem__(self, idx):
|
25 |
-
|
26 |
-
|
|
|
27 |
enc = self.tokenizer(prompt, padding="max_length", truncation=True, max_length=self.max_len, return_tensors="pt")
|
28 |
input_ids = enc["input_ids"].squeeze(0)
|
29 |
-
|
30 |
-
# Label: 0 if Evo picked option1, else 1
|
31 |
-
label = 0 if evo_ans.strip().lower() == o1.strip().lower() else 1
|
32 |
return input_ids, torch.tensor(label)
|
33 |
|
|
|
34 |
def load_feedback():
|
35 |
data = []
|
36 |
if not os.path.exists("feedback_log.csv"):
|
@@ -39,16 +37,30 @@ def load_feedback():
|
|
39 |
with open("feedback_log.csv", encoding="utf-8") as f:
|
40 |
reader = csv.DictReader(f)
|
41 |
for row in reader:
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
return data
|
51 |
|
|
|
52 |
def build_model(config):
|
53 |
from model import EvoEncoder
|
54 |
class EvoClassifier(nn.Module):
|
@@ -62,7 +74,7 @@ def build_model(config):
|
|
62 |
memory_enabled=config["memory_enabled"]
|
63 |
)
|
64 |
self.pool = nn.AdaptiveAvgPool1d(1)
|
65 |
-
self.classifier = nn.Linear(512, 2)
|
66 |
|
67 |
def forward(self, input_ids):
|
68 |
x = self.encoder(input_ids)
|
@@ -71,6 +83,7 @@ def build_model(config):
|
|
71 |
|
72 |
return EvoClassifier().to(device)
|
73 |
|
|
|
74 |
def train_evo():
|
75 |
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
76 |
data = load_feedback()
|
@@ -109,7 +122,8 @@ def train_evo():
|
|
109 |
os.makedirs("trained_model", exist_ok=True)
|
110 |
torch.save(model.state_dict(), "trained_model/evo_retrained.pt")
|
111 |
log_genome(new_config, acc)
|
112 |
-
print("✅
|
113 |
|
|
|
114 |
if __name__ == "__main__":
|
115 |
train_evo()
|
|
|
3 |
import torch.optim as optim
|
4 |
from torch.utils.data import DataLoader, Dataset
|
5 |
from transformers import AutoTokenizer
|
|
|
6 |
from evo_architecture import mutate_genome, default_config, log_genome
|
7 |
from evo_model import EvoTransformerV22
|
8 |
+
import csv, os
|
|
|
9 |
|
10 |
+
# 💻 Device setup
|
11 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
12 |
|
13 |
+
# 📦 Dataset built from feedback CSV
|
14 |
class FeedbackDataset(Dataset):
|
15 |
def __init__(self, tokenizer, data, max_len=128):
|
16 |
self.tokenizer = tokenizer
|
|
|
21 |
return len(self.samples)
|
22 |
|
23 |
def __getitem__(self, idx):
|
24 |
+
item = self.samples[idx]
|
25 |
+
q, o1, o2, ctx, label = item
|
26 |
+
prompt = f"{q} [SEP] {o1} [SEP] {o2} [CTX] {ctx}"
|
27 |
enc = self.tokenizer(prompt, padding="max_length", truncation=True, max_length=self.max_len, return_tensors="pt")
|
28 |
input_ids = enc["input_ids"].squeeze(0)
|
|
|
|
|
|
|
29 |
return input_ids, torch.tensor(label)
|
30 |
|
31 |
+
# 🧠 Load feedback data
|
32 |
def load_feedback():
|
33 |
data = []
|
34 |
if not os.path.exists("feedback_log.csv"):
|
|
|
37 |
with open("feedback_log.csv", encoding="utf-8") as f:
|
38 |
reader = csv.DictReader(f)
|
39 |
for row in reader:
|
40 |
+
q = row["question"]
|
41 |
+
o1 = row["option1"]
|
42 |
+
o2 = row["option2"]
|
43 |
+
ctx = row["context"]
|
44 |
+
evo_out = row["evo_output"].strip()
|
45 |
+
vote = row.get("user_preference", "").lower()
|
46 |
+
evo_correct = row.get("evo_was_correct", "").lower()
|
47 |
+
|
48 |
+
# Priority 1: user vote
|
49 |
+
if vote == "evo":
|
50 |
+
label = 1
|
51 |
+
elif vote == "gpt":
|
52 |
+
label = 0
|
53 |
+
# Priority 2: evo correctness
|
54 |
+
elif evo_correct == "yes":
|
55 |
+
label = 1
|
56 |
+
else:
|
57 |
+
continue # skip uncertain rows
|
58 |
+
|
59 |
+
# Label 1 means Evo was correct/preferred
|
60 |
+
data.append([q, o1, o2, ctx, label])
|
61 |
return data
|
62 |
|
63 |
+
# 🔧 Evo model builder from config
|
64 |
def build_model(config):
|
65 |
from model import EvoEncoder
|
66 |
class EvoClassifier(nn.Module):
|
|
|
74 |
memory_enabled=config["memory_enabled"]
|
75 |
)
|
76 |
self.pool = nn.AdaptiveAvgPool1d(1)
|
77 |
+
self.classifier = nn.Linear(512, 2)
|
78 |
|
79 |
def forward(self, input_ids):
|
80 |
x = self.encoder(input_ids)
|
|
|
83 |
|
84 |
return EvoClassifier().to(device)
|
85 |
|
86 |
+
# 🔁 Train Evo on feedback
|
87 |
def train_evo():
|
88 |
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
89 |
data = load_feedback()
|
|
|
122 |
os.makedirs("trained_model", exist_ok=True)
|
123 |
torch.save(model.state_dict(), "trained_model/evo_retrained.pt")
|
124 |
log_genome(new_config, acc)
|
125 |
+
print("✅ Evo retrained and genome logged.")
|
126 |
|
127 |
+
# 🔁 Entry point
|
128 |
if __name__ == "__main__":
|
129 |
train_evo()
|