File size: 3,822 Bytes
5c94b85
 
 
 
 
 
 
bddaae0
5c94b85
 
 
bddaae0
5c94b85
 
 
 
 
 
 
 
 
 
 
 
bddaae0
 
 
5c94b85
bddaae0
 
 
5c94b85
 
 
 
 
 
 
 
 
 
bddaae0
5c94b85
 
 
 
 
bddaae0
5c94b85
 
 
 
 
bddaae0
5c94b85
 
 
 
 
 
 
 
 
 
bddaae0
5c94b85
 
 
 
 
 
bddaae0
5c94b85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bddaae0
5c94b85
 
 
 
 
bddaae0
 
5c94b85
 
 
 
 
 
bddaae0
 
5c94b85
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer

from evo_architecture import mutate_genome, default_config, log_genome
from model import EvoTransformerV22  # Ensure this is compatible with config
import csv
import os

# Pick the compute device: CUDA when a GPU is visible, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

class FeedbackDataset(Dataset):
    """Feedback rows as a binary-classification dataset.

    Each sample is a 5-tuple (question, option1, option2, context,
    evo_answer); the prompt is the first four fields joined with
    " [SEP] ", and the target is 0 when evo_answer matches option1
    (case-insensitive, whitespace-trimmed), else 1.
    """

    def __init__(self, tokenizer, data, max_len=128):
        self.tokenizer = tokenizer
        self.samples = data
        self.max_len = max_len

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        question, opt1, opt2, context, answer = self.samples[idx]
        prompt = " [SEP] ".join([question, opt1, opt2, context])
        encoded = self.tokenizer(
            prompt,
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
        )
        # Drop the batch dimension the tokenizer adds for a single prompt.
        token_ids = encoded["input_ids"].squeeze(0)

        # Target is 0 if Evo picked option1, otherwise 1.
        target = 0 if answer.strip().lower() == opt1.strip().lower() else 1
        return token_ids, torch.tensor(target)

def load_feedback(path="feedback_log.csv"):
    """Load feedback rows where Evo's answer was marked correct.

    Generalized: the CSV location is now a parameter (default keeps the
    old hard-coded "feedback_log.csv" behavior for existing callers).

    Args:
        path: CSV file with columns question, option1, option2, context,
            evo_output and evo_was_correct.

    Returns:
        List of [question, option1, option2, context, evo_output] lists;
        empty when the file is missing or contains no correct rows.
    """
    data = []
    if not os.path.exists(path):
        return data

    with open(path, encoding="utf-8") as f:
        reader = csv.DictReader(f)
        for row in reader:
            # Keep only examples Evo got right, so retraining reinforces wins.
            if row.get("evo_was_correct", "no").strip().lower() == "yes":
                data.append([
                    row["question"],
                    row["option1"],
                    row["option2"],
                    row["context"],
                    row["evo_output"].strip(),
                ])
    return data

def build_model(config):
    """Build a two-way classifier around an EvoEncoder configured by the
    evolved genome *config*, moved onto the module-level ``device``.

    config must provide "num_heads", "ffn_dim", "num_layers" and
    "memory_enabled"; the model width is fixed at 512.
    """
    from model import EvoEncoder

    class EvoClassifier(nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder = EvoEncoder(
                d_model=512,
                num_heads=config["num_heads"],
                ffn_dim=config["ffn_dim"],
                num_layers=config["num_layers"],
                memory_enabled=config["memory_enabled"],
            )
            self.pool = nn.AdaptiveAvgPool1d(1)
            # Two logits: Evo picked option1 vs option2.
            self.classifier = nn.Linear(512, 2)

        def forward(self, input_ids):
            hidden = self.encoder(input_ids)
            # Average-pool over the sequence axis; assumes encoder output
            # is (batch, seq, 512) — TODO confirm against EvoEncoder.
            pooled = self.pool(hidden.transpose(1, 2)).squeeze(-1)
            return self.classifier(pooled)

    return EvoClassifier().to(device)

def train_evo():
    """Run one feedback-driven retraining cycle.

    Loads positively-rated feedback, mutates the default genome into a
    fresh architecture config, trains the resulting classifier for three
    epochs with Adam, then saves the weights and logs the genome with its
    final-epoch accuracy. Prints and returns early when no feedback exists.
    """
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    feedback = load_feedback()
    if not feedback:
        print("❌ No feedback data found.")
        return

    # Derive a mutated architecture from the default genome and build it.
    genome = mutate_genome(default_config())
    model = build_model(genome)
    model.train()

    dataset = FeedbackDataset(tokenizer, feedback)
    loader = DataLoader(dataset, batch_size=4, shuffle=True)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(3):
        epoch_loss = 0.0
        hits = 0
        for batch_ids, batch_labels in loader:
            batch_ids = batch_ids.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_ids)
            loss = criterion(logits, batch_labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            hits += (logits.argmax(dim=1) == batch_labels).sum().item()

        # Accuracy over the whole dataset for this epoch.
        acc = hits / len(dataset)
        print(f"✅ Epoch {epoch+1} | Loss={epoch_loss:.4f} | Acc={acc:.4f}")

    os.makedirs("trained_model", exist_ok=True)
    torch.save(model.state_dict(), "trained_model/evo_retrained.pt")
    # Record the genome together with its last-epoch accuracy.
    log_genome(genome, acc)
    print("✅ Model saved and genome logged.")

# Script entry point: run a single feedback-driven retraining cycle.
if __name__ == "__main__":
    train_evo()