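# Retraining script: pulls logged human feedback from Firestore and retrains
# EvoTransformer on it, then refreshes the architecture summary and accuracy plot.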
import os
import torch
import firebase_admin
from firebase_admin import credentials, firestore
from evo_model import EvoTransformerForClassification, EvoTransformerConfig
from transformers import BertTokenizer

from init_model import load_model
from dashboard import evolution_accuracy_plot

# Initialize the Firebase Admin app once (skip if already initialized)
if not firebase_admin._apps:
    cred = credentials.Certificate("firebase_key.json")
    firebase_admin.initialize_app(cred)

db = firestore.client()

def fetch_training_data(tokenizer):
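    """Stream feedback docs from the 'evo_feedback' collection and build tensors.

    Each document is expected to carry a 'prompt' and a 'winner' field; pairs are
    encoded as "prompt [SEP] winner". Returns (input_ids, attention_masks, labels)
    as stacked tensors, or (None, None, None) when no usable documents exist.
    """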
    logs_ref = db.collection("evo_feedback")
    docs = logs_ref.stream()

    input_ids, attention_masks, labels = [], [], []
    for doc in docs:
        data = doc.to_dict()
        prompt = data.get("prompt", "")
        winner = data.get("winner", "")
        if winner and prompt:
            text = prompt + " [SEP] " + winner
            encoding = tokenizer(
                text,
                truncation=True,
                padding="max_length",
                max_length=128,
                return_tensors="pt"
            )
            input_ids.append(encoding["input_ids"][0])
            attention_masks.append(encoding["attention_mask"][0])
            # Assumes the winner field names the winning option (e.g. "Response 1"
            # vs. "Response 2"); option 1 maps to label 0, anything else to label 1.
            label = 0 if "1" in winner else 1
            labels.append(label)

    if not input_ids:
        return None, None, None

    return (
        torch.stack(input_ids),
        torch.stack(attention_masks),
        torch.tensor(labels, dtype=torch.long)
    )

def get_architecture_summary(model):
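    """Return a human-readable summary of the model's architecture attributes."""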
    summary = {
        "Layers": getattr(model, "num_layers", "N/A"),
        "Attention Heads": getattr(model, "num_heads", "N/A"),
        "FFN Dim": getattr(model, "ffn_dim", "N/A"),
        "Memory Enabled": getattr(model, "use_memory", "N/A"),
    }
    return "\n".join(f"{k}: {v}" for k, v in summary.items())

def retrain_model():
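    """Retrain EvoTransformer on logged feedback.

    Returns an (architecture_summary, plot, status_message) tuple so a dashboard
    caller can display the result.
    """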
    try:
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        input_ids, attention_masks, labels = fetch_training_data(tokenizer)

        if input_ids is None or len(input_ids) < 2:
            return "⚠️ Not enough data to retrain.", None, "Please log more feedback first."

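        # Note: this initializes a fresh EvoTransformer from config rather than
        # resuming from the previously saved weights in trained_model/.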
        config = EvoTransformerConfig()
        model = EvoTransformerForClassification(config)
        model.train()

        optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
        loss_fn = torch.nn.CrossEntropyLoss()

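        # Simple full-batch loop: one Adam step per epoch over the entire
        # feedback set (no mini-batching).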
        for epoch in range(3):
            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask=attention_masks)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
            print(f"Epoch {epoch+1}: Loss = {loss.item():.4f}")

        os.makedirs("trained_model", exist_ok=True)
        model.save_pretrained("trained_model")

        print("✅ EvoTransformer retrained and saved.")

        # Reload the updated model for summary + plot
        updated_model = load_model()
        arch_text = get_architecture_summary(updated_model)
        plot = evolution_accuracy_plot()

        return arch_text, plot, "✅ EvoTransformer retrained successfully!"

    except Exception as e:
        print(f"❌ Retraining failed: {e}")
        return "❌ Error", None, f"Retrain failed: {e}"

# Allow running this script directly
if __name__ == "__main__":
    retrain_model()