HemanM commited on
Commit
7a05320
·
verified ·
1 Parent(s): 1bd6310

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -12
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # app.py — Minimal, pro UI with big transparent sphere and clean hover
2
  import math, json, random, time, threading
3
  from dataclasses import dataclass, asdict
4
  from typing import List, Tuple, Dict, Any, Optional
@@ -23,7 +23,7 @@ CUSTOM_CSS = """
23
  .gradio-container { max-width: 1180px !important; }
24
  #header { border-radius: var(--radius); padding: 8px 6px; }
25
  h1, h2, h3, .gr-markdown { color: var(--fg); }
26
- .gr-button { border-radius: 10px; }
27
  .controls .gr-group, .panel { border: 1px solid var(--line); border-radius: var(--radius); }
28
  .panel { padding: 10px; }
29
  #stats { font-weight: 300; color: var(--fg); }
@@ -109,15 +109,24 @@ class TinyMLP(nn.Module):
109
 
110
  @lru_cache(maxsize=4)
111
  def _cached_dataset(name: str):
112
- if name.startswith("PIQA"): return load_piqa(subset=800, seed=42)
113
- if name.startswith("HellaSwag"): return load_hellaswag(subset=800, seed=42)
 
 
 
 
114
  return None
115
 
116
  def _train_eval_proxy(genome: Genome, dataset_name: str, explore: float, device: str = "cpu"):
117
  data = _cached_dataset(dataset_name)
118
  if data is None:
 
119
  v = genome.vector() * 2 - 1
120
- return float(rastrigin(v)), None
 
 
 
 
121
  Xtr_txt, ytr, Xva_txt, yva = data
122
  nfeat = 4096
123
  Xtr = hash_vectorize(Xtr_txt, n_features=nfeat, seed=1234)
@@ -173,6 +182,7 @@ def evaluate_genome(genome: Genome, dataset: str, explore: float):
173
  return _train_eval_proxy(genome, "PIQA", explore)
174
  if dataset.startswith("HellaSwag"):
175
  return _train_eval_proxy(genome, "HellaSwag", explore)
 
176
  v = genome.vector() * 2 - 1
177
  return float(rastrigin(v)), None
178
 
@@ -188,11 +198,7 @@ def sphere_project(points: np.ndarray) -> np.ndarray:
188
  return (Y / norms) * 1.2
189
 
190
  def _species_colors(species: np.ndarray) -> list:
191
- colors = []
192
- for s in species:
193
- c = PALETTE[int(s) % len(PALETTE)]
194
- colors.append(c)
195
- return colors
196
 
197
  def make_sphere_figure(points3d: np.ndarray, genomes: List[Genome], gen_idx: int) -> go.Figure:
198
  species = np.array([g.species for g in genomes])
@@ -264,8 +270,13 @@ def make_history_figure(history: List[Tuple[int,float,float]], metric: str) -> g
264
  return fig
265
 
266
  def fig_to_html(fig: go.Figure) -> str:
267
- # Robust Plotly rendering inside Gradio
268
- return pio.to_html(fig, include_plotlyjs="cdn", full_html=False, config=dict(displaylogo=False))
 
 
 
 
 
269
 
270
  def approx_params(g: Genome) -> int:
271
  per_layer = (4.0 + 2.0 * float(g.ffn_mult)) * (g.d_model ** 2)
 
1
+ # app.py — Minimal, pro UI with big transparent sphere, accuracy, and robust rendering
2
  import math, json, random, time, threading
3
  from dataclasses import dataclass, asdict
4
  from typing import List, Tuple, Dict, Any, Optional
 
23
  .gradio-container { max-width: 1180px !important; }
24
  #header { border-radius: var(--radius); padding: 8px 6px; }
25
  h1, h2, h3, .gr-markdown { color: var(--fg); }
26
+ .gr-button { border-radius: 10px }
27
  .controls .gr-group, .panel { border: 1px solid var(--line); border-radius: var(--radius); }
28
  .panel { padding: 10px; }
29
  #stats { font-weight: 300; color: var(--fg); }
 
109
 
110
  @lru_cache(maxsize=4)
111
  def _cached_dataset(name: str):
112
+ # Defensive: if loading fails (e.g., datasets version / no internet), return None
113
+ try:
114
+ if name.startswith("PIQA"): return load_piqa(subset=800, seed=42)
115
+ if name.startswith("HellaSwag"): return load_hellaswag(subset=800, seed=42)
116
+ except Exception:
117
+ return None
118
  return None
119
 
120
  def _train_eval_proxy(genome: Genome, dataset_name: str, explore: float, device: str = "cpu"):
121
  data = _cached_dataset(dataset_name)
122
  if data is None:
123
+ # Fallback to surrogate to keep the UI alive
124
  v = genome.vector() * 2 - 1
125
+ base = rastrigin(v)
126
+ parsimony = 0.001 * (genome.d_model + 50*genome.n_layers + 20*genome.n_heads + 100*genome.memory_tokens)
127
+ noise = np.random.normal(scale=0.05 * max(0.0, min(1.0, explore)))
128
+ return float(base + parsimony + noise), None
129
+
130
  Xtr_txt, ytr, Xva_txt, yva = data
131
  nfeat = 4096
132
  Xtr = hash_vectorize(Xtr_txt, n_features=nfeat, seed=1234)
 
182
  return _train_eval_proxy(genome, "PIQA", explore)
183
  if dataset.startswith("HellaSwag"):
184
  return _train_eval_proxy(genome, "HellaSwag", explore)
185
+ # Fallback
186
  v = genome.vector() * 2 - 1
187
  return float(rastrigin(v)), None
188
 
 
198
  return (Y / norms) * 1.2
199
 
200
  def _species_colors(species: np.ndarray) -> list:
201
+ return [PALETTE[int(s) % len(PALETTE)] for s in species]
 
 
 
 
202
 
203
  def make_sphere_figure(points3d: np.ndarray, genomes: List[Genome], gen_idx: int) -> go.Figure:
204
  species = np.array([g.species for g in genomes])
 
270
  return fig
271
 
272
  def fig_to_html(fig: go.Figure) -> str:
273
+ # Inline Plotly JS to avoid any CDN/network dependency in Spaces
274
+ return pio.to_html(
275
+ fig,
276
+ include_plotlyjs=True, # IMPORTANT: inline JS so the sphere always renders
277
+ full_html=False,
278
+ config=dict(displaylogo=False)
279
+ )
280
 
281
  def approx_params(g: Genome) -> int:
282
  per_layer = (4.0 + 2.0 * float(g.ffn_mult)) * (g.d_model ** 2)