Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
# app.py
|
2 |
import math, json, random, time, threading
|
3 |
from dataclasses import dataclass, asdict
|
4 |
from typing import List, Tuple, Dict, Any, Optional
|
@@ -6,33 +6,32 @@ from functools import lru_cache
|
|
6 |
|
7 |
import numpy as np
|
8 |
import plotly.graph_objs as go
|
|
|
9 |
import gradio as gr
|
10 |
import pandas as pd
|
11 |
|
12 |
-
# Proxy fitness deps
|
13 |
import torch
|
14 |
import torch.nn as nn
|
15 |
import torch.optim as optim
|
16 |
|
17 |
from data_utils import load_piqa, load_hellaswag, hash_vectorize
|
18 |
|
19 |
-
#
|
20 |
-
# UX THEME & STYLES (cleaner, pro)
|
21 |
-
# =========================
|
22 |
CUSTOM_CSS = """
|
23 |
-
:root { --radius
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
}
|
29 |
-
.gr-
|
30 |
-
|
|
|
|
|
|
|
31 |
"""
|
32 |
|
33 |
-
#
|
34 |
-
# GENOME & EVOLUTION CORE
|
35 |
-
# =========================
|
36 |
@dataclass
|
37 |
class Genome:
|
38 |
d_model: int
|
@@ -43,7 +42,7 @@ class Genome:
|
|
43 |
dropout: float
|
44 |
species: int = 0
|
45 |
fitness: float = float("inf")
|
46 |
-
acc: Optional[float] = None
|
47 |
|
48 |
def vector(self) -> np.ndarray:
|
49 |
return np.array([
|
@@ -91,9 +90,7 @@ def crossover(a: Genome, b: Genome, rng: random.Random) -> Genome:
|
|
91 |
acc = None
|
92 |
)
|
93 |
|
94 |
-
#
|
95 |
-
# PROXY FITNESS (Phase 2a)
|
96 |
-
# =========================
|
97 |
def rastrigin(x: np.ndarray) -> float:
|
98 |
A, n = 10.0, x.shape[0]
|
99 |
return A * n + np.sum(x**2 - A * np.cos(2 * math.pi * x))
|
@@ -116,14 +113,11 @@ def _cached_dataset(name: str):
|
|
116 |
if name.startswith("HellaSwag"): return load_hellaswag(subset=800, seed=42)
|
117 |
return None
|
118 |
|
119 |
-
def _train_eval_proxy(genome: Genome, dataset_name: str, explore: float, device: str = "cpu")
|
120 |
-
"""Returns (fitness, accuracy or None)."""
|
121 |
data = _cached_dataset(dataset_name)
|
122 |
if data is None:
|
123 |
-
# Demo path handled elsewhere
|
124 |
v = genome.vector() * 2 - 1
|
125 |
return float(rastrigin(v)), None
|
126 |
-
|
127 |
Xtr_txt, ytr, Xva_txt, yva = data
|
128 |
nfeat = 4096
|
129 |
Xtr = hash_vectorize(Xtr_txt, n_features=nfeat, seed=1234)
|
@@ -139,8 +133,7 @@ def _train_eval_proxy(genome: Genome, dataset_name: str, explore: float, device:
|
|
139 |
lossf = nn.BCEWithLogitsLoss()
|
140 |
|
141 |
model.train()
|
142 |
-
steps, bs = 120, 256
|
143 |
-
N = Xtr_t.size(0)
|
144 |
for _ in range(steps):
|
145 |
idx = torch.randint(0, N, (bs,))
|
146 |
xb = Xtr_t[idx].to(device); yb = ytr_t[idx].to(device)
|
@@ -169,7 +162,7 @@ def _train_eval_proxy(genome: Genome, dataset_name: str, explore: float, device:
|
|
169 |
fitness = (1.0 - acc) + parsimony + noise
|
170 |
return float(max(0.0, min(1.5, fitness))), float(acc)
|
171 |
|
172 |
-
def evaluate_genome(genome: Genome, dataset: str, explore: float)
|
173 |
if dataset == "Demo (Surrogate)":
|
174 |
v = genome.vector() * 2 - 1
|
175 |
base = rastrigin(v)
|
@@ -180,23 +173,30 @@ def evaluate_genome(genome: Genome, dataset: str, explore: float) -> Tuple[float
|
|
180 |
return _train_eval_proxy(genome, "PIQA", explore)
|
181 |
if dataset.startswith("HellaSwag"):
|
182 |
return _train_eval_proxy(genome, "HellaSwag", explore)
|
183 |
-
# fallback
|
184 |
v = genome.vector() * 2 - 1
|
185 |
return float(rastrigin(v)), None
|
186 |
|
187 |
-
#
|
188 |
-
|
189 |
-
|
|
|
190 |
def sphere_project(points: np.ndarray) -> np.ndarray:
|
191 |
rng = np.random.RandomState(42)
|
192 |
W = rng.normal(size=(points.shape[1], 3)).astype(np.float32)
|
193 |
Y = points @ W
|
194 |
norms = np.linalg.norm(Y, axis=1, keepdims=True) + 1e-8
|
195 |
-
return (Y / norms) * 1.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
196 |
|
197 |
def make_sphere_figure(points3d: np.ndarray, genomes: List[Genome], gen_idx: int) -> go.Figure:
|
198 |
species = np.array([g.species for g in genomes])
|
199 |
-
|
200 |
custom = np.array([[g.d_model, g.n_layers, g.n_heads, g.ffn_mult, g.memory_tokens, g.dropout,
|
201 |
g.species, g.fitness, (g.acc if g.acc is not None else -1.0)]
|
202 |
for g in genomes], dtype=np.float32)
|
@@ -204,11 +204,11 @@ def make_sphere_figure(points3d: np.ndarray, genomes: List[Genome], gen_idx: int
|
|
204 |
scatter = go.Scatter3d(
|
205 |
x=points3d[:,0], y=points3d[:,1], z=points3d[:,2],
|
206 |
mode='markers',
|
207 |
-
marker=dict(size=
|
208 |
customdata=custom,
|
209 |
hovertemplate=(
|
210 |
-
"
|
211 |
-
"layers=%{customdata[1]:.0f} · heads=%{customdata[2]:.0f}<br>"
|
212 |
"ffn_mult=%{customdata[3]:.1f} · mem=%{customdata[4]:.0f} · drop=%{customdata[5]:.2f}<br>"
|
213 |
"species=%{customdata[6]:.0f}<br>"
|
214 |
"fitness=%{customdata[7]:.4f}<br>"
|
@@ -216,46 +216,63 @@ def make_sphere_figure(points3d: np.ndarray, genomes: List[Genome], gen_idx: int
|
|
216 |
)
|
217 |
)
|
218 |
|
219 |
-
#
|
220 |
-
u = np.linspace(0, 2*np.pi,
|
221 |
-
v = np.linspace(0, np.pi,
|
222 |
-
r = 1.
|
223 |
xs = r*np.outer(np.cos(u), np.sin(v))
|
224 |
ys = r*np.outer(np.sin(u), np.sin(v))
|
225 |
zs = r*np.outer(np.ones_like(u), np.cos(v))
|
226 |
-
sphere = go.Surface(
|
|
|
|
|
|
|
|
|
|
|
|
|
227 |
|
228 |
layout = go.Layout(
|
229 |
-
|
230 |
-
|
231 |
-
|
|
|
|
|
|
|
|
|
232 |
showlegend=False,
|
233 |
-
height=
|
|
|
234 |
)
|
235 |
return go.Figure(data=[sphere, scatter], layout=layout)
|
236 |
|
237 |
def make_history_figure(history: List[Tuple[int,float,float]], metric: str) -> go.Figure:
|
238 |
-
# history items: (gen, best_fitness, best_acc or NaN)
|
239 |
xs = [h[0] for h in history]
|
240 |
if metric == "Accuracy":
|
241 |
-
ys = [h[2] if (h[2] == h[2]) else None for h in history]
|
242 |
title, ylab = "Best Accuracy per Generation", "Accuracy"
|
243 |
else:
|
244 |
ys = [h[1] for h in history]
|
245 |
-
title, ylab = "Best Fitness per Generation", "Fitness (
|
246 |
-
fig = go.Figure(data=[go.Scatter(x=xs, y=ys, mode="lines+markers")])
|
247 |
-
fig.update_layout(
|
248 |
-
|
|
|
|
|
|
|
|
|
|
|
249 |
return fig
|
250 |
|
|
|
|
|
|
|
|
|
251 |
def approx_params(g: Genome) -> int:
|
252 |
per_layer = (4.0 + 2.0 * float(g.ffn_mult)) * (g.d_model ** 2)
|
253 |
total = per_layer * g.n_layers + 1000 * g.memory_tokens
|
254 |
return int(total)
|
255 |
|
256 |
-
#
|
257 |
-
# ORCHESTRATOR
|
258 |
-
# =========================
|
259 |
class EvoRunner:
|
260 |
def __init__(self):
|
261 |
self.lock = threading.Lock()
|
@@ -269,25 +286,22 @@ class EvoRunner:
|
|
269 |
self.running = True
|
270 |
|
271 |
pop: List[Genome] = [random_genome(rng) for _ in range(pop_size)]
|
272 |
-
# initial eval
|
273 |
for g in pop:
|
274 |
fit, acc = evaluate_genome(g, dataset, explore)
|
275 |
g.fitness, g.acc = fit, acc
|
276 |
|
277 |
-
history: List[Tuple[int,float,float]] = []
|
278 |
best_overall: Optional[Genome] = None
|
279 |
|
280 |
for gen in range(1, generations+1):
|
281 |
if self.stop_flag: break
|
282 |
|
283 |
-
# Selection (tournament)
|
284 |
k = max(2, int(2 + exploit * 5))
|
285 |
parents = []
|
286 |
for _ in range(pop_size):
|
287 |
sample = rng.sample(pop, k=k)
|
288 |
parents.append(min(sample, key=lambda x: x.fitness))
|
289 |
|
290 |
-
# Reproduce
|
291 |
children = []
|
292 |
for i in range(0, pop_size, 2):
|
293 |
a = parents[i]; b = parents[(i+1) % pop_size]
|
@@ -296,30 +310,25 @@ class EvoRunner:
|
|
296 |
children.extend([child1, child2])
|
297 |
children = children[:pop_size]
|
298 |
|
299 |
-
# Evaluate children
|
300 |
for c in children:
|
301 |
fit, acc = evaluate_genome(c, dataset, explore)
|
302 |
c.fitness, c.acc = fit, acc
|
303 |
|
304 |
-
# Elitism
|
305 |
elite_n = max(1, pop_size // 10)
|
306 |
elites = sorted(pop, key=lambda x: x.fitness)[:elite_n]
|
307 |
-
|
308 |
-
# Next pop
|
309 |
pop = sorted(children, key=lambda x: x.fitness)
|
310 |
pop[-elite_n:] = elites
|
311 |
|
312 |
best = min(pop, key=lambda x: x.fitness)
|
313 |
-
if best_overall is None or best.fitness < best_overall.fitness:
|
314 |
-
best_overall = best
|
315 |
|
316 |
history.append((gen, best.fitness, (best.acc if best.acc is not None else float("nan"))))
|
317 |
|
318 |
-
# Viz snapshot
|
319 |
P = np.stack([g.vector() for g in pop], axis=0)
|
320 |
P3 = sphere_project(P)
|
321 |
sphere_fig = make_sphere_figure(P3, pop, gen)
|
322 |
hist_fig = make_history_figure(history, metric_choice)
|
|
|
323 |
top = sorted(pop, key=lambda x: x.fitness)[: min(12, len(pop))]
|
324 |
top_table = [
|
325 |
{
|
@@ -340,8 +349,8 @@ class EvoRunner:
|
|
340 |
|
341 |
with self.lock:
|
342 |
self.state = {
|
343 |
-
"
|
344 |
-
"
|
345 |
"top": top_table,
|
346 |
"best": best_card,
|
347 |
"gen": gen,
|
@@ -350,21 +359,17 @@ class EvoRunner:
|
|
350 |
}
|
351 |
|
352 |
time.sleep(max(0.0, pace_ms/1000.0))
|
353 |
-
|
354 |
self.running = False
|
355 |
|
356 |
def start(self, *args, **kwargs):
|
357 |
if self.running: return
|
358 |
t = threading.Thread(target=self.run, args=args, kwargs=kwargs, daemon=True)
|
359 |
t.start()
|
360 |
-
|
361 |
def stop(self): self.stop_flag = True
|
362 |
|
363 |
runner = EvoRunner()
|
364 |
|
365 |
-
#
|
366 |
-
# UI CALLBACKS
|
367 |
-
# =========================
|
368 |
def start_evo(dataset, pop, gens, mut, explore, exploit, seed, pace_ms, metric_choice):
|
369 |
runner.start(dataset, int(pop), int(gens), float(mut), float(explore), float(exploit), int(seed), int(pace_ms), metric_choice)
|
370 |
return (gr.update(interactive=False), gr.update(interactive=True))
|
@@ -376,13 +381,12 @@ def stop_evo():
|
|
376 |
def poll_state():
|
377 |
with runner.lock:
|
378 |
s = runner.state.copy()
|
379 |
-
|
380 |
-
|
381 |
best = s.get("best", {})
|
382 |
gen = s.get("gen", 0)
|
383 |
dataset = s.get("dataset", "Demo (Surrogate)")
|
384 |
top = s.get("top", [])
|
385 |
-
# Stats text
|
386 |
if best:
|
387 |
acc_txt = "—" if best.get("accuracy") is None else f"{best.get('accuracy'):.3f}"
|
388 |
stats_md = (
|
@@ -398,7 +402,7 @@ def poll_state():
|
|
398 |
else:
|
399 |
stats_md = "Waiting… click **Start Evolution**."
|
400 |
df = pd.DataFrame(top)
|
401 |
-
return
|
402 |
|
403 |
def export_snapshot():
|
404 |
from json import dumps
|
@@ -409,25 +413,19 @@ def export_snapshot():
|
|
409 |
f.write(payload)
|
410 |
return path
|
411 |
|
412 |
-
#
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
with gr.Column(elem_id="header-card"):
|
417 |
-
gr.Markdown(
|
418 |
-
"# Evo Playground — Live Evolution of Transformer Architectures\n"
|
419 |
-
"Tune the search, watch the population converge, and track **accuracy** in real time (PIQA/HellaSwag)."
|
420 |
-
)
|
421 |
|
422 |
with gr.Row():
|
423 |
-
|
424 |
-
with gr.Column(scale=1):
|
425 |
with gr.Group():
|
426 |
dataset = gr.Dropdown(
|
427 |
label="Dataset",
|
428 |
choices=["Demo (Surrogate)", "PIQA (Phase 2)", "HellaSwag (Phase 2)"],
|
429 |
value="Demo (Surrogate)",
|
430 |
-
info="PIQA/HellaSwag compute real proxy accuracy; Demo
|
431 |
)
|
432 |
pop = gr.Slider(8, 80, value=24, step=2, label="Population size")
|
433 |
gens = gr.Slider(5, 200, value=60, step=1, label="Max generations")
|
@@ -442,18 +440,19 @@ with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
|
|
442 |
start = gr.Button("▶ Start Evolution", variant="primary")
|
443 |
stop = gr.Button("⏹ Stop", variant="secondary")
|
444 |
|
445 |
-
with gr.Group(
|
446 |
-
stats_md = gr.Markdown("Waiting…", elem_id="stats
|
|
|
|
|
447 |
export_btn = gr.Button("Export Snapshot (JSON)")
|
448 |
export_file = gr.File(label="Download snapshot", visible=False)
|
449 |
|
450 |
-
# RIGHT: Viz + Table
|
451 |
with gr.Column(scale=2):
|
452 |
-
with gr.Group(
|
453 |
-
|
454 |
-
with gr.Group(
|
455 |
-
|
456 |
-
with gr.Group(
|
457 |
top_df = gr.Dataframe(label="Top Genomes (live)", wrap=True, interactive=False)
|
458 |
|
459 |
# Wiring
|
@@ -461,12 +460,9 @@ with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
|
|
461 |
stop.click(stop_evo, [], [start, stop])
|
462 |
export_btn.click(export_snapshot, [], [export_file])
|
463 |
|
464 |
-
# Initial paint
|
465 |
-
demo.load(poll_state, None, [
|
466 |
-
|
467 |
-
# Continuous polling (every 0.7s)
|
468 |
-
poller = gr.Timer(0.7)
|
469 |
-
poller.tick(poll_state, None, [sphere_plot, hist_plot, stats_md, top_df])
|
470 |
|
471 |
if __name__ == "__main__":
|
472 |
demo.launch()
|
|
|
1 |
+
# app.py — Minimal, pro UI with big transparent sphere and clean hover
|
2 |
import math, json, random, time, threading
|
3 |
from dataclasses import dataclass, asdict
|
4 |
from typing import List, Tuple, Dict, Any, Optional
|
|
|
6 |
|
7 |
import numpy as np
|
8 |
import plotly.graph_objs as go
|
9 |
+
import plotly.io as pio
|
10 |
import gradio as gr
|
11 |
import pandas as pd
|
12 |
|
|
|
13 |
import torch
|
14 |
import torch.nn as nn
|
15 |
import torch.optim as optim
|
16 |
|
17 |
from data_utils import load_piqa, load_hellaswag, hash_vectorize
|
18 |
|
19 |
+
# ---------- Minimal style ----------
|
|
|
|
|
20 |
CUSTOM_CSS = """
|
21 |
+
:root { --radius: 14px; --fg:#0f172a; --muted:#64748b; --line:#e5e7eb; }
|
22 |
+
* { font-family: Inter, ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Helvetica Neue, Arial, "Noto Sans", "Apple Color Emoji", "Segoe UI Emoji"; }
|
23 |
+
.gradio-container { max-width: 1180px !important; }
|
24 |
+
#header { border-radius: var(--radius); padding: 8px 6px; }
|
25 |
+
h1, h2, h3, .gr-markdown { color: var(--fg); }
|
26 |
+
.gr-button { border-radius: 10px; }
|
27 |
+
.controls .gr-group, .panel { border: 1px solid var(--line); border-radius: var(--radius); }
|
28 |
+
.panel { padding: 10px; }
|
29 |
+
#stats { font-weight: 300; color: var(--fg); }
|
30 |
+
#stats strong { font-weight: 500; }
|
31 |
+
.small { font-size: 13px; color: var(--muted); }
|
32 |
"""
|
33 |
|
34 |
+
# ---------- Genome ----------
|
|
|
|
|
35 |
@dataclass
|
36 |
class Genome:
|
37 |
d_model: int
|
|
|
42 |
dropout: float
|
43 |
species: int = 0
|
44 |
fitness: float = float("inf")
|
45 |
+
acc: Optional[float] = None
|
46 |
|
47 |
def vector(self) -> np.ndarray:
|
48 |
return np.array([
|
|
|
90 |
acc = None
|
91 |
)
|
92 |
|
93 |
+
# ---------- Proxy fitness ----------
|
|
|
|
|
94 |
def rastrigin(x: np.ndarray) -> float:
    """Evaluate the Rastrigin test function at ``x``.

    Computes ``A * n + sum(x**2 - A * cos(2 * pi * x))`` with ``A = 10`` and
    ``n = x.shape[0]``. The value is 0 when every component of ``x`` is 0.

    Args:
        x: NumPy array of coordinates; its first dimension is used as ``n``.

    Returns:
        The function value as a built-in ``float`` (the annotation promises
        ``float``, so the NumPy scalar produced by ``np.sum`` is converted).
    """
    A, n = 10.0, x.shape[0]
    return float(A * n + np.sum(x**2 - A * np.cos(2 * math.pi * x)))
|
|
|
113 |
if name.startswith("HellaSwag"): return load_hellaswag(subset=800, seed=42)
|
114 |
return None
|
115 |
|
116 |
+
def _train_eval_proxy(genome: Genome, dataset_name: str, explore: float, device: str = "cpu"):
|
|
|
117 |
data = _cached_dataset(dataset_name)
|
118 |
if data is None:
|
|
|
119 |
v = genome.vector() * 2 - 1
|
120 |
return float(rastrigin(v)), None
|
|
|
121 |
Xtr_txt, ytr, Xva_txt, yva = data
|
122 |
nfeat = 4096
|
123 |
Xtr = hash_vectorize(Xtr_txt, n_features=nfeat, seed=1234)
|
|
|
133 |
lossf = nn.BCEWithLogitsLoss()
|
134 |
|
135 |
model.train()
|
136 |
+
steps, bs, N = 120, 256, Xtr_t.size(0)
|
|
|
137 |
for _ in range(steps):
|
138 |
idx = torch.randint(0, N, (bs,))
|
139 |
xb = Xtr_t[idx].to(device); yb = ytr_t[idx].to(device)
|
|
|
162 |
fitness = (1.0 - acc) + parsimony + noise
|
163 |
return float(max(0.0, min(1.5, fitness))), float(acc)
|
164 |
|
165 |
+
def evaluate_genome(genome: Genome, dataset: str, explore: float):
|
166 |
if dataset == "Demo (Surrogate)":
|
167 |
v = genome.vector() * 2 - 1
|
168 |
base = rastrigin(v)
|
|
|
173 |
return _train_eval_proxy(genome, "PIQA", explore)
|
174 |
if dataset.startswith("HellaSwag"):
|
175 |
return _train_eval_proxy(genome, "HellaSwag", explore)
|
|
|
176 |
v = genome.vector() * 2 - 1
|
177 |
return float(rastrigin(v)), None
|
178 |
|
179 |
+
# ---------- Viz helpers (bigger, transparent sphere) ----------
|
180 |
+
PALETTE = ["#111827", "#334155", "#475569", "#64748b", "#94a3b8"] # muted grayscale/blue
|
181 |
+
BG = "white"
|
182 |
+
|
183 |
def sphere_project(points: np.ndarray, radius: float = 1.2, seed: int = 42) -> np.ndarray:
    """Project row vectors onto a sphere in 3-D via a fixed random linear map.

    Each row of ``points`` is multiplied by a ``(points.shape[1], 3)`` Gaussian
    matrix drawn from ``np.random.RandomState(seed)``, then rescaled to length
    ``radius``. Because the generator is re-seeded on every call, the same
    input always maps to the same output.

    Args:
        points: 2-D array with one vector per row.
        radius: Length given to every projected row. Defaults to 1.2, the
            value that was previously hard-coded here.
        seed: Seed for the projection matrix. Defaults to 42, the previously
            hard-coded value.

    Returns:
        Array of shape ``(points.shape[0], 3)``.
    """
    rng = np.random.RandomState(seed)
    W = rng.normal(size=(points.shape[1], 3)).astype(np.float32)
    Y = points @ W
    # The small epsilon keeps an all-zero projected row from dividing by zero.
    norms = np.linalg.norm(Y, axis=1, keepdims=True) + 1e-8
    return (Y / norms) * radius
|
189 |
+
|
190 |
+
def _species_colors(species: np.ndarray) -> list:
|
191 |
+
colors = []
|
192 |
+
for s in species:
|
193 |
+
c = PALETTE[int(s) % len(PALETTE)]
|
194 |
+
colors.append(c)
|
195 |
+
return colors
|
196 |
|
197 |
def make_sphere_figure(points3d: np.ndarray, genomes: List[Genome], gen_idx: int) -> go.Figure:
|
198 |
species = np.array([g.species for g in genomes])
|
199 |
+
colors = _species_colors(species)
|
200 |
custom = np.array([[g.d_model, g.n_layers, g.n_heads, g.ffn_mult, g.memory_tokens, g.dropout,
|
201 |
g.species, g.fitness, (g.acc if g.acc is not None else -1.0)]
|
202 |
for g in genomes], dtype=np.float32)
|
|
|
204 |
scatter = go.Scatter3d(
|
205 |
x=points3d[:,0], y=points3d[:,1], z=points3d[:,2],
|
206 |
mode='markers',
|
207 |
+
marker=dict(size=6.5, color=colors, opacity=0.92),
|
208 |
customdata=custom,
|
209 |
hovertemplate=(
|
210 |
+
"<b>Genome</b><br>"
|
211 |
+
"d_model=%{customdata[0]:.0f} · layers=%{customdata[1]:.0f} · heads=%{customdata[2]:.0f}<br>"
|
212 |
"ffn_mult=%{customdata[3]:.1f} · mem=%{customdata[4]:.0f} · drop=%{customdata[5]:.2f}<br>"
|
213 |
"species=%{customdata[6]:.0f}<br>"
|
214 |
"fitness=%{customdata[7]:.4f}<br>"
|
|
|
216 |
)
|
217 |
)
|
218 |
|
219 |
+
# Subtle, large sphere
|
220 |
+
u = np.linspace(0, 2*np.pi, 72)
|
221 |
+
v = np.linspace(0, np.pi, 36)
|
222 |
+
r = 1.2
|
223 |
xs = r*np.outer(np.cos(u), np.sin(v))
|
224 |
ys = r*np.outer(np.sin(u), np.sin(v))
|
225 |
zs = r*np.outer(np.ones_like(u), np.cos(v))
|
226 |
+
sphere = go.Surface(
|
227 |
+
x=xs, y=ys, z=zs,
|
228 |
+
opacity=0.08,
|
229 |
+
showscale=False,
|
230 |
+
colorscale=[[0, "#cbd5e1"], [1, "#cbd5e1"]],
|
231 |
+
hoverinfo="skip"
|
232 |
+
)
|
233 |
|
234 |
layout = go.Layout(
|
235 |
+
paper_bgcolor=BG, plot_bgcolor=BG,
|
236 |
+
title=f"Evo Architecture Sphere — Gen {gen_idx}",
|
237 |
+
scene=dict(
|
238 |
+
xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False),
|
239 |
+
bgcolor=BG
|
240 |
+
),
|
241 |
+
margin=dict(l=0, r=0, t=36, b=0),
|
242 |
showlegend=False,
|
243 |
+
height=720,
|
244 |
+
font=dict(family="Inter, Arial, sans-serif", size=14)
|
245 |
)
|
246 |
return go.Figure(data=[sphere, scatter], layout=layout)
|
247 |
|
248 |
def make_history_figure(history: List[Tuple[int,float,float]], metric: str) -> go.Figure:
    """Build the per-generation line chart of the best individual.

    Args:
        history: Tuples of ``(generation, best_fitness, best_accuracy)``. The
            accuracy slot holds NaN when no accuracy was available.
        metric: ``"Accuracy"`` plots the accuracy slot; any other value plots
            the fitness slot.

    Returns:
        A Plotly figure with a single lines+markers trace.
    """
    xs = [h[0] for h in history]
    if metric == "Accuracy":
        # Missing accuracies (NaN) are passed to Plotly as None. The explicit
        # isnan check replaces the former `h[2] == h[2]` self-comparison.
        ys = [None if (h[2] is None or math.isnan(h[2])) else h[2] for h in history]
        title, ylab = "Best Accuracy per Generation", "Accuracy"
    else:
        ys = [h[1] for h in history]
        title, ylab = "Best Fitness per Generation", "Fitness (↓ better)"
    fig = go.Figure(data=[go.Scatter(x=xs, y=ys, mode="lines+markers", line=dict(width=2))])
    fig.update_layout(
        paper_bgcolor=BG, plot_bgcolor=BG,
        title=title, xaxis_title="Generation", yaxis_title=ylab,
        margin=dict(l=30, r=10, t=36, b=30),
        height=340,
        font=dict(family="Inter, Arial, sans-serif", size=14)
    )
    return fig
|
265 |
|
266 |
+
def fig_to_html(fig: go.Figure) -> str:
    """Serialize a Plotly figure to an HTML fragment for the Gradio UI.

    The fragment is not a full HTML document, loads plotly.js from the CDN,
    and hides the Plotly logo in the mode bar.
    """
    # NOTE(review): confirm the component that receives this fragment executes
    # the embedded <script> tags; otherwise the chart will not render.
    render_config = dict(displaylogo=False)
    return pio.to_html(
        fig,
        include_plotlyjs="cdn",
        full_html=False,
        config=render_config,
    )
|
269 |
+
|
270 |
def approx_params(g: "Genome") -> int:
    """Return a rough parameter-count estimate for a genome.

    The estimate is
    ``(4 + 2 * ffn_mult) * d_model**2 * n_layers + 1000 * memory_tokens``,
    truncated to an integer.

    Args:
        g: Object exposing ``ffn_mult``, ``d_model``, ``n_layers`` and
            ``memory_tokens``.

    Returns:
        The estimated parameter count as an ``int``.
    """
    per_layer = (4.0 + 2.0 * float(g.ffn_mult)) * (g.d_model ** 2)
    layer_total = per_layer * g.n_layers
    memory_total = 1000 * g.memory_tokens
    return int(layer_total + memory_total)
|
274 |
|
275 |
+
# ---------- Orchestrator ----------
|
|
|
|
|
276 |
class EvoRunner:
|
277 |
def __init__(self):
|
278 |
self.lock = threading.Lock()
|
|
|
286 |
self.running = True
|
287 |
|
288 |
pop: List[Genome] = [random_genome(rng) for _ in range(pop_size)]
|
|
|
289 |
for g in pop:
|
290 |
fit, acc = evaluate_genome(g, dataset, explore)
|
291 |
g.fitness, g.acc = fit, acc
|
292 |
|
293 |
+
history: List[Tuple[int,float,float]] = []
|
294 |
best_overall: Optional[Genome] = None
|
295 |
|
296 |
for gen in range(1, generations+1):
|
297 |
if self.stop_flag: break
|
298 |
|
|
|
299 |
k = max(2, int(2 + exploit * 5))
|
300 |
parents = []
|
301 |
for _ in range(pop_size):
|
302 |
sample = rng.sample(pop, k=k)
|
303 |
parents.append(min(sample, key=lambda x: x.fitness))
|
304 |
|
|
|
305 |
children = []
|
306 |
for i in range(0, pop_size, 2):
|
307 |
a = parents[i]; b = parents[(i+1) % pop_size]
|
|
|
310 |
children.extend([child1, child2])
|
311 |
children = children[:pop_size]
|
312 |
|
|
|
313 |
for c in children:
|
314 |
fit, acc = evaluate_genome(c, dataset, explore)
|
315 |
c.fitness, c.acc = fit, acc
|
316 |
|
|
|
317 |
elite_n = max(1, pop_size // 10)
|
318 |
elites = sorted(pop, key=lambda x: x.fitness)[:elite_n]
|
|
|
|
|
319 |
pop = sorted(children, key=lambda x: x.fitness)
|
320 |
pop[-elite_n:] = elites
|
321 |
|
322 |
best = min(pop, key=lambda x: x.fitness)
|
323 |
+
if best_overall is None or best.fitness < best_overall.fitness: best_overall = best
|
|
|
324 |
|
325 |
history.append((gen, best.fitness, (best.acc if best.acc is not None else float("nan"))))
|
326 |
|
|
|
327 |
P = np.stack([g.vector() for g in pop], axis=0)
|
328 |
P3 = sphere_project(P)
|
329 |
sphere_fig = make_sphere_figure(P3, pop, gen)
|
330 |
hist_fig = make_history_figure(history, metric_choice)
|
331 |
+
|
332 |
top = sorted(pop, key=lambda x: x.fitness)[: min(12, len(pop))]
|
333 |
top_table = [
|
334 |
{
|
|
|
349 |
|
350 |
with self.lock:
|
351 |
self.state = {
|
352 |
+
"sphere_html": fig_to_html(sphere_fig),
|
353 |
+
"history_html": fig_to_html(hist_fig),
|
354 |
"top": top_table,
|
355 |
"best": best_card,
|
356 |
"gen": gen,
|
|
|
359 |
}
|
360 |
|
361 |
time.sleep(max(0.0, pace_ms/1000.0))
|
|
|
362 |
self.running = False
|
363 |
|
364 |
def start(self, *args, **kwargs):
|
365 |
if self.running: return
|
366 |
t = threading.Thread(target=self.run, args=args, kwargs=kwargs, daemon=True)
|
367 |
t.start()
|
|
|
368 |
def stop(self): self.stop_flag = True
|
369 |
|
370 |
runner = EvoRunner()
|
371 |
|
372 |
+
# ---------- UI callbacks ----------
|
|
|
|
|
373 |
def start_evo(dataset, pop, gens, mut, explore, exploit, seed, pace_ms, metric_choice):
    """Start an evolution run from the UI control values.

    Numeric controls are coerced to ``int``/``float`` before being handed to
    ``runner.start``; ``dataset`` and ``metric_choice`` are passed through
    unchanged.

    Returns:
        A pair of ``gr.update`` objects: the first sets ``interactive=False``,
        the second ``interactive=True``. Presumably these target the Start and
        Stop buttons; confirm against the ``start.click`` wiring.
    """
    runner.start(
        dataset,
        int(pop),
        int(gens),
        float(mut),
        float(explore),
        float(exploit),
        int(seed),
        int(pace_ms),
        metric_choice,
    )
    first_update = gr.update(interactive=False)
    second_update = gr.update(interactive=True)
    return (first_update, second_update)
|
|
|
381 |
def poll_state():
|
382 |
with runner.lock:
|
383 |
s = runner.state.copy()
|
384 |
+
sphere_html = s.get("sphere_html", "")
|
385 |
+
history_html = s.get("history_html", "")
|
386 |
best = s.get("best", {})
|
387 |
gen = s.get("gen", 0)
|
388 |
dataset = s.get("dataset", "Demo (Surrogate)")
|
389 |
top = s.get("top", [])
|
|
|
390 |
if best:
|
391 |
acc_txt = "—" if best.get("accuracy") is None else f"{best.get('accuracy'):.3f}"
|
392 |
stats_md = (
|
|
|
402 |
else:
|
403 |
stats_md = "Waiting… click **Start Evolution**."
|
404 |
df = pd.DataFrame(top)
|
405 |
+
return sphere_html, history_html, stats_md, df
|
406 |
|
407 |
def export_snapshot():
|
408 |
from json import dumps
|
|
|
413 |
f.write(payload)
|
414 |
return path
|
415 |
|
416 |
+
# ---------- Build UI (minimal layout) ----------
|
417 |
+
with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft()) as demo:
|
418 |
+
with gr.Column(elem_id="header"):
|
419 |
+
gr.Markdown("## Evo Playground — Minimal Live Evolution (PIQA / HellaSwag accuracy)")
|
|
|
|
|
|
|
|
|
|
|
420 |
|
421 |
with gr.Row():
|
422 |
+
with gr.Column(scale=1, elem_classes=["controls"]):
|
|
|
423 |
with gr.Group():
|
424 |
dataset = gr.Dropdown(
|
425 |
label="Dataset",
|
426 |
choices=["Demo (Surrogate)", "PIQA (Phase 2)", "HellaSwag (Phase 2)"],
|
427 |
value="Demo (Surrogate)",
|
428 |
+
info="PIQA/HellaSwag compute real proxy accuracy; Demo is a fast surrogate."
|
429 |
)
|
430 |
pop = gr.Slider(8, 80, value=24, step=2, label="Population size")
|
431 |
gens = gr.Slider(5, 200, value=60, step=1, label="Max generations")
|
|
|
440 |
start = gr.Button("▶ Start Evolution", variant="primary")
|
441 |
stop = gr.Button("⏹ Stop", variant="secondary")
|
442 |
|
443 |
+
with gr.Group(elem_classes=["panel"]):
|
444 |
+
stats_md = gr.Markdown("Waiting…", elem_id="stats")
|
445 |
+
|
446 |
+
with gr.Group(elem_classes=["panel"]):
|
447 |
export_btn = gr.Button("Export Snapshot (JSON)")
|
448 |
export_file = gr.File(label="Download snapshot", visible=False)
|
449 |
|
|
|
450 |
with gr.Column(scale=2):
|
451 |
+
with gr.Group(elem_classes=["panel"]):
|
452 |
+
sphere_html = gr.HTML()
|
453 |
+
with gr.Group(elem_classes=["panel"]):
|
454 |
+
hist_html = gr.HTML()
|
455 |
+
with gr.Group(elem_classes=["panel"]):
|
456 |
top_df = gr.Dataframe(label="Top Genomes (live)", wrap=True, interactive=False)
|
457 |
|
458 |
# Wiring
|
|
|
460 |
stop.click(stop_evo, [], [start, stop])
|
461 |
export_btn.click(export_snapshot, [], [export_file])
|
462 |
|
463 |
+
# Initial paint + polling
|
464 |
+
demo.load(poll_state, None, [sphere_html, hist_html, stats_md, top_df])
|
465 |
+
gr.Timer(0.7).tick(poll_state, None, [sphere_html, hist_html, stats_md, top_df])
|
|
|
|
|
|
|
466 |
|
467 |
if __name__ == "__main__":
|
468 |
demo.launch()
|