HemanM commited on
Commit
25ccd85
Β·
verified Β·
1 Parent(s): a8f377f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +296 -462
app.py CHANGED
@@ -1,481 +1,315 @@
1
- # app.py β€” Minimal dark UI, default idle sphere, Clear button, inline Plotly
2
- import math, json, random, time, threading
3
- from dataclasses import dataclass, asdict
4
- from typing import List, Tuple, Dict, Any, Optional
5
- from functools import lru_cache
6
-
7
- import numpy as np
8
- import plotly.graph_objs as go
9
- import plotly.io as pio
10
- import gradio as gr
11
- import pandas as pd
12
-
13
- import torch
14
- import torch.nn as nn
15
- import torch.optim as optim
16
-
17
- from data_utils import load_piqa, load_hellaswag, hash_vectorize
18
 
19
  # =========================
20
- # STYLE β€” calm, dark, thin
21
  # =========================
22
- CUSTOM_CSS = """
23
- :root { --radius: 14px; --fg:#E5E7EB; --muted:#94A3B8; --line:#111827; --bg:#0F1A24; }
24
- * { font-family: Inter, ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Helvetica Neue, Arial; font-weight: 300; }
25
- .gradio-container { max-width: 1140px !important; background: var(--bg); }
26
- #header { border-radius: var(--radius); padding: 6px 2px; }
27
- h1, h2, h3, .gr-markdown { color: var(--fg); }
28
- .gr-button { border-radius: 10px; }
29
- .controls .gr-group, .panel { border: 1px solid #1f2b36; border-radius: var(--radius); background: #0c161f; }
30
- .panel { padding: 10px; }
31
- #stats { color: var(--fg); }
32
- #stats strong { font-weight: 500; }
33
- .small { font-size: 12px; color: var(--muted); }
34
- label, .gradio-container * { color: var(--fg); }
35
- input, textarea, select { color: var(--fg) !important; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  """
37
 
