import os
import torch
import firebase_admin
from firebase_admin import credentials, firestore
from evo_model import EvoTransformerForClassification, EvoTransformerConfig
from transformers import BertTokenizer
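
# Retraining script: reads prompt/winner feedback from the "evo_feedback"
# Firestore collection and trains a fresh EvoTransformer binary classifier on it.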

# Initialize Firebase if not already initialized
if not firebase_admin._apps:
    cred = credentials.Certificate("firebase_key.json")
    firebase_admin.initialize_app(cred)

db = firestore.client()
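

# Build the training set from Firestore feedback: each document contributes a
# prompt and the winning response, joined into one "[SEP]"-separated sequence
# and tokenized for the classifier.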
def fetch_training_data(tokenizer):
    logs_ref = db.collection("evo_feedback")
    docs = logs_ref.stream()

    input_ids, attention_masks, labels = [], [], []
    for doc in docs:
        data = doc.to_dict()
        prompt = data.get("prompt", "")
        winner = data.get("winner", "")
        if winner and prompt:
            text = prompt + " [SEP] " + winner
            encoding = tokenizer(
                text,
                truncation=True,
                padding="max_length",
                max_length=128,
                return_tensors="pt"
            )
            input_ids.append(encoding["input_ids"][0])
            attention_masks.append(encoding["attention_mask"][0])
            # Class 0 when the winner string mentions "1", otherwise class 1
            label = 0 if "1" in winner else 1
            labels.append(label)

    if not input_ids:
        return None, None, None

    return (
        torch.stack(input_ids),
        torch.stack(attention_masks),
        torch.tensor(labels, dtype=torch.long)
    )
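

# Retrain EvoTransformer from a fresh config on the collected feedback and save
# the weights to trained_model/.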
def retrain_and_save():
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    input_ids, attention_masks, labels = fetch_training_data(tokenizer)

    if input_ids is None or len(input_ids) < 2:
        print("⚠️ Not enough training data.")
        return

    config = EvoTransformerConfig()
    model = EvoTransformerForClassification(config)
    model.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
    loss_fn = torch.nn.CrossEntropyLoss()
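
    # Full-batch training: the entire dataset goes through the model once per
    # epoch, for three epochs.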
    for epoch in range(3):
        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask=attention_masks)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch+1}: Loss = {loss.item():.4f}")

    os.makedirs("trained_model", exist_ok=True)
    model.save_pretrained("trained_model")
    print("✅ EvoTransformer retrained and saved to trained_model/")


# Alias to match expected import
retrain_model = retrain_and_save

if __name__ == "__main__":
    retrain_and_save()