HemanM commited on
Commit
3796a74
·
verified ·
1 Parent(s): bf6c353

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -118
app.py CHANGED
@@ -9,27 +9,25 @@ import plotly.graph_objs as go
9
  import gradio as gr
10
  import pandas as pd
11
 
12
- # New deps for proxy fitness
13
  import torch
14
  import torch.nn as nn
15
  import torch.optim as optim
16
 
17
- # Local utils (add this file next to app.py)
18
  from data_utils import load_piqa, load_hellaswag, hash_vectorize
19
 
20
  # =========================
21
- # UX THEME & STYLES
22
  # =========================
23
  CUSTOM_CSS = """
24
- :root { --radius-2xl: 20px; }
25
- .gradio-container {max-width: 1400px !important}
26
- #header-card {border-radius: var(--radius-2xl); box-shadow: 0 6px 24px rgba(0,0,0,0.08)}
27
- #viz-card, #right-card, #table-card {border-radius: var(--radius-2xl); box-shadow: 0 6px 24px rgba(0,0,0,0.06)}
28
- #stats {display:flex; gap:16px; flex-wrap:wrap}
29
- .stat {flex:1; min-width:180px; background:#0b1220; color:white; border-radius:16px; padding:14px 16px}
30
- .stat .k {font-size:14px; opacity:0.8}
31
- .stat .v {font-size:22px; font-weight:700}
32
- .gr-button {border-radius:14px}
33
  """
34
 
35
  # =========================
@@ -45,9 +43,9 @@ class Genome:
45
  dropout: float
46
  species: int = 0
47
  fitness: float = float("inf")
 
48
 
49
  def vector(self) -> np.ndarray:
50
- # Normalized structural vector (0..1)
51
  return np.array([
52
  self.d_model / 1024.0,
53
  self.n_layers / 24.0,
@@ -77,7 +75,7 @@ def mutate(g: Genome, rng: random.Random, rate: float) -> Genome:
77
  if rng.random() < rate: g.memory_tokens = rng.choice([0, 4, 8, 16])
78
  if rng.random() < rate: g.dropout = rng.choice([0.0, 0.05, 0.1, 0.15])
79
  if rng.random() < rate * 0.5: g.species = rng.randrange(5)
80
- g.fitness = float("inf")
81
  return g
82
 
83
  def crossover(a: Genome, b: Genome, rng: random.Random) -> Genome:
@@ -89,7 +87,8 @@ def crossover(a: Genome, b: Genome, rng: random.Random) -> Genome:
89
  memory_tokens = a.memory_tokens if rng.random()<0.5 else b.memory_tokens,
90
  dropout = a.dropout if rng.random()<0.5 else b.dropout,
91
  species = a.species if rng.random()<0.5 else b.species,
92
- fitness = float("inf")
 
93
  )
94
 
95
  # =========================
@@ -100,7 +99,6 @@ def rastrigin(x: np.ndarray) -> float:
100
  return A * n + np.sum(x**2 - A * np.cos(2 * math.pi * x))
101
 
102
  class TinyMLP(nn.Module):
103
- """Small MLP whose capacity depends on the genome (so evolution matters)."""
104
  def __init__(self, in_dim: int, genome: Genome):
105
  super().__init__()
106
  h1 = max(64, int(0.25 * genome.d_model))
@@ -110,29 +108,27 @@ class TinyMLP(nn.Module):
110
  nn.Linear(h1, h2), nn.ReLU(),
111
  nn.Linear(h2, 1)
112
  )
113
- def forward(self, x):
114
- return self.net(x).squeeze(-1)
115
 
116
  @lru_cache(maxsize=4)
117
  def _cached_dataset(name: str):
118
- if name.startswith("PIQA"):
119
- return load_piqa(subset=800, seed=42)
120
- if name.startswith("HellaSwag"):
121
- return load_hellaswag(subset=800, seed=42)
122
- return None # Demo uses surrogate
123
 
124
- def _train_eval_proxy(genome: Genome, dataset_name: str, explore: float, device: str = "cpu") -> Optional[float]:
 
125
  data = _cached_dataset(dataset_name)
126
  if data is None:
127
- return None
128
- Xtr_txt, ytr, Xva_txt, yva = data
 
129
 
130
- # Hash vectorize to fixed dimension (fast, no tokenizer)
131
  nfeat = 4096
132
  Xtr = hash_vectorize(Xtr_txt, n_features=nfeat, seed=1234)
133
  Xva = hash_vectorize(Xva_txt, n_features=nfeat, seed=5678)
134
 
