HemanM commited on
Commit
e4791e2
·
verified ·
1 Parent(s): 7a05320

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +123 -123
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # app.py — Minimal, pro UI with big transparent sphere, accuracy, and robust rendering
2
  import math, json, random, time, threading
3
  from dataclasses import dataclass, asdict
4
  from typing import List, Tuple, Dict, Any, Optional
@@ -16,22 +16,28 @@ import torch.optim as optim
16
 
17
  from data_utils import load_piqa, load_hellaswag, hash_vectorize
18
 
19
- # ---------- Minimal style ----------
 
 
20
  CUSTOM_CSS = """
21
- :root { --radius: 14px; --fg:#0f172a; --muted:#64748b; --line:#e5e7eb; }
22
- * { font-family: Inter, ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Helvetica Neue, Arial, "Noto Sans", "Apple Color Emoji", "Segoe UI Emoji"; }
23
- .gradio-container { max-width: 1180px !important; }
24
- #header { border-radius: var(--radius); padding: 8px 6px; }
25
  h1, h2, h3, .gr-markdown { color: var(--fg); }
26
- .gr-button { border-radius: 10px }
27
- .controls .gr-group, .panel { border: 1px solid var(--line); border-radius: var(--radius); }
28
  .panel { padding: 10px; }
29
- #stats { font-weight: 300; color: var(--fg); }
30
  #stats strong { font-weight: 500; }
31
- .small { font-size: 13px; color: var(--muted); }
 
 
32
  """
33
 
34
- # ---------- Genome ----------
 
 
35
  @dataclass
36
  class Genome:
37
  d_model: int
@@ -86,11 +92,12 @@ def crossover(a: Genome, b: Genome, rng: random.Random) -> Genome:
86
  memory_tokens = a.memory_tokens if rng.random()<0.5 else b.memory_tokens,
87
  dropout = a.dropout if rng.random()<0.5 else b.dropout,
88
  species = a.species if rng.random()<0.5 else b.species,
89
- fitness = float("inf"),
90
- acc = None
91
  )
92
 
93
- # ---------- Proxy fitness ----------
 
 
94
  def rastrigin(x: np.ndarray) -> float:
95
  A, n = 10.0, x.shape[0]
96
  return A * n + np.sum(x**2 - A * np.cos(2 * math.pi * x))
@@ -107,9 +114,9 @@ class TinyMLP(nn.Module):
107
  )
108
  def forward(self, x): return self.net(x).squeeze(-1)
109
 
 
110
  @lru_cache(maxsize=4)
111
  def _cached_dataset(name: str):
112
- # Defensive: if loading fails (e.g., datasets version / no internet), return None
113
  try:
114
  if name.startswith("PIQA"): return load_piqa(subset=800, seed=42)
115
  if name.startswith("HellaSwag"): return load_hellaswag(subset=800, seed=42)
@@ -117,10 +124,9 @@ def _cached_dataset(name: str):
117
  return None
118
  return None
119
 
120
- def _train_eval_proxy(genome: Genome, dataset_name: str, explore: float, device: str = "cpu"):
121
  data = _cached_dataset(dataset_name)
122
  if data is None:
123
- # Fallback to surrogate to keep the UI alive
124
  v = genome.vector() * 2 - 1
125
  base = rastrigin(v)
126
  parsimony = 0.001 * (genome.d_model + 50*genome.n_layers + 20*genome.n_heads + 100*genome.memory_tokens)
