import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from firebase_admin import firestore

from evo_model import EvoTransformerForClassification
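
# Wraps the concatenated (goal, solution_1, solution_2) texts and their labels
# as a PyTorch Dataset. Tokenization happens once up front; padding=True pads
# every sequence to the longest in the list, so all input_ids share one length
# and the default DataLoader collate can stack them into batches.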
class EvoDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=64):
        self.encodings = tokenizer(texts, truncation=True, padding=True, max_length=max_length)
        self.labels = labels

    def __getitem__(self, idx):
        input_ids = torch.tensor(self.encodings["input_ids"][idx])
        label = torch.tensor(self.labels[idx])
        return input_ids, label

    def __len__(self):
        return len(self.labels)
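
# Pulls labeled head-to-head feedback from Firestore and fine-tunes a fresh
# EvoTransformer on it. Returns True on success, False when no usable data
# exists or any step fails.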
def manual_retrain():
    try:
        db = firestore.client()
        docs = db.collection("evo_feedback_logs").stream()

        goals, solution1, solution2, labels = [], [], [], []
        for doc in docs:
            d = doc.to_dict()
            # Skip incomplete records; a usable example needs all four fields.
            if all(k in d for k in ["goal", "solution_1", "solution_2", "correct_answer"]):
                goals.append(d["goal"])
                solution1.append(d["solution_1"])
                solution2.append(d["solution_2"])
                # Label 0 means "Solution 1" won the comparison, 1 means "Solution 2".
                labels.append(0 if d["correct_answer"] == "Solution 1" else 1)

        if not goals:
            print("[Retrain Error] No training data found.")
            return False

        # Encode each triple as one sequence, joined by BERT's [SEP] token.
        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        texts = [f"{g} [SEP] {s1} [SEP] {s2}" for g, s1, s2 in zip(goals, solution1, solution2)]
        dataset = EvoDataset(texts, labels, tokenizer)
        loader = DataLoader(dataset, batch_size=4, shuffle=True)

        # Build a fresh EvoTransformer sized for binary classification.
        config = {
            "vocab_size": tokenizer.vocab_size,
            "d_model": 256,
            "nhead": 4,
            "dim_feedforward": 512,
            "num_hidden_layers": 4,
        }
        model = EvoTransformerForClassification.from_config_dict(config)
        model.train()

        optimizer = optim.AdamW(model.parameters(), lr=1e-4)
        criterion = nn.CrossEntropyLoss()

        for epoch in range(3):
            for input_ids, label in loader:
                optimizer.zero_grad()
                logits = model(input_ids)
                loss = criterion(logits, label)
                loss.backward()
                optimizer.step()

        model.save_pretrained("trained_evo")
        print("✅ Retraining complete.")
        return True

    except Exception as e:
        print(f"[Retrain Error] {e}")
        return False
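

# A minimal manual entry point, assuming firebase_admin.initialize_app() has
# already been called elsewhere in the hosting app so firestore.client() can
# connect.
if __name__ == "__main__":
    manual_retrain()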