# watchdog.py

import firebase_admin
from firebase_admin import credentials, firestore
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import BertTokenizer
from torch.utils.data import DataLoader, Dataset
from evo_model import EvoTransformerForClassification, EvoTransformerConfig

# Initialize Firebase
if not firebase_admin._apps:
    cred = credentials.Certificate("firebase_key.json")
    firebase_admin.initialize_app(cred)

db = firestore.client()
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Dataset for training
class FeedbackDataset(Dataset):
    def __init__(self, records, tokenizer, max_length=64):
        self.records = records
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.label_map = {"Solution 1": 0, "Solution 2": 1}

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        row = self.records[idx]
        combined = f"Goal: {row['goal']} Option 1: {row['solution_1']} Option 2: {row['solution_2']}"
        inputs = self.tokenizer(combined, padding="max_length", truncation=True,
                                max_length=self.max_length, return_tensors="pt")
        label = self.label_map[row["correct_answer"]]
        return {
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
            "labels": torch.tensor(label)
        }
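
# For reference, each feedback record consumed above is expected to look
# roughly like this (field names come from FeedbackDataset's usage; the
# values are illustrative only):
#
#   {
#       "goal": "Reduce cloud costs",
#       "solution_1": "Right-size the VM fleet",
#       "solution_2": "Move all workloads to spot instances",
#       "correct_answer": "Solution 1",
#   }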

# Manual retrain trigger
def manual_retrain():
    try:
        # Step 1: Fetch feedback data from Firestore
        docs = db.collection("evo_feedback_logs").stream()
        # Keep only records that carry every field FeedbackDataset indexes,
        # with a correct_answer the label map can resolve.
        required_fields = ("goal", "solution_1", "solution_2", "correct_answer")
        feedback_data = [
            d for d in (doc.to_dict() for doc in docs)
            if all(field in d for field in required_fields)
            and d["correct_answer"] in ("Solution 1", "Solution 2")
        ]

        if len(feedback_data) < 5:
            print("[Retrain Skipped] Not enough feedback.")
            return False

        # Step 2: Load tokenizer and dataset
        dataset = FeedbackDataset(feedback_data, tokenizer)
        loader = DataLoader(dataset, batch_size=4, shuffle=True)

        # Step 3: Build the model. Note: this constructs a fresh
        # EvoTransformerForClassification from config on every run, so
        # retraining starts from a new initialization rather than resuming
        # the weights previously saved to trained_model.pt.
        config = EvoTransformerConfig()
        model = EvoTransformerForClassification(config)
        model.train()

        # Step 4: Define optimizer and loss
        optimizer = optim.Adam(model.parameters(), lr=2e-5)
        loss_fn = nn.CrossEntropyLoss()

        # Step 5: Train
        for epoch in range(3):
            total_loss = 0
            for batch in loader:
                optimizer.zero_grad()
                input_ids = batch["input_ids"]
                attention_mask = batch["attention_mask"]
                labels = batch["labels"]

                # The Evo forward is called with input_ids only; the attention
                # mask is batched above but not consumed by this model.
                logits = model(input_ids)
                loss = loss_fn(logits, labels)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            print(f"[Retrain] Epoch {epoch + 1} Loss: {total_loss:.4f}")

        # Step 6: Save updated model
        torch.save(model.state_dict(), "trained_model.pt")
        print("✅ Evo updated with latest feedback.")
        return True

    except Exception as e:
        print(f"[Retrain Error] {e}")
        return False
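

# Minimal entry point so the module can be run directly as a one-off retrain
# job (a convenience sketch; not part of any scheduler wiring implied above).
if __name__ == "__main__":
    ok = manual_retrain()
    print(f"[Watchdog] Manual retrain {'succeeded' if ok else 'did not run'}.")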