File size: 3,781 Bytes
6e70824
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75d1957
 
 
 
 
6e70824
75d1957
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# watchdog.py

import firebase_admin
from firebase_admin import credentials, firestore
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from model import EvoTransformer  # make sure this is in your project
import time
import datetime
import os

# Firebase setup — initialize the default app exactly once per process
# (firebase_admin._apps is non-empty after initialize_app has run, so a
# re-import / re-execution of this module won't double-initialize).
if not firebase_admin._apps:
    # NOTE(review): service-account key is loaded from the current working
    # directory — confirm the path is valid in the deployment environment.
    cred = credentials.Certificate("evotransformer-firebase-adminsdk-fbsvc-37a4b838aa.json")
    firebase_admin.initialize_app(cred)

db = firestore.client()  # shared Firestore client used by the fetch/train functions below
COLLECTION = "evo_feedback_logs"  # Firestore collection holding feedback documents
LAST_CHECK_FILE = "last_feedback_timestamp.txt"  # local high-water-mark: last processed timestamp

# Dataset for training
class EvoDataset(Dataset):
    """Torch dataset over feedback records for preference training.

    Each record is a dict with keys ``goal``, ``solution1``, ``solution2``
    and ``correct`` (the string "Solution 1" or "Solution 2"). Items are
    returned as ``(text, label)`` where the text joins the three fields
    with " [SEP] " and the label is 0 for "Solution 1", else 1.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        text = " [SEP] ".join(
            (record["goal"], record["solution1"], record["solution2"])
        )
        label = 0 if record["correct"] == "Solution 1" else 1
        return text, label

# Dummy tokenizer (replace with your tokenizer if needed)
def tokenize(text, max_len=256):
    """Encode *text* as a fixed-length tensor of character codes.

    Dummy character-level tokenizer: each character maps to ``ord(c) % 128``.
    The output is truncated to *max_len* and zero-padded up to it, so every
    call returns a tensor of shape ``(max_len,)``.

    BUG FIX: the original returned a variable-length tensor (only truncated,
    never padded), which made ``torch.stack`` over a batch of unequal-length
    strings raise a RuntimeError in ``train_on_feedback``.

    Args:
        text: input string.
        max_len: fixed output length (default 256, the original cap).

    Returns:
        1-D ``torch.Tensor`` of ``int64`` codes, length exactly *max_len*.
    """
    codes = [ord(c) % 128 for c in text[:max_len]]
    codes.extend([0] * (max_len - len(codes)))  # zero-pad short inputs
    return torch.tensor(codes)

# Fetch new data
def fetch_new_feedback():
    """Fetch feedback documents newer than the stored high-water mark.

    Reads the last-processed timestamp from ``LAST_CHECK_FILE`` (falling
    back to the Unix epoch when the file is absent), streams Firestore
    docs with a strictly newer ``timestamp``, and keeps only those that
    carry all required fields. The high-water mark is advanced on disk
    only when at least one valid record was collected.

    Returns:
        List of dicts with keys goal/solution1/solution2/correct.
    """
    last_ts = "1970-01-01T00:00:00Z"
    if os.path.exists(LAST_CHECK_FILE):
        with open(LAST_CHECK_FILE, "r") as f:
            last_ts = f.read().strip()

    docs = db.collection(COLLECTION).where("timestamp", ">", last_ts).stream()

    required = ("goal", "sol1", "sol2", "correct")
    feedbacks = []
    latest_ts = last_ts
    for doc in docs:
        data = doc.to_dict()
        if not all(key in data for key in required):
            continue  # skip malformed documents
        feedbacks.append({
            "goal": data["goal"],
            "solution1": data["sol1"],
            "solution2": data["sol2"],
            "correct": data["correct"],
        })
        latest_ts = max(latest_ts, data.get("timestamp", last_ts))

    if feedbacks:
        with open(LAST_CHECK_FILE, "w") as f:
            f.write(latest_ts)

    return feedbacks

# Train Evo on new data
def train_on_feedback(feedbacks):
    """Fine-tune the EvoTransformer on newly collected feedback records.

    Loads previously saved weights from ``trained_model.pt`` when present,
    runs a short 3-epoch fine-tune over the feedback, prints per-epoch loss
    and accuracy, and writes the updated weights back to disk.

    Args:
        feedbacks: list of dicts with keys 'goal', 'solution1', 'solution2',
            'correct' (as produced by fetch_new_feedback). An empty/falsy
            list is a no-op.
    """
    if not feedbacks:
        print("No new feedback to train on.")
        return

    print(f"πŸ” Retraining on {len(feedbacks)} new examples...")

    dataset = EvoDataset(feedbacks)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    model = EvoTransformer()
    if os.path.exists("trained_model.pt"):
        # map_location keeps loading robust if the checkpoint was saved on
        # a different device; training in this script runs on CPU.
        model.load_state_dict(torch.load("trained_model.pt", map_location="cpu"))

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    max_len = 256  # must match the truncation cap used by tokenize()

    def _encode(text):
        # BUG FIX: tokenize() may return a variable-length 1-D tensor, and
        # torch.stack raises a RuntimeError on ragged tensors — any batch of
        # unequal-length strings used to crash. Zero-pad to a fixed width.
        ids = tokenize(text)
        if ids.size(0) < max_len:
            ids = torch.cat([ids, ids.new_zeros(max_len - ids.size(0))])
        return ids

    model.train()
    for epoch in range(3):  # quick fine-tuning, not a full retrain
        total_loss = 0.0
        correct = 0
        for inputs, labels in dataloader:
            inputs = torch.stack([_encode(x) for x in inputs])
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()

        acc = correct / len(dataset)
        print(f"Epoch {epoch+1}: Loss={total_loss:.4f}, Accuracy={acc:.2%}")

    torch.save(model.state_dict(), "trained_model.pt")
    print("βœ… Updated model saved.")

# Watch loop
def watch():
    """Poll Firestore forever, retraining whenever new feedback arrives.

    Runs one fetch-and-train cycle per minute. Any exception from a cycle
    is printed and swallowed so the watchdog keeps running — this is a
    deliberate top-level boundary, not silent error hiding.
    """
    print("🧠 Evo Watchdog started...")
    while True:
        try:
            train_on_feedback(fetch_new_feedback())
        except Exception as e:
            print(f"⚠️ Error: {str(e)}")
        time.sleep(60)  # poll interval: one minute between cycles

def manual_retrain():
    """Run a single fetch-and-train cycle immediately (no polling loop)."""
    train_on_feedback(fetch_new_feedback())

# Entry point: start the polling loop only when executed as a script.
# Importing this module (e.g. to call manual_retrain()) does not start it.
if __name__ == "__main__":
    watch()