HemanM commited on
Commit
bddaae0
·
verified ·
1 Parent(s): 32d29da

Update retrain_from_feedback

Browse files
Files changed (1) hide show
  1. retrain_from_feedback +18 -19
retrain_from_feedback CHANGED
@@ -1,5 +1,3 @@
1
- # retrain_from_feedback.py
2
-
3
  import torch
4
  import torch.nn as nn
5
  import torch.optim as optim
@@ -7,10 +5,11 @@ from torch.utils.data import DataLoader, Dataset
7
  from transformers import AutoTokenizer
8
 
9
  from evo_architecture import mutate_genome, default_config, log_genome
10
- from model import EvoTransformerV22 # Must accept dynamic config
11
  import csv
12
  import os
13
 
 
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
16
  class FeedbackDataset(Dataset):
@@ -23,11 +22,13 @@ class FeedbackDataset(Dataset):
23
  return len(self.samples)
24
 
25
  def __getitem__(self, idx):
26
- q, o1, o2, ctx, ans, label = self.samples[idx]
27
- text = f"{q} [SEP] {o1} [SEP] {o2} [SEP] {ctx}"
28
- enc = self.tokenizer(text, padding="max_length", truncation=True, max_length=self.max_len, return_tensors="pt")
29
  input_ids = enc["input_ids"].squeeze(0)
30
- label = 1 if ans.strip().lower() == o1.strip().lower() else 0
 
 
31
  return input_ids, torch.tensor(label)
32
 
33
  def load_feedback():
@@ -38,20 +39,19 @@ def load_feedback():
38
  with open("feedback_log.csv", encoding="utf-8") as f:
39
  reader = csv.DictReader(f)
40
  for row in reader:
41
- if row["evo_was_correct"].strip().lower() == "yes":
42
  data.append([
43
  row["question"],
44
  row["option1"],
45
  row["option2"],
46
  row["context"],
47
- row["evo_output"],
48
- "yes"
49
  ])
50
  return data
51
 
52
  def build_model(config):
53
  from model import EvoEncoder
54
- class CustomEvo(nn.Module):
55
  def __init__(self):
56
  super().__init__()
57
  self.encoder = EvoEncoder(
@@ -62,14 +62,14 @@ def build_model(config):
62
  memory_enabled=config["memory_enabled"]
63
  )
64
  self.pool = nn.AdaptiveAvgPool1d(1)
65
- self.classifier = nn.Linear(512, 1)
66
 
67
  def forward(self, input_ids):
68
  x = self.encoder(input_ids)
69
  x = self.pool(x.transpose(1, 2)).squeeze(-1)
70
  return self.classifier(x)
71
 
72
- return CustomEvo().to(device)
73
 
74
  def train_evo():
75
  tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
@@ -86,27 +86,26 @@ def train_evo():
86
  dataset = FeedbackDataset(tokenizer, data)
87
  loader = DataLoader(dataset, batch_size=4, shuffle=True)
88
 
89
- loss_fn = nn.BCEWithLogitsLoss()
90
  optimizer = optim.Adam(model.parameters(), lr=1e-4)
91
 
92
  for epoch in range(3):
93
  total_loss, correct = 0, 0
94
  for input_ids, labels in loader:
95
- input_ids, labels = input_ids.to(device), labels.float().to(device)
96
- logits = model(input_ids).squeeze(-1)
97
  loss = loss_fn(logits, labels)
98
  optimizer.zero_grad()
99
  loss.backward()
100
  optimizer.step()
101
 
102
  total_loss += loss.item()
103
- preds = (torch.sigmoid(logits) > 0.5).long()
104
- correct += (preds == labels.long()).sum().item()
105
 
106
  acc = correct / len(dataset)
107
  print(f"✅ Epoch {epoch+1} | Loss={total_loss:.4f} | Acc={acc:.4f}")
108
 
109
- # Save model + genome
110
  os.makedirs("trained_model", exist_ok=True)
111
  torch.save(model.state_dict(), "trained_model/evo_retrained.pt")
112
  log_genome(new_config, acc)
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.optim as optim
 
5
  from transformers import AutoTokenizer
6
 
7
  from evo_architecture import mutate_genome, default_config, log_genome
8
+ from model import EvoTransformerV22 # Ensure this is compatible with config
9
  import csv
10
  import os
11
 
12
+ # Device setup
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
 
15
  class FeedbackDataset(Dataset):
 
22
  return len(self.samples)
23
 
24
  def __getitem__(self, idx):
25
+ q, o1, o2, ctx, evo_ans = self.samples[idx]
26
+ prompt = f"{q} [SEP] {o1} [SEP] {o2} [SEP] {ctx}"
27
+ enc = self.tokenizer(prompt, padding="max_length", truncation=True, max_length=self.max_len, return_tensors="pt")
28
  input_ids = enc["input_ids"].squeeze(0)
29
+
30
+ # Label: 0 if Evo picked option1, else 1
31
+ label = 0 if evo_ans.strip().lower() == o1.strip().lower() else 1
32
  return input_ids, torch.tensor(label)
33
 
34
  def load_feedback():
 
39
  with open("feedback_log.csv", encoding="utf-8") as f:
40
  reader = csv.DictReader(f)
41
  for row in reader:
42
+ if row.get("evo_was_correct", "no").strip().lower() == "yes":
43
  data.append([
44
  row["question"],
45
  row["option1"],
46
  row["option2"],
47
  row["context"],
48
+ row["evo_output"].strip()
 
49
  ])
50
  return data
51
 
52
  def build_model(config):
53
  from model import EvoEncoder
54
+ class EvoClassifier(nn.Module):
55
  def __init__(self):
56
  super().__init__()
57
  self.encoder = EvoEncoder(
 
62
  memory_enabled=config["memory_enabled"]
63
  )
64
  self.pool = nn.AdaptiveAvgPool1d(1)
65
+ self.classifier = nn.Linear(512, 2) # two-class classification
66
 
67
  def forward(self, input_ids):
68
  x = self.encoder(input_ids)
69
  x = self.pool(x.transpose(1, 2)).squeeze(-1)
70
  return self.classifier(x)
71
 
72
+ return EvoClassifier().to(device)
73
 
74
  def train_evo():
75
  tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
 
86
  dataset = FeedbackDataset(tokenizer, data)
87
  loader = DataLoader(dataset, batch_size=4, shuffle=True)
88
 
89
+ loss_fn = nn.CrossEntropyLoss()
90
  optimizer = optim.Adam(model.parameters(), lr=1e-4)
91
 
92
  for epoch in range(3):
93
  total_loss, correct = 0, 0
94
  for input_ids, labels in loader:
95
+ input_ids, labels = input_ids.to(device), labels.to(device)
96
+ logits = model(input_ids)
97
  loss = loss_fn(logits, labels)
98
  optimizer.zero_grad()
99
  loss.backward()
100
  optimizer.step()
101
 
102
  total_loss += loss.item()
103
+ preds = torch.argmax(logits, dim=1)
104
+ correct += (preds == labels).sum().item()
105
 
106
  acc = correct / len(dataset)
107
  print(f"✅ Epoch {epoch+1} | Loss={total_loss:.4f} | Acc={acc:.4f}")
108
 
 
109
  os.makedirs("trained_model", exist_ok=True)
110
  torch.save(model.state_dict(), "trained_model/evo_retrained.pt")
111
  log_genome(new_config, acc)