@@ -132,17 +138,14 @@ def _train_eval_proxy(genome: Genome, dataset_name: str, explore: float, device:
132
  Xtr = hash_vectorize(Xtr_txt, n_features=nfeat, seed=1234)
133
  Xva = hash_vectorize(Xva_txt, n_features=nfeat, seed=5678)
134
 
135
- Xtr_t = torch.from_numpy(Xtr)
136
- ytr_t = torch.from_numpy(ytr.astype(np.float32))
137
- Xva_t = torch.from_numpy(Xva)
138
- yva_t = torch.from_numpy(yva.astype(np.float32))
139
 
140
  model = TinyMLP(nfeat, genome).to(device)
141
  opt = optim.AdamW(model.parameters(), lr=2e-3)
142
  lossf = nn.BCEWithLogitsLoss()
143
 
144
- model.train()
145
- steps, bs, N = 120, 256, Xtr_t.size(0)
146
  for _ in range(steps):
147
  idx = torch.randint(0, N, (bs,))
148
  xb = Xtr_t[idx].to(device); yb = ytr_t[idx].to(device)
@@ -157,12 +160,12 @@ def _train_eval_proxy(genome: Genome, dataset_name: str, explore: float, device:
157
  probs = torch.sigmoid(logits).cpu().numpy()
158
 
159
  if dataset_name.startswith("PIQA"):
160
- probs = probs.reshape(-1, 2); yva2 = yva.reshape(-1, 2)
161
  pred = (probs[:,0] > probs[:,1]).astype(np.int64)
162
  truth = (yva2[:,0] == 1).astype(np.int64)
163
  acc = float((pred == truth).mean())
164
  else:
165
- probs = probs.reshape(-1, 4); yva2 = yva.reshape(-1, 4)
166
  pred = probs.argmax(axis=1); truth = yva2.argmax(axis=1)
167
  acc = float((pred == truth).mean())
168
 
@@ -178,39 +181,53 @@ def evaluate_genome(genome: Genome, dataset: str, explore: float):
178
  parsimony = 0.001 * (genome.d_model + 50*genome.n_layers + 20*genome.n_heads + 100*genome.memory_tokens)
179
  noise = np.random.normal(scale=0.05 * max(0.0, min(1.0, explore)))
180
  return float(base + parsimony + noise), None
181
- if dataset.startswith("PIQA"):
182
- return _train_eval_proxy(genome, "PIQA", explore)
183
- if dataset.startswith("HellaSwag"):
184
- return _train_eval_proxy(genome, "HellaSwag", explore)
185
- # Fallback
186
  v = genome.vector() * 2 - 1
187
  return float(rastrigin(v)), None
188
 
189
- # ---------- Viz helpers (bigger, transparent sphere) ----------
190
- PALETTE = ["#111827", "#334155", "#475569", "#64748b", "#94a3b8"] # muted grayscale/blue
191
- BG = "white"
 
 
 
192
 
193
  def sphere_project(points: np.ndarray) -> np.ndarray:
194
  rng = np.random.RandomState(42)
195
  W = rng.normal(size=(points.shape[1], 3)).astype(np.float32)
196
  Y = points @ W
197
  norms = np.linalg.norm(Y, axis=1, keepdims=True) + 1e-8
198
- return (Y / norms) * 1.2
199
 
200
- def _species_colors(species: np.ndarray) -> list:
201
- return [PALETTE[int(s) % len(PALETTE)] for s in species]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
  def make_sphere_figure(points3d: np.ndarray, genomes: List[Genome], gen_idx: int) -> go.Figure:
204
- species = np.array([g.species for g in genomes])
205
- colors = _species_colors(species)
206
  custom = np.array([[g.d_model, g.n_layers, g.n_heads, g.ffn_mult, g.memory_tokens, g.dropout,
207
  g.species, g.fitness, (g.acc if g.acc is not None else -1.0)]
208
  for g in genomes], dtype=np.float32)
209
-
210
  scatter = go.Scatter3d(
211
  x=points3d[:,0], y=points3d[:,1], z=points3d[:,2],
212
  mode='markers',
213
- marker=dict(size=6.5, color=colors, opacity=0.92),
214
  customdata=custom,
215
  hovertemplate=(
216
  "<b>Genome</b><br>"
@@ -221,35 +238,10 @@ def make_sphere_figure(points3d: np.ndarray, genomes: List[Genome], gen_idx: int
221
  "accuracy=%{customdata[8]:.3f}<extra></extra>"
222
  )
223
  )
224
-
225
- # Subtle, large sphere
226
- u = np.linspace(0, 2*np.pi, 72)
227
- v = np.linspace(0, np.pi, 36)
228
- r = 1.2
229
- xs = r*np.outer(np.cos(u), np.sin(v))
230
- ys = r*np.outer(np.sin(u), np.sin(v))
231
- zs = r*np.outer(np.ones_like(u), np.cos(v))
232
- sphere = go.Surface(
233
- x=xs, y=ys, z=zs,
234
- opacity=0.08,
235
- showscale=False,
236
- colorscale=[[0, "#cbd5e1"], [1, "#cbd5e1"]],
237
- hoverinfo="skip"
238
- )
239
-
240
- layout = go.Layout(
241
- paper_bgcolor=BG, plot_bgcolor=BG,
242
- title=f"Evo Architecture Sphere — Gen {gen_idx}",
243
- scene=dict(
244
- xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False),
245
- bgcolor=BG
246
- ),
247
- margin=dict(l=0, r=0, t=36, b=0),
248
- showlegend=False,
249
- height=720,
250
- font=dict(family="Inter, Arial, sans-serif", size=14)
251
- )
252
- return go.Figure(data=[sphere, scatter], layout=layout)
253
 
254
  def make_history_figure(history: List[Tuple[int,float,float]], metric: str) -> go.Figure:
255
  xs = [h[0] for h in history]
@@ -259,37 +251,36 @@ def make_history_figure(history: List[Tuple[int,float,float]], metric: str) -> g
259
  else:
260
  ys = [h[1] for h in history]
261
  title, ylab = "Best Fitness per Generation", "Fitness (↓ better)"
262
- fig = go.Figure(data=[go.Scatter(x=xs, y=ys, mode="lines+markers", line=dict(width=2))])
263
  fig.update_layout(
264
- paper_bgcolor=BG, plot_bgcolor=BG,
265
  title=title, xaxis_title="Generation", yaxis_title=ylab,
266
- margin=dict(l=30, r=10, t=36, b=30),
267
- height=340,
268
- font=dict(family="Inter, Arial, sans-serif", size=14)
269
  )
 
270
  return fig
271
 
272
  def fig_to_html(fig: go.Figure) -> str:
273
- # Inline Plotly JS to avoid any CDN/network dependency in Spaces
274
- return pio.to_html(
275
- fig,
276
- include_plotlyjs=True, # IMPORTANT: inline JS so the sphere always renders
277
- full_html=False,
278
- config=dict(displaylogo=False)
279
- )
280
 
281
  def approx_params(g: Genome) -> int:
282
  per_layer = (4.0 + 2.0 * float(g.ffn_mult)) * (g.d_model ** 2)
283
  total = per_layer * g.n_layers + 1000 * g.memory_tokens
284
  return int(total)
285
 
286
- # ---------- Orchestrator ----------
 
 
287
  class EvoRunner:
288
  def __init__(self):
289
  self.lock = threading.Lock()
290
  self.running = False
291
  self.stop_flag = False
292
  self.state: Dict[str, Any] = {}
 
 
 
 
293
 
294
  def run(self, dataset, pop_size, generations, mutation_rate, explore, exploit, seed, pace_ms, metric_choice):
295
  rng = random.Random(int(seed))
@@ -302,16 +293,12 @@ class EvoRunner:
302
  g.fitness, g.acc = fit, acc
303
 
304
  history: List[Tuple[int,float,float]] = []
305
- best_overall: Optional[Genome] = None
306
 
307
  for gen in range(1, generations+1):
308
  if self.stop_flag: break
309
 
310
  k = max(2, int(2 + exploit * 5))
311
- parents = []
312
- for _ in range(pop_size):
313
- sample = rng.sample(pop, k=k)
314
- parents.append(min(sample, key=lambda x: x.fitness))
315
 
316
  children = []
317
  for i in range(0, pop_size, 2):
@@ -331,8 +318,6 @@ class EvoRunner:
331
  pop[-elite_n:] = elites
332
 
333
  best = min(pop, key=lambda x: x.fitness)
334
- if best_overall is None or best.fitness < best_overall.fitness: best_overall = best
335
-
336
  history.append((gen, best.fitness, (best.acc if best.acc is not None else float("nan"))))
337
 
338
  P = np.stack([g.vector() for g in pop], axis=0)
@@ -341,22 +326,14 @@ class EvoRunner:
341
  hist_fig = make_history_figure(history, metric_choice)
342
 
343
  top = sorted(pop, key=lambda x: x.fitness)[: min(12, len(pop))]
344
- top_table = [
345
- {
346
- "gen": gen,
347
- "fitness": round(t.fitness, 4),
348
- "accuracy": (None if t.acc is None else round(float(t.acc), 4)),
349
- "d_model": t.d_model,
350
- "layers": t.n_layers,
351
- "heads": t.n_heads,
352
- "ffn_mult": t.ffn_mult,
353
- "mem": t.memory_tokens,
354
- "dropout": t.dropout,
355
- "species": t.species,
356
- "params_approx": approx_params(t)
357
- } for t in top
358
- ]
359
- best_card = top_table[0] if len(top_table) else {}
360
 
361
  with self.lock:
362
  self.state = {
@@ -376,24 +353,42 @@ class EvoRunner:
376
  if self.running: return
377
  t = threading.Thread(target=self.run, args=args, kwargs=kwargs, daemon=True)
378
  t.start()
 
379
  def stop(self): self.stop_flag = True
380
 
 
 
 
 
 
 
 
 
 
381
  runner = EvoRunner()
382
 
383
- # ---------- UI callbacks ----------
 
 
384
  def start_evo(dataset, pop, gens, mut, explore, exploit, seed, pace_ms, metric_choice):
385
  runner.start(dataset, int(pop), int(gens), float(mut), float(explore), float(exploit), int(seed), int(pace_ms), metric_choice)
386
- return (gr.update(interactive=False), gr.update(interactive=True))
387
 
388
  def stop_evo():
389
  runner.stop()
390
- return (gr.update(interactive=True), gr.update(interactive=False))
 
 
 
 
 
 
391
 
392
  def poll_state():
393
  with runner.lock:
394
  s = runner.state.copy()
395
- sphere_html = s.get("sphere_html", "")
396
- history_html = s.get("history_html", "")
397
  best = s.get("best", {})
398
  gen = s.get("gen", 0)
399
  dataset = s.get("dataset", "Demo (Surrogate)")
@@ -411,7 +406,7 @@ def poll_state():
411
  f"**~Params (rough):** {best.get('params_approx'):,}"
412
  )
413
  else:
414
- stats_md = "Waiting… click **Start Evolution**."
415
  df = pd.DataFrame(top)
416
  return sphere_html, history_html, stats_md, df
417
 
@@ -424,10 +419,12 @@ def export_snapshot():
424
  f.write(payload)
425
  return path
426
 
427
- # ---------- Build UI (minimal layout) ----------
428
- with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft()) as demo:
 
 
429
  with gr.Column(elem_id="header"):
430
- gr.Markdown("## Evo Playground — Minimal Live Evolution (PIQA / HellaSwag accuracy)")
431
 
432
  with gr.Row():
433
  with gr.Column(scale=1, elem_classes=["controls"]):
@@ -436,7 +433,7 @@ with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft()) as demo:
436
  label="Dataset",
