Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
# app.py — Minimal, pro UI with big transparent sphere and
|
2 |
import math, json, random, time, threading
|
3 |
from dataclasses import dataclass, asdict
|
4 |
from typing import List, Tuple, Dict, Any, Optional
|
@@ -23,7 +23,7 @@ CUSTOM_CSS = """
|
|
23 |
.gradio-container { max-width: 1180px !important; }
|
24 |
#header { border-radius: var(--radius); padding: 8px 6px; }
|
25 |
h1, h2, h3, .gr-markdown { color: var(--fg); }
|
26 |
-
.gr-button { border-radius: 10px
|
27 |
.controls .gr-group, .panel { border: 1px solid var(--line); border-radius: var(--radius); }
|
28 |
.panel { padding: 10px; }
|
29 |
#stats { font-weight: 300; color: var(--fg); }
|
@@ -109,15 +109,24 @@ class TinyMLP(nn.Module):
|
|
109 |
|
110 |
@lru_cache(maxsize=4)
|
111 |
def _cached_dataset(name: str):
|
112 |
-
if
|
113 |
-
|
|
|
|
|
|
|
|
|
114 |
return None
|
115 |
|
116 |
def _train_eval_proxy(genome: Genome, dataset_name: str, explore: float, device: str = "cpu"):
|
117 |
data = _cached_dataset(dataset_name)
|
118 |
if data is None:
|
|
|
119 |
v = genome.vector() * 2 - 1
|
120 |
-
|
|
|
|
|
|
|
|
|
121 |
Xtr_txt, ytr, Xva_txt, yva = data
|
122 |
nfeat = 4096
|
123 |
Xtr = hash_vectorize(Xtr_txt, n_features=nfeat, seed=1234)
|
@@ -173,6 +182,7 @@ def evaluate_genome(genome: Genome, dataset: str, explore: float):
|
|
173 |
return _train_eval_proxy(genome, "PIQA", explore)
|
174 |
if dataset.startswith("HellaSwag"):
|
175 |
return _train_eval_proxy(genome, "HellaSwag", explore)
|
|
|
176 |
v = genome.vector() * 2 - 1
|
177 |
return float(rastrigin(v)), None
|
178 |
|
@@ -188,11 +198,7 @@ def sphere_project(points: np.ndarray) -> np.ndarray:
|
|
188 |
return (Y / norms) * 1.2
|
189 |
|
190 |
def _species_colors(species: np.ndarray) -> list:
|
191 |
-
|
192 |
-
for s in species:
|
193 |
-
c = PALETTE[int(s) % len(PALETTE)]
|
194 |
-
colors.append(c)
|
195 |
-
return colors
|
196 |
|
197 |
def make_sphere_figure(points3d: np.ndarray, genomes: List[Genome], gen_idx: int) -> go.Figure:
|
198 |
species = np.array([g.species for g in genomes])
|
@@ -264,8 +270,13 @@ def make_history_figure(history: List[Tuple[int,float,float]], metric: str) -> g
|
|
264 |
return fig
|
265 |
|
266 |
def fig_to_html(fig: go.Figure) -> str:
|
267 |
-
#
|
268 |
-
return pio.to_html(
|
|
|
|
|
|
|
|
|
|
|
269 |
|
270 |
def approx_params(g: Genome) -> int:
|
271 |
per_layer = (4.0 + 2.0 * float(g.ffn_mult)) * (g.d_model ** 2)
|
|
|
1 |
+
# app.py — Minimal, pro UI with big transparent sphere, accuracy, and robust rendering
|
2 |
import math, json, random, time, threading
|
3 |
from dataclasses import dataclass, asdict
|
4 |
from typing import List, Tuple, Dict, Any, Optional
|
|
|
23 |
.gradio-container { max-width: 1180px !important; }
|
24 |
#header { border-radius: var(--radius); padding: 8px 6px; }
|
25 |
h1, h2, h3, .gr-markdown { color: var(--fg); }
|
26 |
+
.gr-button { border-radius: 10px }
|
27 |
.controls .gr-group, .panel { border: 1px solid var(--line); border-radius: var(--radius); }
|
28 |
.panel { padding: 10px; }
|
29 |
#stats { font-weight: 300; color: var(--fg); }
|
|
|
109 |
|
110 |
@lru_cache(maxsize=4)
|
111 |
def _cached_dataset(name: str):
|
112 |
+
# Defensive: if loading fails (e.g., datasets version / no internet), return None
|
113 |
+
try:
|
114 |
+
if name.startswith("PIQA"): return load_piqa(subset=800, seed=42)
|
115 |
+
if name.startswith("HellaSwag"): return load_hellaswag(subset=800, seed=42)
|
116 |
+
except Exception:
|
117 |
+
return None
|
118 |
return None
|
119 |
|
120 |
def _train_eval_proxy(genome: Genome, dataset_name: str, explore: float, device: str = "cpu"):
|
121 |
data = _cached_dataset(dataset_name)
|
122 |
if data is None:
|
123 |
+
# Fallback to surrogate to keep the UI alive
|
124 |
v = genome.vector() * 2 - 1
|
125 |
+
base = rastrigin(v)
|
126 |
+
parsimony = 0.001 * (genome.d_model + 50*genome.n_layers + 20*genome.n_heads + 100*genome.memory_tokens)
|
127 |
+
noise = np.random.normal(scale=0.05 * max(0.0, min(1.0, explore)))
|
128 |
+
return float(base + parsimony + noise), None
|
129 |
+
|
130 |
Xtr_txt, ytr, Xva_txt, yva = data
|
131 |
nfeat = 4096
|
132 |
Xtr = hash_vectorize(Xtr_txt, n_features=nfeat, seed=1234)
|
|
|
182 |
return _train_eval_proxy(genome, "PIQA", explore)
|
183 |
if dataset.startswith("HellaSwag"):
|
184 |
return _train_eval_proxy(genome, "HellaSwag", explore)
|
185 |
+
# Fallback
|
186 |
v = genome.vector() * 2 - 1
|
187 |
return float(rastrigin(v)), None
|
188 |
|
|
|
198 |
return (Y / norms) * 1.2
|
199 |
|
200 |
def _species_colors(species: np.ndarray) -> list:
|
201 |
+
return [PALETTE[int(s) % len(PALETTE)] for s in species]
|
|
|
|
|
|
|
|
|
202 |
|
203 |
def make_sphere_figure(points3d: np.ndarray, genomes: List[Genome], gen_idx: int) -> go.Figure:
|
204 |
species = np.array([g.species for g in genomes])
|
|
|
270 |
return fig
|
271 |
|
272 |
def fig_to_html(fig: go.Figure) -> str:
|
273 |
+
# Inline Plotly JS to avoid any CDN/network dependency in Spaces
|
274 |
+
return pio.to_html(
|
275 |
+
fig,
|
276 |
+
include_plotlyjs=True, # IMPORTANT: inline JS so the sphere always renders
|
277 |
+
full_html=False,
|
278 |
+
config=dict(displaylogo=False)
|
279 |
+
)
|
280 |
|
281 |
def approx_params(g: Genome) -> int:
|
282 |
per_layer = (4.0 + 2.0 * float(g.ffn_mult)) * (g.d_model ** 2)
|