HemanM committed · verified
Commit b0e670a · 1 Parent(s): 515b961

Update retrain_from_feedback.py

Files changed (1):
  1. retrain_from_feedback.py +41 -129
retrain_from_feedback.py CHANGED
@@ -1,129 +1,41 @@
- import torch
- import torch.nn as nn
- import torch.optim as optim
- from torch.utils.data import DataLoader, Dataset
- from transformers import AutoTokenizer
- from evo_architecture import mutate_genome, default_config, log_genome
- from evo_model import EvoTransformerV22
- import csv, os
-
- # 💻 Device setup
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- # 📦 Dataset built from feedback CSV
- class FeedbackDataset(Dataset):
-     def __init__(self, tokenizer, data, max_len=128):
-         self.tokenizer = tokenizer
-         self.samples = data
-         self.max_len = max_len
-
-     def __len__(self):
-         return len(self.samples)
-
-     def __getitem__(self, idx):
-         item = self.samples[idx]
-         q, o1, o2, ctx, label = item
-         prompt = f"{q} [SEP] {o1} [SEP] {o2} [CTX] {ctx}"
-         enc = self.tokenizer(prompt, padding="max_length", truncation=True, max_length=self.max_len, return_tensors="pt")
-         input_ids = enc["input_ids"].squeeze(0)
-         return input_ids, torch.tensor(label)
-
- # 🧠 Load feedback data
- def load_feedback():
-     data = []
-     if not os.path.exists("feedback_log.csv"):
-         return data
-
-     with open("feedback_log.csv", encoding="utf-8") as f:
-         reader = csv.DictReader(f)
-         for row in reader:
-             q = row["question"]
-             o1 = row["option1"]
-             o2 = row["option2"]
-             ctx = row["context"]
-             evo_out = row["evo_output"].strip()
-             vote = row.get("user_preference", "").lower()
-             evo_correct = row.get("evo_was_correct", "").lower()
-
-             # Priority 1: user vote
-             if vote == "evo":
-                 label = 1
-             elif vote == "gpt":
-                 label = 0
-             # Priority 2: evo correctness
-             elif evo_correct == "yes":
-                 label = 1
-             else:
-                 continue  # skip uncertain rows
-
-             # Label 1 means Evo was correct/preferred
-             data.append([q, o1, o2, ctx, label])
-     return data
-
- # 🔧 Evo model builder from config
- def build_model(config):
-     from evo_model import EvoEncoder
-     class EvoClassifier(nn.Module):
-         def __init__(self):
-             super().__init__()
-             self.encoder = EvoEncoder(
-                 d_model=512,
-                 num_heads=config["num_heads"],
-                 ffn_dim=config["ffn_dim"],
-                 num_layers=config["num_layers"],
-                 memory_enabled=config["memory_enabled"]
-             )
-             self.pool = nn.AdaptiveAvgPool1d(1)
-             self.classifier = nn.Linear(512, 2)
-
-         def forward(self, input_ids):
-             x = self.encoder(input_ids)
-             x = self.pool(x.transpose(1, 2)).squeeze(-1)
-             return self.classifier(x)
-
-     return EvoClassifier().to(device)
-
- # 🔁 Train Evo on feedback
- def train_evo():
-     tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
-     data = load_feedback()
-     if not data:
-         print("❌ No feedback data found.")
-         return
-
-     base_config = default_config()
-     new_config = mutate_genome(base_config)
-     model = build_model(new_config)
-     model.train()
-
-     dataset = FeedbackDataset(tokenizer, data)
-     loader = DataLoader(dataset, batch_size=4, shuffle=True)
-
-     loss_fn = nn.CrossEntropyLoss()
-     optimizer = optim.Adam(model.parameters(), lr=1e-4)
-
-     for epoch in range(3):
-         total_loss, correct = 0, 0
-         for input_ids, labels in loader:
-             input_ids, labels = input_ids.to(device), labels.to(device)
-             logits = model(input_ids)
-             loss = loss_fn(logits, labels)
-             optimizer.zero_grad()
-             loss.backward()
-             optimizer.step()
-
-             total_loss += loss.item()
-             preds = torch.argmax(logits, dim=1)
-             correct += (preds == labels).sum().item()
-
-         acc = correct / len(dataset)
-         print(f"✅ Epoch {epoch+1} | Loss={total_loss:.4f} | Acc={acc:.4f}")
-
-     os.makedirs("trained_model", exist_ok=True)
-     torch.save(model.state_dict(), "trained_model/evo_retrained.pt")
-     log_genome(new_config, acc)
-     print("✅ Evo retrained and genome logged.")
-
- # 🔁 Entry point
- if __name__ == "__main__":
-     train_evo()
 
+ import csv
+ import os
+ from datetime import datetime
+ from retrain_from_feedback import train_evo
+
+ # 🔁 Main entry point for feedback-triggered retraining
+ def retrain_from_feedback(feedback_log):
+     # ✅ Check if feedback is present
+     if not feedback_log:
+         return "⚠️ No feedback data to retrain from."
+
+     # 📝 Write feedback to CSV
+     try:
+         os.makedirs("feedback", exist_ok=True)
+         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+         filepath = f"feedback/feedback_log.csv"  # also usable for loading
+         with open(filepath, "w", newline="", encoding="utf-8") as f:
+             writer = csv.writer(f)
+             writer.writerow([
+                 "question", "option1", "option2", "answer",
+                 "confidence", "reasoning", "context",
+                 "user_preference", "evo_was_correct", "evo_output"
+             ])
+             for row in feedback_log:
+                 question, option1, option2, answer, confidence, reasoning, context = row
+
+                 # Simulate Evo being preferred (you can modify this logic later)
+                 writer.writerow([
+                     question, option1, option2, answer,
+                     confidence, reasoning, context,
+                     "evo", "yes", answer
+                 ])
+     except Exception as e:
+         return f"❌ Failed to save feedback: {str(e)}"
+
+     # 🔁 Trigger training
+     try:
+         train_evo()  # This uses the latest feedback_log.csv
+         return "✅ Evo retrained and weights saved."
+     except Exception as e:
+         return f"❌ Evo training failed: {str(e)}"
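
For context, a minimal usage sketch of the new retrain_from_feedback() entry point. The feedback values below are invented for illustration; the 7-field tuple layout follows the unpacking in the writer loop above, and train_evo() is assumed to be importable as referenced in the new file.

# Illustrative only: each feedback row is a 7-tuple of
# (question, option1, option2, answer, confidence, reasoning, context).
from retrain_from_feedback import retrain_from_feedback

feedback_log = [
    (
        "Which exit should be used during a fire drill?",  # question (made-up example)
        "Take the elevator",                                # option1
        "Take the stairs",                                  # option2
        "Take the stairs",                                  # answer Evo gave
        0.92,                                               # confidence
        "Elevators may be disabled during a fire.",         # reasoning
        "Office evacuation scenario",                       # context
    ),
]

status = retrain_from_feedback(feedback_log)
print(status)  # "✅ Evo retrained and weights saved." on success, or an error string

The function returns a status string rather than raising, so a caller (for example a Gradio handler) can surface the message directly in the UI.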