437
  choices=["Demo (Surrogate)", "PIQA (Phase 2)", "HellaSwag (Phase 2)"],
438
  value="Demo (Surrogate)",
439
- info="PIQA/HellaSwag compute real proxy accuracy; Demo is a fast surrogate."
440
  )
441
  pop = gr.Slider(8, 80, value=24, step=2, label="Population size")
442
  gens = gr.Slider(5, 200, value=60, step=1, label="Max generations")
@@ -445,14 +442,16 @@ with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft()) as demo:
445
  explore = gr.Slider(0.0, 1.0, value=0.35, step=0.05, label="Exploration")
446
  exploit = gr.Slider(0.0, 1.0, value=0.65, step=0.05, label="Exploitation")
447
  seed = gr.Number(value=42, label="Seed", precision=0)
448
- pace = gr.Slider(0, 1000, value=120, step=10, label="Pace (ms between gens)")
449
  metric_choice = gr.Radio(choices=["Accuracy", "Fitness"], value="Accuracy", label="History Metric")
 
450
  with gr.Row():
451
- start = gr.Button("▶ Start Evolution", variant="primary")
452
- stop = gr.Button("⏹ Stop", variant="secondary")
 
453
 
454
  with gr.Group(elem_classes=["panel"]):
455
- stats_md = gr.Markdown("Waiting…", elem_id="stats")
456
 