38
- # =========================
39
- # GENOME
40
- # =========================
41
- @dataclass
42
- class Genome:
43
- d_model: int
44
- n_layers: int
45
- n_heads: int
46
- ffn_mult: float
47
- memory_tokens: int
48
- dropout: float
49
- species: int = 0
50
- fitness: float = float("inf")
51
- acc: Optional[float] = None
52
-
53
- def vector(self) -> np.ndarray:
54
- return np.array([
55
- self.d_model / 1024.0,
56
- self.n_layers / 24.0,
57
- self.n_heads / 32.0,
58
- self.ffn_mult / 8.0,
59
- self.memory_tokens / 64.0,
60
- self.dropout / 0.5
61
- ], dtype=np.float32)
62
-
63
- def random_genome(rng: random.Random) -> Genome:
64
- return Genome(
65
- d_model=rng.choice([256, 384, 512, 640]),
66
- n_layers=rng.choice([4, 6, 8, 10, 12]),
67
- n_heads=rng.choice([4, 6, 8, 10, 12]),
68
- ffn_mult=rng.choice([2.0, 3.0, 4.0, 6.0]),
69
- memory_tokens=rng.choice([0, 4, 8, 16]),
70
- dropout=rng.choice([0.0, 0.05, 0.1, 0.15]),
71
- species=rng.randrange(5)
72
- )
73
-
74
- def mutate(g: Genome, rng: random.Random, rate: float) -> Genome:
75
- g = Genome(**asdict(g))
76
- if rng.random() < rate: g.d_model = rng.choice([256, 384, 512, 640])
77
- if rng.random() < rate: g.n_layers = rng.choice([4, 6, 8, 10, 12])
78
- if rng.random() < rate: g.n_heads = rng.choice([4, 6, 8, 10, 12])
79
- if rng.random() < rate: g.ffn_mult = rng.choice([2.0, 3.0, 4.0, 6.0])
80
- if rng.random() < rate: g.memory_tokens = rng.choice([0, 4, 8, 16])
81
- if rng.random() < rate: g.dropout = rng.choice([0.0, 0.05, 0.1, 0.15])
82
- if rng.random() < rate * 0.5: g.species = rng.randrange(5)
83
- g.fitness = float("inf"); g.acc = None
84
- return g
85
-
86
- def crossover(a: Genome, b: Genome, rng: random.Random) -> Genome:
87
- return Genome(
88
- d_model = a.d_model if rng.random()<0.5 else b.d_model,
89
- n_layers = a.n_layers if rng.random()<0.5 else b.n_layers,
90
- n_heads = a.n_heads if rng.random()<0.5 else b.n_heads,
91
- ffn_mult = a.ffn_mult if rng.random()<0.5 else b.ffn_mult,
92
- memory_tokens = a.memory_tokens if rng.random()<0.5 else b.memory_tokens,
93
- dropout = a.dropout if rng.random()<0.5 else b.dropout,
94
- species = a.species if rng.random()<0.5 else b.species,
95
- fitness = float("inf"), acc=None
96
- )
97
-
98
- # =========================
99
- # PROXY FITNESS
100
- # =========================
101
- def rastrigin(x: np.ndarray) -> float:
102
- A, n = 10.0, x.shape[0]
103
- return A * n + np.sum(x**2 - A * np.cos(2 * math.pi * x))
104
-
105
- class TinyMLP(nn.Module):
106
- def __init__(self, in_dim: int, genome: Genome):
107
- super().__init__()
108
- h1 = max(64, int(0.25 * genome.d_model))
109
- h2 = max(32, int(genome.ffn_mult * 32))
110
- self.net = nn.Sequential(
111
- nn.Linear(in_dim, h1), nn.ReLU(),
112
- nn.Linear(h1, h2), nn.ReLU(),
113
- nn.Linear(h2, 1)
114
- )
115
- def forward(self, x): return self.net(x).squeeze(-1)
116
-
117
- from functools import lru_cache
118
- @lru_cache(maxsize=4)
119
- def _cached_dataset(name: str):
120
- try:
121
- if name.startswith("PIQA"): return load_piqa(subset=800, seed=42)
122
- if name.startswith("HellaSwag"): return load_hellaswag(subset=800, seed=42)
123
- except Exception:
124
- return None
125
- return None
126
-
127
- def _train_eval_proxy(genome: Genome, dataset_name: str, explore: float, device: str="cpu"):
128
- data = _cached_dataset(dataset_name)
129
- if data is None:
130
- v = genome.vector() * 2 - 1
131
- base = rastrigin(v)
132
- parsimony = 0.001 * (genome.d_model + 50*genome.n_layers + 20*genome.n_heads + 100*genome.memory_tokens)
133
- noise = np.random.normal(scale=0.05 * max(0.0, min(1.0, explore)))
134
- return float(base + parsimony + noise), None
135
-
136
- Xtr_txt, ytr, Xva_txt, yva = data
137
- nfeat = 4096
138
- Xtr = hash_vectorize(Xtr_txt, n_features=nfeat, seed=1234)
139
- Xva = hash_vectorize(Xva_txt, n_features=nfeat, seed=5678)
140
-
141
- Xtr_t = torch.from_numpy(Xtr); ytr_t = torch.from_numpy(ytr.astype(np.float32))
142
- Xva_t = torch.from_numpy(Xva); yva_t = torch.from_numpy(yva.astype(np.float32))
143
-
144
- model = TinyMLP(nfeat, genome).to(device)
145
- opt = optim.AdamW(model.parameters(), lr=2e-3)
146
- lossf = nn.BCEWithLogitsLoss()
147
-
148
- model.train(); steps, bs, N = 120, 256, Xtr_t.size(0)
149
- for _ in range(steps):
150
- idx = torch.randint(0, N, (bs,))
151
- xb = Xtr_t[idx].to(device); yb = ytr_t[idx].to(device)
152
- logits = model(xb); loss = lossf(logits, yb)
153
- opt.zero_grad(); loss.backward()
154
- torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
155
- opt.step()
156
-
157
- model.eval()
158
- with torch.no_grad():
159
- logits = model(Xva_t.to(device))
160
- probs = torch.sigmoid(logits).cpu().numpy()
161
-
162
- if dataset_name.startswith("PIQA"):
163
- probs = probs.reshape(-1,2); yva2 = yva.reshape(-1,2)
164
- pred = (probs[:,0] > probs[:,1]).astype(np.int64)
165
- truth = (yva2[:,0] == 1).astype(np.int64)
166
- acc = float((pred == truth).mean())
167
- else:
168
- probs = probs.reshape(-1,4); yva2 = yva.reshape(-1,4)
169
- pred = probs.argmax(axis=1); truth = yva2.argmax(axis=1)
170
- acc = float((pred == truth).mean())
171
-
172
- parsimony = 0.00000002 * (genome.d_model**2 * genome.n_layers) + 0.0001 * genome.memory_tokens
173
- noise = np.random.normal(scale=0.01 * max(0.0, min(1.0, explore)))
174
- fitness = (1.0 - acc) + parsimony + noise
175
- return float(max(0.0, min(1.5, fitness))), float(acc)
176
-
177
- def evaluate_genome(genome: Genome, dataset: str, explore: float):
178
- if dataset == "Demo (Surrogate)":
179
- v = genome.vector() * 2 - 1
180
- base = rastrigin(v)
181
- parsimony = 0.001 * (genome.d_model + 50*genome.n_layers + 20*genome.n_heads + 100*genome.memory_tokens)
182
- noise = np.random.normal(scale=0.05 * max(0.0, min(1.0, explore)))
183
- return float(base + parsimony + noise), None
184
- if dataset.startswith("PIQA"): return _train_eval_proxy(genome, "PIQA", explore)
185
- if dataset.startswith("HellaSwag"): return _train_eval_proxy(genome, "HellaSwag", explore)
186
- v = genome.vector() * 2 - 1
187
- return float(rastrigin(v)), None
188
 
