Spaces:
Running
Running
File size: 2,359 Bytes
e7e30db da42a90 e7e30db 3489232 e7e30db 3489232 e7e30db 3489232 cae5830 e7e30db 3489232 e7e30db 3489232 e7e30db 3489232 e7e30db 3489232 e7e30db 3489232 e7e30db 3489232 e7e30db 3489232 e7e30db 3489232 e7e30db 3489232 e7e30db |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import os
import torch
import firebase_admin
from firebase_admin import credentials, firestore
from evo_model import EvoTransformerForClassification, EvoTransformerConfig
from transformers import BertTokenizer
# Create the default Firebase app from the local service-account key, but
# only when no app is registered yet (checked through the SDK's private
# ``_apps`` attribute), so this setup is not repeated.
if not firebase_admin._apps:
    cred = credentials.Certificate("firebase_key.json")
    firebase_admin.initialize_app(cred)

# Module-level Firestore client used by the functions below.
db = firestore.client()
def fetch_training_data(tokenizer, max_length=128):
    """Build training tensors from the ``evo_feedback`` Firestore collection.

    Every document with a non-empty string ``prompt`` and ``winner`` becomes
    one example: the text ``prompt + " [SEP] " + winner`` is tokenized
    (truncated and padded to ``max_length``) and labelled ``0`` when the
    winner string contains ``"1"``, otherwise ``1``.

    Args:
        tokenizer: Callable returning a mapping with ``"input_ids"`` and
            ``"attention_mask"`` entries whose first row is the encoded
            example (the caller in this file passes a ``BertTokenizer``).
        max_length: Fixed token length each example is padded or truncated
            to. Defaults to 128.

    Returns:
        A tuple ``(input_ids, attention_masks, labels)`` of stacked tensors
        (``labels`` has dtype ``torch.long``), or ``(None, None, None)`` when
        no usable document was found.
    """
    input_ids, attention_masks, labels = [], [], []

    for doc in db.collection("evo_feedback").stream():
        data = doc.to_dict() or {}
        prompt = data.get("prompt", "")
        winner = data.get("winner", "")

        # Stored fields may be of any type; a truthy non-string value would
        # raise TypeError on the concatenation / substring test below, so
        # such documents are skipped rather than aborting the whole fetch.
        if not (isinstance(prompt, str) and isinstance(winner, str)):
            continue
        if not (prompt and winner):
            continue

        encoding = tokenizer(
            prompt + " [SEP] " + winner,
            truncation=True,
            padding="max_length",
            max_length=max_length,
            return_tensors="pt",
        )
        # return_tensors="pt" yields a batch of one; keep only its single row.
        input_ids.append(encoding["input_ids"][0])
        attention_masks.append(encoding["attention_mask"][0])

        # NOTE(review): the label comes from a "1" substring in ``winner``,
        # and ``winner`` is also part of the encoded text -- confirm this
        # matches the feedback schema.
        labels.append(0 if "1" in winner else 1)

    if not input_ids:
        return None, None, None

    return (
        torch.stack(input_ids),
        torch.stack(attention_masks),
        torch.tensor(labels, dtype=torch.long),
    )
def retrain_and_save():
    """Retrain EvoTransformer on stored feedback and save it to ``trained_model/``.

    Loads the ``bert-base-uncased`` tokenizer and gathers examples through
    ``fetch_training_data``. With fewer than two examples it prints a warning
    and returns. Otherwise a newly constructed model is put in training mode,
    updated for three full-batch Adam steps (lr 2e-4, cross-entropy loss),
    and written out with ``save_pretrained``.
    """
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    batch_ids, batch_mask, targets = fetch_training_data(tokenizer)

    has_enough_data = batch_ids is not None and len(batch_ids) >= 2
    if not has_enough_data:
        print("⚠️ Not enough training data.")
        return

    model = EvoTransformerForClassification(EvoTransformerConfig())
    model.train()
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

    # Each epoch is a single optimizer step over the entire dataset.
    for epoch in range(3):
        optimizer.zero_grad()
        model_out = model(batch_ids, attention_mask=batch_mask)
        loss = criterion(model_out, targets)
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch+1}: Loss = {loss.item():.4f}")

    os.makedirs("trained_model", exist_ok=True)
    model.save_pretrained("trained_model")
    print("✅ EvoTransformer retrained and saved to trained_model/")
if __name__ == "__main__":
    # Script entry point: run one retraining pass.
    retrain_and_save()
|