EvoTransformer-Demo / evo_transformer.py
HemanM's picture
Update evo_transformer.py
d6397a3 verified
raw
history blame
1.34 kB
import random
import copy
class EvoTransformer:
    """Random-walk "evolution" over a transformer hyper-parameter config.

    A config is a plain dict with the keys ``layers``, ``attention_heads``,
    ``ffn_dim``, ``dropout`` and ``memory``.  Each generation is produced by
    randomly perturbing the previous one; every generation is recorded in
    ``self.history`` (oldest first).
    """

    # Integer genes as (key, possible steps, lower bound, upper bound).
    # The tuple order is the order in which random numbers are drawn, so it
    # must stay stable for seeded runs to be reproducible.
    _INT_GENES = (
        ("layers", (-1, 1), 1, 12),
        ("attention_heads", (-1, 1), 1, 12),
        ("ffn_dim", (-512, 512), 128, 4096),
    )
    _DROPOUT_STEPS = (-0.02, 0.02)
    _DROPOUT_MIN = 0.0
    _DROPOUT_MAX = 0.5
    # Probability that any single numeric gene changes in one mutation.
    _NUMERIC_MUTATION_RATE = 0.5
    # Probability that the boolean ``memory`` flag is flipped.
    _MEMORY_FLIP_RATE = 0.3

    def __init__(self):
        """Start with an empty history and the default seed configuration."""
        self.history = []
        self.base_config = {
            "layers": 4,
            "attention_heads": 4,
            "ffn_dim": 1024,
            "dropout": 0.1,
            "memory": False,
        }

    def reset(self):
        """Discard all recorded generations."""
        self.history = []

    def mutate(self, config):
        """Return a randomly perturbed copy of *config*.

        The input dict is never modified.  Each numeric gene changes with
        probability 0.5 by one step up or down and is clamped to its allowed
        range; ``memory`` is flipped with probability 0.3.

        Args:
            config: Dict containing all five config keys.

        Returns:
            A new dict with the same keys.

        Raises:
            KeyError: If a gene selected for mutation is missing from
                *config*.
        """
        new_config = copy.deepcopy(config)

        for key, steps, low, high in self._INT_GENES:
            if random.random() < self._NUMERIC_MUTATION_RATE:
                stepped = new_config[key] + random.choice(steps)
                new_config[key] = min(high, max(low, stepped))

        if random.random() < self._NUMERIC_MUTATION_RATE:
            stepped = new_config["dropout"] + random.choice(self._DROPOUT_STEPS)
            clamped = min(self._DROPOUT_MAX, max(self._DROPOUT_MIN, stepped))
            # Round so repeated +/-0.02 steps do not accumulate float noise.
            new_config["dropout"] = round(clamped, 2)

        if random.random() < self._MEMORY_FLIP_RATE:
            new_config["memory"] = not new_config["memory"]

        return new_config

    def run_evolution(self, generations=5):
        """Append *generations* configs to ``self.history``.

        The first appended entry is a copy of ``self.base_config``; each
        following entry is a mutation of the one before it.  Existing history
        is kept, so call :meth:`reset` first for a fresh run.

        Args:
            generations: Number of configs to record.  Values below 1 record
                nothing.
        """
        if generations < 1:
            return

        # Record a copy: storing base_config itself would let any edit of
        # history[0] silently change the seed used by later runs.
        current = copy.deepcopy(self.base_config)
        self.history.append(current)
        for _ in range(generations - 1):
            current = self.mutate(current)
            self.history.append(current)