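"""Manual retraining loop for EvoTransformer.

Pulls human feedback from the Firestore collection ``evo_feedback_logs``,
fine-tunes the classifier on it, and saves the updated weights back to
``trained_model``. Assumes ``firebase_admin.initialize_app()`` has already
been called elsewhere in the application before ``manual_retrain()`` runs.
"""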
import torch
from transformers import AutoTokenizer
from evo_model import EvoTransformerForClassification
from firebase_admin import firestore
import pandas as pd
# Load tokenizer once
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
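# Illustrative shape of one Firestore document in ``evo_feedback_logs``
# (field values below are hypothetical examples, not taken from real data):
#   {
#       "goal": "Boil water quickly",
#       "solution_1": "Use an electric kettle",
#       "solution_2": "Leave a pot in the sun",
#       "correct_answer": "Solution 1",
#   }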
def load_feedback_data():
    db = firestore.client()
    docs = db.collection("evo_feedback_logs").stream()
    data = []
    for doc in docs:
        d = doc.to_dict()
        # Skip incomplete records; all four fields are required for training.
        if all(k in d for k in ["goal", "solution_1", "solution_2", "correct_answer"]):
            data.append((
                d["goal"],
                d["solution_1"],
                d["solution_2"],
                # "Solution 1" maps to class 0; anything else maps to class 1.
                0 if d["correct_answer"] == "Solution 1" else 1
            ))
    return pd.DataFrame(data, columns=["goal", "sol1", "sol2", "label"])
def encode(goal, sol1, sol2):
    prompt = f"Goal: {goal} Option 1: {sol1} Option 2: {sol2}"
    encoded = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    return encoded.input_ids, encoded.attention_mask
def manual_retrain():
    try:
        data = load_feedback_data()
        if data.empty:
            print("[Retrain Error] No training data found.")
            return False
        model = EvoTransformerForClassification.from_pretrained("trained_model")
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
        loss_fn = torch.nn.CrossEntropyLoss()
        model.train()
        # One pass over the feedback data in random order, one example at a time.
        for _, row in data.sample(frac=1).iterrows():
            input_ids, attention_mask = encode(row["goal"], row["sol1"], row["sol2"])
            label = torch.tensor([row["label"]])
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            # Support models that return either an output object or a raw tensor.
            logits = outputs.logits if hasattr(outputs, "logits") else outputs
            if logits.ndim == 2 and label.ndim == 1:
                loss = loss_fn(logits, label)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            else:
                print("[Retrain Warning] Shape mismatch, skipping one example.")
        model.save_pretrained("trained_model")
        print("✅ Evo retrained and saved.")
        return True
    except Exception as e:
        print(f"[Retrain Error] {e}")
        return False
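# Minimal sketch of a standalone invocation (an assumption, not part of the
# original script): initialize the default Firebase app from ambient
# credentials if no app has been created yet, then run one retraining pass.
if __name__ == "__main__":
    import firebase_admin

    if not firebase_admin._apps:
        firebase_admin.initialize_app()
    manual_retrain()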