import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
from evo_architecture import mutate_genome, default_config, log_genome
from evo_model import EvoTransformerV22
import csv, os
# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Dataset built from feedback CSV
class FeedbackDataset(Dataset):
    def __init__(self, tokenizer, data, max_len=128):
        self.tokenizer = tokenizer
        self.samples = data
        self.max_len = max_len

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        item = self.samples[idx]
        q, o1, o2, ctx, label = item
        prompt = f"{q} [SEP] {o1} [SEP] {o2} [CTX] {ctx}"
        enc = self.tokenizer(prompt, padding="max_length", truncation=True, max_length=self.max_len, return_tensors="pt")
        input_ids = enc["input_ids"].squeeze(0)
        return input_ids, torch.tensor(label)
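
# Each sample encodes a single prompt of the form
#   "<question> [SEP] <option1> [SEP] <option2> [CTX] <context>"
# padded/truncated to max_len; the label is 1 if Evo was preferred/correct, else 0.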
# Load feedback data
def load_feedback():
    data = []
    if not os.path.exists("feedback_log.csv"):
        return data
    with open("feedback_log.csv", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        for row in reader:
            q = row["question"]
            o1 = row["option1"]
            o2 = row["option2"]
            ctx = row["context"]
            evo_out = row["evo_output"].strip()
            vote = row.get("user_preference", "").lower()
            evo_correct = row.get("evo_was_correct", "").lower()
            # Priority 1: user vote
            if vote == "evo":
                label = 1
            elif vote == "gpt":
                label = 0
            # Priority 2: evo correctness
            elif evo_correct == "yes":
                label = 1
            else:
                continue  # skip uncertain rows
            # Label 1 means Evo was correct/preferred
            data.append([q, o1, o2, ctx, label])
    return data
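
# feedback_log.csv is expected to provide at least these columns (as read above):
# question, option1, option2, context, evo_output, user_preference, evo_was_correct.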
# Evo model builder from config
def build_model(config):
    from model import EvoEncoder

    class EvoClassifier(nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder = EvoEncoder(
                d_model=512,
                num_heads=config["num_heads"],
                ffn_dim=config["ffn_dim"],
                num_layers=config["num_layers"],
                memory_enabled=config["memory_enabled"]
            )
            self.pool = nn.AdaptiveAvgPool1d(1)
            self.classifier = nn.Linear(512, 2)

        def forward(self, input_ids):
            x = self.encoder(input_ids)
            x = self.pool(x.transpose(1, 2)).squeeze(-1)
            return self.classifier(x)

    return EvoClassifier().to(device)
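
# Shape sketch for EvoClassifier.forward, assuming EvoEncoder returns token
# embeddings of shape (batch, seq_len, 512):
#   encoder(input_ids)        -> (batch, seq_len, 512)
#   .transpose(1, 2)          -> (batch, 512, seq_len)
#   pool(...).squeeze(-1)     -> (batch, 512)
#   classifier(...)           -> (batch, 2) logits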
# Train Evo on feedback
def train_evo():
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    data = load_feedback()
    if not data:
        print("No feedback data found.")
        return

    base_config = default_config()
    new_config = mutate_genome(base_config)
    model = build_model(new_config)
    model.train()

    dataset = FeedbackDataset(tokenizer, data)
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(3):
        total_loss, correct = 0, 0
        for input_ids, labels in loader:
            input_ids, labels = input_ids.to(device), labels.to(device)
            logits = model(input_ids)
            loss = loss_fn(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            preds = torch.argmax(logits, dim=1)
            correct += (preds == labels).sum().item()
        acc = correct / len(dataset)
        print(f"Epoch {epoch+1} | Loss={total_loss:.4f} | Acc={acc:.4f}")

    os.makedirs("trained_model", exist_ok=True)
    torch.save(model.state_dict(), "trained_model/evo_retrained.pt")
    log_genome(new_config, acc)
    print("Evo retrained and genome logged.")
# Entry point
if __name__ == "__main__":
    train_evo()
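
# Sketch (not part of this script): to reuse the retrained weights elsewhere, the
# classifier would need to be rebuilt with the same mutated genome that log_genome
# recorded for this run; `saved_config` below is a hypothetical placeholder for it.
#
#   model = build_model(saved_config)
#   state = torch.load("trained_model/evo_retrained.pt", map_location=device)
#   model.load_state_dict(state)
#   model.eval()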