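"""Retraining script for EvoTransformer.

Pulls human feedback logs from the evo_feedback Firestore collection,
encodes each prompt/winner pair with a BERT tokenizer, trains an
EvoTransformerForClassification model on them, and saves the result
to trained_model/.
"""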
import os
import torch
import firebase_admin
from firebase_admin import credentials, firestore
from evo_model import EvoTransformerForClassification, EvoTransformerConfig
from transformers import BertTokenizer

# Initialize Firebase if not already initialized
if not firebase_admin._apps:
    cred = credentials.Certificate("firebase_key.json")
    firebase_admin.initialize_app(cred)

db = firestore.client()

def fetch_training_data(tokenizer):
    """Read feedback logs from Firestore and encode them as training tensors.

    Returns (input_ids, attention_masks, labels) as stacked tensors,
    or (None, None, None) if no usable feedback is found.
    """
    logs_ref = db.collection("evo_feedback")
    docs = logs_ref.stream()

    input_ids, attention_masks, labels = [], [], []
    for doc in docs:
        data = doc.to_dict()
        prompt = data.get("prompt", "")
        winner = data.get("winner", "")
        if winner and prompt:
            text = prompt + " [SEP] " + winner
            encoding = tokenizer(
                text,
                truncation=True,
                padding="max_length",
                max_length=128,
                return_tensors="pt"
            )
            input_ids.append(encoding["input_ids"][0])
            attention_masks.append(encoding["attention_mask"][0])
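            # Map the winner string to a class index: label 0 when it
            # contains "1" (i.e. the first option won), else label 1.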
            label = 0 if "1" in winner else 1
            labels.append(label)

    if not input_ids:
        return None, None, None

    return (
        torch.stack(input_ids),
        torch.stack(attention_masks),
        torch.tensor(labels, dtype=torch.long)
    )

def retrain_and_save():
    """Train a fresh EvoTransformer on the logged feedback and save it."""
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    input_ids, attention_masks, labels = fetch_training_data(tokenizer)

    if input_ids is None or len(input_ids) < 2:
        print("⚠️ Not enough training data.")
        return

    # Build a fresh model from the default config each run (no warm start
    # from previously saved weights).
    config = EvoTransformerConfig()
    model = EvoTransformerForClassification(config)
    model.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
    loss_fn = torch.nn.CrossEntropyLoss()

    for epoch in range(3):
        optimizer.zero_grad()
        # One full-batch step per epoch; fine while the feedback set is small.
        outputs = model(input_ids, attention_mask=attention_masks)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch+1}: Loss = {loss.item():.4f}")

    os.makedirs("trained_model", exist_ok=True)
    model.save_pretrained("trained_model")
    print("✅ EvoTransformer retrained and saved to trained_model/")

if __name__ == "__main__":
    retrain_and_save()

# Alias to match expected import
retrain_model = retrain_and_save
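
# Example usage (a sketch; "retrain" as this file's module name is an
# assumption):
#
#     from retrain import retrain_model
#     retrain_model()  # fetches fresh feedback and trains a new model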