135
- # to torch tensors
136
  Xtr_t = torch.from_numpy(Xtr)
137
  ytr_t = torch.from_numpy(ytr.astype(np.float32))
138
  Xva_t = torch.from_numpy(Xva)
@@ -142,129 +138,119 @@ def _train_eval_proxy(genome: Genome, dataset_name: str, explore: float, device:
142
  opt = optim.AdamW(model.parameters(), lr=2e-3)
143
  lossf = nn.BCEWithLogitsLoss()
144
 
145
- # small, fast loop
146
  model.train()
147
- steps = 120
148
- bs = 256
149
  N = Xtr_t.size(0)
150
  for _ in range(steps):
151
  idx = torch.randint(0, N, (bs,))
152
- xb = Xtr_t[idx].to(device)
153
- yb = ytr_t[idx].to(device)
154
- logits = model(xb)
155
- loss = lossf(logits, yb)
156
- opt.zero_grad()
157
- loss.backward()
158
  torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
159
  opt.step()
160
 
161
- # eval
162
  model.eval()
163
  with torch.no_grad():
164
  logits = model(Xva_t.to(device))
165
  probs = torch.sigmoid(logits).cpu().numpy()
166
 
167
- # Turn rows into accuracy
168
  if dataset_name.startswith("PIQA"):
169
- # rows in pairs [A,B]; label vector marks which row is positive
170
- probs = probs.reshape(-1, 2)
171
- yva2 = yva.reshape(-1, 2)
172
- pred = (probs[:, 0] > probs[:, 1]).astype(np.int64)
173
- truth = (yva2[:, 0] == 1).astype(np.int64) # 1 means first row is correct
174
  acc = float((pred == truth).mean())
175
  else:
176
- # HellaSwag: groups of 4; pick argmax
177
- probs = probs.reshape(-1, 4)
178
- yva2 = yva.reshape(-1, 4)
179
- pred = probs.argmax(axis=1)
180
- truth = yva2.argmax(axis=1)
181
  acc = float((pred == truth).mean())
182
 
183
- # Fitness = error + tiny parsimony + small exploration noise (minimize)
184
  parsimony = 0.00000002 * (genome.d_model**2 * genome.n_layers) + 0.0001 * genome.memory_tokens
185
  noise = np.random.normal(scale=0.01 * max(0.0, min(1.0, explore)))
186
  fitness = (1.0 - acc) + parsimony + noise
187
- return float(max(0.0, min(1.5, fitness)))
188
 
189
- def fitness_hook(genome: Genome, dataset: str, explore: float) -> float:
190
- """Selects the correct fitness path based on dropdown."""
191
  if dataset == "Demo (Surrogate)":
192
  v = genome.vector() * 2 - 1
193
  base = rastrigin(v)
194
  parsimony = 0.001 * (genome.d_model + 50*genome.n_layers + 20*genome.n_heads + 100*genome.memory_tokens)
195
  noise = np.random.normal(scale=0.05 * max(0.0, min(1.0, explore)))
196
- return float(base + parsimony + noise)
197
-
198
  if dataset.startswith("PIQA"):
199
- fit = _train_eval_proxy(genome, "PIQA", explore)
200
- if fit is not None:
201
- return fit
202
-
203
  if dataset.startswith("HellaSwag"):
204
- fit = _train_eval_proxy(genome, "HellaSwag", explore)
205
- if fit is not None:
206
- return fit
207
-
208
- # fallback to surrogate if something went wrong
209
  v = genome.vector() * 2 - 1
210
- return float(rastrigin(v))
211
 
212
  # =========================
213
- # PROJECTION & VIZ
214
  # =========================
215
  def sphere_project(points: np.ndarray) -> np.ndarray:
216
- # Fixed random projection 6D -> 3D then normalize to unit sphere
217
  rng = np.random.RandomState(42)
218
  W = rng.normal(size=(points.shape[1], 3)).astype(np.float32)
219
  Y = points @ W
220
  norms = np.linalg.norm(Y, axis=1, keepdims=True) + 1e-8
221
- return Y / norms
222
 
223
  def make_sphere_figure(points3d: np.ndarray, genomes: List[Genome], gen_idx: int) -> go.Figure:
224
  species = np.array([g.species for g in genomes])
225
- tooltip = [
226
- json.dumps({k:v for k,v in asdict(g).items() if k!="fitness"}) + f"\nfitness={g.fitness:.3f}"
227
- for g in genomes
228
- ]
229
 
230
  scatter = go.Scatter3d(
231
  x=points3d[:,0], y=points3d[:,1], z=points3d[:,2],
232
  mode='markers',
233
- marker=dict(size=6, color=species, opacity=0.9),
234
- text=tooltip, hoverinfo='text'
 
 
 
 
 
 
 
 
235
  )
236
 
237
- # Sphere mesh
238
- u = np.linspace(0, 2*np.pi, 48)
239
- v = np.linspace(0, np.pi, 24)
240
- xs = np.outer(np.cos(u), np.sin(v))
241
- ys = np.outer(np.sin(u), np.sin(v))
242
- zs = np.outer(np.ones_like(u), np.cos(v))
243
- sphere = go.Surface(x=xs, y=ys, z=zs, opacity=0.15, showscale=False)
 
244
 
245
  layout = go.Layout(
246
  title=f"Evo Sphere — Generation {gen_idx}",
247
  scene=dict(xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False)),
248
  margin=dict(l=0, r=0, t=40, b=0),
249
- showlegend=False
 
250
  )
