File size: 2,256 Bytes
da42a90
cae5830
 
 
 
c605926
d29dccb
cae5830
c605926
cae5830
 
 
 
 
 
 
 
 
 
 
 
 
 
da42a90
cae5830
 
d29dccb
 
da42a90
92432bf
 
cae5830
 
 
92432bf
c605926
cae5830
 
 
c605926
da42a90
cae5830
d29dccb
cae5830
d29dccb
 
 
 
cae5830
 
da42a90
 
 
cae5830
 
c605926
cae5830
 
92432bf
d29dccb
92432bf
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
from transformers import AutoTokenizer
from evo_model import EvoTransformerForClassification
from firebase_admin import firestore
import pandas as pd

# Load the tokenizer once at import time so every encode() call reuses the
# same instance instead of re-creating it per example.
# NOTE(review): from_pretrained may need network access or a warm local cache
# for "bert-base-uncased" — confirm for the deployment environment.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

def load_feedback_data():
    """Fetch logged feedback from Firestore and return it as a DataFrame.

    Streams every document in the ``evo_feedback_logs`` collection and keeps
    only those containing all of ``goal``, ``solution_1``, ``solution_2`` and
    ``correct_answer``.

    Returns:
        A ``pandas.DataFrame`` with columns ``goal``, ``sol1``, ``sol2`` and
        ``label``. ``label`` is 0 when ``correct_answer`` equals
        ``"Solution 1"`` and 1 for any other value.
    """
    required_keys = ("goal", "solution_1", "solution_2", "correct_answer")
    collection = firestore.client().collection("evo_feedback_logs")

    rows = []
    for snapshot in collection.stream():
        record = snapshot.to_dict()
        if any(key not in record for key in required_keys):
            # Skip documents that are missing any required field.
            continue
        label = 0 if record["correct_answer"] == "Solution 1" else 1
        rows.append(
            (record["goal"], record["solution_1"], record["solution_2"], label)
        )

    return pd.DataFrame(rows, columns=["goal", "sol1", "sol2", "label"])

def encode(goal, sol1, sol2):
    """Tokenize one goal / two-option example into model inputs.

    The three texts are combined into a single prompt of the form
    ``"Goal: ... Option 1: ... Option 2: ..."`` and passed to the module-level
    ``tokenizer`` with PyTorch tensors requested and padding and truncation
    enabled.

    Args:
        goal: Text of the goal.
        sol1: Text of the first candidate solution.
        sol2: Text of the second candidate solution.

    Returns:
        An ``(input_ids, attention_mask)`` tuple taken from the tokenizer
        output.
    """
    text = f"Goal: {goal} Option 1: {sol1} Option 2: {sol2}"
    batch = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    ids, mask = batch.input_ids, batch.attention_mask
    return ids, mask

def manual_retrain():
    """Fine-tune the saved Evo model on logged feedback and save it back.

    Loads feedback via ``load_feedback_data()``, restores the model from the
    ``trained_model`` directory, and makes one pass over the shuffled
    examples, taking an Adam step (learning rate 1e-4) on the cross-entropy
    loss of each example individually. The model is then saved back to
    ``trained_model``.

    Returns:
        ``True`` if the pass completed and the model was saved; ``False`` if
        no feedback data was found or any exception was raised (the error is
        printed rather than propagated).
    """
    try:
        feedback = load_feedback_data()
        if feedback.empty:
            print("[Retrain Error] No training data found.")
            return False

        model = EvoTransformerForClassification.from_pretrained("trained_model")
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
        criterion = torch.nn.CrossEntropyLoss()
        model.train()

        # frac=1 samples every row, i.e. a full random shuffle of the data.
        shuffled = feedback.sample(frac=1)
        for _, example in shuffled.iterrows():
            ids, mask = encode(example["goal"], example["sol1"], example["sol2"])
            target = torch.tensor([example["label"]])

            result = model(input_ids=ids, attention_mask=mask)
            # Accept either an output object exposing ``.logits`` or the
            # logits returned directly.
            logits = getattr(result, "logits", result)

            if logits.ndim != 2 or target.ndim != 1:
                print("[Retrain Warning] Shape mismatch, skipping one example.")
                continue

            loss = criterion(logits, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        model.save_pretrained("trained_model")
        print("✅ Evo retrained and saved.")
        return True

    except Exception as err:
        # Top-level boundary: report the failure and signal it via the return
        # value instead of propagating to the caller.
        print(f"[Retrain Error] {err}")
        return False