HemanM committed on
Commit
250de38
·
verified ·
1 Parent(s): 1dec93e

Update retrain_from_feedback.py

Browse files
Files changed (1) hide show
  1. retrain_from_feedback.py +33 -19
retrain_from_feedback.py CHANGED
@@ -3,15 +3,14 @@ import torch.nn as nn
3
  import torch.optim as optim
4
  from torch.utils.data import DataLoader, Dataset
5
  from transformers import AutoTokenizer
6
-
7
  from evo_architecture import mutate_genome, default_config, log_genome
8
  from evo_model import EvoTransformerV22
9
- import csv
10
- import os
11
 
12
- # Device setup
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
 
 
15
  class FeedbackDataset(Dataset):
16
  def __init__(self, tokenizer, data, max_len=128):
17
  self.tokenizer = tokenizer
@@ -22,15 +21,14 @@ class FeedbackDataset(Dataset):
22
  return len(self.samples)
23
 
24
  def __getitem__(self, idx):
25
- q, o1, o2, ctx, evo_ans = self.samples[idx]
26
- prompt = f"{q} [SEP] {o1} [SEP] {o2} [SEP] {ctx}"
 
27
  enc = self.tokenizer(prompt, padding="max_length", truncation=True, max_length=self.max_len, return_tensors="pt")
28
  input_ids = enc["input_ids"].squeeze(0)
29
-
30
- # Label: 0 if Evo picked option1, else 1
31
- label = 0 if evo_ans.strip().lower() == o1.strip().lower() else 1
32
  return input_ids, torch.tensor(label)
33
 
 
34
  def load_feedback():
35
  data = []
36
  if not os.path.exists("feedback_log.csv"):
@@ -39,16 +37,30 @@ def load_feedback():
39
  with open("feedback_log.csv", encoding="utf-8") as f:
40
  reader = csv.DictReader(f)
41
  for row in reader:
42
- if row.get("evo_was_correct", "no").strip().lower() == "yes":
43
- data.append([
44
- row["question"],
45
- row["option1"],
46
- row["option2"],
47
- row["context"],
48
- row["evo_output"].strip()
49
- ])
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  return data
51
 
 
52
  def build_model(config):
53
  from model import EvoEncoder
54
  class EvoClassifier(nn.Module):
@@ -62,7 +74,7 @@ def build_model(config):
62
  memory_enabled=config["memory_enabled"]
63
  )
64
  self.pool = nn.AdaptiveAvgPool1d(1)
65
- self.classifier = nn.Linear(512, 2) # two-class classification
66
 
67
  def forward(self, input_ids):
68
  x = self.encoder(input_ids)
@@ -71,6 +83,7 @@ def build_model(config):
71
 
72
  return EvoClassifier().to(device)
73
 
 
74
  def train_evo():
75
  tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
76
  data = load_feedback()
@@ -109,7 +122,8 @@ def train_evo():
109
  os.makedirs("trained_model", exist_ok=True)
110
  torch.save(model.state_dict(), "trained_model/evo_retrained.pt")
111
  log_genome(new_config, acc)
112
- print("✅ Model saved and genome logged.")
113
 
 
114
  if __name__ == "__main__":
115
  train_evo()
 
3
  import torch.optim as optim
4
  from torch.utils.data import DataLoader, Dataset
5
  from transformers import AutoTokenizer
 
6
  from evo_architecture import mutate_genome, default_config, log_genome
7
  from evo_model import EvoTransformerV22
8
+ import csv, os
 
9
 
10
+ # 💻 Device setup
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
 
13
+ # 📦 Dataset built from feedback CSV
14
  class FeedbackDataset(Dataset):
15
  def __init__(self, tokenizer, data, max_len=128):
16
  self.tokenizer = tokenizer
 
21
  return len(self.samples)
22
 
23
  def __getitem__(self, idx):
24
+ item = self.samples[idx]
25
+ q, o1, o2, ctx, label = item
26
+ prompt = f"{q} [SEP] {o1} [SEP] {o2} [CTX] {ctx}"
27
  enc = self.tokenizer(prompt, padding="max_length", truncation=True, max_length=self.max_len, return_tensors="pt")
28
  input_ids = enc["input_ids"].squeeze(0)
 
 
 
29
  return input_ids, torch.tensor(label)
30
 
31
+ # 🧠 Load feedback data
32
  def load_feedback():
33
  data = []
34
  if not os.path.exists("feedback_log.csv"):
 
37
  with open("feedback_log.csv", encoding="utf-8") as f:
38
  reader = csv.DictReader(f)
39
  for row in reader:
40
+ q = row["question"]
41
+ o1 = row["option1"]
42
+ o2 = row["option2"]
43
+ ctx = row["context"]
44
+ evo_out = row["evo_output"].strip()
45
+ vote = row.get("user_preference", "").lower()
46
+ evo_correct = row.get("evo_was_correct", "").lower()
47
+
48
+ # Priority 1: user vote
49
+ if vote == "evo":
50
+ label = 1
51
+ elif vote == "gpt":
52
+ label = 0
53
+ # Priority 2: evo correctness
54
+ elif evo_correct == "yes":
55
+ label = 1
56
+ else:
57
+ continue # skip uncertain rows
58
+
59
+ # Label 1 means Evo was correct/preferred
60
+ data.append([q, o1, o2, ctx, label])
61
  return data
62
 
63
+ # 🔧 Evo model builder from config
64
  def build_model(config):
65
  from model import EvoEncoder
66
  class EvoClassifier(nn.Module):
 
74
  memory_enabled=config["memory_enabled"]
75
  )
76
  self.pool = nn.AdaptiveAvgPool1d(1)
77
+ self.classifier = nn.Linear(512, 2)
78
 
79
  def forward(self, input_ids):
80
  x = self.encoder(input_ids)
 
83
 
84
  return EvoClassifier().to(device)
85
 
86
+ # 🔁 Train Evo on feedback
87
  def train_evo():
88
  tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
89
  data = load_feedback()
 
122
  os.makedirs("trained_model", exist_ok=True)
123
  torch.save(model.state_dict(), "trained_model/evo_retrained.pt")
124
  log_genome(new_config, acc)
125
+ print("✅ Evo retrained and genome logged.")
126
 
127
+ # 🔁 Entry point
128
  if __name__ == "__main__":
129
  train_evo()