457
  with gr.Group(elem_classes=["panel"]):
458
  export_btn = gr.Button("Export Snapshot (JSON)")
@@ -466,12 +465,13 @@ with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft()) as demo:
466
  with gr.Group(elem_classes=["panel"]):
467
  top_df = gr.Dataframe(label="Top Genomes (live)", wrap=True, interactive=False)
468
 
469
- # Wiring
470
- start.click(start_evo, [dataset, pop, gens, mut, explore, exploit, seed, pace, metric_choice], [start, stop])
471
- stop.click(stop_evo, [], [start, stop])
 
472
  export_btn.click(export_snapshot, [], [export_file])
473
 
474
- # Initial paint + polling
475
  demo.load(poll_state, None, [sphere_html, hist_html, stats_md, top_df])
476
  gr.Timer(0.7).tick(poll_state, None, [sphere_html, hist_html, stats_md, top_df])
477
 
 
1
+ # app.py — Minimal dark UI, default idle sphere, Clear button, inline Plotly
2
  import math, json, random, time, threading
3
  from dataclasses import dataclass, asdict
4
  from typing import List, Tuple, Dict, Any, Optional
 
16
 
17
  from data_utils import load_piqa, load_hellaswag, hash_vectorize
18
 
19
+ # =========================
20
+ # STYLE — calm, dark, thin
21
+ # =========================
22
  CUSTOM_CSS = """
23
+ :root { --radius: 14px; --fg:#E5E7EB; --muted:#94A3B8; --line:#111827; --bg:#0F1A24; }
24
+ * { font-family: Inter, ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Helvetica Neue, Arial; font-weight: 300; }
25
+ .gradio-container { max-width: 1140px !important; background: var(--bg); }
26
+ #header { border-radius: var(--radius); padding: 6px 2px; }
27
  h1, h2, h3, .gr-markdown { color: var(--fg); }
28
+ .gr-button { border-radius: 10px; }
29
+ .controls .gr-group, .panel { border: 1px solid #1f2b36; border-radius: var(--radius); background: #0c161f; }
30
  .panel { padding: 10px; }
31
+ #stats { color: var(--fg); }
32
  #stats strong { font-weight: 500; }
33
+ .small { font-size: 12px; color: var(--muted); }
34
+ label, .gradio-container * { color: var(--fg); }
35
+ input, textarea, select { color: var(--fg) !important; }
36
  """