251
  return go.Figure(data=[sphere, scatter], layout=layout)
252
 
253
- def make_history_figure(history: List[Tuple[int,float]]) -> go.Figure:
 
254
  xs = [h[0] for h in history]
255
- ys = [h[1] for h in history]
 
 
 
 
 
256
  fig = go.Figure(data=[go.Scatter(x=xs, y=ys, mode="lines+markers")])
257
- fig.update_layout(title="Best Fitness per Generation", xaxis_title="Generation",
258
- yaxis_title="Fitness (lower is better)",
259
- margin=dict(l=30,r=10,t=40,b=30))
260
  return fig
261
 
262
  def approx_params(g: Genome) -> int:
263
- # Very rough estimate ignoring embeddings/vocab:
264
- # per-layer ~ (4 + 2*ffn_mult) * d_model^2
265
  per_layer = (4.0 + 2.0 * float(g.ffn_mult)) * (g.d_model ** 2)
266
- total = per_layer * g.n_layers
267
- total += 1000 * g.memory_tokens # tiny bump for memory pathways (illustrative)
268
  return int(total)
269
 
270
  # =========================
@@ -277,7 +263,7 @@ class EvoRunner:
277
  self.stop_flag = False
278
  self.state: Dict[str, Any] = {}
279
 
280
- def run(self, dataset, pop_size, generations, mutation_rate, explore, exploit, seed, pace_ms):
281
  rng = random.Random(int(seed))
282
  self.stop_flag = False
283
  self.running = True
@@ -285,15 +271,16 @@ class EvoRunner:
285
  pop: List[Genome] = [random_genome(rng) for _ in range(pop_size)]
286
  # initial eval
287
  for g in pop:
288
- g.fitness = fitness_hook(g, dataset, explore)
 
289
 
290
- history: List[Tuple[int,float]] = []
291
  best_overall: Optional[Genome] = None
292
 
293
  for gen in range(1, generations+1):
294
  if self.stop_flag: break
295
 
296
- # Selection: tournament size depends on exploitation
297
  k = max(2, int(2 + exploit * 5))
298
  parents = []
299
  for _ in range(pop_size):
@@ -303,16 +290,16 @@ class EvoRunner:
303
  # Reproduce
304
  children = []
305
  for i in range(0, pop_size, 2):
306
- a = parents[i]
307
- b = parents[(i+1) % pop_size]
308
  child1 = mutate(crossover(a,b,rng), rng, mutation_rate)
309
  child2 = mutate(crossover(b,a,rng), rng, mutation_rate)
310
  children.extend([child1, child2])
311
  children = children[:pop_size]
312
 
313
- # Evaluate kids
314
  for c in children:
315
- c.fitness = fitness_hook(c, dataset, explore)
 
316
 
317
  # Elitism
