File size: 4,025 Bytes
e7e30db
c0a6a03
da42a90
e7e30db
 
3489232
 
e7e30db
ef5a88b
 
 
e7e30db
 
 
 
 
 
 
3489232
e7e30db
 
3489232
 
cae5830
e7e30db
3489232
e7e30db
3489232
 
 
 
 
 
 
 
 
 
 
e7e30db
 
3489232
 
 
 
 
 
 
 
 
e7e30db
ef5a88b
 
87324d5
 
 
 
ef5a88b
 
3489232
ef5a88b
 
 
 
e7e30db
ef5a88b
 
3489232
87324d5
 
 
 
 
 
 
 
 
 
ef5a88b
 
e7e30db
ef5a88b
 
3489232
ef5a88b
 
 
 
 
 
 
e7e30db
87324d5
 
c0a6a03
87324d5
c0a6a03
ef5a88b
 
c0a6a03
 
 
 
 
 
 
 
 
 
 
87324d5
c0a6a03
ef5a88b
 
87324d5
ef5a88b
 
 
40911df
ef5a88b
 
 
 
 
 
87324d5
ef5a88b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os
import json
import torch
import firebase_admin
from firebase_admin import credentials, firestore
from evo_model import EvoTransformerForClassification, EvoTransformerConfig
from transformers import BertTokenizer

from init_model import load_model
from dashboard import evolution_accuracy_plot

# Initialize the default Firebase app once per process. ``get_app()`` raises
# ValueError when the default app has not been initialized yet; using this
# public API (instead of the private ``firebase_admin._apps`` registry) also
# checks the *default* app specifically, which is the one
# ``firestore.client()`` below uses.
try:
    firebase_admin.get_app()
except ValueError:
    cred = credentials.Certificate("firebase_key.json")
    firebase_admin.initialize_app(cred)

# Firestore client for the default app, shared by the functions in this module.
db = firestore.client()

def fetch_training_data(tokenizer, *, max_length=128):
    """Build a classification training set from the ``evo_feedback`` collection.

    Streams every document in the Firestore ``evo_feedback`` collection and
    keeps those with a non-empty ``prompt`` and ``winner``. Each kept record is
    encoded as ``prompt + " [SEP] " + winner``.

    Args:
        tokenizer: Tokenizer used to encode the texts (a ``BertTokenizer`` in
            this module); it is called once on the whole list of texts.
        max_length: Length every example is truncated / padded to.

    Returns:
        ``(input_ids, attention_masks, labels)`` as tensors with one row per
        kept record, or ``(None, None, None)`` when no usable feedback exists.
    """
    texts, labels = [], []
    for doc in db.collection("evo_feedback").stream():
        # to_dict() can return None for a snapshot without data; treat it as empty.
        data = doc.to_dict() or {}
        prompt = data.get("prompt", "")
        winner = data.get("winner", "")
        if not (winner and prompt):
            continue
        texts.append(prompt + " [SEP] " + winner)
        # NOTE(review): the label is derived from the winner text, which is also
        # part of the model input; presumably winners look like "Solution 1" /
        # "Solution 2" — confirm against whatever writes evo_feedback.
        labels.append(0 if "1" in winner else 1)

    if not texts:
        return None, None, None

    # One batched tokenizer call replaces per-document encoding + torch.stack;
    # padding="max_length" makes every row exactly max_length tokens long.
    encoding = tokenizer(
        texts,
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )
    return (
        encoding["input_ids"],
        encoding["attention_mask"],
        torch.tensor(labels, dtype=torch.long),
    )

def get_architecture_summary(model):
    """Describe the model's architecture as newline-separated ``"Label: value"`` lines.

    Reads ``num_layers``, ``num_heads``, ``ffn_dim`` and ``use_memory`` from
    ``model.config``; any of these the config lacks is shown as ``N/A``.
    """
    cfg = model.config
    rows = (
        ("Layers", "num_layers"),
        ("Attention Heads", "num_heads"),
        ("FFN Dim", "ffn_dim"),
        ("Memory Enabled", "use_memory"),
    )
    lines = []
    for label, attr in rows:
        lines.append(f"{label}: {getattr(cfg, attr, 'N/A')}")
    return "\n".join(lines)