37
 
38
+ # =========================
39
+ # GENOME
40
+ # =========================
41
  @dataclass
42
  class Genome:
43
  d_model: int
 
92
  memory_tokens = a.memory_tokens if rng.random()<0.5 else b.memory_tokens,
93
  dropout = a.dropout if rng.random()<0.5 else b.dropout,
94
  species = a.species if rng.random()<0.5 else b.species,
95
+ fitness = float("inf"), acc=None
 
96
  )
97
 
98
+ # =========================
99
+ # PROXY FITNESS
100
+ # =========================
101
  def rastrigin(x: np.ndarray) -> float:
102
  A, n = 10.0, x.shape[0]
103
  return A * n + np.sum(x**2 - A * np.cos(2 * math.pi * x))
 
114
  )
115
  def forward(self, x): return self.net(x).squeeze(-1)
116
 
117
+ from functools import lru_cache
118
  @lru_cache(maxsize=4)
119
  def _cached_dataset(name: str):
 
120
  try:
121
  if name.startswith("PIQA"): return load_piqa(subset=800, seed=42)
122
  if name.startswith("HellaSwag"): return load_hellaswag(subset=800, seed=42)
 
124
  return None
125
  return None
126
 
127
+ def _train_eval_proxy(genome: Genome, dataset_name: str, explore: float, device: str="cpu"):
128
  data = _cached_dataset(dataset_name)
129
  if data is None:
 
130
  v = genome.vector() * 2 - 1
131
  base = rastrigin(v)
132
  parsimony = 0.001 * (genome.d_model + 50*genome.n_layers + 20*genome.n_heads + 100*genome.memory_tokens)
 
138
  Xtr = hash_vectorize(Xtr_txt, n_features=nfeat, seed=1234)
139
  Xva = hash_vectorize(Xva_txt, n_features=nfeat, seed=5678)
140
 
141
+ Xtr_t = torch.from_numpy(Xtr); ytr_t = torch.from_numpy(ytr.astype(np.float32))
142
+ Xva_t = torch.from_numpy(Xva); yva_t = torch.from_numpy(yva.astype(np.float32))
 
 
143
 
144
  model = TinyMLP(nfeat, genome).to(device)
145
  opt = optim.AdamW(model.parameters(), lr=2e-3)
146
  lossf = nn.BCEWithLogitsLoss()
147
 
148
+ model.train(); steps, bs, N = 120, 256, Xtr_t.size(0)
 
149
  for _ in range(steps):
150
  idx = torch.randint(0, N, (bs,))
151
  xb = Xtr_t[idx].to(device); yb = ytr_t[idx].to(device)
 
160
  probs = torch.sigmoid(logits).cpu().numpy()
161
 
162
  if dataset_name.startswith("PIQA"):
163
+ probs = probs.reshape(-1,2); yva2 = yva.reshape(-1,2)
164
  pred = (probs[:,0] > probs[:,1]).astype(np.int64)