318
  elite_n = max(1, pop_size // 10)
@@ -326,18 +313,19 @@ class EvoRunner:
326
  if best_overall is None or best.fitness < best_overall.fitness:
327
  best_overall = best
328
 
329
- history.append((gen, best.fitness))
330
 
331
  # Viz snapshot
332
  P = np.stack([g.vector() for g in pop], axis=0)
333
  P3 = sphere_project(P)
334
  sphere_fig = make_sphere_figure(P3, pop, gen)
335
- hist_fig = make_history_figure(history)
336
  top = sorted(pop, key=lambda x: x.fitness)[: min(12, len(pop))]
337
  top_table = [
338
  {
339
  "gen": gen,
340
  "fitness": round(t.fitness, 4),
 
341
  "d_model": t.d_model,
342
  "layers": t.n_layers,
343
  "heads": t.n_heads,
@@ -357,7 +345,8 @@ class EvoRunner:
357
  "top": top_table,
358
  "best": best_card,
359
  "gen": gen,
360
- "dataset": dataset
 
361
  }
362
 
363
  time.sleep(max(0.0, pace_ms/1000.0))
@@ -369,16 +358,15 @@ class EvoRunner:
369
  t = threading.Thread(target=self.run, args=args, kwargs=kwargs, daemon=True)
370
  t.start()
371
 
372
- def stop(self):
373
- self.stop_flag = True
374
 
375
  runner = EvoRunner()
376
 
377
  # =========================
378
- # GRADIO UI CALLBACKS
379
  # =========================
380
- def start_evo(dataset, pop, gens, mut, explore, exploit, seed, pace_ms):
381
- runner.start(dataset, int(pop), int(gens), float(mut), float(explore), float(exploit), int(seed), int(pace_ms))
382
  return (gr.update(interactive=False), gr.update(interactive=True))
383
 
384
  def stop_evo():
@@ -389,16 +377,19 @@ def poll_state():
389
  with runner.lock:
390
  s = runner.state.copy()
391
  sphere = s.get("sphere", go.Figure())
392
- history = s.get("history", go.Figure())
393
  best = s.get("best", {})
394
  gen = s.get("gen", 0)
395
  dataset = s.get("dataset", "Demo (Surrogate)")
396
  top = s.get("top", [])
 
397
  if best:
 
398
  stats_md = (
399
  f"**Dataset:** {dataset} \n"
400
  f"**Generation:** {gen} \n"
401
  f"**Best fitness:** {best.get('fitness','–')} \n"
 
402
  f"**Config:** d_model={best.get('d_model')} · layers={best.get('layers')} · "
403
  f"heads={best.get('heads')} · ffn_mult={best.get('ffn_mult')} · mem={best.get('mem')} · "
404
  f"dropout={best.get('dropout')} \n"
@@ -419,14 +410,13 @@ def export_snapshot():
419
  return path
420
 
421
  # =========================
422
- # BUILD UI
423
  # =========================
424
  with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
425
  with gr.Column(elem_id="header-card"):
426
  gr.Markdown(
427
- "# Evo Playground — Live Evolving Transformer Architectures\n"
428
- "Watch the population **mutate, recombine, and converge** in real time. "
429
- "Choose a dataset and search behavior; the 3D sphere shows the architecture landscape (species = colors)."
430
  )
431
 
432
  with gr.Row():
@@ -435,9 +425,9 @@ with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
435
  with gr.Group():
436
  dataset = gr.Dropdown(
437
  label="Dataset",
438
- choices=["Demo (Surrogate)", "PIQA (Phase 2)", "HellaSwag (Phase 2)", "WikiText Perplexity (Phase 2)"],
439
  value="Demo (Surrogate)",
440
- info="Demo is instant. PIQA/HellaSwag run a tiny CPU MLP proxy for real dataset fitness."
441
  )
442
  pop = gr.Slider(8, 80, value=24, step=2, label="Population size")
443
  gens = gr.Slider(5, 200, value=60, step=1, label="Max generations")
@@ -447,12 +437,13 @@ with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
447
  exploit = gr.Slider(0.0, 1.0, value=0.65, step=0.05, label="Exploitation")
448
  seed = gr.Number(value=42, label="Seed", precision=0)
449
  pace = gr.Slider(0, 1000, value=120, step=10, label="Pace (ms between gens)")
 
450
  with gr.Row():
451
  start = gr.Button("▶ Start Evolution", variant="primary")
452
  stop = gr.Button("⏹ Stop", variant="secondary")
453
 
454
  with gr.Group(elem_id="right-card"):
455
- stats_md = gr.Markdown("Waiting…")
456
  export_btn = gr.Button("Export Snapshot (JSON)")
457
  export_file = gr.File(label="Download snapshot", visible=False)
458
 
@@ -461,16 +452,16 @@ with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
461
  with gr.Group(elem_id="viz-card"):
462
  sphere_plot = gr.Plot(label="Evolution Sphere")
463
  with gr.Group(elem_id="viz-card"):
464
- hist_plot = gr.Plot(label="Best Fitness History")
465
  with gr.Group(elem_id="table-card"):
466
  top_df = gr.Dataframe(label="Top Genomes (live)", wrap=True, interactive=False)
467
 
468
  # Wiring
469
- start.click(start_evo, [dataset, pop, gens, mut, explore, exploit, seed, pace], [start, stop])
470
  stop.click(stop_evo, [], [start, stop])
471
  export_btn.click(export_snapshot, [], [export_file])
472
 
473
- # Initial paint once when app loads
474
  demo.load(poll_state, None, [sphere_plot, hist_plot, stats_md, top_df])
475
 
476
  # Continuous polling (every 0.7s)
 
9
  import gradio as gr
10
  import pandas as pd
11
 
12
+ # Proxy fitness deps
13
  import torch
14
  import torch.nn as nn
15
  import torch.optim as optim
16
 
 
17
  from data_utils import load_piqa, load_hellaswag, hash_vectorize
18
 
19
  # =========================
20
+ # UX THEME & STYLES (cleaner, pro)
21
  # =========================
22
  CUSTOM_CSS = """
23
+ :root { --radius-2xl: 18px; }
24
+ .gradio-container {max-width: 1320px !important}
25
+ #header-card, #viz-card, #right-card, #table-card {
26
+ border-radius: var(--radius-2xl);
27
+ box-shadow: 0 6px 24px rgba(0,0,0,0.06);
28
+ }
29
+ .gr-button {border-radius: 12px}
30
+ #stats-md {font-size: 15px;}
 
31
  """
32
 
33
  # =========================
 
43
  dropout: float
44
  species: int = 0
45
  fitness: float = float("inf")
46
+ acc: Optional[float] = None # accuracy when dataset is PIQA/HS
47
 
48
  def vector(self) -> np.ndarray:
 
49
  return np.array([
50
  self.d_model / 1024.0,
51
  self.n_layers / 24.0,
 
75
  if rng.random() < rate: g.memory_tokens = rng.choice([0, 4, 8, 16])
76
  if rng.random() < rate: g.dropout = rng.choice([0.0, 0.05, 0.1, 0.15])
77
  if rng.random() < rate * 0.5: g.species = rng.randrange(5)
78
+ g.fitness = float("inf"); g.acc = None
79
  return g
80
 
81
  def crossover(a: Genome, b: Genome, rng: random.Random) -> Genome:
 
87
  memory_tokens = a.memory_tokens if rng.random()<0.5 else b.memory_tokens,
88
  dropout = a.dropout if rng.random()<0.5 else b.dropout,
89
  species = a.species if rng.random()<0.5 else b.species,
90
+ fitness = float("inf"),
91
+ acc = None
92
  )
93
 
94
  # =========================
 
99
  return A * n + np.sum(x**2 - A * np.cos(2 * math.pi * x))
100
 
101
  class TinyMLP(nn.Module):
 
102
  def __init__(self, in_dim: int, genome: Genome):
103
  super().__init__()
104
  h1 = max(64, int(0.25 * genome.d_model))
 
108
  nn.Linear(h1, h2), nn.ReLU(),
109
  nn.Linear(h2, 1)
110
  )
111
+ def forward(self, x): return self.net(x).squeeze(-1)
 
112
 
113
  @lru_cache(maxsize=4)
114
  def _cached_dataset(name: str):
115
+ if name.startswith("PIQA"): return load_piqa(subset=800, seed=42)
116
+ if name.startswith("HellaSwag"): return load_hellaswag(subset=800, seed=42)
117
+ return None
 
 
118
 
119
+ def _train_eval_proxy(genome: Genome, dataset_name: str, explore: float, device: str = "cpu") -> Tuple[float, Optional[float]]:
120
+ """Returns (fitness, accuracy or None)."""
121
  data = _cached_dataset(dataset_name)
122
  if data is None:
123
+ # Demo path handled elsewhere
124
+ v = genome.vector() * 2 - 1
125
+ return float(rastrigin(v)), None
126
 
127
+ Xtr_txt, ytr, Xva_txt, yva = data
128
  nfeat = 4096
129
  Xtr = hash_vectorize(Xtr_txt, n_features=nfeat, seed=1234)
130
  Xva = hash_vectorize(Xva_txt, n_features=nfeat, seed=5678)
131
 
 
132
  Xtr_t = torch.from_numpy(Xtr)
133
  ytr_t = torch.from_numpy(ytr.astype(np.float32))
134
  Xva_t = torch.from_numpy(Xva)
 
138
  opt = optim.AdamW(model.parameters(), lr=2e-3)
139
  lossf = nn.BCEWithLogitsLoss()
140
 
 
141
  model.train()
142
+ steps, bs = 120, 256
 
143
  N = Xtr_t.size(0)
144
  for _ in range(steps):
145
  idx = torch.randint(0, N, (bs,))
146
+ xb = Xtr_t[idx].to(device); yb = ytr_t[idx].to(device)
147
+ logits = model(xb); loss = lossf(logits, yb)
148
+ opt.zero_grad(); loss.backward()
 
 
 
149
  torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
150
  opt.step()
151
 
 
152
  model.eval()
153
  with torch.no_grad():
154
  logits = model(Xva_t.to(device))
155
  probs = torch.sigmoid(logits).cpu().numpy()
156
 
 
157
  if dataset_name.startswith("PIQA"):
158
+ probs = probs.reshape(-1, 2); yva2 = yva.reshape(-1, 2)
159
+ pred = (probs[:,0] > probs[:,1]).astype(np.int64)
160
+ truth = (yva2[:,0] == 1).astype(np.int64)
 
 
161
  acc = float((pred == truth).mean())
162
  else:
163
+ probs = probs.reshape(-1, 4); yva2 = yva.reshape(-1, 4)
164
+ pred = probs.argmax(axis=1); truth = yva2.argmax(axis=1)
 
 
 
165
  acc = float((pred == truth).mean())
166
 
 
167
  parsimony = 0.00000002 * (genome.d_model**2 * genome.n_layers) + 0.0001 * genome.memory_tokens
168
  noise = np.random.normal(scale=0.01 * max(0.0, min(1.0, explore)))
169
  fitness = (1.0 - acc) + parsimony + noise
170
+ return float(max(0.0, min(1.5, fitness))), float(acc)
171
 
172
+ def evaluate_genome(genome: Genome, dataset: str, explore: float) -> Tuple[float, Optional[float]]:
 
173
  if dataset == "Demo (Surrogate)":
174
  v = genome.vector() * 2 - 1
175
  base = rastrigin(v)
176
  parsimony = 0.001 * (genome.d_model + 50*genome.n_layers + 20*genome.n_heads + 100*genome.memory_tokens)
177
  noise = np.random.normal(scale=0.05 * max(0.0, min(1.0, explore)))
178
+ return float(base + parsimony + noise), None
 
179
  if dataset.startswith("PIQA"):
180
+ return _train_eval_proxy(genome, "PIQA", explore)
 
 
 
181
  if dataset.startswith("HellaSwag"):
182
+ return _train_eval_proxy(genome, "HellaSwag", explore)
183
+ # fallback
 
 
 
184
  v = genome.vector() * 2 - 1
185
+ return float(rastrigin(v)), None
186
 
187
  # =========================
188
+ # PROJECTION & VIZ (bigger, transparent sphere, rich hover)
189
  # =========================
190
  def sphere_project(points: np.ndarray) -> np.ndarray:
 
191
  rng = np.random.RandomState(42)
192
  W = rng.normal(size=(points.shape[1], 3)).astype(np.float32)
193
  Y = points @ W
194
  norms = np.linalg.norm(Y, axis=1, keepdims=True) + 1e-8
195
+ return (Y / norms) * 1.15 # slightly larger radius
196
 
197
  def make_sphere_figure(points3d: np.ndarray, genomes: List[Genome], gen_idx: int) -> go.Figure:
198
  species = np.array([g.species for g in genomes])
199
+ # Prepare hover with all fields
200
+ custom = np.array([[g.d_model, g.n_layers, g.n_heads, g.ffn_mult, g.memory_tokens, g.dropout,
201
+ g.species, g.fitness, (g.acc if g.acc is not None else -1.0)]
202
+ for g in genomes], dtype=np.float32)
203
 
204
  scatter = go.Scatter3d(
205
  x=points3d[:,0], y=points3d[:,1], z=points3d[:,2],
206
  mode='markers',
207
+ marker=dict(size=7, color=species, opacity=0.95),
208
+ customdata=custom,
209
+ hovertemplate=(
210
+ "d_model=%{customdata[0]:.0f}<br>"
211
+ "layers=%{customdata[1]:.0f} · heads=%{customdata[2]:.0f}<br>"
212
+ "ffn_mult=%{customdata[3]:.1f} · mem=%{customdata[4]:.0f} · drop=%{customdata[5]:.2f}<br>"
213
+ "species=%{customdata[6]:.0f}<br>"
214
+ "fitness=%{customdata[7]:.4f}<br>"
215
+ "accuracy=%{customdata[8]:.3f}<extra></extra>"
216
+ )
217
  )
218
 
219
+ # Faint sphere
220
+ u = np.linspace(0, 2*np.pi, 64)
221
+ v = np.linspace(0, np.pi, 32)
222
+ r = 1.15
223
+ xs = r*np.outer(np.cos(u), np.sin(v))
224
+ ys = r*np.outer(np.sin(u), np.sin(v))
225
+ zs = r*np.outer(np.ones_like(u), np.cos(v))
226
+ sphere = go.Surface(x=xs, y=ys, z=zs, opacity=0.06, showscale=False)
227
 
228
  layout = go.Layout(
229
  title=f"Evo Sphere — Generation {gen_idx}",
230
  scene=dict(xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False)),
231
  margin=dict(l=0, r=0, t=40, b=0),
232
+ showlegend=False,
233
+ height=680
234
  )
235
  return go.Figure(data=[sphere, scatter], layout=layout)
236
 
237
+ def make_history_figure(history: List[Tuple[int,float,float]], metric: str) -> go.Figure:
238
+ # history items: (gen, best_fitness, best_acc or NaN)
239
  xs = [h[0] for h in history]
240
+ if metric == "Accuracy":
241
+ ys = [h[2] if (h[2] == h[2]) else None for h in history] # keep None for Demo
242
+ title, ylab = "Best Accuracy per Generation", "Accuracy"
243
+ else:
244
+ ys = [h[1] for h in history]
245
+ title, ylab = "Best Fitness per Generation", "Fitness (lower is better)"
246
  fig = go.Figure(data=[go.Scatter(x=xs, y=ys, mode="lines+markers")])
247
+ fig.update_layout(title=title, xaxis_title="Generation", yaxis_title=ylab,
248
+ margin=dict(l=30,r=10,t=40,b=30), height=360)
 
249
  return fig
250
 
251
  def approx_params(g: Genome) -> int:
 
 
252
  per_layer = (4.0 + 2.0 * float(g.ffn_mult)) * (g.d_model ** 2)
253
+ total = per_layer * g.n_layers + 1000 * g.memory_tokens
 
254
  return int(total)
255
 
256
  # =========================
 
263
  self.stop_flag = False
264
  self.state: Dict[str, Any] = {}
265
 
266
+ def run(self, dataset, pop_size, generations, mutation_rate, explore, exploit, seed, pace_ms, metric_choice):
267
  rng = random.Random(int(seed))
268
  self.stop_flag = False
269
  self.running = True
 
271
  pop: List[Genome] = [random_genome(rng) for _ in range(pop_size)]
272
  # initial eval
273
  for g in pop:
274
+ fit, acc = evaluate_genome(g, dataset, explore)
275
+ g.fitness, g.acc = fit, acc
276
 
277
+ history: List[Tuple[int,float,float]] = [] # (gen, best_fitness, best_acc or NaN)
278
  best_overall: Optional[Genome] = None
279
 
280
  for gen in range(1, generations+1):
281
  if self.stop_flag: break
282
 
283
+ # Selection (tournament)
284
  k = max(2, int(2 + exploit * 5))
285
  parents = []
286
  for _ in range(pop_size):
 
290
  # Reproduce
291
  children = []
292
  for i in range(0, pop_size, 2):
293
+ a = parents[i]; b = parents[(i+1) % pop_size]
 
294
  child1 = mutate(crossover(a,b,rng), rng, mutation_rate)
295
  child2 = mutate(crossover(b,a,rng), rng, mutation_rate)
296
  children.extend([child1, child2])
297
  children = children[:pop_size]
298
 
299
+ # Evaluate children
300
  for c in children:
301
+ fit, acc = evaluate_genome(c, dataset, explore)
302
+ c.fitness, c.acc = fit, acc
303
 
304
  # Elitism
305
  elite_n = max(1, pop_size // 10)
 
313
  if best_overall is None or best.fitness < best_overall.fitness:
314
  best_overall = best
315
 
316
+ history.append((gen, best.fitness, (best.acc if best.acc is not None else float("nan"))))
317
 
318
  # Viz snapshot
319
  P = np.stack([g.vector() for g in pop], axis=0)
320
  P3 = sphere_project(P)
321
  sphere_fig = make_sphere_figure(P3, pop, gen)
322
+ hist_fig = make_history_figure(history, metric_choice)
323
  top = sorted(pop, key=lambda x: x.fitness)[: min(12, len(pop))]
324
  top_table = [
325
  {
326
  "gen": gen,
327
  "fitness": round(t.fitness, 4),
328
+ "accuracy": (None if t.acc is None else round(float(t.acc), 4)),
329
  "d_model": t.d_model,
330
  "layers": t.n_layers,
331
  "heads": t.n_heads,
 
345
  "top": top_table,
346
  "best": best_card,
347
  "gen": gen,
348
+ "dataset": dataset,
349
+ "metric": metric_choice
350
  }
351
 
352
  time.sleep(max(0.0, pace_ms/1000.0))
 
358
  t = threading.Thread(target=self.run, args=args, kwargs=kwargs, daemon=True)
359
  t.start()
360
 
361
+ def stop(self): self.stop_flag = True
 
362
 
363
  runner = EvoRunner()
364
 
365
  # =========================
366
+ # UI CALLBACKS
367
  # =========================
368
+ def start_evo(dataset, pop, gens, mut, explore, exploit, seed, pace_ms, metric_choice):
369
+ runner.start(dataset, int(pop), int(gens), float(mut), float(explore), float(exploit), int(seed), int(pace_ms), metric_choice)
370
  return (gr.update(interactive=False), gr.update(interactive=True))
371
 
372
  def stop_evo():
 
377
  with runner.lock:
378
  s = runner.state.copy()
379
  sphere = s.get("sphere", go.Figure())
380
+ history = s.get("history", go.Figure()) # already built by runner
381
  best = s.get("best", {})
382
  gen = s.get("gen", 0)
383
  dataset = s.get("dataset", "Demo (Surrogate)")
384
  top = s.get("top", [])
385
+ # Stats text
386
  if best:
387
+ acc_txt = "—" if best.get("accuracy") is None else f"{best.get('accuracy'):.3f}"
388
  stats_md = (
389
  f"**Dataset:** {dataset} \n"
390
  f"**Generation:** {gen} \n"
391
  f"**Best fitness:** {best.get('fitness','–')} \n"
392
+ f"**Best accuracy:** {acc_txt} \n"
393
  f"**Config:** d_model={best.get('d_model')} · layers={best.get('layers')} · "
394
  f"heads={best.get('heads')} · ffn_mult={best.get('ffn_mult')} · mem={best.get('mem')} · "
395
  f"dropout={best.get('dropout')} \n"
 
410
  return path
411
 
412
  # =========================
413
+ # BUILD UI (bigger sphere, metric toggle)
414
  # =========================
415
  with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
416
  with gr.Column(elem_id="header-card"):
417
  gr.Markdown(
418
+ "# Evo Playground — Live Evolution of Transformer Architectures\n"
419
+ "Tune the search, watch the population converge, and track **accuracy** in real time (PIQA/HellaSwag)."
 
420
  )
421
 
422
  with gr.Row():
 
425
  with gr.Group():
426
  dataset = gr.Dropdown(
427
  label="Dataset",
428
+ choices=["Demo (Surrogate)", "PIQA (Phase 2)", "HellaSwag (Phase 2)"],
429
  value="Demo (Surrogate)",
430
+ info="PIQA/HellaSwag compute real proxy accuracy; Demo uses a fast surrogate."
431
  )
432
  pop = gr.Slider(8, 80, value=24, step=2, label="Population size")
433
  gens = gr.Slider(5, 200, value=60, step=1, label="Max generations")
 
437
  exploit = gr.Slider(0.0, 1.0, value=0.65, step=0.05, label="Exploitation")
438
  seed = gr.Number(value=42, label="Seed", precision=0)
439
  pace = gr.Slider(0, 1000, value=120, step=10, label="Pace (ms between gens)")
440
+ metric_choice = gr.Radio(choices=["Accuracy", "Fitness"], value="Accuracy", label="History Metric")
441
  with gr.Row():
442
  start = gr.Button("▶ Start Evolution", variant="primary")
443
  stop = gr.Button("⏹ Stop", variant="secondary")
444
 
445
  with gr.Group(elem_id="right-card"):
446
+ stats_md = gr.Markdown("Waiting…", elem_id="stats-md")
447
  export_btn = gr.Button("Export Snapshot (JSON)")
448
  export_file = gr.File(label="Download snapshot", visible=False)
449
 
 
452
  with gr.Group(elem_id="viz-card"):
453
  sphere_plot = gr.Plot(label="Evolution Sphere")
454
  with gr.Group(elem_id="viz-card"):
455
+ hist_plot = gr.Plot(label="Progress")
456
  with gr.Group(elem_id="table-card"):
457
  top_df = gr.Dataframe(label="Top Genomes (live)", wrap=True, interactive=False)
458
 
459
  # Wiring
460
+ start.click(start_evo, [dataset, pop, gens, mut, explore, exploit, seed, pace, metric_choice], [start, stop])
461
  stop.click(stop_evo, [], [start, stop])
462
  export_btn.click(export_snapshot, [], [export_file])
463
 
464
+ # Initial paint
465
  demo.load(poll_state, None, [sphere_plot, hist_plot, stats_md, top_df])
466
 
467
  # Continuous polling (every 0.7s)