HemanM commited on
Commit
f8b9865
·
verified ·
1 Parent(s): 25ccd85

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +452 -77
app.py CHANGED
@@ -1,5 +1,20 @@
1
- # app.py — Enhanced UI with better layout, visual hierarchy, and UX
2
- # ... [All your imports and backend code remain the same] ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  # =========================
5
  # ENHANCED CSS
@@ -44,10 +59,7 @@ h1, h2, h3, .gr-markdown {
44
  background: var(--accent) !important;
45
  border: 1px solid var(--accent) !important;
46
  }
47
-
48
- .btn-primary:hover {
49
- background: var(--accent-hover) !important;
50
- }
51
 
52
  .btn-secondary {
53
  background: transparent !important;
@@ -92,9 +104,7 @@ h1, h2, h3, .gr-markdown {
92
  color: var(--accent);
93
  }
94
 
95
- .param-slider {
96
- margin-bottom: 12px;
97
- }
98
 
99
  .visualization-container {
100
  display: flex;
@@ -103,10 +113,7 @@ h1, h2, h3, .gr-markdown {
103
  height: 100%;
104
  }
105
 
106
- .viz-panel {
107
- flex: 1;
108
- min-height: 300px;
109
- }
110
 
111
  .viz-header {
112
  display: flex;
@@ -143,17 +150,9 @@ h1, h2, h3, .gr-markdown {
143
  grid-template-columns: 1fr 1fr;
144
  gap: 16px;
145
  }
 
146
 
147
- @media (max-width: 1200px) {
148
- .controls-grid {
149
- grid-template-columns: 1fr;
150
- }
151
- }
152
-
153
- .data-table {
154
- max-height: 400px;
155
- overflow-y: auto;
156
- }
157
 
158
  .data-table table {
159
  width: 100%;
@@ -176,15 +175,9 @@ h1, h2, h3, .gr-markdown {
176
  border-bottom: 1px solid rgba(31, 43, 54, 0.5);
177
  }
178
 
179
- .data-table tr:hover {
180
- background: rgba(31, 43, 54, 0.3);
181
- }
182
 
183
- .action-buttons {
184
- display: flex;
185
- gap: 12px;
186
- margin-top: 20px;
187
- }
188
 
189
  .footer {
190
  margin-top: 20px;
@@ -196,7 +189,403 @@ h1, h2, h3, .gr-markdown {
196
  }
197
  """
198
 
199
- # ... [All your backend code remains the same] ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
 
201
  # =========================
202
  # BUILD ENHANCED UI
@@ -204,17 +593,16 @@ h1, h2, h3, .gr-markdown {
204
  with gr.Blocks(css=ENHANCED_CSS, theme=gr.themes.Default()) as demo:
205
  # Header
206
  with gr.Column(elem_id="header"):
207
- gr.Markdown("## 🧬 Neuroevolution Playground", elem_classes=["header-title"])
208
- gr.Markdown("Evolve neural architectures using genetic algorithms",
209
- elem_classes=["header-subtitle"])
210
-
211
  with gr.Row():
212
  # Left Panel - Controls
213
  with gr.Column(scale=1):
214
  # Parameters Group
215
  with gr.Group(elem_classes=["control-group"]):
216
  gr.Markdown("### 🛠 Evolution Parameters")
217
-
218
  with gr.Column():
219
  dataset = gr.Dropdown(
220
  label="Evaluation Dataset",
@@ -222,45 +610,38 @@ with gr.Blocks(css=ENHANCED_CSS, theme=gr.themes.Default()) as demo:
222
  value="Demo (Surrogate)",
223
  info="Dataset used for fitness evaluation"
224
  )
225
-
226
  with gr.Row():
227
  with gr.Column():
228
- pop = gr.Slider(8, 80, value=24, step=2, label="Population Size",
229
- elem_classes=["param-slider"])
230
- gens = gr.Slider(5, 200, value=60, step=1, label="Max Generations",
231
- elem_classes=["param-slider"])
232
- mut = gr.Slider(0.05, 0.9, value=0.25, step=0.01, label="Mutation Rate",
233
- elem_classes=["param-slider"])
234
  with gr.Column():
235
- explore = gr.Slider(0.0, 1.0, value=0.35, step=0.05, label="Exploration",
236
- elem_classes=["param-slider"])
237
- exploit = gr.Slider(0.0, 1.0, value=0.65, step=0.05, label="Exploitation",
238
- elem_classes=["param-slider"])
239
  seed = gr.Number(value=42, label="Random Seed", precision=0)
240
-
241
- pace = gr.Slider(0, 1000, value=120, step=10, label="Simulation Speed (ms)",
242
- elem_classes=["param-slider"])
243
- metric_choice = gr.Radio(choices=["Accuracy", "Fitness"], value="Accuracy",
244
- label="History Metric Display")
245
-
246
  # Status Panel
247
  with gr.Group(elem_classes=["panel", "stats-panel"]):
248
  gr.Markdown("### 📊 Current Status")
249
  stats_md = gr.Markdown("Ready. Press **Start** to begin evolution.", elem_id="stats")
250
-
251
  # Action Buttons
252
  with gr.Row(elem_classes=["action-buttons"]):
253
  start = gr.Button("▶ Start Evolution", variant="primary", elem_classes=["btn-primary"])
254
- stop = gr.Button("⏹ Stop", variant="stop", elem_classes=["btn-danger"], interactive=False)
255
  clear = gr.Button("↻ Reset", elem_classes=["btn-secondary"])
256
-
257
  # Export
258
  with gr.Group(elem_classes=["panel"]):
259
  gr.Markdown("### 💾 Export Results")
260
  with gr.Row():
261
  export_btn = gr.Button("Save Snapshot (JSON)")
262
  export_file = gr.File(label="Download snapshot", visible=False)
263
-
264
  # Right Panel - Visualizations
265
  with gr.Column(scale=2):
266
  # 3D Visualization
@@ -270,46 +651,40 @@ with gr.Blocks(css=ENHANCED_CSS, theme=gr.themes.Default()) as demo:
270
  gr.Markdown("### 🌐 Architecture Space", elem_classes=["viz-title"])
271
  gen_counter = gr.Markdown("", elem_classes=["gen-counter"])
272
  sphere_html = gr.HTML()
273
-
274
  # History Visualization
275
  with gr.Group(elem_classes=["panel", "viz-panel"]):
276
  with gr.Column(elem_classes=["viz-header"]):
277
  gr.Markdown("### 📈 Performance History", elem_classes=["viz-title"])
278
  hist_html = gr.HTML()
279
-
280
  # Results Table
281
  with gr.Group(elem_classes=["panel"]):
282
  gr.Markdown("### 🏆 Top Genomes")
283
  with gr.Column(elem_classes=["data-table"]):
284
- top_df = gr.Dataframe(
285
- label="",
286
- headers=["Fitness", "Accuracy", "d_model", "Layers", "Heads", "FFN", "Mem", "Dropout", "Params"],
287
- datatype=["number", "number", "number", "number", "number", "number", "number", "number", "number"],
288
- wrap=True,
289
- interactive=False
290
- )
291
-
292
  # Footer
293
  with gr.Column(elem_classes=["footer"]):
294
- gr.Markdown("Evotransformer Playground v1.0 • Using Plotly and Gradio")
295
-
296
  # Wiring
297
  start.click(
298
- start_evo,
299
- [dataset, pop, gens, mut, explore, exploit, seed, pace, metric_choice],
300
  [start, stop, clear]
301
  )
302
  stop.click(stop_evo, [], [start, stop, clear])
303
  clear.click(
304
- clear_evo,
305
- [],
306
- [sphere_html, hist_html, stats_md, top_df, start, stop, clear]
307
  )
308
  export_btn.click(export_snapshot, [], [export_file])
309
-
310
  # State polling
311
  demo.load(poll_state, None, [sphere_html, hist_html, stats_md, top_df, gen_counter])
312
  gr.Timer(0.7).tick(poll_state, None, [sphere_html, hist_html, stats_md, top_df, gen_counter])
313
 
314
  if __name__ == "__main__":
315
- demo.launch()
 
1
+ # app.py — Enhanced UI + stable backend (idle sphere, Clear, inline Plotly, accuracy)
2
+ import math, random, time, threading
3
+ from dataclasses import dataclass, asdict
4
+ from typing import List, Tuple, Dict, Any, Optional
5
+ from functools import lru_cache
6
+
7
+ import numpy as np
8
+ import plotly.graph_objs as go
9
+ import plotly.io as pio
10
+ import gradio as gr
11
+ import pandas as pd
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.optim as optim
16
+
17
+ from data_utils import load_piqa, load_hellaswag, hash_vectorize
18
 
19
  # =========================
20
  # ENHANCED CSS
 
59
  background: var(--accent) !important;
60
  border: 1px solid var(--accent) !important;
61
  }
62
+ .btn-primary:hover { background: var(--accent-hover) !important; }
 
 
 
63
 
64
  .btn-secondary {
65
  background: transparent !important;
 
104
  color: var(--accent);
105
  }
106
 
107
+ .param-slider { margin-bottom: 12px; }
 
 
108
 
109
  .visualization-container {
110
  display: flex;
 
113
  height: 100%;
114
  }
115
 
116
+ .viz-panel { flex: 1; min-height: 300px; }
 
 
 
117
 
118
  .viz-header {
119
  display: flex;
 
150
  grid-template-columns: 1fr 1fr;
151
  gap: 16px;
152
  }
153
+ @media (max-width: 1200px) { .controls-grid { grid-template-columns: 1fr; } }
154
 
155
+ .data-table { max-height: 400px; overflow-y: auto; }
 
 
 
 
 
 
 
 
 
156
 
157
  .data-table table {
158
  width: 100%;
 
175
  border-bottom: 1px solid rgba(31, 43, 54, 0.5);
176
  }
177
 
178
+ .data-table tr:hover { background: rgba(31, 43, 54, 0.3); }
 
 
179
 
180
+ .action-buttons { display: flex; gap: 12px; margin-top: 20px; }
 
 
 
 
181
 
182
  .footer {
183
  margin-top: 20px;
 
189
  }
190
  """
191
 
192
+ # =========================
193
+ # GENOME + EVOLUTION CORE
194
+ # =========================
195
+ @dataclass
196
+ class Genome:
197
+ d_model: int
198
+ n_layers: int
199
+ n_heads: int
200
+ ffn_mult: float
201
+ memory_tokens: int
202
+ dropout: float
203
+ species: int = 0
204
+ fitness: float = float("inf")
205
+ acc: Optional[float] = None
206
+
207
+ def vector(self) -> np.ndarray:
208
+ return np.array([
209
+ self.d_model / 1024.0,
210
+ self.n_layers / 24.0,
211
+ self.n_heads / 32.0,
212
+ self.ffn_mult / 8.0,
213
+ self.memory_tokens / 64.0,
214
+ self.dropout / 0.5
215
+ ], dtype=np.float32)
216
+
217
+ def random_genome(rng: random.Random) -> Genome:
218
+ return Genome(
219
+ d_model=rng.choice([256, 384, 512, 640]),
220
+ n_layers=rng.choice([4, 6, 8, 10, 12]),
221
+ n_heads=rng.choice([4, 6, 8, 10, 12]),
222
+ ffn_mult=rng.choice([2.0, 3.0, 4.0, 6.0]),
223
+ memory_tokens=rng.choice([0, 4, 8, 16]),
224
+ dropout=rng.choice([0.0, 0.05, 0.1, 0.15]),
225
+ species=rng.randrange(5)
226
+ )
227
+
228
+ def mutate(g: Genome, rng: random.Random, rate: float) -> Genome:
229
+ g = Genome(**asdict(g))
230
+ if rng.random() < rate: g.d_model = rng.choice([256, 384, 512, 640])
231
+ if rng.random() < rate: g.n_layers = rng.choice([4, 6, 8, 10, 12])
232
+ if rng.random() < rate: g.n_heads = rng.choice([4, 6, 8, 10, 12])
233
+ if rng.random() < rate: g.ffn_mult = rng.choice([2.0, 3.0, 4.0, 6.0])
234
+ if rng.random() < rate: g.memory_tokens = rng.choice([0, 4, 8, 16])
235
+ if rng.random() < rate: g.dropout = rng.choice([0.0, 0.05, 0.1, 0.15])
236
+ if rng.random() < rate * 0.5: g.species = rng.randrange(5)
237
+ g.fitness = float("inf"); g.acc = None
238
+ return g
239
+
240
+ def crossover(a: Genome, b: Genome, rng: random.Random) -> Genome:
241
+ return Genome(
242
+ d_model = a.d_model if rng.random()<0.5 else b.d_model,
243
+ n_layers = a.n_layers if rng.random()<0.5 else b.n_layers,
244
+ n_heads = a.n_heads if rng.random()<0.5 else b.n_heads,
245
+ ffn_mult = a.ffn_mult if rng.random()<0.5 else b.ffn_mult,
246
+ memory_tokens = a.memory_tokens if rng.random()<0.5 else b.memory_tokens,
247
+ dropout = a.dropout if rng.random()<0.5 else b.dropout,
248
+ species = a.species if rng.random()<0.5 else b.species,
249
+ fitness = float("inf"), acc=None
250
+ )
251
+
252
+ # =========================
253
+ # PROXY FITNESS
254
+ # =========================
255
+ def rastrigin(x: np.ndarray) -> float:
256
+ A, n = 10.0, x.shape[0]
257
+ return A * n + np.sum(x**2 - A * np.cos(2 * math.pi * x))
258
+
259
+ class TinyMLP(nn.Module):
260
+ def __init__(self, in_dim: int, genome: Genome):
261
+ super().__init__()
262
+ h1 = max(64, int(0.25 * genome.d_model))
263
+ h2 = max(32, int(genome.ffn_mult * 32))
264
+ self.net = nn.Sequential(
265
+ nn.Linear(in_dim, h1), nn.ReLU(),
266
+ nn.Linear(h1, h2), nn.ReLU(),
267
+ nn.Linear(h2, 1)
268
+ )
269
+ def forward(self, x): return self.net(x).squeeze(-1)
270
+
271
+ @lru_cache(maxsize=4)
272
+ def _cached_dataset(name: str):
273
+ try:
274
+ if name.startswith("PIQA"): return load_piqa(subset=800, seed=42)
275
+ if name.startswith("HellaSwag"): return load_hellaswag(subset=800, seed=42)
276
+ except Exception:
277
+ return None
278
+ return None
279
+
280
+ def _train_eval_proxy(genome: Genome, dataset_name: str, explore: float, device: str="cpu"):
281
+ data = _cached_dataset(dataset_name)
282
+ if data is None:
283
+ # Fallback to surrogate so UI still runs
284
+ v = genome.vector() * 2 - 1
285
+ base = rastrigin(v)
286
+ parsimony = 0.001 * (genome.d_model + 50*genome.n_layers + 20*genome.n_heads + 100*genome.memory_tokens)
287
+ noise = np.random.normal(scale=0.05 * max(0.0, min(1.0, explore)))
288
+ return float(base + parsimony + noise), None
289
+
290
+ Xtr_txt, ytr, Xva_txt, yva = data
291
+ nfeat = 4096
292
+ Xtr = hash_vectorize(Xtr_txt, n_features=nfeat, seed=1234)
293
+ Xva = hash_vectorize(Xva_txt, n_features=nfeat, seed=5678)
294
+
295
+ Xtr_t = torch.from_numpy(Xtr); ytr_t = torch.from_numpy(ytr.astype(np.float32))
296
+ Xva_t = torch.from_numpy(Xva); yva_t = torch.from_numpy(yva.astype(np.float32))
297
+
298
+ model = TinyMLP(nfeat, genome).to(device)
299
+ opt = optim.AdamW(model.parameters(), lr=2e-3)
300
+ lossf = nn.BCEWithLogitsLoss()
301
+
302
+ model.train(); steps, bs, N = 120, 256, Xtr_t.size(0)
303
+ for _ in range(steps):
304
+ idx = torch.randint(0, N, (bs,))
305
+ xb = Xtr_t[idx].to(device); yb = ytr_t[idx].to(device)
306
+ logits = model(xb); loss = lossf(logits, yb)
307
+ opt.zero_grad(); loss.backward()
308
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
309
+ opt.step()
310
+
311
+ model.eval()
312
+ with torch.no_grad():
313
+ logits = model(Xva_t.to(device))
314
+ probs = torch.sigmoid(logits).cpu().numpy()
315
+
316
+ if dataset_name.startswith("PIQA"):
317
+ probs = probs.reshape(-1,2); yva2 = yva.reshape(-1,2)
318
+ pred = (probs[:,0] > probs[:,1]).astype(np.int64)
319
+ truth = (yva2[:,0] == 1).astype(np.int64)
320
+ acc = float((pred == truth).mean())
321
+ else:
322
+ probs = probs.reshape(-1,4); yva2 = yva.reshape(-1,4)
323
+ pred = probs.argmax(axis=1); truth = yva2.argmax(axis=1)
324
+ acc = float((pred == truth).mean())
325
+
326
+ parsimony = 0.00000002 * (genome.d_model**2 * genome.n_layers) + 0.0001 * genome.memory_tokens
327
+ noise = np.random.normal(scale=0.01 * max(0.0, min(1.0, explore)))
328
+ fitness = (1.0 - acc) + parsimony + noise
329
+ return float(max(0.0, min(1.5, fitness))), float(acc)
330
+
331
+ def evaluate_genome(genome: Genome, dataset: str, explore: float):
332
+ if dataset == "Demo (Surrogate)":
333
+ v = genome.vector() * 2 - 1
334
+ base = rastrigin(v)
335
+ parsimony = 0.001 * (genome.d_model + 50*genome.n_layers + 20*genome.n_heads + 100*genome.memory_tokens)
336
+ noise = np.random.normal(scale=0.05 * max(0.0, min(1.0, explore)))
337
+ return float(base + parsimony + noise), None
338
+ if dataset.startswith("PIQA"): return _train_eval_proxy(genome, "PIQA", explore)
339
+ if dataset.startswith("HellaSwag"): return _train_eval_proxy(genome, "HellaSwag", explore)
340
+ v = genome.vector() * 2 - 1
341
+ return float(rastrigin(v)), None
342
+
343
+ # =========================
344
+ # VIZ — idle sphere, big transparent surface
345
+ # =========================
346
+ BG = "#0F1A24"
347
+ DOT = "#93C5FD"
348
+ SPHERE = "#cbd5e1"
349
+
350
+ def sphere_project(points: np.ndarray) -> np.ndarray:
351
+ rng = np.random.RandomState(42)
352
+ W = rng.normal(size=(points.shape[1], 3)).astype(np.float32)
353
+ Y = points @ W
354
+ norms = np.linalg.norm(Y, axis=1, keepdims=True) + 1e-8
355
+ return (Y / norms) * 1.22
356
+
357
+ def make_idle_sphere() -> go.Figure:
358
+ u = np.linspace(0, 2*np.pi, 72)
359
+ v = np.linspace(0, np.pi, 36)
360
+ r = 1.22
361
+ xs = r*np.outer(np.cos(u), np.sin(v))
362
+ ys = r*np.outer(np.sin(u), np.sin(v))
363
+ zs = r*np.outer(np.ones_like(u), np.cos(v))
364
+ sphere = go.Surface(
365
+ x=xs, y=ys, z=zs,
366
+ opacity=0.06, showscale=False,
367
+ colorscale=[[0, SPHERE],[1, SPHERE]],
368
+ hoverinfo="skip"
369
+ )
370
+ layout = go.Layout(
371
+ paper_bgcolor=BG, plot_bgcolor=BG,
372
+ title=dict(text="Architecture Space (idle)", font=dict(color="#E5E7EB")),
373
+ scene=dict(
374
+ xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False),
375
+ bgcolor=BG
376
+ ),
377
+ margin=dict(l=0, r=0, t=36, b=0), showlegend=False, height=720,
378
+ font=dict(family="Inter, Arial, sans-serif", size=14, color="#E5E7EB")
379
+ )
380
+ return go.Figure(data=[sphere], layout=layout)
381
+
382
+ def make_sphere_figure(points3d: np.ndarray, genomes: List[Genome], gen_idx: int) -> go.Figure:
383
+ custom = np.array([[g.d_model, g.n_layers, g.n_heads, g.ffn_mult, g.memory_tokens, g.dropout,
384
+ g.species, g.fitness, (g.acc if g.acc is not None else -1.0)]
385
+ for g in genomes], dtype=np.float32)
386
+ scatter = go.Scatter3d(
387
+ x=points3d[:,0], y=points3d[:,1], z=points3d[:,2],
388
+ mode='markers',
389
+ marker=dict(size=7.0, color=DOT, opacity=0.92),
390
+ customdata=custom,
391
+ hovertemplate=(
392
+ "<b>Genome</b><br>"
393
+ "d_model=%{customdata[0]:.0f} · layers=%{customdata[1]:.0f} · heads=%{customdata[2]:.0f}<br>"
394
+ "ffn_mult=%{customdata[3]:.1f} · mem=%{customdata[4]:.0f} · drop=%{customdata[5]:.2f}<br>"
395
+ "fitness=%{customdata[7]:.4f} · acc=%{customdata[8]:.3f}<extra></extra>"
396
+ )
397
+ )
398
+ idle = make_idle_sphere()
399
+ fig = go.Figure(data=idle.data + (scatter,), layout=idle.layout)
400
+ fig.update_layout(title=dict(text=f"Evo Architecture Space — Gen {gen_idx}"))
401
+ return fig
402
+
403
+ def make_history_figure(history: List[Tuple[int,float,float]], metric: str) -> go.Figure:
404
+ xs = [h[0] for h in history]
405
+ if metric == "Accuracy":
406
+ ys = [h[2] if (h[2] == h[2]) else None for h in history]
407
+ title, ylab = "Best Accuracy per Generation", "Accuracy"
408
+ else:
409
+ ys = [h[1] for h in history]
410
+ title, ylab = "Best Fitness per Generation", "Fitness (↓ better)"
411
+ fig = go.Figure(data=[go.Scatter(x=xs, y=ys, mode="lines+markers", line=dict(width=2), marker=dict(color=DOT))])
412
+ fig.update_layout(
413
+ paper_bgcolor=BG, plot_bgcolor=BG, font=dict(color="#E5E7EB"),
414
+ title=dict(text=title), xaxis_title="Generation", yaxis_title=ylab,
415
+ margin=dict(l=30, r=10, t=36, b=30), height=340
416
+ )
417
+ fig.update_xaxes(gridcolor="#1f2b36"); fig.update_yaxes(gridcolor="#1f2b36")
418
+ return fig
419
+
420
+ def fig_to_html(fig: go.Figure) -> str:
421
+ # Inline Plotly JS so it renders even without CDN
422
+ return pio.to_html(fig, include_plotlyjs=True, full_html=False, config=dict(displaylogo=False))
423
+
424
+ def approx_params(g: Genome) -> int:
425
+ per_layer = (4.0 + 2.0 * float(g.ffn_mult)) * (g.d_model ** 2)
426
+ total = per_layer * g.n_layers + 1000 * g.memory_tokens
427
+ return int(total)
428
+
429
+ # =========================
430
+ # RUNNER
431
+ # =========================
432
+ class EvoRunner:
433
+ def __init__(self):
434
+ self.lock = threading.Lock()
435
+ self.running = False
436
+ self.stop_flag = False
437
+ self.state: Dict[str, Any] = {}
438
+ # Seed idle visuals
439
+ idle = fig_to_html(make_idle_sphere())
440
+ self.state = {
441
+ "sphere_html": idle,
442
+ "history_html": fig_to_html(make_history_figure([], "Accuracy")),
443
+ "top": [], "best": {}, "gen": 0,
444
+ "dataset": "Demo (Surrogate)", "metric": "Accuracy"
445
+ }
446
+
447
+ def run(self, dataset, pop_size, generations, mutation_rate, explore, exploit, seed, pace_ms, metric_choice):
448
+ rng = random.Random(int(seed))
449
+ self.stop_flag = False
450
+ self.running = True
451
+
452
+ pop: List[Genome] = [random_genome(rng) for _ in range(pop_size)]
453
+ for g in pop:
454
+ fit, acc = evaluate_genome(g, dataset, explore)
455
+ g.fitness, g.acc = fit, acc
456
+
457
+ history: List[Tuple[int,float,float]] = []
458
+
459
+ for gen in range(1, generations+1):
460
+ if self.stop_flag: break
461
+
462
+ k = max(2, int(2 + exploit * 5))
463
+ parents = [min(rng.sample(pop, k=k), key=lambda x: x.fitness) for _ in range(pop_size)]
464
+
465
+ children = []
466
+ for i in range(0, pop_size, 2):
467
+ a = parents[i]; b = parents[(i+1) % pop_size]
468
+ child1 = mutate(crossover(a,b,rng), rng, mutation_rate)
469
+ child2 = mutate(crossover(b,a,rng), rng, mutation_rate)
470
+ children.extend([child1, child2])
471
+ children = children[:pop_size]
472
+
473
+ for c in children:
474
+ fit, acc = evaluate_genome(c, dataset, explore)
475
+ c.fitness, c.acc = fit, acc
476
+
477
+ elite_n = max(1, pop_size // 10)
478
+ elites = sorted(pop, key=lambda x: x.fitness)[:elite_n]
479
+ pop = sorted(children, key=lambda x: x.fitness)
480
+ pop[-elite_n:] = elites
481
+
482
+ best = min(pop, key=lambda x: x.fitness)
483
+ history.append((gen, best.fitness, (best.acc if best.acc is not None else float("nan"))))
484
+
485
+ P = np.stack([g.vector() for g in pop], axis=0)
486
+ P3 = sphere_project(P)
487
+ sphere_fig = make_sphere_figure(P3, pop, gen)
488
+ hist_fig = make_history_figure(history, metric_choice)
489
+
490
+ top = sorted(pop, key=lambda x: x.fitness)[: min(12, len(pop))]
491
+ top_table = [{
492
+ "gen": gen, "fitness": round(t.fitness, 4),
493
+ "accuracy": (None if t.acc is None else round(float(t.acc), 4)),
494
+ "d_model": t.d_model, "layers": t.n_layers, "heads": t.n_heads,
495
+ "ffn_mult": t.ffn_mult, "mem": t.memory_tokens, "dropout": t.dropout,
496
+ "params_approx": approx_params(t)
497
+ } for t in top]
498
+ best_card = top_table[0] if top_table else {}
499
+
500
+ with self.lock:
501
+ self.state = {
502
+ "sphere_html": fig_to_html(sphere_fig),
503
+ "history_html": fig_to_html(hist_fig),
504
+ "top": top_table,
505
+ "best": best_card,
506
+ "gen": gen,
507
+ "dataset": dataset,
508
+ "metric": metric_choice
509
+ }
510
+
511
+ time.sleep(max(0.0, pace_ms/1000.0))
512
+ self.running = False
513
+
514
+ def start(self, *args, **kwargs):
515
+ if self.running: return
516
+ t = threading.Thread(target=self.run, args=args, kwargs=kwargs, daemon=True)
517
+ t.start()
518
+
519
+ def stop(self): self.stop_flag = True
520
+
521
+ def clear(self):
522
+ # stop and reset to idle sphere
523
+ self.stop_flag = True
524
+ idle = fig_to_html(make_idle_sphere())
525
+ with self.lock:
526
+ self.running = False
527
+ self.state = {
528
+ "sphere_html": idle,
529
+ "history_html": fig_to_html(make_history_figure([], "Accuracy")),
530
+ "top": [], "best": {}, "gen": 0,
531
+ "dataset": "Demo (Surrogate)", "metric": "Accuracy"
532
+ }
533
+
534
+ runner = EvoRunner()
535
+
536
+ # =========================
537
+ # UI CALLBACKS
538
+ # =========================
539
+ def start_evo(dataset, pop, gens, mut, explore, exploit, seed, pace_ms, metric_choice):
540
+ runner.start(dataset, int(pop), int(gens), float(mut), float(explore), float(exploit), int(seed), int(pace_ms), metric_choice)
541
+ return (gr.update(interactive=False), gr.update(interactive=True), gr.update(interactive=False))
542
+
543
+ def stop_evo():
544
+ runner.stop()
545
+ return (gr.update(interactive=True), gr.update(interactive=False), gr.update(interactive=True))
546
+
547
+ def clear_evo():
548
+ runner.clear()
549
+ sphere_html, history_html, stats_md, df, gen_counter_md = poll_state()
550
+ return sphere_html, history_html, stats_md, df, gen_counter_md, gr.update(interactive=True), gr.update(interactive=False), gr.update(interactive=True)
551
+
552
+ def poll_state():
553
+ with runner.lock:
554
+ s = runner.state.copy()
555
+ sphere_html = s.get("sphere_html", "")
556
+ history_html = s.get("history_html", "")
557
+ best = s.get("best", {})
558
+ gen = s.get("gen", 0)
559
+ dataset = s.get("dataset", "Demo (Surrogate)")
560
+ top = s.get("top", [])
561
+
562
+ if best:
563
+ acc_txt = "—" if best.get("accuracy") is None else f"{best.get('accuracy'):.3f}"
564
+ stats_md = (
565
+ f"**Dataset:** {dataset} \n"
566
+ f"**Generation:** {gen} \n"
567
+ f"**Best fitness:** {best.get('fitness','–')} \n"
568
+ f"**Best accuracy:** {acc_txt} \n"
569
+ f"**Config:** d_model={best.get('d_model')} · layers={best.get('layers')} · "
570
+ f"heads={best.get('heads')} · ffn_mult={best.get('ffn_mult')} · mem={best.get('mem')} · "
571
+ f"dropout={best.get('dropout')} \n"
572
+ f"**~Params (rough):** {best.get('params_approx'):,}"
573
+ )
574
+ else:
575
+ stats_md = "Ready. Press **Start** to begin evolution."
576
+
577
+ df = pd.DataFrame(top)
578
+ gen_counter_md = f"Gen **{gen}**"
579
+ return sphere_html, history_html, stats_md, df, gen_counter_md
580
+
581
+ def export_snapshot():
582
+ from json import dumps
583
+ with runner.lock:
584
+ payload = dumps(runner.state, default=lambda o: o, indent=2)
585
+ path = "evo_snapshot.json"
586
+ with open(path, "w", encoding="utf-8") as f:
587
+ f.write(payload)
588
+ return path
589
 
590
  # =========================
591
  # BUILD ENHANCED UI
 
593
  with gr.Blocks(css=ENHANCED_CSS, theme=gr.themes.Default()) as demo:
594
  # Header
595
  with gr.Column(elem_id="header"):
596
+ gr.Markdown("## 🧬 Neuroevolution Playground")
597
+ gr.Markdown("Evolve neural architectures using genetic algorithms")
598
+
 
599
  with gr.Row():
600
  # Left Panel - Controls
601
  with gr.Column(scale=1):
602
  # Parameters Group
603
  with gr.Group(elem_classes=["control-group"]):
604
  gr.Markdown("### 🛠 Evolution Parameters")
605
+
606
  with gr.Column():
607
  dataset = gr.Dropdown(
608
  label="Evaluation Dataset",
 
610
  value="Demo (Surrogate)",
611
  info="Dataset used for fitness evaluation"
612
  )
613
+
614
  with gr.Row():
615
  with gr.Column():
616
+ pop = gr.Slider(8, 80, value=24, step=2, label="Population Size", elem_classes=["param-slider"])
617
+ gens = gr.Slider(5, 200, value=60, step=1, label="Max Generations", elem_classes=["param-slider"])
618
+ mut = gr.Slider(0.05, 0.9, value=0.25, step=0.01, label="Mutation Rate", elem_classes=["param-slider"])
 
 
 
619
  with gr.Column():
620
+ explore = gr.Slider(0.0, 1.0, value=0.35, step=0.05, label="Exploration", elem_classes=["param-slider"])
621
+ exploit = gr.Slider(0.0, 1.0, value=0.65, step=0.05, label="Exploitation", elem_classes=["param-slider"])
 
 
622
  seed = gr.Number(value=42, label="Random Seed", precision=0)
623
+
624
+ pace = gr.Slider(0, 1000, value=120, step=10, label="Simulation Speed (ms)", elem_classes=["param-slider"])
625
+ metric_choice = gr.Radio(choices=["Accuracy", "Fitness"], value="Accuracy", label="History Metric Display")
626
+
 
 
627
  # Status Panel
628
  with gr.Group(elem_classes=["panel", "stats-panel"]):
629
  gr.Markdown("### 📊 Current Status")
630
  stats_md = gr.Markdown("Ready. Press **Start** to begin evolution.", elem_id="stats")
631
+
632
  # Action Buttons
633
  with gr.Row(elem_classes=["action-buttons"]):
634
  start = gr.Button("▶ Start Evolution", variant="primary", elem_classes=["btn-primary"])
635
+ stop = gr.Button("⏹ Stop", variant="secondary", elem_classes=["btn-danger"], interactive=False)
636
  clear = gr.Button("↻ Reset", elem_classes=["btn-secondary"])
637
+
638
  # Export
639
  with gr.Group(elem_classes=["panel"]):
640
  gr.Markdown("### 💾 Export Results")
641
  with gr.Row():
642
  export_btn = gr.Button("Save Snapshot (JSON)")
643
  export_file = gr.File(label="Download snapshot", visible=False)
644
+
645
  # Right Panel - Visualizations
646
  with gr.Column(scale=2):
647
  # 3D Visualization
 
651
  gr.Markdown("### 🌐 Architecture Space", elem_classes=["viz-title"])
652
  gen_counter = gr.Markdown("", elem_classes=["gen-counter"])
653
  sphere_html = gr.HTML()
654
+
655
  # History Visualization
656
  with gr.Group(elem_classes=["panel", "viz-panel"]):
657
  with gr.Column(elem_classes=["viz-header"]):
658
  gr.Markdown("### 📈 Performance History", elem_classes=["viz-title"])
659
  hist_html = gr.HTML()
660
+
661
  # Results Table
662
  with gr.Group(elem_classes=["panel"]):
663
  gr.Markdown("### 🏆 Top Genomes")
664
  with gr.Column(elem_classes=["data-table"]):
665
+ top_df = gr.Dataframe(label="", wrap=True, interactive=False)
666
+
 
 
 
 
 
 
667
  # Footer
668
  with gr.Column(elem_classes=["footer"]):
669
+ gr.Markdown("Neuroevolution Playground v1.0 • Plotly + Gradio")
670
+
671
  # Wiring
672
  start.click(
673
+ start_evo,
674
+ [dataset, pop, gens, mut, explore, exploit, seed, pace, metric_choice],
675
  [start, stop, clear]
676
  )
677
  stop.click(stop_evo, [], [start, stop, clear])
678
  clear.click(
679
+ clear_evo,
680
+ [],
681
+ [sphere_html, hist_html, stats_md, top_df, gen_counter, start, stop, clear]
682
  )
683
  export_btn.click(export_snapshot, [], [export_file])
684
+
685
  # State polling
686
  demo.load(poll_state, None, [sphere_html, hist_html, stats_md, top_df, gen_counter])
687
  gr.Timer(0.7).tick(poll_state, None, [sphere_html, hist_html, stats_md, top_df, gen_counter])
688
 
689
  if __name__ == "__main__":
690
+ demo.launch()