165
  truth = (yva2[:,0] == 1).astype(np.int64)
166
  acc = float((pred == truth).mean())
167
  else:
168
+ probs = probs.reshape(-1,4); yva2 = yva.reshape(-1,4)
169
  pred = probs.argmax(axis=1); truth = yva2.argmax(axis=1)
170
  acc = float((pred == truth).mean())
171
 
 
181
  parsimony = 0.001 * (genome.d_model + 50*genome.n_layers + 20*genome.n_heads + 100*genome.memory_tokens)
182
  noise = np.random.normal(scale=0.05 * max(0.0, min(1.0, explore)))
183
  return float(base + parsimony + noise), None
184
+ if dataset.startswith("PIQA"): return _train_eval_proxy(genome, "PIQA", explore)
185
+ if dataset.startswith("HellaSwag"): return _train_eval_proxy(genome, "HellaSwag", explore)
 
 
 
186
  v = genome.vector() * 2 - 1
187
  return float(rastrigin(v)), None
188
 
189
+ # =========================
190
+ # VIZ big transparent sphere
191
+ # =========================
192
+ BG = "#0F1A24"
193
+ DOT = "#93C5FD" # soft blue dot
194
+ SPHERE = "#cbd5e1" # subtle sphere tint
195
 
196
  def sphere_project(points: np.ndarray) -> np.ndarray:
197
  rng = np.random.RandomState(42)
198
  W = rng.normal(size=(points.shape[1], 3)).astype(np.float32)
199
  Y = points @ W
200
  norms = np.linalg.norm(Y, axis=1, keepdims=True) + 1e-8
201
+ return (Y / norms) * 1.22
202
 
203
+ def make_idle_sphere() -> go.Figure:
204
+ # empty scatter, only sphere
205
+ u = np.linspace(0, 2*np.pi, 72)
206
+ v = np.linspace(0, np.pi, 36)
207
+ r = 1.22
208
+ xs = r*np.outer(np.cos(u), np.sin(v))
209
+ ys = r*np.outer(np.sin(u), np.sin(v))
210
+ zs = r*np.outer(np.ones_like(u), np.cos(v))
211
+ sphere = go.Surface(x=xs, y=ys, z=zs, opacity=0.06, showscale=False,
212
+ colorscale=[[0, SPHERE],[1, SPHERE]], hoverinfo="skip")
213
+ layout = go.Layout(
214
+ paper_bgcolor=BG, plot_bgcolor=BG,
215
+ title="Architecture Sphere (idle)", titlefont=dict(color="#E5E7EB"),
216
+ scene=dict(xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False), bgcolor=BG),
217
+ margin=dict(l=0, r=0, t=36, b=0), showlegend=False, height=720,
218
+ font=dict(family="Inter, Arial, sans-serif", size=14, color="#E5E7EB")
219
+ )
220
+ return go.Figure(data=[sphere], layout=layout)
221
 
222
  def make_sphere_figure(points3d: np.ndarray, genomes: List[Genome], gen_idx: int) -> go.Figure:
223
+ # single-color dots for a sober look
 
224
  custom = np.array([[g.d_model, g.n_layers, g.n_heads, g.ffn_mult, g.memory_tokens, g.dropout,
225
  g.species, g.fitness, (g.acc if g.acc is not None else -1.0)]
226
  for g in genomes], dtype=np.float32)
 
227
  scatter = go.Scatter3d(
228
  x=points3d[:,0], y=points3d[:,1], z=points3d[:,2],
229
  mode='markers',
230
+ marker=dict(size=7.2, color=DOT, opacity=0.92),
231
  customdata=custom,
232
  hovertemplate=(
233
  "<b>Genome</b><br>"
 
238
  "accuracy=%{customdata[8]:.3f}<extra></extra>"
239
  )
240
  )
241
+ idle = make_idle_sphere()
242
+ layout = idle.layout.update(title=f"Evo Architecture Sphere — Gen {gen_idx}")
243
+ fig = go.Figure(data=idle.data + (scatter,), layout=layout)
244
+ return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
 
246
  def make_history_figure(history: List[Tuple[int,float,float]], metric: str) -> go.Figure:
247
  xs = [h[0] for h in history]
 
251
  else:
252
  ys = [h[1] for h in history]
253
  title, ylab = "Best Fitness per Generation", "Fitness (↓ better)"
