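"""Retrain EvoTransformer on human feedback logged in Firestore.

Fetches prompt/winner pairs from the "evo_feedback" collection, trains a
freshly initialized EvoTransformerForClassification on them, saves the
weights, and returns an architecture summary plus an accuracy plot for the
dashboard.
"""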
import os
import torch
import firebase_admin
from firebase_admin import credentials, firestore
from evo_model import EvoTransformerForClassification, EvoTransformerConfig
from transformers import BertTokenizer
from init_model import load_model
from dashboard import evolution_accuracy_plot
# Initialize Firebase if not already initialized
if not firebase_admin._apps:
    cred = credentials.Certificate("firebase_key.json")
    firebase_admin.initialize_app(cred)

db = firestore.client()

def fetch_training_data(tokenizer):
    """Build training tensors from the feedback logged in Firestore."""
    logs_ref = db.collection("evo_feedback")
    docs = logs_ref.stream()

    input_ids, attention_masks, labels = [], [], []

    for doc in docs:
        data = doc.to_dict()
        prompt = data.get("prompt", "")
        winner = data.get("winner", "")

        if winner and prompt:
            # Encode "prompt [SEP] winner" as a single sequence
            text = prompt + " [SEP] " + winner
            encoding = tokenizer(
                text,
                truncation=True,
                padding="max_length",
                max_length=128,
                return_tensors="pt"
            )
            input_ids.append(encoding["input_ids"][0])
            attention_masks.append(encoding["attention_mask"][0])

            # Winner strings mentioning "1" map to class 0, everything else to class 1
            label = 0 if "1" in winner else 1
            labels.append(label)

    if not input_ids:
        return None, None, None

    return (
        torch.stack(input_ids),
        torch.stack(attention_masks),
        torch.tensor(labels, dtype=torch.long)
    )

def get_architecture_summary(model):
    """Return a readable summary of the model's key architecture settings."""
    summary = {
        "Layers": getattr(model, "num_layers", "N/A"),
        "Attention Heads": getattr(model, "num_heads", "N/A"),
        "FFN Dim": getattr(model, "ffn_dim", "N/A"),
        "Memory Enabled": getattr(model, "use_memory", "N/A"),
    }
    return "\n".join(f"{k}: {v}" for k, v in summary.items())

def retrain_model():
    try:
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        input_ids, attention_masks, labels = fetch_training_data(tokenizer)

        if input_ids is None or len(input_ids) < 2:
            return "⚠️ Not enough data to retrain.", None, "Please log more feedback first."

        # Train a freshly initialized model on the logged feedback
        config = EvoTransformerConfig()
        model = EvoTransformerForClassification(config)
        model.train()

        optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
        loss_fn = torch.nn.CrossEntropyLoss()

        # Full-batch training for a few epochs
        for epoch in range(3):
            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask=attention_masks)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
            print(f"Epoch {epoch+1}: Loss = {loss.item():.4f}")

        os.makedirs("trained_model", exist_ok=True)
        model.save_pretrained("trained_model")
        print("✅ EvoTransformer retrained and saved.")

        # Reload the updated model for the architecture summary and plot
        updated_model = load_model()
        arch_text = get_architecture_summary(updated_model)
        plot = evolution_accuracy_plot()

        return arch_text, plot, "✅ EvoTransformer retrained successfully!"

    except Exception as e:
        print(f"❌ Retraining failed: {e}")
        return "❌ Error", None, f"Retrain failed: {e}"


# Support direct script run
if __name__ == "__main__":
    retrain_model()