HemanM commited on
Commit
5c94b85
·
verified ·
1 Parent(s): 2f93104

Create retrain_from_feedback

Browse files
Files changed (1) hide show
  1. retrain_from_feedback +116 -0
retrain_from_feedback ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # retrain_from_feedback.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.optim as optim
6
+ from torch.utils.data import DataLoader, Dataset
7
+ from transformers import AutoTokenizer
8
+
9
+ from evo_architecture import mutate_genome, default_config, log_genome
10
+ from model import EvoTransformerV22 # Must accept dynamic config
11
+ import csv
12
+ import os
13
+
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+
16
+ class FeedbackDataset(Dataset):
17
+ def __init__(self, tokenizer, data, max_len=128):
18
+ self.tokenizer = tokenizer
19
+ self.samples = data
20
+ self.max_len = max_len
21
+
22
+ def __len__(self):
23
+ return len(self.samples)
24
+
25
+ def __getitem__(self, idx):
26
+ q, o1, o2, ctx, ans, label = self.samples[idx]
27
+ text = f"{q} [SEP] {o1} [SEP] {o2} [SEP] {ctx}"
28
+ enc = self.tokenizer(text, padding="max_length", truncation=True, max_length=self.max_len, return_tensors="pt")
29
+ input_ids = enc["input_ids"].squeeze(0)
30
+ label = 1 if ans.strip().lower() == o1.strip().lower() else 0
31
+ return input_ids, torch.tensor(label)
32
+
33
+ def load_feedback():
34
+ data = []
35
+ if not os.path.exists("feedback_log.csv"):
36
+ return data
37
+
38
+ with open("feedback_log.csv", encoding="utf-8") as f:
39
+ reader = csv.DictReader(f)
40
+ for row in reader:
41
+ if row["evo_was_correct"].strip().lower() == "yes":
42
+ data.append([
43
+ row["question"],
44
+ row["option1"],
45
+ row["option2"],
46
+ row["context"],
47
+ row["evo_output"],
48
+ "yes"
49
+ ])
50
+ return data
51
+
52
+ def build_model(config):
53
+ from model import EvoEncoder
54
+ class CustomEvo(nn.Module):
55
+ def __init__(self):
56
+ super().__init__()
57
+ self.encoder = EvoEncoder(
58
+ d_model=512,
59
+ num_heads=config["num_heads"],
60
+ ffn_dim=config["ffn_dim"],
61
+ num_layers=config["num_layers"],
62
+ memory_enabled=config["memory_enabled"]
63
+ )
64
+ self.pool = nn.AdaptiveAvgPool1d(1)
65
+ self.classifier = nn.Linear(512, 1)
66
+
67
+ def forward(self, input_ids):
68
+ x = self.encoder(input_ids)
69
+ x = self.pool(x.transpose(1, 2)).squeeze(-1)
70
+ return self.classifier(x)
71
+
72
+ return CustomEvo().to(device)
73
+
74
+ def train_evo():
75
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
76
+ data = load_feedback()
77
+ if not data:
78
+ print("❌ No feedback data found.")
79
+ return
80
+
81
+ base_config = default_config()
82
+ new_config = mutate_genome(base_config)
83
+ model = build_model(new_config)
84
+ model.train()
85
+
86
+ dataset = FeedbackDataset(tokenizer, data)
87
+ loader = DataLoader(dataset, batch_size=4, shuffle=True)
88
+
89
+ loss_fn = nn.BCEWithLogitsLoss()
90
+ optimizer = optim.Adam(model.parameters(), lr=1e-4)
91
+
92
+ for epoch in range(3):
93
+ total_loss, correct = 0, 0
94
+ for input_ids, labels in loader:
95
+ input_ids, labels = input_ids.to(device), labels.float().to(device)
96
+ logits = model(input_ids).squeeze(-1)
97
+ loss = loss_fn(logits, labels)
98
+ optimizer.zero_grad()
99
+ loss.backward()
100
+ optimizer.step()
101
+
102
+ total_loss += loss.item()
103
+ preds = (torch.sigmoid(logits) > 0.5).long()
104
+ correct += (preds == labels.long()).sum().item()
105
+
106
+ acc = correct / len(dataset)
107
+ print(f"✅ Epoch {epoch+1} | Loss={total_loss:.4f} | Acc={acc:.4f}")
108
+
109
+ # Save model + genome
110
+ os.makedirs("trained_model", exist_ok=True)
111
+ torch.save(model.state_dict(), "trained_model/evo_retrained.pt")
112
+ log_genome(new_config, acc)
113
+ print("✅ Model saved and genome logged.")
114
+
115
+ if __name__ == "__main__":
116
+ train_evo()