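# Spaces: Sleeping Sleeping
# NOTE(review): the line above appears to be hosting-page status text captured
# along with the source; it is not part of the module's code.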
import random | |
import json | |
import os | |
import uuid | |
import csv | |
import torch.nn as nn | |
# CSV file that log_genome() appends one row per genome to.
GENOME_LOG = "genome_log.csv"

# JSON file written by save_best_genome() and read by load_best_genome().
BEST_GENOME_FILE = "best_genome.json"

# Inclusive (lower, upper) bounds that mutate_genome() clamps each numeric
# gene to after applying a random delta.
MUTATION_LIMITS = {
    "num_layers": (2, 12),
    "ffn_dim": (256, 4096),
    "num_heads": (2, 16),
}
def default_config():
    """Return a fresh baseline genome.

    The result is a new dict on every call, carrying a newly generated
    UUID4 string under ``"genome_id"`` plus the baseline hyper-parameters
    (6 layers, FFN width 1024, 8 heads, memory enabled).
    """
    baseline = {"genome_id": str(uuid.uuid4())}
    baseline.update(
        num_layers=6,
        ffn_dim=1024,
        num_heads=8,
        memory_enabled=True,
    )
    return baseline
def mutate_genome(base_config, exploration_rate=0.5, *, embed_dim=512):
    """Return a mutated copy of *base_config* with a new ``genome_id``.

    Exactly one gene is picked uniformly at random from ``num_layers``,
    ``ffn_dim``, ``num_heads`` and ``memory_enabled``:

    * ``memory_enabled`` is flipped;
    * a numeric gene is shifted by a random integer in ``[-change, change]``
      where ``change = int((upper - lower) * exploration_rate)``, then
      clamped to its ``MUTATION_LIMITS`` range.

    Afterwards ``num_heads`` is repaired so that it is a positive divisor
    of *embed_dim*, regardless of which gene was mutated.

    Args:
        base_config: Genome dict to start from; it is shallow-copied and
            never modified.
        exploration_rate: Fraction of a gene's allowed range used as the
            maximum absolute step for numeric mutations.
        embed_dim: Model width that ``num_heads`` must divide. Defaults to
            512, the value this function previously hard-coded.

    Returns:
        A new genome dict.

    Raises:
        ValueError: If *embed_dim* is smaller than 1.
    """
    if embed_dim < 1:
        raise ValueError(f"embed_dim must be a positive integer, got {embed_dim!r}")

    config = base_config.copy()
    config["genome_id"] = str(uuid.uuid4())

    mutation_type = random.choice(["num_layers", "ffn_dim", "num_heads", "memory_enabled"])
    if mutation_type == "memory_enabled":
        config["memory_enabled"] = not config["memory_enabled"]
    else:
        min_val, max_val = MUTATION_LIMITS[mutation_type]
        change = int((max_val - min_val) * exploration_rate)
        delta = random.randint(-change, change)
        config[mutation_type] = max(min_val, min(max_val, config[mutation_type] + delta))

    heads = config["num_heads"]
    if heads > embed_dim:
        heads = max(1, embed_dim // 64)
    elif heads < 1:
        # A zero head count would make the modulo below raise
        # ZeroDivisionError, and negative divisors (e.g. 512 % -2 == 0)
        # would otherwise be accepted as "valid".
        heads = 1
    # Walk down to the nearest divisor; terminates at 1 at the latest.
    while embed_dim % heads != 0:
        heads -= 1
    config["num_heads"] = heads

    return config
def log_genome(config, score=None):
    """Append one genome record to the CSV file named by ``GENOME_LOG``.

    A header row is written first whenever the log is missing *or empty*,
    so a log file that was created or truncated ahead of time still ends up
    with column names.

    Args:
        config: Genome dict; absent keys are logged as empty strings.
        score: Optional score for the genome; ``None`` is logged as an
            empty string.
    """
    columns = ["genome_id", "num_layers", "ffn_dim", "num_heads", "memory_enabled"]
    row = [config.get(column, "") for column in columns]
    row.append(score if score is not None else "")

    try:
        needs_header = os.path.getsize(GENOME_LOG) == 0
    except OSError:
        # Path does not exist (or cannot be stat'ed): treat it as a new log.
        needs_header = True

    with open(GENOME_LOG, "a", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        if needs_header:
            writer.writerow(columns + ["score"])
        writer.writerow(row)
def save_best_genome(config):
    """Write *config* as JSON to ``BEST_GENOME_FILE``, replacing any previous content."""
    with open(BEST_GENOME_FILE, mode="w", encoding="utf-8") as out_file:
        json.dump(config, out_file)
def load_best_genome():
    """Return the genome stored in ``BEST_GENOME_FILE``.

    When that file does not exist, a fresh ``default_config()`` is returned
    instead.
    """
    if not os.path.exists(BEST_GENOME_FILE):
        return default_config()

    with open(BEST_GENOME_FILE, mode="r", encoding="utf-8") as in_file:
        stored = json.load(in_file)
    return stored
def build_model_from_config(config):
    """Instantiate ``evo_model.EvoTransformerV22`` from a genome dict.

    Args:
        config: Genome dict, passed unchanged as the constructor's only
            argument.

    Returns:
        The newly constructed ``EvoTransformerV22`` instance.
    """
    # The import happens at call time, so evo_model is only loaded when a
    # model is actually built.
    # NOTE(review): presumably this avoids an import cycle or a heavy import
    # at module load - confirm.
    from evo_model import EvoTransformerV22 as model_cls

    model = model_cls(config)
    return model