AbstractPhil committed
Commit d08afa7 · verified · 1 Parent(s): 5d8d3ef

Update beeper_model.py

Files changed (1)
  1. beeper_model.py +129 -44
beeper_model.py CHANGED
@@ -1,10 +1,10 @@
  # beeper.py
  # --------------------------------------------------------------------------------------------------
- # Beeper — Rose-based tiny GPT (inference module)
- # - Decoder-only GPT with SDPA (FlashAttention path on Ampere+)
- # - Model exactly mirrors the training-time architecture you provided (dim=512, L=6, H=8)
- # - Safe state-dict loader that auto-sizes pentachora banks before strict load
- # - Generation API with repetition/presence/frequency penalties (same defaults as training)
+ # Beeper Full Penta Controller — Rose-based tiny GPT (inference module with runtime pentachora influence)
+ # - Decoder-only GPT with SDPA (FlashAttention path on Ampere/Hopper)
+ # - Runtime "vertex pull" uses config["runtime_pentachora"] to bias hidden states toward
+ #   pentachora vertices (coarse/topic/mood) exactly like training-time behavior, but non-destructive
+ #   and fully toggleable.
  # --------------------------------------------------------------------------------------------------
  from __future__ import annotations
 
@@ -12,7 +12,7 @@ import math
  import re
  import inspect
  from contextlib import nullcontext
- from typing import Optional, Tuple
+ from typing import Optional, Tuple, Dict, Any
 
  import torch
  import torch.nn as nn
@@ -44,13 +44,9 @@ except Exception:
 
 
  def sdpa_ctx_prefer_flash():
-     """
-     Best-effort context to bias SDPA toward FlashAttention on supported GPUs.
-     Falls back to no-op if not available.
-     """
+     """Bias SDPA toward FlashAttention where possible; otherwise no-op."""
      if _sdpa_kernel is None or _SDPA_SIG is None:
          return nullcontext()
-
      params = {p.name for p in _SDPA_SIG.parameters.values()}
      try:
          if "backends" in params and _SDPBackend is not None:
@@ -72,11 +68,7 @@ def sdpa_ctx_prefer_flash():
 
  # --------------------------------- Core blocks ------------------------------------------------------
  class CausalSelfAttention(nn.Module):
-     """
-     Multi-head causal self-attention layer using PyTorch SDPA.
-     - On CUDA, uses scaled_dot_product_attention with is_causal=True and dropout during training.
-     - On CPU, falls back to manual masked attention.
-     """
+     """Multi-head causal self-attention using PyTorch SDPA."""
      def __init__(self, dim: int, n_heads: int, attn_dropout: float = 0.0):
          super().__init__()
          assert dim % n_heads == 0, "dim must be divisible by n_heads"
@@ -139,10 +131,18 @@ class BeeperRoseGPT(nn.Module):
      Config keys used:
        - vocab_size, dim, context, n_heads, n_layers, mlp_ratio
        - resid_dropout, dropout, grad_checkpoint
+       - runtime_pentachora: {
+             "enable": bool,
+             "pool": "mean" | "last",
+             "temp": float,          # similarity temperature (default: 0.10)
+             "coarse_alpha": float,  # hidden blend strength for coarse bank
+             "topic_alpha": float,   # hidden blend strength for topic bank
+             "mood_alpha": float     # hidden blend strength for mood bank
+         }
      Notes:
        - Shares token embedding with LM head (tied weights).
-       - Includes Rose projection/anchors and pentachora banks; unused for plain generation,
-         but kept for full compatibility with trained checkpoints.
+       - Includes Rose anchors and pentachora banks; at runtime we can apply a *non-destructive*
+         vertex pull to hidden states before the LM head using the above config.
      """
      def __init__(self, cfg: dict):
          super().__init__()
@@ -150,6 +150,7 @@
          H, L, MR = cfg["n_heads"], cfg["n_layers"], cfg["mlp_ratio"]
          RD, AD = cfg.get("resid_dropout", 0.1), cfg.get("dropout", 0.0)
          self.grad_checkpoint = bool(cfg.get("grad_checkpoint", False))
+         self.runtime_cfg: Dict[str, Any] = dict(cfg.get("runtime_pentachora", {}) or {})
 
          self.vocab_size, self.context = int(V), int(Ctx)
 
@@ -169,9 +170,7 @@
 
          self.norm = nn.LayerNorm(D)
          self.lm_head = nn.Linear(D, V, bias=False)
-
-         # Weight tying
-         self.lm_head.weight = self.token_emb.weight
+         self.lm_head.weight = self.token_emb.weight  # weight tying
 
          # Rose projection + anchors (present in checkpoints)
          self.rose_proj = nn.Linear(D, D, bias=False)
@@ -196,16 +195,12 @@
 
      # ---- Pentachora creation (must match sizes in checkpoint before strict load) -------------------
      def ensure_pentachora(self, coarse_C: int, medium_C: int, fine_C: int, dim: int, device: torch.device):
-         """
-         Initialize pentachora banks if not already present.
-         Shapes must match checkpoint entries for strict loading.
-         """
+         """Initialize pentachora banks if not already present."""
          if self.pent_inited.item() == 1:
              return
 
          def bank(C: int) -> nn.Parameter:
              if C <= 0:
-                 # Keep a zero-sized parameter to satisfy strict loading (rare).
                  return nn.Parameter(torch.zeros((0, 5, dim), device=device))
              pts = torch.randn(C, 5, dim, device=device)
              pts = F.normalize(pts - pts.mean(dim=1, keepdim=True), dim=-1)
@@ -216,6 +211,94 @@
          self.penta_fine = bank(int(fine_C))
          self.pent_inited.fill_(1)
 
+     # ---- Runtime configuration helpers -------------------------------------------------------------
+     def set_runtime_pentachora(self, cfg: Dict[str, Any]) -> None:
+         """Update runtime pentachora behavior (enable/alphas/temp/pool)."""
+         self.runtime_cfg.update(cfg or {})
+
+     def _pool_hidden(self, h: torch.Tensor, mode: str) -> torch.Tensor:
+         return h.mean(dim=1) if mode == "mean" else h[:, -1, :]
+
+     @staticmethod
+     def _weighted_nearest_vertex_target(
+         pooled: torch.Tensor,   # [B,D]
+         bank: torch.Tensor,     # [C,5,D]
+         temp: float
+     ) -> torch.Tensor:
+         """
+         For each class (simplex) pick its nearest vertex to the pooled latent,
+         then compute a softmax over classes of -min_dists/temp and take the
+         weighted average of those nearest vertices => [B,D] target.
+         """
+         B, D = pooled.shape
+         C = bank.size(0)
+         if C == 0:
+             return pooled
+
+         # distances to each vertex
+         diffs = pooled[:, None, None, :] - bank[None, :, :, :]      # [B,C,5,D]
+         dists = torch.norm(diffs, dim=-1)                           # [B,C,5]
+
+         min_dists, min_idx = dists.min(dim=2)                       # [B,C], [B,C]
+         sims = -min_dists / max(1e-8, float(temp))                  # [B,C]
+         weights = F.softmax(sims, dim=-1)                           # [B,C]
+
+         # gather nearest vertex vectors: [B,C,D]
+         bank_exp = bank.unsqueeze(0).expand(B, -1, -1, -1)          # [B,C,5,D]
+         gather_idx = min_idx.unsqueeze(-1).unsqueeze(-1).expand(B, C, 1, D)
+         nearest = torch.gather(bank_exp, 2, gather_idx).squeeze(2)  # [B,C,D]
+
+         target = (weights.unsqueeze(-1) * nearest).sum(dim=1)       # [B,D]
+         return target
+
+     def _apply_runtime_vertex_pull(
+         self,
+         h: torch.Tensor,                 # [B,T,D]
+         runtime_cfg: Dict[str, Any]
+     ) -> torch.Tensor:
+         """
+         Apply non-destructive vertex pull to hidden states using banks selected by runtime_cfg.
+         We compute a pooled latent, a per-bank target vector, form a delta, and blend it back into h.
+         """
+         if not runtime_cfg or not runtime_cfg.get("enable", False):
+             return h
+
+         pool_mode = str(runtime_cfg.get("pool", "mean"))
+         temp = float(runtime_cfg.get("temp", 0.10))
+
+         # Strengths per bank
+         alpha_coarse = float(runtime_cfg.get("coarse_alpha", 0.0))
+         alpha_topic = float(runtime_cfg.get("topic_alpha", 0.0))
+         alpha_mood = float(runtime_cfg.get("mood_alpha", 0.0))
+
+         if alpha_coarse <= 0 and alpha_topic <= 0 and alpha_mood <= 0:
+             return h
+
+         pooled = self._pool_hidden(h, pool_mode)  # [B,D]
+
+         total_delta = None
+         if alpha_coarse > 0 and getattr(self, "penta_coarse", None) is not None:
+             tgt = self._weighted_nearest_vertex_target(pooled, self.penta_coarse, temp)
+             delta = tgt - pooled
+             total_delta = (alpha_coarse * delta) if total_delta is None else total_delta + alpha_coarse * delta
+
+         if alpha_topic > 0 and getattr(self, "penta_medium", None) is not None:
+             tgt = self._weighted_nearest_vertex_target(pooled, self.penta_medium, temp)
+             delta = tgt - pooled
+             total_delta = delta * alpha_topic if total_delta is None else total_delta + alpha_topic * delta
+
+         if alpha_mood > 0 and getattr(self, "penta_fine", None) is not None:
+             tgt = self._weighted_nearest_vertex_target(pooled, self.penta_fine, temp)
+             delta = tgt - pooled
+             total_delta = delta * alpha_mood if total_delta is None else total_delta + alpha_mood * delta
+
+         if total_delta is None:
+             return h
+
+         # Broadcast same delta to all time steps (global conditioning shift)
+         h = h + total_delta.unsqueeze(1)  # [B,T,D]
+         return h
+
      # ---- Backbone / forward -----------------------------------------------------------------------
      def _block_forward(self, blk: nn.ModuleDict, x: torch.Tensor) -> torch.Tensor:
          x = x + blk["attn"](blk["norm1"](x))
@@ -235,8 +318,14 @@
              x = self._block_forward(blk, x)
          return self.norm(x)
 
-     def forward(self, idx: torch.Tensor) -> torch.Tensor:
+     def forward(self, idx: torch.Tensor, runtime_cfg: Optional[Dict[str, Any]] = None) -> torch.Tensor:
+         """
+         Forward pass with optional runtime pentachora influence.
+         If runtime_cfg is None, falls back to self.runtime_cfg set at init or via set_runtime_pentachora().
+         """
          h = self.backbone(idx)
+         cfg = self.runtime_cfg if runtime_cfg is None else {**self.runtime_cfg, **(runtime_cfg or {})}
+         h = self._apply_runtime_vertex_pull(h, cfg)
          return self.lm_head(h)
 
      # ---- Utilities ---------------------------------------------------------------------------------
@@ -245,7 +334,7 @@
          return self.backbone(idx)
 
      def rose_hidden_pool(self, h: torch.Tensor, mode: str = "mean") -> torch.Tensor:
-         """Pool hidden states for Rose-related terms (unused in plain generation)."""
+         """Pool hidden states for Rose-related terms."""
          return h.mean(dim=1) if mode == "mean" else h[:, -1, :]
 
 
@@ -257,9 +346,7 @@ def prepare_model_for_state_dict(
  ) -> None:
      """
      Ensure model has pentachora parameters sized to match the incoming state_dict,
-     so we can load with strict=True.
-
-     If the checkpoint has no pentachora (older versions), we do nothing.
+     so we can load with strict=True. No-op if checkpoint lacks penta_* keys.
      """
      device = device or next(model.parameters()).device
      need = all(k in state_dict for k in ("penta_coarse", "penta_medium", "penta_fine"))
@@ -267,18 +354,15 @@
          return
 
      pc, pt, pm = state_dict["penta_coarse"], state_dict["penta_medium"], state_dict["penta_fine"]
-     # Expect [C,5,D]
-     def dims_ok(t: torch.Tensor) -> bool:
-         return t.ndim == 3 and t.size(1) == 5 and t.size(2) == model.token_emb.embedding_dim
 
-     if not (dims_ok(pc) and dims_ok(pt) and dims_ok(pm)):
-         # Shapes inconsistent; fall back to non-strict load later.
+     def dims_ok(t: torch.Tensor, D: int) -> bool:
+         return t.ndim == 3 and t.size(1) == 5 and t.size(2) == D
+
+     D = model.token_emb.embedding_dim
+     if not (dims_ok(pc, D) and dims_ok(pt, D) and dims_ok(pm, D)):
          return
 
-     coarse_C = pc.size(0)
-     topic_C = pt.size(0)
-     mood_C = pm.size(0)
-     model.ensure_pentachora(coarse_C, topic_C, mood_C, dim=pc.size(2), device=device)
+     model.ensure_pentachora(pc.size(0), pt.size(0), pm.size(0), dim=D, device=device)
 
 
  # --------------------------------- Generation -------------------------------------------------------
@@ -304,9 +388,10 @@
      frequency_penalty: Optional[float] = None,
      device: Optional[torch.device] = None,
      detokenize: bool = True,
+     runtime_cfg: Optional[Dict[str, Any]] = None,  # <— NEW: pass-through to forward()
  ) -> str:
      """
-     Penalized nucleus sampling (same knobs as training script).
+     Penalized nucleus sampling with optional runtime pentachora influence.
      """
      temperature = cfg.get("temperature", 0.9) if temperature is None else float(temperature)
      top_k = cfg.get("top_k", 40) if top_k is None else int(top_k)
@@ -327,10 +412,10 @@
          counts[t] += 1
 
      for _ in range(int(max_new_tokens)):
-         logits = model(x[:, -cfg["context"]:])
+         logits = model(x[:, -cfg["context"]:], runtime_cfg=runtime_cfg)
          logits = logits[:, -1, :]
 
-         # Repetition penalty (CTRL-like)
+         # Repetition penalty
          if repetition_penalty and repetition_penalty != 1.0:
              mask = counts > 0
              if mask.any():
@@ -338,7 +423,7 @@
                  logits[:, mask][pos] /= repetition_penalty
                  logits[:, mask][~pos] *= repetition_penalty
 
-         # Presence/frequency penalties (OpenAI-like)
+         # Presence/frequency penalties
          if presence_penalty or frequency_penalty:
              pen = counts.float() * (frequency_penalty or 0.0) + (counts > 0).float() * (presence_penalty or 0.0)
              logits = logits - pen.unsqueeze(0)
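
The hunks above add runtime-only controls: a per-model runtime_cfg stored at init, a set_runtime_pentachora() setter, and a runtime_cfg pass-through on forward() and generate(). A minimal usage sketch follows (not part of the commit); it assumes the constructor needs no config keys beyond those listed in the class docstring, and the bank sizes, vocab size, mlp_ratio, and alpha values are illustrative only.

# --- Hypothetical usage sketch (not from the commit) -------------------------------------------------
import torch
from beeper_model import BeeperRoseGPT  # module name assumed from the edited file

cfg = {
    "vocab_size": 8192, "dim": 512, "context": 512,    # illustrative sizes
    "n_heads": 8, "n_layers": 6, "mlp_ratio": 4,
    "runtime_pentachora": {                            # keys documented in the class docstring
        "enable": True, "pool": "mean", "temp": 0.10,
        "coarse_alpha": 0.15, "topic_alpha": 0.10, "mood_alpha": 0.05,
    },
}
model = BeeperRoseGPT(cfg).eval()
# Random banks just for the demo; real banks come from a checkpoint.
model.ensure_pentachora(coarse_C=32, medium_C=64, fine_C=128, dim=cfg["dim"],
                        device=torch.device("cpu"))

idx = torch.randint(0, cfg["vocab_size"], (1, 16))     # dummy token ids
with torch.no_grad():
    logits = model(idx)                                   # vertex pull applied before the LM head
    logits = model(idx, runtime_cfg={"mood_alpha": 0.3})  # per-call override, merged over the stored cfg

Because the per-call dict is merged over the model-level one, a single call can raise or lower one bank's pull without touching the stored configuration.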
 
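Loading is unchanged in spirit: prepare_model_for_state_dict sizes the penta_* banks from the checkpoint so load_state_dict(strict=True) can succeed. A minimal loading sketch, assuming the function takes (model, state_dict, device) as its body suggests and that a local checkpoint file exists (the path below is a placeholder):

# --- Hypothetical loading sketch (not from the commit) ------------------------------------------------
import torch
from beeper_model import BeeperRoseGPT, prepare_model_for_state_dict

state_dict = torch.load("beeper_checkpoint.pt", map_location="cpu")  # placeholder path
model = BeeperRoseGPT(cfg)  # cfg as in the previous sketch

# Size the penta_* banks from the checkpoint so strict loading can succeed.
prepare_model_for_state_dict(model, state_dict, device=torch.device("cpu"))
try:
    model.load_state_dict(state_dict, strict=True)
except RuntimeError:
    # Older checkpoints (missing penta_* keys or mismatched shapes) fall back to a non-strict load.
    model.load_state_dict(state_dict, strict=False)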