def _train(model, input_ids, attention_masks, labels, *, epochs=3, lr=2e-4):
    """Train *model* in place: full-batch Adam, one optimizer step per epoch.

    The loss of every step is printed to stdout.
    """
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = torch.nn.CrossEntropyLoss()

    for epoch in range(epochs):
        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask=attention_masks)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch+1}: Loss = {loss.item():.4f}")


def _training_accuracy(model, input_ids, attention_masks, labels):
    """Return the fraction of the training examples *model* classifies correctly.

    The model output is treated as per-class logits — the same tensor that is
    fed to ``CrossEntropyLoss`` in ``_train`` — so the prediction is its argmax
    over the last dimension. This is accuracy on the *training* data, not a
    held-out evaluation.
    """
    model.eval()
    with torch.no_grad():
        logits = model(input_ids, attention_mask=attention_masks)
    return (logits.argmax(dim=-1) == labels).float().mean().item()


def _append_evolution_log(log_path, accuracy):
    """Append ``{"accuracy": accuracy}`` to the JSON list stored at *log_path*.

    A new one-element list is written when the file does not exist yet.
    """
    if os.path.exists(log_path):
        with open(log_path, "r", encoding="utf-8") as f:
            history = json.load(f)
    else:
        history = []

    history.append({"accuracy": accuracy})

    with open(log_path, "w", encoding="utf-8") as f:
        json.dump(history, f)


def retrain_model():
    """Retrain EvoTransformer from logged feedback and refresh dashboard data.

    Fetches the feedback dataset and trains a freshly constructed
    ``EvoTransformerForClassification`` with a fixed architecture. The model
    is saved to ``trained_model/`` and its training-set accuracy is appended
    to ``trained_model/evolution_log.json``. Finally ``load_model()`` is
    called to build the architecture summary, and the accuracy plot is
    regenerated.

    Returns:
        A ``(architecture_text, plot, status_message)`` tuple for the UI. When
        there is not enough data, or any ``Exception`` is raised, ``plot`` is
        ``None`` and the other two entries describe the problem; the exception
        is reported rather than propagated.
    """
    model_dir = "trained_model"
    log_path = "trained_model/evolution_log.json"

    try:
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        input_ids, attention_masks, labels = fetch_training_data(tokenizer)

        if input_ids is None or len(input_ids) < 2:
            return "⚠️ Not enough data to retrain.", None, "Please log more feedback first."

        # ✅ Explicitly define architecture details. A new model is built on
        # every call, so training starts from fresh weights each time.
        config = EvoTransformerConfig(
            hidden_size=384,
            num_layers=6,
            num_labels=2,
            num_heads=6,
            ffn_dim=1024,
            use_memory=False
        )
        model = EvoTransformerForClassification(config)

        _train(model, input_ids, attention_masks, labels)
        # Replaces the former hard-coded 1.0 placeholder with a measured value.
        accuracy = _training_accuracy(model, input_ids, attention_masks, labels)

        # Save the weights before logging, so the evolution log never records
        # a run whose model failed to be written.
        os.makedirs(model_dir, exist_ok=True)
        model.save_pretrained(model_dir)
        _append_evolution_log(log_path, accuracy)
        print("✅ EvoTransformer retrained and saved.")

        # Reload and return dashboard updates
        updated_model = load_model()
        arch_text = get_architecture_summary(updated_model)
        plot = evolution_accuracy_plot()

        return arch_text, plot, "✅ EvoTransformer retrained successfully!"

    except Exception as e:
        # Top-level UI boundary: report instead of raising, but keep the full
        # traceback on the console so failures can be debugged.
        import traceback
        traceback.print_exc()
        print(f"❌ Retraining failed: {e}")
        return "❌ Error", None, f"Retrain failed: {e}"

# Allow direct script run: perform one retraining pass and report the outcome.
if __name__ == "__main__":
    # retrain_model() returns (architecture_text, plot, status_message). The
    # "not enough data" outcome is reported only through the status message,
    # so print it; otherwise a direct run would exit without any output.
    _, _, status = retrain_model()
    print(status)