HemanM committed on
Commit
d4bbba0
·
verified ·
1 Parent(s): 3796a74

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -102
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # app.py
2
  import math, json, random, time, threading
3
  from dataclasses import dataclass, asdict
4
  from typing import List, Tuple, Dict, Any, Optional
@@ -6,33 +6,32 @@ from functools import lru_cache
6
 
7
  import numpy as np
8
  import plotly.graph_objs as go
 
9
  import gradio as gr
10
  import pandas as pd
11
 
12
- # Proxy fitness deps
13
  import torch
14
  import torch.nn as nn
15
  import torch.optim as optim
16
 
17
  from data_utils import load_piqa, load_hellaswag, hash_vectorize
18
 
19
- # =========================
20
- # UX THEME & STYLES (cleaner, pro)
21
- # =========================
22
  CUSTOM_CSS = """
23
- :root { --radius-2xl: 18px; }
24
- .gradio-container {max-width: 1320px !important}
25
- #header-card, #viz-card, #right-card, #table-card {
26
- border-radius: var(--radius-2xl);
27
- box-shadow: 0 6px 24px rgba(0,0,0,0.06);
28
- }
29
- .gr-button {border-radius: 12px}
30
- #stats-md {font-size: 15px;}
 
 
 
31
  """
32
 
33
- # =========================
34
- # GENOME & EVOLUTION CORE
35
- # =========================
36
  @dataclass
37
  class Genome:
38
  d_model: int
@@ -43,7 +42,7 @@ class Genome:
43
  dropout: float
44
  species: int = 0
45
  fitness: float = float("inf")
46
- acc: Optional[float] = None # accuracy when dataset is PIQA/HS
47
 
48
  def vector(self) -> np.ndarray:
