# Spaces: Running
import os | |
import torch | |
import firebase_admin | |
from firebase_admin import credentials, firestore | |
from model import SimpleEvoModel | |
# Initialize the default Firebase app once per process, then create the
# Firestore client used by fetch_training_data().
# firebase_admin.get_app() (public API) raises ValueError when the *default*
# app has not been initialized yet. Unlike checking the private
# firebase_admin._apps, this also initializes correctly when only some other
# named app exists; firestore.client() below requires the default app.
try:
    firebase_admin.get_app()
except ValueError:
    cred = credentials.Certificate("firebase_key.json")
    firebase_admin.initialize_app(cred)
db = firestore.client()
def _encode_text(text, dim):
    """Encode text as ``(ord(c) % 256) / 255`` per character, padded/truncated to ``dim``."""
    # Simulated encoding (character codes, not a learned embedding).
    # Truncating before encoding is equivalent to encoding then truncating.
    vector = [float(ord(c) % 256) / 255.0 for c in text[:dim]]
    vector.extend([0.0] * (dim - len(vector)))
    return vector


def fetch_training_data(*, docs=None, dim=768):
    """Build a ``(features, labels)`` training set from feedback records.

    Each record's ``prompt`` and ``winner`` values are concatenated and turned
    into a fixed-length float vector by a simulated character encoding.
    Records without a truthy ``winner`` are skipped.

    Args:
        docs: Optional iterable of document snapshots (objects with a
            ``to_dict()`` method). When omitted, the ``evo_feedback``
            collection is streamed from the module-level Firestore ``db``.
        dim: Length each feature vector is padded/truncated to (default 768).

    Returns:
        ``(X, y)``: ``X`` is a float32 tensor of shape ``(num_records, dim)``,
        including ``(0, dim)`` when no record qualifies. ``y`` is an int64
        tensor of shape ``(num_records,)`` with labels 0 or 1.

    Raises:
        ValueError: if ``dim`` is not positive.
    """
    if dim <= 0:
        raise ValueError(f"dim must be positive, got {dim}")
    if docs is None:
        docs = db.collection("evo_feedback").stream()

    inputs, labels = [], []
    for doc in docs:
        # to_dict() may return None (e.g. a missing snapshot); treat as empty.
        data = doc.to_dict() or {}
        # `or ""` also covers fields stored as null, which would otherwise make
        # the concatenation below raise TypeError; str() guards non-string values.
        goal = str(data.get("prompt") or "")
        winner = str(data.get("winner") or "")
        if not winner:
            continue
        inputs.append(_encode_text(goal + winner, dim))
        # NOTE(review): any "1" anywhere in the winner string maps to class 0,
        # so e.g. "Solution 10" is also labelled 0 — confirm against the
        # actual winner format.
        labels.append(0 if "1" in winner else 1)

    # reshape keeps X two-dimensional even when no records qualified.
    X = torch.tensor(inputs, dtype=torch.float32).reshape(-1, dim)
    y = torch.tensor(labels, dtype=torch.long)
    return X, y
def retrain_and_save(*, epochs=5, lr=1e-3, output_dir="trained_model"):
    """Train a fresh SimpleEvoModel on feedback data and save its state_dict.

    Args:
        epochs: Number of full-batch optimization steps (default 5).
        lr: Adam learning rate (default 1e-3).
        output_dir: Directory that receives ``pytorch_model.bin``; created if
            missing (default ``"trained_model"``).

    Returns:
        None. Prints a warning and returns early without saving when fewer
        than two training examples are available.
    """
    X, y = fetch_training_data()
    if len(X) < 2:
        print("⚠️ Not enough training data.")
        return

    # NOTE(review): assumes SimpleEvoModel maps X to (N, C) logits with C >= 2,
    # as CrossEntropyLoss with labels {0, 1} requires — confirm in model.py.
    model = SimpleEvoModel()
    model.train()  # explicit: training-mode behaviour for any dropout/norm layers
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Full-batch training: each step uses the entire dataset at once.
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = loss_fn(model(X), y)
        loss.backward()
        optimizer.step()

    # Save retrained weights to <output_dir>/pytorch_model.bin.
    os.makedirs(output_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
    print(f"✅ EvoTransformer retrained and saved to {output_dir}/")
if __name__ == "__main__":
    # Script entry point: perform one retrain-and-save pass.
    retrain_and_save()