254
+ fig = go.Figure(data=[go.Scatter(x=xs, y=ys, mode="lines+markers", line=dict(width=2), marker=dict(color=DOT))])
255
  fig.update_layout(
256
+ paper_bgcolor=BG, plot_bgcolor=BG, font=dict(color="#E5E7EB"),
257
  title=title, xaxis_title="Generation", yaxis_title=ylab,
258
+ margin=dict(l=30, r=10, t=36, b=30), height=340
 
 
259
  )
260
+ fig.update_xaxes(gridcolor="#1f2b36"); fig.update_yaxes(gridcolor="#1f2b36")
261
  return fig
262
 
263
  def fig_to_html(fig: go.Figure) -> str:
264
+ return pio.to_html(fig, include_plotlyjs=True, full_html=False, config=dict(displaylogo=False))
 
 
 
 
 
 
265
 
266
  def approx_params(g: Genome) -> int:
267
  per_layer = (4.0 + 2.0 * float(g.ffn_mult)) * (g.d_model ** 2)
268
  total = per_layer * g.n_layers + 1000 * g.memory_tokens
269
  return int(total)
270
 
271
+ # =========================
272
+ # RUNNER
273
+ # =========================
274
  class EvoRunner:
275
  def __init__(self):
276
  self.lock = threading.Lock()
277
  self.running = False
278
  self.stop_flag = False
279
  self.state: Dict[str, Any] = {}
280
+ # seed the idle sphere immediately
281
+ idle = fig_to_html(make_idle_sphere())
282
+ self.state = {"sphere_html": idle, "history_html": fig_to_html(make_history_figure([], "Accuracy")),
283
+ "top": [], "best": {}, "gen": 0, "dataset": "Demo (Surrogate)", "metric": "Accuracy"}
284
 
285
  def run(self, dataset, pop_size, generations, mutation_rate, explore, exploit, seed, pace_ms, metric_choice):
286
  rng = random.Random(int(seed))
 
293
  g.fitness, g.acc = fit, acc
294
 
295
  history: List[Tuple[int,float,float]] = []
 
296
 
297
  for gen in range(1, generations+1):
298
  if self.stop_flag: break
299
 
300
  k = max(2, int(2 + exploit * 5))
301
+ parents = [min(rng.sample(pop, k=k), key=lambda x: x.fitness) for _ in range(pop_size)]
 
 
 
302
 
303
  children = []
304
  for i in range(0, pop_size, 2):
 
318
  pop[-elite_n:] = elites
319
 
320
  best = min(pop, key=lambda x: x.fitness)
 
 
321
  history.append((gen, best.fitness, (best.acc if best.acc is not None else float("nan"))))
322
 
323
  P = np.stack([g.vector() for g in pop], axis=0)
 
326
  hist_fig = make_history_figure(history, metric_choice)
327
 
328
  top = sorted(pop, key=lambda x: x.fitness)[: min(12, len(pop))]
329
+ top_table = [{
330
+ "gen": gen, "fitness": round(t.fitness, 4),
331
+ "accuracy": (None if t.acc is None else round(float(t.acc), 4)),
332
+ "d_model": t.d_model, "layers": t.n_layers, "heads": t.n_heads,
333
+ "ffn_mult": t.ffn_mult, "mem": t.memory_tokens, "dropout": t.dropout,
334
+ "params_approx": approx_params(t)
335
+ } for t in top]
336
+ best_card = top_table[0] if top_table else {}
 
 
 
 
 
 
 
 
337
 
338
  with self.lock:
339
  self.state = {
 
353
  if self.running: return
354
  t = threading.Thread(target=self.run, args=args, kwargs=kwargs, daemon=True)
355
  t.start()
356
+
357
  def stop(self): self.stop_flag = True
358
 
359
+ def clear(self):
360
+ # stop and reset to idle sphere
361
+ self.stop_flag = True
362
+ idle = fig_to_html(make_idle_sphere())
363
+ with self.lock:
364
+ self.running = False
365
+ self.state = {"sphere_html": idle, "history_html": fig_to_html(make_history_figure([], "Accuracy")),
366
+ "top": [], "best": {}, "gen": 0, "dataset": "Demo (Surrogate)", "metric": "Accuracy"}
367
+
368
  runner = EvoRunner()
369
 
370
+ # =========================
371
+ # UI CALLBACKS
372
+ # =========================
373
  def start_evo(dataset, pop, gens, mut, explore, exploit, seed, pace_ms, metric_choice):
374
  runner.start(dataset, int(pop), int(gens), float(mut), float(explore), float(exploit), int(seed), int(pace_ms), metric_choice)
375
+ return (gr.update(interactive=False), gr.update(interactive=True), gr.update(interactive=False))
376
 
377
  def stop_evo():
378
  runner.stop()
379
+ return (gr.update(interactive=True), gr.update(interactive=False), gr.update(interactive=True))
380
+
381
+ def clear_evo():
382
+ runner.clear()
383
+ # return updated visuals + reset buttons
384
+ sphere_html, history_html, stats_md, df = poll_state()
385
+ return sphere_html, history_html, stats_md, df, gr.update(interactive=True), gr.update(interactive=False), gr.update(interactive=True)
386
 
387
  def poll_state():
388
  with runner.lock:
389
  s = runner.state.copy()
390
+ sphere_html = s.get("sphere_html", fig_to_html(make_idle_sphere()))
391
+ history_html = s.get("history_html", fig_to_html(make_history_figure([], "Accuracy")))
392
  best = s.get("best", {})
393
  gen = s.get("gen", 0)
394
  dataset = s.get("dataset", "Demo (Surrogate)")
 
406
  f"**~Params (rough):** {best.get('params_approx'):,}"
407
  )
408
  else:
409
+ stats_md = "Ready. Press **Start** to evolve, or **Clear** anytime."
410
  df = pd.DataFrame(top)
411
  return sphere_html, history_html, stats_md, df
412
 
 
419
  f.write(payload)
420
  return path
421
 
422
+ # =========================
423
+ # BUILD UI
424
+ # =========================
425
+ with gr.Blocks(css=CUSTOM_CSS) as demo:
426
  with gr.Column(elem_id="header"):
427
+ gr.Markdown("### Evo Playground — Live Evolution (clean dark)")
428
 
429
  with gr.Row():
430
  with gr.Column(scale=1, elem_classes=["controls"]):
 
433
  label="Dataset",
434
  choices=["Demo (Surrogate)", "PIQA (Phase 2)", "HellaSwag (Phase 2)"],
435
  value="Demo (Surrogate)",
436
+ info="PIQA/HellaSwag compute proxy accuracy; Demo is a fast surrogate."
437
  )
438
  pop = gr.Slider(8, 80, value=24, step=2, label="Population size")
439
  gens = gr.Slider(5, 200, value=60, step=1, label="Max generations")
 
442
  explore = gr.Slider(0.0, 1.0, value=0.35, step=0.05, label="Exploration")
443
  exploit = gr.Slider(0.0, 1.0, value=0.65, step=0.05, label="Exploitation")
444
  seed = gr.Number(value=42, label="Seed", precision=0)
445
+ pace = gr.Slider(0, 1000, value=120, step=10, label="Pace (ms)")
446
  metric_choice = gr.Radio(choices=["Accuracy", "Fitness"], value="Accuracy", label="History Metric")
447
+
448
  with gr.Row():
449
+ start = gr.Button("▶ Start", variant="primary")
450
+ stop = gr.Button("⏹ Stop", variant="secondary", interactive=False)
451
+ clear = gr.Button("↺ Clear", variant="secondary")
452
 
453
  with gr.Group(elem_classes=["panel"]):
454
+ stats_md = gr.Markdown("Ready. Press **Start** to evolve, or **Clear** anytime.", elem_id="stats")
455
 
456
  with gr.Group(elem_classes=["panel"]):
457
  export_btn = gr.Button("Export Snapshot (JSON)")
 
465
  with gr.Group(elem_classes=["panel"]):
466
  top_df = gr.Dataframe(label="Top Genomes (live)", wrap=True, interactive=False)
467
 
468
+ # wiring
469
+ start.click(start_evo, [dataset, pop, gens, mut, explore, exploit, seed, pace, metric_choice], [start, stop, clear])
470
+ stop.click(stop_evo, [], [start, stop, clear])
471
+ clear.click(clear_evo, [], [sphere_html, hist_html, stats_md, top_df, start, stop, clear])
472
  export_btn.click(export_snapshot, [], [export_file])
473
 
474
+ # initial paint + polling
475
  demo.load(poll_state, None, [sphere_html, hist_html, stats_md, top_df])
476
  gr.Timer(0.7).tick(poll_state, None, [sphere_html, hist_html, stats_md, top_df])
477