189
  # =========================
190
- # VIZ β€” big transparent sphere
191
- # =========================
192
- BG = "#0F1A24"
193
- DOT = "#93C5FD" # soft blue dot
194
- SPHERE = "#cbd5e1" # subtle sphere tint
195
-
196
- def sphere_project(points: np.ndarray) -> np.ndarray:
197
- rng = np.random.RandomState(42)
198
- W = rng.normal(size=(points.shape[1], 3)).astype(np.float32)
199
- Y = points @ W
200
- norms = np.linalg.norm(Y, axis=1, keepdims=True) + 1e-8
201
- return (Y / norms) * 1.22
202
-
203
- def make_idle_sphere() -> go.Figure:
204
- # empty scatter, only sphere
205
- u = np.linspace(0, 2*np.pi, 72)
206
- v = np.linspace(0, np.pi, 36)
207
- r = 1.22
208
- xs = r*np.outer(np.cos(u), np.sin(v))
209
- ys = r*np.outer(np.sin(u), np.sin(v))
210
- zs = r*np.outer(np.ones_like(u), np.cos(v))
211
- sphere = go.Surface(x=xs, y=ys, z=zs, opacity=0.06, showscale=False,
212
- colorscale=[[0, SPHERE],[1, SPHERE]], hoverinfo="skip")
213
- layout = go.Layout(
214
- paper_bgcolor=BG, plot_bgcolor=BG,
215
- title="Architecture Sphere (idle)", titlefont=dict(color="#E5E7EB"),
216
- scene=dict(xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False), bgcolor=BG),
217
- margin=dict(l=0, r=0, t=36, b=0), showlegend=False, height=720,
218
- font=dict(family="Inter, Arial, sans-serif", size=14, color="#E5E7EB")
219
- )
220
- return go.Figure(data=[sphere], layout=layout)
221
-
222
- def make_sphere_figure(points3d: np.ndarray, genomes: List[Genome], gen_idx: int) -> go.Figure:
223
- # single-color dots for a sober look
224
- custom = np.array([[g.d_model, g.n_layers, g.n_heads, g.ffn_mult, g.memory_tokens, g.dropout,
225
- g.species, g.fitness, (g.acc if g.acc is not None else -1.0)]
226
- for g in genomes], dtype=np.float32)
227
- scatter = go.Scatter3d(
228
- x=points3d[:,0], y=points3d[:,1], z=points3d[:,2],
229
- mode='markers',
230
- marker=dict(size=7.2, color=DOT, opacity=0.92),
231
- customdata=custom,
232
- hovertemplate=(
233
- "<b>Genome</b><br>"
234
- "d_model=%{customdata[0]:.0f} Β· layers=%{customdata[1]:.0f} Β· heads=%{customdata[2]:.0f}<br>"
235
- "ffn_mult=%{customdata[3]:.1f} Β· mem=%{customdata[4]:.0f} Β· drop=%{customdata[5]:.2f}<br>"
236
- "species=%{customdata[6]:.0f}<br>"
237
- "fitness=%{customdata[7]:.4f}<br>"
238
- "accuracy=%{customdata[8]:.3f}<extra></extra>"
239
- )
240
- )
241
- idle = make_idle_sphere()
242
- layout = idle.layout.update(title=f"Evo Architecture Sphere β€” Gen {gen_idx}")
243
- fig = go.Figure(data=idle.data + (scatter,), layout=layout)
244
- return fig
245
-
246
- def make_history_figure(history: List[Tuple[int,float,float]], metric: str) -> go.Figure:
247
- xs = [h[0] for h in history]
248
- if metric == "Accuracy":
249
- ys = [h[2] if (h[2] == h[2]) else None for h in history]
250
- title, ylab = "Best Accuracy per Generation", "Accuracy"
251
- else:
252
- ys = [h[1] for h in history]
253
- title, ylab = "Best Fitness per Generation", "Fitness (↓ better)"
254
- fig = go.Figure(data=[go.Scatter(x=xs, y=ys, mode="lines+markers", line=dict(width=2), marker=dict(color=DOT))])
255
- fig.update_layout(
256
- paper_bgcolor=BG, plot_bgcolor=BG, font=dict(color="#E5E7EB"),
257
- title=title, xaxis_title="Generation", yaxis_title=ylab,
258
- margin=dict(l=30, r=10, t=36, b=30), height=340
259
- )
260
- fig.update_xaxes(gridcolor="#1f2b36"); fig.update_yaxes(gridcolor="#1f2b36")
261
- return fig
262
-
263
- def fig_to_html(fig: go.Figure) -> str:
264
- return pio.to_html(fig, include_plotlyjs=True, full_html=False, config=dict(displaylogo=False))
265
-
266
- def approx_params(g: Genome) -> int:
267
- per_layer = (4.0 + 2.0 * float(g.ffn_mult)) * (g.d_model ** 2)
268
- total = per_layer * g.n_layers + 1000 * g.memory_tokens
269
- return int(total)
270
-
271
- # =========================
272
- # RUNNER
273
- # =========================
274
- class EvoRunner:
275
- def __init__(self):
276
- self.lock = threading.Lock()
277
- self.running = False
278
- self.stop_flag = False
279
- self.state: Dict[str, Any] = {}
280
- # seed the idle sphere immediately
281
- idle = fig_to_html(make_idle_sphere())
282
- self.state = {"sphere_html": idle, "history_html": fig_to_html(make_history_figure([], "Accuracy")),
283
- "top": [], "best": {}, "gen": 0, "dataset": "Demo (Surrogate)", "metric": "Accuracy"}
284
-
285
- def run(self, dataset, pop_size, generations, mutation_rate, explore, exploit, seed, pace_ms, metric_choice):
286
- rng = random.Random(int(seed))
287
- self.stop_flag = False
288
- self.running = True
289
-
290
- pop: List[Genome] = [random_genome(rng) for _ in range(pop_size)]
291
- for g in pop:
292
- fit, acc = evaluate_genome(g, dataset, explore)
293
- g.fitness, g.acc = fit, acc
294
-
295
- history: List[Tuple[int,float,float]] = []
296
-
297
- for gen in range(1, generations+1):
298
- if self.stop_flag: break
299
-
300
- k = max(2, int(2 + exploit * 5))
301
- parents = [min(rng.sample(pop, k=k), key=lambda x: x.fitness) for _ in range(pop_size)]
302
-
303
- children = []
304
- for i in range(0, pop_size, 2):
305
- a = parents[i]; b = parents[(i+1) % pop_size]
306
- child1 = mutate(crossover(a,b,rng), rng, mutation_rate)
307
- child2 = mutate(crossover(b,a,rng), rng, mutation_rate)
308
- children.extend([child1, child2])
309
- children = children[:pop_size]
310
-
311
- for c in children:
312
- fit, acc = evaluate_genome(c, dataset, explore)
313
- c.fitness, c.acc = fit, acc
314
-
315
- elite_n = max(1, pop_size // 10)
316
- elites = sorted(pop, key=lambda x: x.fitness)[:elite_n]
317
- pop = sorted(children, key=lambda x: x.fitness)
318
- pop[-elite_n:] = elites
319
-
320
- best = min(pop, key=lambda x: x.fitness)
321
- history.append((gen, best.fitness, (best.acc if best.acc is not None else float("nan"))))
322
-
323
- P = np.stack([g.vector() for g in pop], axis=0)
324
- P3 = sphere_project(P)
325
- sphere_fig = make_sphere_figure(P3, pop, gen)
326
- hist_fig = make_history_figure(history, metric_choice)
327
-
328
- top = sorted(pop, key=lambda x: x.fitness)[: min(12, len(pop))]
329
- top_table = [{
330
- "gen": gen, "fitness": round(t.fitness, 4),
331
- "accuracy": (None if t.acc is None else round(float(t.acc), 4)),
332
- "d_model": t.d_model, "layers": t.n_layers, "heads": t.n_heads,
333
- "ffn_mult": t.ffn_mult, "mem": t.memory_tokens, "dropout": t.dropout,
334
- "params_approx": approx_params(t)
335
- } for t in top]
336
- best_card = top_table[0] if top_table else {}
337
-
338
- with self.lock:
339
- self.state = {
340
- "sphere_html": fig_to_html(sphere_fig),
341
- "history_html": fig_to_html(hist_fig),
342
- "top": top_table,
343
- "best": best_card,
344
- "gen": gen,
345
- "dataset": dataset,
346
- "metric": metric_choice
347
- }
348
-
349
- time.sleep(max(0.0, pace_ms/1000.0))
350
- self.running = False
351
-
352
- def start(self, *args, **kwargs):
353
- if self.running: return
354
- t = threading.Thread(target=self.run, args=args, kwargs=kwargs, daemon=True)
355
- t.start()
356
-
357
- def stop(self): self.stop_flag = True
358
-
359
- def clear(self):
360
- # stop and reset to idle sphere
361
- self.stop_flag = True
362
- idle = fig_to_html(make_idle_sphere())
363
- with self.lock:
364
- self.running = False
365
- self.state = {"sphere_html": idle, "history_html": fig_to_html(make_history_figure([], "Accuracy")),
366
- "top": [], "best": {}, "gen": 0, "dataset": "Demo (Surrogate)", "metric": "Accuracy"}
367
-
368
- runner = EvoRunner()
369
-
370
  # =========================
371
- # UI CALLBACKS
372
- # =========================
373
- def start_evo(dataset, pop, gens, mut, explore, exploit, seed, pace_ms, metric_choice):
374
- runner.start(dataset, int(pop), int(gens), float(mut), float(explore), float(exploit), int(seed), int(pace_ms), metric_choice)
375
- return (gr.update(interactive=False), gr.update(interactive=True), gr.update(interactive=False))
376
-
377
- def stop_evo():
378
- runner.stop()
379
- return (gr.update(interactive=True), gr.update(interactive=False), gr.update(interactive=True))
380
-
381
- def clear_evo():
382
- runner.clear()
383
- # return updated visuals + reset buttons
384
- sphere_html, history_html, stats_md, df = poll_state()
385
- return sphere_html, history_html, stats_md, df, gr.update(interactive=True), gr.update(interactive=False), gr.update(interactive=True)
386
-
387
- def poll_state():
388
- with runner.lock:
389
- s = runner.state.copy()
390
- sphere_html = s.get("sphere_html", fig_to_html(make_idle_sphere()))
391
- history_html = s.get("history_html", fig_to_html(make_history_figure([], "Accuracy")))
392
- best = s.get("best", {})
393
- gen = s.get("gen", 0)
394
- dataset = s.get("dataset", "Demo (Surrogate)")
395
- top = s.get("top", [])
396
- if best:
397
- acc_txt = "β€”" if best.get("accuracy") is None else f"{best.get('accuracy'):.3f}"
398
- stats_md = (
399
- f"**Dataset:** {dataset} \n"
400
- f"**Generation:** {gen} \n"
401
- f"**Best fitness:** {best.get('fitness','–')} \n"
402
- f"**Best accuracy:** {acc_txt} \n"
403
- f"**Config:** d_model={best.get('d_model')} Β· layers={best.get('layers')} Β· "
404
- f"heads={best.get('heads')} Β· ffn_mult={best.get('ffn_mult')} Β· mem={best.get('mem')} Β· "
405
- f"dropout={best.get('dropout')} \n"
406
- f"**~Params (rough):** {best.get('params_approx'):,}"
407
- )
408
- else:
409
- stats_md = "Ready. Press **Start** to evolve, or **Clear** anytime."
410
- df = pd.DataFrame(top)
411
- return sphere_html, history_html, stats_md, df
412
-
413
- def export_snapshot():
414
- from json import dumps
415
- with runner.lock:
416
- payload = dumps(runner.state, default=lambda o: o, indent=2)
417
- path = "evo_snapshot.json"
418
- with open(path, "w", encoding="utf-8") as f:
419
- f.write(payload)
420
- return path
421
-
422
- # =========================
423
- # BUILD UI
424
- # =========================
425
- with gr.Blocks(css=CUSTOM_CSS) as demo:
426
  with gr.Column(elem_id="header"):
427
- gr.Markdown("### Evo Playground β€” Live Evolution (clean dark)")
428
-
 
 
429
  with gr.Row():
430
- with gr.Column(scale=1, elem_classes=["controls"]):
431
- with gr.Group():
432
- dataset = gr.Dropdown(
433
- label="Dataset",
434
- choices=["Demo (Surrogate)", "PIQA (Phase 2)", "HellaSwag (Phase 2)"],
435
- value="Demo (Surrogate)",
436
- info="PIQA/HellaSwag compute proxy accuracy; Demo is a fast surrogate."
437
- )
438
- pop = gr.Slider(8, 80, value=24, step=2, label="Population size")
439
- gens = gr.Slider(5, 200, value=60, step=1, label="Max generations")
440
- mut = gr.Slider(0.05, 0.9, value=0.25, step=0.01, label="Mutation rate")
441
- with gr.Row():
442
- explore = gr.Slider(0.0, 1.0, value=0.35, step=0.05, label="Exploration")
443
- exploit = gr.Slider(0.0, 1.0, value=0.65, step=0.05, label="Exploitation")
444
- seed = gr.Number(value=42, label="Seed", precision=0)
445
- pace = gr.Slider(0, 1000, value=120, step=10, label="Pace (ms)")
446
- metric_choice = gr.Radio(choices=["Accuracy", "Fitness"], value="Accuracy", label="History Metric")
447
-
448
- with gr.Row():
449
- start = gr.Button("β–Ά Start", variant="primary")
450
- stop = gr.Button("⏹ Stop", variant="secondary", interactive=False)
451
- clear = gr.Button("β†Ί Clear", variant="secondary")
452
-
453
- with gr.Group(elem_classes=["panel"]):
454
- stats_md = gr.Markdown("Ready. Press **Start** to evolve, or **Clear** anytime.", elem_id="stats")
455
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
456
  with gr.Group(elem_classes=["panel"]):
457
- export_btn = gr.Button("Export Snapshot (JSON)")
458
- export_file = gr.File(label="Download snapshot", visible=False)
459
-
 
 
 
460
  with gr.Column(scale=2):
461
- with gr.Group(elem_classes=["panel"]):
 
 
 
 
 
462
  sphere_html = gr.HTML()
463
- with gr.Group(elem_classes=["panel"]):
 
 
 
 
464
  hist_html = gr.HTML()
 
 
465
  with gr.Group(elem_classes=["panel"]):
466
- top_df = gr.Dataframe(label="Top Genomes (live)", wrap=True, interactive=False)
467
-
468
- # wiring
469
- start.click(start_evo, [dataset, pop, gens, mut, explore, exploit, seed, pace, metric_choice], [start, stop, clear])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
470
  stop.click(stop_evo, [], [start, stop, clear])
471
- clear.click(clear_evo, [], [sphere_html, hist_html, stats_md, top_df, start, stop, clear])
 
 
 
 
472
  export_btn.click(export_snapshot, [], [export_file])
473
-
474
- # initial paint + polling
475
- demo.load(poll_state, None, [sphere_html, hist_html, stats_md, top_df])
476
- gr.Timer(0.7).tick(poll_state, None, [sphere_html, hist_html, stats_md, top_df])
477
 
478
  if __name__ == "__main__":
479
- demo.launch()
480
-
481
- ##
 
1
+ # app.py β€” Enhanced UI with better layout, visual hierarchy, and UX
2
+ # ... [All your imports and backend code remain the same] ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  # =========================
5
+ # ENHANCED CSS
6
  # =========================
7
+ ENHANCED_CSS = """
8
+ :root {
9
+ --radius: 14px;
10
+ --fg: #E5E7EB;
11
+ --muted: #94A3B8;
12
+ --line: #1f2b36;
13
+ --bg: #0F1A24;
14
+ --panel-bg: #0c161f;
15
+ --accent: #3B82F6;
16
+ --accent-hover: #2563EB;
17
+ --danger: #EF4444;
18
+ }
19
+
20
+ .gradio-container {
21
+ max-width: 1400px !important;
22
+ background: var(--bg);
23
+ padding: 16px !important;
24
+ }
25
+
26
+ #header {
27
+ padding: 16px 0;
28
+ margin-bottom: 16px;
29
+ border-bottom: 1px solid var(--line);
30
+ }
31
+
32
+ h1, h2, h3, .gr-markdown {
33
+ color: var(--fg);
34
+ }
35
+
36
+ .gr-button {
37
+ border-radius: 8px;
38
+ padding: 8px 16px;
39
+ transition: all 0.2s ease;
40
+ font-weight: 500 !important;
41
+ }
42
+
43
+ .btn-primary {
44
+ background: var(--accent) !important;
45
+ border: 1px solid var(--accent) !important;
46
+ }
47
+
48
+ .btn-primary:hover {
49
+ background: var(--accent-hover) !important;
50
+ }
51
+
52
+ .btn-secondary {
53
+ background: transparent !important;
54
+ border: 1px solid var(--line) !important;
55
+ }
56
+
57
+ .btn-danger {
58
+ background: var(--danger) !important;
59
+ border: 1px solid var(--danger) !important;
60
+ }
61
+
62
+ .control-group {
63
+ border: 1px solid var(--line);
64
+ border-radius: var(--radius);
65
+ background: var(--panel-bg);
66
+ padding: 20px;
67
+ margin-bottom: 20px;
68
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05);
69
+ }
70
+
71
+ .panel {
72
+ border: 1px solid var(--line);
73
+ border-radius: var(--radius);
74
+ background: var(--panel-bg);
75
+ padding: 20px;
76
+ margin-bottom: 20px;
77
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05);
78
+ }
79
+
80
+ .stats-panel {
81
+ background: linear-gradient(145deg, #0a121b, #0c161f);
82
+ border-left: 3px solid var(--accent);
83
+ }
84
+
85
+ #stats {
86
+ color: var(--fg);
87
+ line-height: 1.6;
88
+ }
89
+
90
+ #stats strong {
91
+ font-weight: 500;
92
+ color: var(--accent);
93
+ }
94
+
95
+ .param-slider {
96
+ margin-bottom: 12px;
97
+ }
98
+
99
+ .visualization-container {
100
+ display: flex;
101
+ flex-direction: column;
102
+ gap: 20px;
103
+ height: 100%;
104
+ }
105
+
106
+ .viz-panel {
107
+ flex: 1;
108
+ min-height: 300px;
109
+ }
110
+
111
+ .viz-header {
112
+ display: flex;
113
+ justify-content: space-between;
114
+ align-items: center;
115
+ margin-bottom: 12px;
116
+ padding-bottom: 8px;
117
+ border-bottom: 1px solid var(--line);
118
+ }
119
+
120
+ .viz-title {
121
+ font-size: 1.1rem;
122
+ font-weight: 500;
123
+ color: var(--accent);
124
+ }
125
+
126
+ .gen-counter {
127
+ font-size: 0.9rem;
128
+ background: rgba(59, 130, 246, 0.15);
129
+ padding: 4px 10px;
130
+ border-radius: 12px;
131
+ }
132
+
133
+ .slider-info {
134
+ display: flex;
135
+ justify-content: space-between;
136
+ font-size: 0.85rem;
137
+ color: var(--muted);
138
+ margin-top: 4px;
139
+ }
140
+
141
+ .controls-grid {
142
+ display: grid;
143
+ grid-template-columns: 1fr 1fr;
144
+ gap: 16px;
145
+ }
146
+
147
+ @media (max-width: 1200px) {
148
+ .controls-grid {
149
+ grid-template-columns: 1fr;
150
+ }
151
+ }
152
+
153
+ .data-table {
154
+ max-height: 400px;
155
+ overflow-y: auto;
156
+ }
157
+
158
+ .data-table table {
159
+ width: 100%;
160
+ border-collapse: collapse;
161
+ }
162
+
163
+ .data-table th {
164
+ background: rgba(15, 26, 36, 0.8);
165
+ position: sticky;
166
+ top: 0;
167
+ text-align: left;
168
+ padding: 10px 12px;
169
+ font-weight: 500;
170
+ color: var(--accent);
171
+ border-bottom: 1px solid var(--line);
172
+ }
173
+
174
+ .data-table td {
175
+ padding: 8px 12px;
176
+ border-bottom: 1px solid rgba(31, 43, 54, 0.5);
177
+ }
178
+
179
+ .data-table tr:hover {
180
+ background: rgba(31, 43, 54, 0.3);
181
+ }
182
+
183
+ .action-buttons {
184
+ display: flex;
185
+ gap: 12px;
186
+ margin-top: 20px;
187
+ }
188
+
189
+ .footer {
190
+ margin-top: 20px;
191
+ padding-top: 20px;
192
+ border-top: 1px solid var(--line);
193
+ font-size: 0.85rem;
194
+ color: var(--muted);
195
+ text-align: center;
196
+ }
197
  """
198
 
199
+ # ... [All your backend code remains the same] ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
 
201
  # =========================
202
+ # BUILD ENHANCED UI
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  # =========================
204
+ with gr.Blocks(css=ENHANCED_CSS, theme=gr.themes.Default()) as demo:
205
+ # Header
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  with gr.Column(elem_id="header"):
207
+ gr.Markdown("## 🧬 Neuroevolution Playground", elem_classes=["header-title"])
208
+ gr.Markdown("Evolve neural architectures using genetic algorithms",
209
+ elem_classes=["header-subtitle"])
210
+
211
  with gr.Row():
212
+ # Left Panel - Controls
213
+ with gr.Column(scale=1):
214
+ # Parameters Group
215
+ with gr.Group(elem_classes=["control-group"]):
216
+ gr.Markdown("### πŸ›  Evolution Parameters")
217
+
218
+ with gr.Column():
219
+ dataset = gr.Dropdown(
220
+ label="Evaluation Dataset",
221
+ choices=["Demo (Surrogate)", "PIQA (Phase 2)", "HellaSwag (Phase 2)"],
222
+ value="Demo (Surrogate)",
223
+ info="Dataset used for fitness evaluation"
224
+ )
225
+
226
+ with gr.Row():
227
+ with gr.Column():
228
+ pop = gr.Slider(8, 80, value=24, step=2, label="Population Size",
229
+ elem_classes=["param-slider"])
230
+ gens = gr.Slider(5, 200, value=60, step=1, label="Max Generations",
231
+ elem_classes=["param-slider"])
232
+ mut = gr.Slider(0.05, 0.9, value=0.25, step=0.01, label="Mutation Rate",
233
+ elem_classes=["param-slider"])
234
+ with gr.Column():
235
+ explore = gr.Slider(0.0, 1.0, value=0.35, step=0.05, label="Exploration",
236
+ elem_classes=["param-slider"])
237
+ exploit = gr.Slider(0.0, 1.0, value=0.65, step=0.05, label="Exploitation",
238
+ elem_classes=["param-slider"])
239
+ seed = gr.Number(value=42, label="Random Seed", precision=0)
240
+
241
+ pace = gr.Slider(0, 1000, value=120, step=10, label="Simulation Speed (ms)",
242
+ elem_classes=["param-slider"])
243
+ metric_choice = gr.Radio(choices=["Accuracy", "Fitness"], value="Accuracy",
244
+ label="History Metric Display")
245
+
246
+ # Status Panel
247
+ with gr.Group(elem_classes=["panel", "stats-panel"]):
248
+ gr.Markdown("### πŸ“Š Current Status")
249
+ stats_md = gr.Markdown("Ready. Press **Start** to begin evolution.", elem_id="stats")
250
+
251
+ # Action Buttons
252
+ with gr.Row(elem_classes=["action-buttons"]):
253
+ start = gr.Button("β–Ά Start Evolution", variant="primary", elem_classes=["btn-primary"])
254
+ stop = gr.Button("⏹ Stop", variant="stop", elem_classes=["btn-danger"], interactive=False)
255
+ clear = gr.Button("↻ Reset", elem_classes=["btn-secondary"])
256
+
257
+ # Export
258
  with gr.Group(elem_classes=["panel"]):
259
+ gr.Markdown("### πŸ’Ύ Export Results")
260
+ with gr.Row():
261
+ export_btn = gr.Button("Save Snapshot (JSON)")
262
+ export_file = gr.File(label="Download snapshot", visible=False)
263
+
264
+ # Right Panel - Visualizations
265
  with gr.Column(scale=2):
266
+ # 3D Visualization
267
+ with gr.Group(elem_classes=["panel", "viz-panel"]):
268
+ with gr.Column(elem_classes=["viz-header"]):
269
+ with gr.Row():
270
+ gr.Markdown("### 🌐 Architecture Space", elem_classes=["viz-title"])
271
+ gen_counter = gr.Markdown("", elem_classes=["gen-counter"])
272
  sphere_html = gr.HTML()
273
+
274
+ # History Visualization
275
+ with gr.Group(elem_classes=["panel", "viz-panel"]):
276
+ with gr.Column(elem_classes=["viz-header"]):
277
+ gr.Markdown("### πŸ“ˆ Performance History", elem_classes=["viz-title"])
278
  hist_html = gr.HTML()
279
+
280
+ # Results Table
281
  with gr.Group(elem_classes=["panel"]):
282
+ gr.Markdown("### πŸ† Top Genomes")
283
+ with gr.Column(elem_classes=["data-table"]):
284
+ top_df = gr.Dataframe(
285
+ label="",
286
+ headers=["Fitness", "Accuracy", "d_model", "Layers", "Heads", "FFN", "Mem", "Dropout", "Params"],
287
+ datatype=["number", "number", "number", "number", "number", "number", "number", "number", "number"],
288
+ wrap=True,
289
+ interactive=False
290
+ )
291
+
292
+ # Footer
293
+ with gr.Column(elem_classes=["footer"]):
294
+ gr.Markdown("Evotransformer Playground v1.0 β€’ Using Plotly and Gradio")
295
+
296
+ # Wiring
297
+ start.click(
298
+ start_evo,
299
+ [dataset, pop, gens, mut, explore, exploit, seed, pace, metric_choice],
300
+ [start, stop, clear]
301
+ )
302
  stop.click(stop_evo, [], [start, stop, clear])
303
+ clear.click(
304
+ clear_evo,
305
+ [],
306
+ [sphere_html, hist_html, stats_md, top_df, start, stop, clear]
307
+ )
308
  export_btn.click(export_snapshot, [], [export_file])
309
+
310
+ # State polling
311
+ demo.load(poll_state, None, [sphere_html, hist_html, stats_md, top_df, gen_counter])
312
+ gr.Timer(0.7).tick(poll_state, None, [sphere_html, hist_html, stats_md, top_df, gen_counter])
313
 
314
  if __name__ == "__main__":
315
+ demo.launch()