49
  return np.array([
@@ -91,9 +90,7 @@ def crossover(a: Genome, b: Genome, rng: random.Random) -> Genome:
91
  acc = None
92
  )
93
 
94
- # =========================
95
- # PROXY FITNESS (Phase 2a)
96
- # =========================
97
  def rastrigin(x: np.ndarray) -> float:
98
  A, n = 10.0, x.shape[0]
99
  return A * n + np.sum(x**2 - A * np.cos(2 * math.pi * x))
@@ -116,14 +113,11 @@ def _cached_dataset(name: str):
116
  if name.startswith("HellaSwag"): return load_hellaswag(subset=800, seed=42)
117
  return None
118
 
119
- def _train_eval_proxy(genome: Genome, dataset_name: str, explore: float, device: str = "cpu") -> Tuple[float, Optional[float]]:
120
- """Returns (fitness, accuracy or None)."""
121
  data = _cached_dataset(dataset_name)
122
  if data is None:
123
- # Demo path handled elsewhere
124
  v = genome.vector() * 2 - 1
125
  return float(rastrigin(v)), None
126
-
127
  Xtr_txt, ytr, Xva_txt, yva = data
128
  nfeat = 4096
129
  Xtr = hash_vectorize(Xtr_txt, n_features=nfeat, seed=1234)
@@ -139,8 +133,7 @@ def _train_eval_proxy(genome: Genome, dataset_name: str, explore: float, device:
139
  lossf = nn.BCEWithLogitsLoss()
140
 
141
  model.train()
142
- steps, bs = 120, 256
143
- N = Xtr_t.size(0)
144
  for _ in range(steps):
145
  idx = torch.randint(0, N, (bs,))
146
  xb = Xtr_t[idx].to(device); yb = ytr_t[idx].to(device)
@@ -169,7 +162,7 @@ def _train_eval_proxy(genome: Genome, dataset_name: str, explore: float, device:
169
  fitness = (1.0 - acc) + parsimony + noise
170
  return float(max(0.0, min(1.5, fitness))), float(acc)
171
 
172
- def evaluate_genome(genome: Genome, dataset: str, explore: float) -> Tuple[float, Optional[float]]:
173
  if dataset == "Demo (Surrogate)":
174
  v = genome.vector() * 2 - 1
175
  base = rastrigin(v)
@@ -180,23 +173,30 @@ def evaluate_genome(genome: Genome, dataset: str, explore: float) -> Tuple[float
180
  return _train_eval_proxy(genome, "PIQA", explore)
181
  if dataset.startswith("HellaSwag"):
182
  return _train_eval_proxy(genome, "HellaSwag", explore)
183
- # fallback
184
  v = genome.vector() * 2 - 1
185
  return float(rastrigin(v)), None
186
 
187
- # =========================
188
- # PROJECTION & VIZ (bigger, transparent sphere, rich hover)
189
- # =========================
 
190
  def sphere_project(points: np.ndarray) -> np.ndarray:
191
  rng = np.random.RandomState(42)
192
  W = rng.normal(size=(points.shape[1], 3)).astype(np.float32)
193
  Y = points @ W
194
  norms = np.linalg.norm(Y, axis=1, keepdims=True) + 1e-8
195
- return (Y / norms) * 1.15 # slightly larger radius
 
 
 
 
 
 
 
196
 
197
  def make_sphere_figure(points3d: np.ndarray, genomes: List[Genome], gen_idx: int) -> go.Figure:
198
  species = np.array([g.species for g in genomes])
199
- # Prepare hover with all fields
200
  custom = np.array([[g.d_model, g.n_layers, g.n_heads, g.ffn_mult, g.memory_tokens, g.dropout,
201
  g.species, g.fitness, (g.acc if g.acc is not None else -1.0)]
202
  for g in genomes], dtype=np.float32)
@@ -204,11 +204,11 @@ def make_sphere_figure(points3d: np.ndarray, genomes: List[Genome], gen_idx: int
204
  scatter = go.Scatter3d(
205
  x=points3d[:,0], y=points3d[:,1], z=points3d[:,2],
206
  mode='markers',
207
- marker=dict(size=7, color=species, opacity=0.95),
208
  customdata=custom,
209
  hovertemplate=(
210
- "d_model=%{customdata[0]:.0f}<br>"
211
- "layers=%{customdata[1]:.0f} · heads=%{customdata[2]:.0f}<br>"
212
  "ffn_mult=%{customdata[3]:.1f} · mem=%{customdata[4]:.0f} · drop=%{customdata[5]:.2f}<br>"
213
  "species=%{customdata[6]:.0f}<br>"
214
  "fitness=%{customdata[7]:.4f}<br>"
@@ -216,46 +216,63 @@ def make_sphere_figure(points3d: np.ndarray, genomes: List[Genome], gen_idx: int
216
  )
217
  )
218
 
219
- # Faint sphere
220
- u = np.linspace(0, 2*np.pi, 64)
221
- v = np.linspace(0, np.pi, 32)
222
- r = 1.15
223
  xs = r*np.outer(np.cos(u), np.sin(v))
224
  ys = r*np.outer(np.sin(u), np.sin(v))
225
  zs = r*np.outer(np.ones_like(u), np.cos(v))
226
- sphere = go.Surface(x=xs, y=ys, z=zs, opacity=0.06, showscale=False)
 
 
 
 
 
 
227
 
228
  layout = go.Layout(
229
- title=f"Evo Sphere — Generation {gen_idx}",
230
- scene=dict(xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False)),
231
- margin=dict(l=0, r=0, t=40, b=0),
 
 
 
 
232
  showlegend=False,
233
- height=680
 
234
  )
235
  return go.Figure(data=[sphere, scatter], layout=layout)
236
 
237
  def make_history_figure(history: List[Tuple[int,float,float]], metric: str) -> go.Figure:
238
- # history items: (gen, best_fitness, best_acc or NaN)
239
  xs = [h[0] for h in history]
240
  if metric == "Accuracy":
241
- ys = [h[2] if (h[2] == h[2]) else None for h in history] # keep None for Demo
242
  title, ylab = "Best Accuracy per Generation", "Accuracy"
243
  else:
244
  ys = [h[1] for h in history]
245
- title, ylab = "Best Fitness per Generation", "Fitness (lower is better)"
246
- fig = go.Figure(data=[go.Scatter(x=xs, y=ys, mode="lines+markers")])
247
- fig.update_layout(title=title, xaxis_title="Generation", yaxis_title=ylab,
248
- margin=dict(l=30,r=10,t=40,b=30), height=360)
 
 
 
 
 
249
  return fig
250
 
 
 
 
 
251
  def approx_params(g: Genome) -> int:
252
  per_layer = (4.0 + 2.0 * float(g.ffn_mult)) * (g.d_model ** 2)
253
  total = per_layer * g.n_layers + 1000 * g.memory_tokens
254
  return int(total)
255
 
256
- # =========================
257
- # ORCHESTRATOR
258
- # =========================
259
  class EvoRunner:
260
  def __init__(self):
261
  self.lock = threading.Lock()
@@ -269,25 +286,22 @@ class EvoRunner:
269
  self.running = True
270
 
271
  pop: List[Genome] = [random_genome(rng) for _ in range(pop_size)]
272
- # initial eval
273
  for g in pop:
274
  fit, acc = evaluate_genome(g, dataset, explore)
275
  g.fitness, g.acc = fit, acc
276
 
277
- history: List[Tuple[int,float,float]] = [] # (gen, best_fitness, best_acc or NaN)
278
  best_overall: Optional[Genome] = None
279
 
280
  for gen in range(1, generations+1):
281
  if self.stop_flag: break
282
 
283
- # Selection (tournament)
284
  k = max(2, int(2 + exploit * 5))
285
  parents = []
286
  for _ in range(pop_size):
287
  sample = rng.sample(pop, k=k)
288
  parents.append(min(sample, key=lambda x: x.fitness))
289
 
290
- # Reproduce
291
  children = []
292
  for i in range(0, pop_size, 2):
293
  a = parents[i]; b = parents[(i+1) % pop_size]
@@ -296,30 +310,25 @@ class EvoRunner:
296
  children.extend([child1, child2])
297
  children = children[:pop_size]
298
 
299
- # Evaluate children
300
  for c in children:
301
  fit, acc = evaluate_genome(c, dataset, explore)
302
  c.fitness, c.acc = fit, acc
303
 
304
- # Elitism
305
  elite_n = max(1, pop_size // 10)
306
  elites = sorted(pop, key=lambda x: x.fitness)[:elite_n]
307
-
308
- # Next pop
309
  pop = sorted(children, key=lambda x: x.fitness)
310
  pop[-elite_n:] = elites
311
 
312
  best = min(pop, key=lambda x: x.fitness)
313
- if best_overall is None or best.fitness < best_overall.fitness:
314
- best_overall = best
315
 
316
  history.append((gen, best.fitness, (best.acc if best.acc is not None else float("nan"))))
317
 
318
- # Viz snapshot
319
  P = np.stack([g.vector() for g in pop], axis=0)
320
  P3 = sphere_project(P)
321
  sphere_fig = make_sphere_figure(P3, pop, gen)
322
  hist_fig = make_history_figure(history, metric_choice)
 
323
  top = sorted(pop, key=lambda x: x.fitness)[: min(12, len(pop))]
324
  top_table = [
325
  {
@@ -340,8 +349,8 @@ class EvoRunner:
340
 
341
  with self.lock:
342
  self.state = {
343
- "sphere": sphere_fig,
344
- "history": hist_fig,
345
  "top": top_table,
346
  "best": best_card,
347
  "gen": gen,
@@ -350,21 +359,17 @@ class EvoRunner:
350
  }
351
 
352
  time.sleep(max(0.0, pace_ms/1000.0))
353
-
354
  self.running = False
355
 
356
  def start(self, *args, **kwargs):
357
  if self.running: return
358
  t = threading.Thread(target=self.run, args=args, kwargs=kwargs, daemon=True)
359
  t.start()
360
-
361
  def stop(self): self.stop_flag = True
362
 
363
  runner = EvoRunner()
364
 
365
- # =========================
366
- # UI CALLBACKS
367
- # =========================
368
  def start_evo(dataset, pop, gens, mut, explore, exploit, seed, pace_ms, metric_choice):
369
  runner.start(dataset, int(pop), int(gens), float(mut), float(explore), float(exploit), int(seed), int(pace_ms), metric_choice)
370
  return (gr.update(interactive=False), gr.update(interactive=True))
@@ -376,13 +381,12 @@ def stop_evo():
376
  def poll_state():
377
  with runner.lock:
378
  s = runner.state.copy()
379
- sphere = s.get("sphere", go.Figure())
380
- history = s.get("history", go.Figure()) # already built by runner
381
  best = s.get("best", {})
382
  gen = s.get("gen", 0)
383
  dataset = s.get("dataset", "Demo (Surrogate)")
384
  top = s.get("top", [])
385
- # Stats text
386
  if best:
387
  acc_txt = "—" if best.get("accuracy") is None else f"{best.get('accuracy'):.3f}"
388
  stats_md = (
@@ -398,7 +402,7 @@ def poll_state():
398
  else:
399
  stats_md = "Waiting… click **Start Evolution**."
400
  df = pd.DataFrame(top)
401
- return sphere, history, stats_md, df
402
 
403
  def export_snapshot():
404
  from json import dumps
@@ -409,25 +413,19 @@ def export_snapshot():
409
  f.write(payload)
410
  return path
411
 
412
- # =========================
413
- # BUILD UI (bigger sphere, metric toggle)
414
- # =========================
415
- with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
416
- with gr.Column(elem_id="header-card"):
417
- gr.Markdown(
418
- "# Evo Playground — Live Evolution of Transformer Architectures\n"
419
- "Tune the search, watch the population converge, and track **accuracy** in real time (PIQA/HellaSwag)."
420
- )
421
 
422
  with gr.Row():
423
- # LEFT: Controls
424
- with gr.Column(scale=1):
425
  with gr.Group():
426
  dataset = gr.Dropdown(
427
  label="Dataset",
428
  choices=["Demo (Surrogate)", "PIQA (Phase 2)", "HellaSwag (Phase 2)"],
429
  value="Demo (Surrogate)",
430
- info="PIQA/HellaSwag compute real proxy accuracy; Demo uses a fast surrogate."
431
  )
432
  pop = gr.Slider(8, 80, value=24, step=2, label="Population size")
433
  gens = gr.Slider(5, 200, value=60, step=1, label="Max generations")
@@ -442,18 +440,19 @@ with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
442
  start = gr.Button("▶ Start Evolution", variant="primary")
443
  stop = gr.Button("⏹ Stop", variant="secondary")
444
 
445
- with gr.Group(elem_id="right-card"):
446
- stats_md = gr.Markdown("Waiting…", elem_id="stats-md")
 
 
447
  export_btn = gr.Button("Export Snapshot (JSON)")
448
  export_file = gr.File(label="Download snapshot", visible=False)
449
 
450
- # RIGHT: Viz + Table
451
  with gr.Column(scale=2):
452
- with gr.Group(elem_id="viz-card"):
453
- sphere_plot = gr.Plot(label="Evolution Sphere")
454
- with gr.Group(elem_id="viz-card"):
455
- hist_plot = gr.Plot(label="Progress")
456
- with gr.Group(elem_id="table-card"):
457
  top_df = gr.Dataframe(label="Top Genomes (live)", wrap=True, interactive=False)
458
 
459
  # Wiring
@@ -461,12 +460,9 @@ with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
461
  stop.click(stop_evo, [], [start, stop])
462
  export_btn.click(export_snapshot, [], [export_file])
463
 
464
- # Initial paint
465
- demo.load(poll_state, None, [sphere_plot, hist_plot, stats_md, top_df])
466
-
467
- # Continuous polling (every 0.7s)
468
- poller = gr.Timer(0.7)
469
- poller.tick(poll_state, None, [sphere_plot, hist_plot, stats_md, top_df])
470
 
471
  if __name__ == "__main__":
472
  demo.launch()
 
1
+ # app.py — Minimal, pro UI with big transparent sphere and clean hover
2
  import math, json, random, time, threading
3
  from dataclasses import dataclass, asdict
4
  from typing import List, Tuple, Dict, Any, Optional
 
6
 
7
  import numpy as np
8
  import plotly.graph_objs as go
9
+ import plotly.io as pio
10
  import gradio as gr
11
  import pandas as pd
12
 
 
13
  import torch
14
  import torch.nn as nn
15
  import torch.optim as optim
16
 
17
  from data_utils import load_piqa, load_hellaswag, hash_vectorize
18
 
19
+ # ---------- Minimal style ----------
 
 
20
  CUSTOM_CSS = """
21
+ :root { --radius: 14px; --fg:#0f172a; --muted:#64748b; --line:#e5e7eb; }
22
+ * { font-family: Inter, ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Helvetica Neue, Arial, "Noto Sans", "Apple Color Emoji", "Segoe UI Emoji"; }
23
+ .gradio-container { max-width: 1180px !important; }
24
+ #header { border-radius: var(--radius); padding: 8px 6px; }
25
+ h1, h2, h3, .gr-markdown { color: var(--fg); }
26
+ .gr-button { border-radius: 10px; }
27
+ .controls .gr-group, .panel { border: 1px solid var(--line); border-radius: var(--radius); }
28
+ .panel { padding: 10px; }
29
+ #stats { font-weight: 300; color: var(--fg); }
30
+ #stats strong { font-weight: 500; }
31
+ .small { font-size: 13px; color: var(--muted); }
32
  """
33
 
34
+ # ---------- Genome ----------
 
 
35
  @dataclass
36
  class Genome:
37
  d_model: int
 
42
  dropout: float
43
  species: int = 0
44
  fitness: float = float("inf")
45
+ acc: Optional[float] = None
46
 
47
  def vector(self) -> np.ndarray:
48
  return np.array([
 
90
  acc = None
91
  )
92
 
93
+ # ---------- Proxy fitness ----------
 
 
94
  def rastrigin(x: np.ndarray) -> float:
95
  A, n = 10.0, x.shape[0]
96
  return A * n + np.sum(x**2 - A * np.cos(2 * math.pi * x))
 
113
  if name.startswith("HellaSwag"): return load_hellaswag(subset=800, seed=42)
114
  return None
115
 
116
+ def _train_eval_proxy(genome: Genome, dataset_name: str, explore: float, device: str = "cpu"):
 
117
  data = _cached_dataset(dataset_name)
118
  if data is None:
 
119
  v = genome.vector() * 2 - 1
120
  return float(rastrigin(v)), None
 
121
  Xtr_txt, ytr, Xva_txt, yva = data
122
  nfeat = 4096
123
  Xtr = hash_vectorize(Xtr_txt, n_features=nfeat, seed=1234)
 
133
  lossf = nn.BCEWithLogitsLoss()
134
 
135
  model.train()
136
+ steps, bs, N = 120, 256, Xtr_t.size(0)
 
137
  for _ in range(steps):
138
  idx = torch.randint(0, N, (bs,))
139
  xb = Xtr_t[idx].to(device); yb = ytr_t[idx].to(device)
 
162
  fitness = (1.0 - acc) + parsimony + noise
163
  return float(max(0.0, min(1.5, fitness))), float(acc)
164
 
165
+ def evaluate_genome(genome: Genome, dataset: str, explore: float):
166
  if dataset == "Demo (Surrogate)":
167
  v = genome.vector() * 2 - 1
168
  base = rastrigin(v)
 
173
  return _train_eval_proxy(genome, "PIQA", explore)
174
  if dataset.startswith("HellaSwag"):
175
  return _train_eval_proxy(genome, "HellaSwag", explore)
 
176
  v = genome.vector() * 2 - 1
177
  return float(rastrigin(v)), None
178
 
179
+ # ---------- Viz helpers (bigger, transparent sphere) ----------
180
+ PALETTE = ["#111827", "#334155", "#475569", "#64748b", "#94a3b8"] # muted grayscale/blue
181
+ BG = "white"
182
+
183
  def sphere_project(points: np.ndarray) -> np.ndarray:
184
  rng = np.random.RandomState(42)
185
  W = rng.normal(size=(points.shape[1], 3)).astype(np.float32)
186
  Y = points @ W
187
  norms = np.linalg.norm(Y, axis=1, keepdims=True) + 1e-8
188
+ return (Y / norms) * 1.2
189
+
190
+ def _species_colors(species: np.ndarray) -> list:
191
+ colors = []
192
+ for s in species:
193
+ c = PALETTE[int(s) % len(PALETTE)]
194
+ colors.append(c)
195
+ return colors
196
 
197
  def make_sphere_figure(points3d: np.ndarray, genomes: List[Genome], gen_idx: int) -> go.Figure:
198
  species = np.array([g.species for g in genomes])
199
+ colors = _species_colors(species)
200
  custom = np.array([[g.d_model, g.n_layers, g.n_heads, g.ffn_mult, g.memory_tokens, g.dropout,
201
  g.species, g.fitness, (g.acc if g.acc is not None else -1.0)]
202
  for g in genomes], dtype=np.float32)
 
204
  scatter = go.Scatter3d(
205
  x=points3d[:,0], y=points3d[:,1], z=points3d[:,2],
206
  mode='markers',
207
+ marker=dict(size=6.5, color=colors, opacity=0.92),
208
  customdata=custom,
209
  hovertemplate=(
210
+ "<b>Genome</b><br>"
211
+ "d_model=%{customdata[0]:.0f} · layers=%{customdata[1]:.0f} · heads=%{customdata[2]:.0f}<br>"
212
  "ffn_mult=%{customdata[3]:.1f} · mem=%{customdata[4]:.0f} · drop=%{customdata[5]:.2f}<br>"
213
  "species=%{customdata[6]:.0f}<br>"
214
  "fitness=%{customdata[7]:.4f}<br>"
 
216
  )
217
  )
218
 
219
+ # Subtle, large sphere
220
+ u = np.linspace(0, 2*np.pi, 72)
221
+ v = np.linspace(0, np.pi, 36)
222
+ r = 1.2
223
  xs = r*np.outer(np.cos(u), np.sin(v))
224
  ys = r*np.outer(np.sin(u), np.sin(v))
225
  zs = r*np.outer(np.ones_like(u), np.cos(v))
226
+ sphere = go.Surface(
227
+ x=xs, y=ys, z=zs,
228
+ opacity=0.08,
229
+ showscale=False,
230
+ colorscale=[[0, "#cbd5e1"], [1, "#cbd5e1"]],
231
+ hoverinfo="skip"
232
+ )
233
 
234
  layout = go.Layout(
235
+ paper_bgcolor=BG, plot_bgcolor=BG,
236
+ title=f"Evo Architecture Sphere — Gen {gen_idx}",
237
+ scene=dict(
238
+ xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False),
239
+ bgcolor=BG
240
+ ),
241
+ margin=dict(l=0, r=0, t=36, b=0),
242
  showlegend=False,
243
+ height=720,
244
+ font=dict(family="Inter, Arial, sans-serif", size=14)
245
  )
246
  return go.Figure(data=[sphere, scatter], layout=layout)
247
 
248
  def make_history_figure(history: List[Tuple[int,float,float]], metric: str) -> go.Figure:
 
249
  xs = [h[0] for h in history]
250
  if metric == "Accuracy":
251
+ ys = [h[2] if (h[2] == h[2]) else None for h in history]
252
  title, ylab = "Best Accuracy per Generation", "Accuracy"
253
  else:
254
  ys = [h[1] for h in history]
255
+ title, ylab = "Best Fitness per Generation", "Fitness ( better)"
256
+ fig = go.Figure(data=[go.Scatter(x=xs, y=ys, mode="lines+markers", line=dict(width=2))])
257
+ fig.update_layout(
258
+ paper_bgcolor=BG, plot_bgcolor=BG,
259
+ title=title, xaxis_title="Generation", yaxis_title=ylab,
260
+ margin=dict(l=30, r=10, t=36, b=30),
261
+ height=340,
262
+ font=dict(family="Inter, Arial, sans-serif", size=14)
263
+ )
264
  return fig
265
 
266
def fig_to_html(fig: go.Figure) -> str:
    """Serialize a Plotly figure to an HTML fragment string.

    The fragment is not a full HTML document (``full_html=False``), requests
    plotly.js via the ``"cdn"`` option, and disables the Plotly logo through
    the ``displaylogo`` config flag.
    """
    # NOTE(review): the fragment relies on embedded <script> tags; confirm the
    # component that displays it actually executes them.
    render_config = dict(displaylogo=False)
    return pio.to_html(
        fig,
        include_plotlyjs="cdn",
        full_html=False,
        config=render_config,
    )
269
+
270
  def approx_params(g: Genome) -> int:
271
  per_layer = (4.0 + 2.0 * float(g.ffn_mult)) * (g.d_model ** 2)
272
  total = per_layer * g.n_layers + 1000 * g.memory_tokens
273
  return int(total)
274
 
275
+ # ---------- Orchestrator ----------
 
 
276
  class EvoRunner:
277
  def __init__(self):
278
  self.lock = threading.Lock()
 
286
  self.running = True
287
 
288
  pop: List[Genome] = [random_genome(rng) for _ in range(pop_size)]
 
289
  for g in pop:
290
  fit, acc = evaluate_genome(g, dataset, explore)
291
  g.fitness, g.acc = fit, acc
292
 
293
+ history: List[Tuple[int,float,float]] = []
294
  best_overall: Optional[Genome] = None
295
 
296
  for gen in range(1, generations+1):
297
  if self.stop_flag: break
298
 
 
299
  k = max(2, int(2 + exploit * 5))
300
  parents = []
301
  for _ in range(pop_size):
302
  sample = rng.sample(pop, k=k)
303
  parents.append(min(sample, key=lambda x: x.fitness))
304
 
 
305
  children = []
306
  for i in range(0, pop_size, 2):
307
  a = parents[i]; b = parents[(i+1) % pop_size]
 
310
  children.extend([child1, child2])
311
  children = children[:pop_size]
312
 
 
313
  for c in children:
314
  fit, acc = evaluate_genome(c, dataset, explore)
315
  c.fitness, c.acc = fit, acc
316
 
 
317
  elite_n = max(1, pop_size // 10)
318
  elites = sorted(pop, key=lambda x: x.fitness)[:elite_n]
 
 
319
  pop = sorted(children, key=lambda x: x.fitness)
320
  pop[-elite_n:] = elites
321
 
322
  best = min(pop, key=lambda x: x.fitness)
323
+ if best_overall is None or best.fitness < best_overall.fitness: best_overall = best
 
324
 
325
  history.append((gen, best.fitness, (best.acc if best.acc is not None else float("nan"))))
326
 
 
327
  P = np.stack([g.vector() for g in pop], axis=0)
328
  P3 = sphere_project(P)
329
  sphere_fig = make_sphere_figure(P3, pop, gen)
330
  hist_fig = make_history_figure(history, metric_choice)
331
+
332
  top = sorted(pop, key=lambda x: x.fitness)[: min(12, len(pop))]
333
  top_table = [
334
  {
 
349
 
350
  with self.lock:
351
  self.state = {
352
+ "sphere_html": fig_to_html(sphere_fig),
353
+ "history_html": fig_to_html(hist_fig),
354
  "top": top_table,
355
  "best": best_card,
356
  "gen": gen,
 
359
  }
360
 
361
  time.sleep(max(0.0, pace_ms/1000.0))
 
362
  self.running = False
363
 
364
  def start(self, *args, **kwargs):
365
  if self.running: return
366
  t = threading.Thread(target=self.run, args=args, kwargs=kwargs, daemon=True)
367
  t.start()
 
368
  def stop(self): self.stop_flag = True
369
 
370
  runner = EvoRunner()
371
 
372
+ # ---------- UI callbacks ----------
 
 
373
  def start_evo(dataset, pop, gens, mut, explore, exploit, seed, pace_ms, metric_choice):
374
  runner.start(dataset, int(pop), int(gens), float(mut), float(explore), float(exploit), int(seed), int(pace_ms), metric_choice)
375
  return (gr.update(interactive=False), gr.update(interactive=True))
 
381
  def poll_state():
382
  with runner.lock:
383
  s = runner.state.copy()
384
+ sphere_html = s.get("sphere_html", "")
385
+ history_html = s.get("history_html", "")
386
  best = s.get("best", {})
387
  gen = s.get("gen", 0)
388
  dataset = s.get("dataset", "Demo (Surrogate)")
389
  top = s.get("top", [])
 
390
  if best:
391
  acc_txt = "—" if best.get("accuracy") is None else f"{best.get('accuracy'):.3f}"
392
  stats_md = (
 
402
  else:
403
  stats_md = "Waiting… click **Start Evolution**."
404
  df = pd.DataFrame(top)
405
+ return sphere_html, history_html, stats_md, df
406
 
407
  def export_snapshot():
408
  from json import dumps
 
413
  f.write(payload)
414
  return path
415
 
416
+ # ---------- Build UI (minimal layout) ----------
417
+ with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft()) as demo:
418
+ with gr.Column(elem_id="header"):
419
+ gr.Markdown("## Evo Playground — Minimal Live Evolution (PIQA / HellaSwag accuracy)")
 
 
 
 
 
420
 
421
  with gr.Row():
422
+ with gr.Column(scale=1, elem_classes=["controls"]):
 
423
  with gr.Group():
424
  dataset = gr.Dropdown(
425
  label="Dataset",
426
  choices=["Demo (Surrogate)", "PIQA (Phase 2)", "HellaSwag (Phase 2)"],
427
  value="Demo (Surrogate)",
428
+ info="PIQA/HellaSwag compute real proxy accuracy; Demo is a fast surrogate."
429
  )
430
  pop = gr.Slider(8, 80, value=24, step=2, label="Population size")
431
  gens = gr.Slider(5, 200, value=60, step=1, label="Max generations")
 
440
  start = gr.Button("▶ Start Evolution", variant="primary")
441
  stop = gr.Button("⏹ Stop", variant="secondary")
442
 
443
+ with gr.Group(elem_classes=["panel"]):
444
+ stats_md = gr.Markdown("Waiting…", elem_id="stats")
445
+
446
+ with gr.Group(elem_classes=["panel"]):
447
  export_btn = gr.Button("Export Snapshot (JSON)")
448
  export_file = gr.File(label="Download snapshot", visible=False)
449
 
 
450
  with gr.Column(scale=2):
451
+ with gr.Group(elem_classes=["panel"]):
452
+ sphere_html = gr.HTML()
453
+ with gr.Group(elem_classes=["panel"]):
454
+ hist_html = gr.HTML()
455
+ with gr.Group(elem_classes=["panel"]):
456
  top_df = gr.Dataframe(label="Top Genomes (live)", wrap=True, interactive=False)
457
 
458
  # Wiring
 
460
  stop.click(stop_evo, [], [start, stop])
461
  export_btn.click(export_snapshot, [], [export_file])
462
 
463
+ # Initial paint + polling
464
+ demo.load(poll_state, None, [sphere_html, hist_html, stats_md, top_df])
465
+ gr.Timer(0.7).tick(poll_state, None, [sphere_html, hist_html, stats_md, top_df])
 
 
 
466
 
467
  if __name__ == "__main__":
468
  demo.launch()