AbstractPhil committed
Commit 0e31052 · verified · 1 Parent(s): bbb5633

Update beeper_model.py

Files changed (1)
  1. beeper_model.py +164 -202
beeper_model.py CHANGED
@@ -1,33 +1,38 @@
- """
- Rose Beeper Model - Inference Components
- Extracted classes and utilities for model inference
- """

- import os
  import math
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
- from typing import Optional, Tuple, Dict, Any
- from contextlib import nullcontext
- import inspect
- import re
- from tokenizers import Tokenizer
- from safetensors.torch import load_file as load_safetensors
-

- # ============================================================================
- # SDPA (Scaled Dot Product Attention) Configuration
- # ============================================================================

- # Version-safe SDPA context helper
  try:
      from torch.nn.attention import sdpa_kernel as _sdpa_kernel_modern
      from torch.nn.attention import SDPBackend as _SDPBackend
      _SDPA_SIG = inspect.signature(_sdpa_kernel_modern)
      _sdpa_kernel = _sdpa_kernel_modern
  except Exception:
      try:
          from torch.backends.cuda import sdp_kernel as _sdpa_kernel_legacy
          _SDPA_SIG = inspect.signature(_sdpa_kernel_legacy)
          _SDPBackend = None
@@ -39,23 +44,23 @@ except Exception:


  def sdpa_ctx_prefer_flash():
-     """Bias SDPA toward FlashAttention when available; no-op if unknown."""
      if _sdpa_kernel is None or _SDPA_SIG is None:
          return nullcontext()

      params = {p.name for p in _SDPA_SIG.parameters.values()}
      try:
-         # Modern API (PyTorch 2.3+): backends=[...]
          if "backends" in params and _SDPBackend is not None:
              return _sdpa_kernel(backends=[
                  _SDPBackend.FLASH_ATTENTION,
                  _SDPBackend.EFFICIENT_ATTENTION,
                  _SDPBackend.MATH
              ])
-         # Modern API (alt): backend=...
          if "backend" in params and _SDPBackend is not None:
              return _sdpa_kernel(backend=_SDPBackend.FLASH_ATTENTION)
-         # Legacy boolean flags (old CUDA backend)
          if {"enable_flash", "enable_math", "enable_mem_efficient"} <= params:
              return _sdpa_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=True)
          if {"use_flash", "use_math", "use_mem_efficient"} <= params:
@@ -65,27 +70,27 @@ def sdpa_ctx_prefer_flash():
      return nullcontext()


- # ============================================================================
- # Model Components
- # ============================================================================
-
  class CausalSelfAttention(nn.Module):
-     """Multi-head causal self-attention with optional FlashAttention."""
-
      def __init__(self, dim: int, n_heads: int, attn_dropout: float = 0.0):
          super().__init__()
-         assert dim % n_heads == 0
-         self.nh = n_heads
-         self.hd = dim // n_heads
          self.qkv = nn.Linear(dim, 3 * dim, bias=False)
          self.proj = nn.Linear(dim, dim, bias=False)
-         self.attn_dropout = attn_dropout

-     def forward(self, x):
          B, T, C = x.shape
          qkv = self.qkv(x)
          q, k, v = qkv.chunk(3, dim=-1)
-         q = q.view(B, T, self.nh, self.hd).transpose(1, 2)
          k = k.view(B, T, self.nh, self.hd).transpose(1, 2)
          v = v.view(B, T, self.nh, self.hd).transpose(1, 2)
@@ -109,16 +114,15 @@ class CausalSelfAttention(nn.Module):


  class MLP(nn.Module):
-     """Feed-forward network with GELU activation."""
-
-     def __init__(self, dim, mlp_ratio=4.0, dropout=0.1):
          super().__init__()
          hidden = int(dim * mlp_ratio)
          self.fc1 = nn.Linear(dim, hidden)
          self.fc2 = nn.Linear(hidden, dim)
          self.drop = nn.Dropout(dropout)
-
-     def forward(self, x):
          x = self.fc1(x)
          x = F.gelu(x, approximate="tanh")
          x = self.drop(x)
@@ -127,47 +131,62 @@ class MLP(nn.Module):
          return x


  class BeeperRoseGPT(nn.Module):
-     """Rose Beeper GPT model with pentachora banks for multi-level control."""
-
      def __init__(self, cfg: dict):
          super().__init__()
          V, D, Ctx = cfg["vocab_size"], cfg["dim"], cfg["context"]
          H, L, MR = cfg["n_heads"], cfg["n_layers"], cfg["mlp_ratio"]
-         RD, AD, CKPT = cfg["resid_dropout"], cfg["dropout"], cfg["grad_checkpoint"]

-         self.vocab_size, self.context = V, Ctx
          self.token_emb = nn.Embedding(V, D)
-         self.pos_emb = nn.Parameter(torch.zeros(1, Ctx, D))
-         self.drop = nn.Dropout(RD)

          self.blocks = nn.ModuleList([
              nn.ModuleDict({
                  "norm1": nn.LayerNorm(D),
-                 "attn": CausalSelfAttention(D, H, attn_dropout=AD),
                  "norm2": nn.LayerNorm(D),
-                 "mlp": MLP(D, mlp_ratio=MR, dropout=RD),
-             }) for _ in range(L)
          ])
-         self.norm = nn.LayerNorm(D)
          self.lm_head = nn.Linear(D, V, bias=False)
          self.lm_head.weight = self.token_emb.weight

-         # Optional Rose projection + anchors
-         self.rose_proj = nn.Linear(D, D, bias=False)
-         self.rose_anchors = nn.Parameter(torch.randn(3, D) / (D**0.5))

-         # Multi-level pentachora; lazily initialized
          self.register_buffer("pent_inited", torch.tensor(0, dtype=torch.uint8), persistent=False)
-         self.penta_coarse = None
-         self.penta_medium = None
-         self.penta_fine = None

-         self.apply(self._init)
-         self.grad_checkpoint = CKPT

      @staticmethod
-     def _init(m):
          if isinstance(m, nn.Linear):
              nn.init.normal_(m.weight, mean=0.0, std=0.02)
              if m.bias is not None:
@@ -175,92 +194,95 @@ class BeeperRoseGPT(nn.Module):
          elif isinstance(m, nn.Embedding):
              nn.init.normal_(m.weight, mean=0.0, std=0.02)

-     def ensure_pentachora(self, coarse_C: int, medium_C: int, fine_C: int, dim: int, device):
-         """Initialize three pentachora banks."""
          if self.pent_inited.item() == 1:
              return

-         def bank(C):
-             pts = []
-             for _ in range(int(C)):
-                 A = torch.randn(5, dim, device=device)
-                 A = F.normalize(A - A.mean(dim=0, keepdim=True), dim=-1)
-                 pts.append(A)
-             return nn.Parameter(torch.stack(pts, dim=0))
-
-         self.penta_coarse = bank(coarse_C)
-         self.penta_medium = bank(medium_C)
-         self.penta_fine = bank(fine_C)
          self.pent_inited.fill_(1)

-     def _block_forward(self, blk, x):
          x = x + blk["attn"](blk["norm1"](x))
          x = x + blk["mlp"](blk["norm2"](x))
          return x

-     def backbone(self, idx):
          B, T = idx.shape
          x = self.token_emb(idx) + self.pos_emb[:, :T, :]
          x = self.drop(x)
          if self.grad_checkpoint and self.training:
              from torch.utils.checkpoint import checkpoint
              for blk in self.blocks:
-                 x = checkpoint(lambda _x: self._block_forward(blk, _x), x)
          else:
              for blk in self.blocks:
                  x = self._block_forward(blk, x)
          return self.norm(x)

-     def forward(self, idx):
          h = self.backbone(idx)
          return self.lm_head(h)

-     def hidden_states(self, idx):
          return self.backbone(idx)

-     def rose_hidden_pool(self, h: torch.Tensor, mode="mean"):
          return h.mean(dim=1) if mode == "mean" else h[:, -1, :]


- # ============================================================================
- # Model I/O Utilities
- # ============================================================================

- class BeeperIO:
-     """Utilities for saving and loading model weights."""
-
-     @staticmethod
-     def clean_state(sd: dict):
-         """Clean state dict keys from various wrappings."""
-         out = {}
-         for k, v in sd.items():
-             if k.startswith("_orig_mod."):
-                 k = k[10:]
-             if k.startswith("module."):
-                 k = k[7:]
-             out[k] = v
-         return out

-     @staticmethod
-     def load_into_model(model: nn.Module, path: str, map_location="cpu", strict: bool = False):
-         """Load weights from file into model."""
-         ext = os.path.splitext(path)[1].lower()
-         if ext == ".safetensors":
-             sd = load_safetensors(path, device="cpu")
-         else:
-             raw = torch.load(path, map_location="cpu")
-             sd = raw["model"] if isinstance(raw, dict) and "model" in raw else raw
-         sd = BeeperIO.clean_state(sd)
-         result = model.load_state_dict(sd, strict=strict)
-         return result.missing_keys, result.unexpected_keys


- # ============================================================================
- # Text Generation
- # ============================================================================

  def _detok(text: str) -> str:
-     """Clean up tokenized text spacing."""
      text = re.sub(r"\s+([,.;:!?%])", r"\1", text)
      text = re.sub(r"\s+([\)\]\}])", r"\1", text)
      text = re.sub(r"([\(\[\{])\s+", r"\1", text)
@@ -268,92 +290,67 @@ def _detok(text: str) -> str:


  @torch.no_grad()
- def generate(model: BeeperRoseGPT,
-              tok: Tokenizer,
-              cfg: dict,
-              prompt: str,
-              max_new_tokens: int = 120,
-              temperature: float = None,
-              top_k: int = None,
-              top_p: float = None,
-              repetition_penalty: float = None,
-              presence_penalty: float = None,
-              frequency_penalty: float = None,
-              device: Optional[torch.device] = None,
-              detokenize: bool = True) -> str:
      """
-     Generate text from a prompt using the model.
-
-     Args:
-         model: The BeeperRoseGPT model
-         tok: Tokenizer instance
-         cfg: Configuration dictionary
-         prompt: Input text prompt
-         max_new_tokens: Maximum number of tokens to generate
-         temperature: Sampling temperature (higher = more random)
-         top_k: Top-k sampling parameter
-         top_p: Top-p (nucleus) sampling parameter
-         repetition_penalty: Penalty for repeated tokens
-         presence_penalty: Penalty for tokens that have appeared
-         frequency_penalty: Penalty based on token frequency
-         device: Device to run on
-         detokenize: Whether to clean up tokenization artifacts
-
-     Returns:
-         Generated text string
      """
-
-     # Use defaults from config if not specified
-     temperature = cfg["temperature"] if temperature is None else temperature
-     top_k = cfg["top_k"] if top_k is None else top_k
-     top_p = cfg["top_p"] if top_p is None else top_p
-     repetition_penalty = cfg["repetition_penalty"] if repetition_penalty is None else repetition_penalty
-     presence_penalty = cfg["presence_penalty"] if presence_penalty is None else presence_penalty
-     frequency_penalty = cfg["frequency_penalty"] if frequency_penalty is None else frequency_penalty

      device = device or next(model.parameters()).device
      model.eval()
-
-     # Tokenize prompt
      ids = tok.encode(prompt).ids
      x = torch.tensor([ids], dtype=torch.long, device=device)
-
-     # Track token counts for penalties
-     counts = torch.zeros(cfg["vocab_size"], dtype=torch.int32, device=device)
      for t in ids:
-         if 0 <= t < cfg["vocab_size"]:
              counts[t] += 1

-     # Generate tokens
-     for _ in range(max_new_tokens):
-         # Get logits for next token
          logits = model(x[:, -cfg["context"]:])
          logits = logits[:, -1, :]

-         # Apply repetition penalty
          if repetition_penalty and repetition_penalty != 1.0:
              mask = counts > 0
              if mask.any():
                  pos = logits[:, mask] > 0
-                 logits[:, mask][pos] /= repetition_penalty
                  logits[:, mask][~pos] *= repetition_penalty

-         # Apply presence and frequency penalties
          if presence_penalty or frequency_penalty:
              pen = counts.float() * (frequency_penalty or 0.0) + (counts > 0).float() * (presence_penalty or 0.0)
              logits = logits - pen.unsqueeze(0)

-         # Apply temperature
          logits = logits / max(1e-8, temperature)

-         # Apply top-k sampling
          if top_k and top_k > 0:
              k = min(top_k, logits.size(-1))
              v, ix = torch.topk(logits, k, dim=-1)
              filt = torch.full_like(logits, float("-inf"))
              logits = filt.scatter_(-1, ix, v)

-         # Apply top-p (nucleus) sampling
          if top_p and top_p < 1.0:
              sl, si = torch.sort(logits, descending=True)
              ps = F.softmax(sl, dim=-1)
@@ -363,47 +360,12 @@ def generate(model: BeeperRoseGPT,
              sl = sl.masked_fill(mask, float("-inf"))
              logits = torch.full_like(logits, float("-inf")).scatter(-1, si, sl)

-         # Sample next token
          probs = F.softmax(logits, dim=-1)
          next_id = torch.multinomial(probs, num_samples=1)
          x = torch.cat([x, next_id], dim=1)
-         counts[next_id.item()] += 1

-     # Decode output
      out = tok.decode(x[0].tolist())
      return _detok(out) if detokenize else out
-
-
- # ============================================================================
- # Default Configuration
- # ============================================================================
-
- def get_default_config():
-     """Get the default configuration for the model."""
-     return {
-         "name": "Rose-Beeper",
-         "context": 512,
-         "vocab_size": 8192,
-         "dim": 512,
-         "n_layers": 6,
-         "n_heads": 8,
-         "mlp_ratio": 4.0,
-         "dropout": 0.0,
-         "resid_dropout": 0.1,
-         "grad_checkpoint": False,
-
-         # Generation defaults
-         "temperature": 0.9,
-         "top_k": 40,
-         "top_p": 0.9,
-         "repetition_penalty": 1.10,
-         "presence_penalty": 0.6,
-         "frequency_penalty": 0.0,
-
-         # Capoera configuration
-         "capoera": {
-             "enable": True,
-             "topic_bins": 512,
-             "mood_bins": 7,
-         }
-     }
1
+ # beeper.py
+ # --------------------------------------------------------------------------------------------------
+ # Beeper Rose-based tiny GPT (inference module)
+ # - Decoder-only GPT with SDPA (FlashAttention path on Ampere+)
+ # - Model exactly mirrors the training-time architecture you provided (dim=512, L=6, H=8)
+ # - Safe state-dict loader that auto-sizes pentachora banks before strict load
+ # - Generation API with repetition/presence/frequency penalties (same defaults as training)
+ # --------------------------------------------------------------------------------------------------
+ from __future__ import annotations

  import math
+ import re
+ import inspect
+ from contextlib import nullcontext
+ from typing import Optional, Tuple
+
  import torch
  import torch.nn as nn
  import torch.nn.functional as F

+ # --- Prefer high-throughput matmul where possible (Ampere/Hopper) ---
+ torch.set_float32_matmul_precision("high")
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+ # ---- Version-safe SDPA (FlashAttention) selection -------------------------------------------------
  try:
+     # PyTorch 2.3+ modern API
      from torch.nn.attention import sdpa_kernel as _sdpa_kernel_modern
      from torch.nn.attention import SDPBackend as _SDPBackend
      _SDPA_SIG = inspect.signature(_sdpa_kernel_modern)
      _sdpa_kernel = _sdpa_kernel_modern
  except Exception:
      try:
+         # Legacy API
          from torch.backends.cuda import sdp_kernel as _sdpa_kernel_legacy
          _SDPA_SIG = inspect.signature(_sdpa_kernel_legacy)
          _SDPBackend = None

  def sdpa_ctx_prefer_flash():
+     """
+     Best-effort context to bias SDPA toward FlashAttention on supported GPUs.
+     Falls back to no-op if not available.
+     """
      if _sdpa_kernel is None or _SDPA_SIG is None:
          return nullcontext()

      params = {p.name for p in _SDPA_SIG.parameters.values()}
      try:
          if "backends" in params and _SDPBackend is not None:
              return _sdpa_kernel(backends=[
                  _SDPBackend.FLASH_ATTENTION,
                  _SDPBackend.EFFICIENT_ATTENTION,
                  _SDPBackend.MATH
              ])
          if "backend" in params and _SDPBackend is not None:
              return _sdpa_kernel(backend=_SDPBackend.FLASH_ATTENTION)
          if {"enable_flash", "enable_math", "enable_mem_efficient"} <= params:
              return _sdpa_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=True)
          if {"use_flash", "use_math", "use_mem_efficient"} <= params:

      return nullcontext()
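For reference, the helper is intended to be used as a context manager around a single SDPA call; a minimal sketch (tensor shapes and flags here are illustrative, not taken from the commit):

    # Illustrative only: bias backend selection for one scaled_dot_product_attention call.
    q = k = v = torch.randn(1, 8, 128, 64)   # [batch, heads, tokens, head_dim]
    with sdpa_ctx_prefer_flash():
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)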
+ # --------------------------------- Core blocks ------------------------------------------------------
  class CausalSelfAttention(nn.Module):
+     """
+     Multi-head causal self-attention layer using PyTorch SDPA.
+     - On CUDA, uses scaled_dot_product_attention with is_causal=True and dropout during training.
+     - On CPU, falls back to manual masked attention.
+     """
      def __init__(self, dim: int, n_heads: int, attn_dropout: float = 0.0):
          super().__init__()
+         assert dim % n_heads == 0, "dim must be divisible by n_heads"
+         self.nh = int(n_heads)
+         self.hd = dim // self.nh
          self.qkv = nn.Linear(dim, 3 * dim, bias=False)
          self.proj = nn.Linear(dim, dim, bias=False)
+         self.attn_dropout = float(attn_dropout)

+     def forward(self, x: torch.Tensor) -> torch.Tensor:
          B, T, C = x.shape
          qkv = self.qkv(x)
          q, k, v = qkv.chunk(3, dim=-1)
+         q = q.view(B, T, self.nh, self.hd).transpose(1, 2)  # [B,H,T,D]
          k = k.view(B, T, self.nh, self.hd).transpose(1, 2)
          v = v.view(B, T, self.nh, self.hd).transpose(1, 2)
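The rest of forward, between the head split and the output projection, is collapsed in this diff. Going only by the class docstring (SDPA with is_causal=True and dropout on CUDA, manual masked attention on CPU), the collapsed step would look roughly like the following sketch; it is illustrative, not the commit's exact code:

    # Sketch of the collapsed attention step, assuming the behavior described in the docstring.
    if q.is_cuda:
        with sdpa_ctx_prefer_flash():
            y = F.scaled_dot_product_attention(
                q, k, v,
                is_causal=True,
                dropout_p=self.attn_dropout if self.training else 0.0,
            )
    else:
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.hd)   # [B,H,T,T]
        causal = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1)
        att = att.masked_fill(causal, float("-inf")).softmax(dim=-1)
        y = att @ v
    y = y.transpose(1, 2).contiguous().view(B, T, C)           # merge heads back to [B,T,C]
    return self.proj(y)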
  class MLP(nn.Module):
+     """GELU MLP with dropout, sized by mlp_ratio."""
+     def __init__(self, dim: int, mlp_ratio: float = 4.0, dropout: float = 0.1):
          super().__init__()
          hidden = int(dim * mlp_ratio)
          self.fc1 = nn.Linear(dim, hidden)
          self.fc2 = nn.Linear(hidden, dim)
          self.drop = nn.Dropout(dropout)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
          x = self.fc1(x)
          x = F.gelu(x, approximate="tanh")
          x = self.drop(x)

          return x
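With the defaults recorded in the removed get_default_config (dim=512, mlp_ratio=4.0), the hidden width works out to int(512 * 4.0) = 2048, so each block's MLP maps 512 → 2048 → 512.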
+ # --------------------------------- Beeper Model -----------------------------------------------------
  class BeeperRoseGPT(nn.Module):
+     """
+     Decoder-only GPT used by Beeper during training and inference.
+
+     Config keys used:
+       - vocab_size, dim, context, n_heads, n_layers, mlp_ratio
+       - resid_dropout, dropout, grad_checkpoint
+     Notes:
+       - Shares token embedding with LM head (tied weights).
+       - Includes Rose projection/anchors and pentachora banks; unused for plain generation,
+         but kept for full compatibility with trained checkpoints.
+     """
      def __init__(self, cfg: dict):
          super().__init__()
          V, D, Ctx = cfg["vocab_size"], cfg["dim"], cfg["context"]
          H, L, MR = cfg["n_heads"], cfg["n_layers"], cfg["mlp_ratio"]
+         RD, AD = cfg.get("resid_dropout", 0.1), cfg.get("dropout", 0.0)
+         self.grad_checkpoint = bool(cfg.get("grad_checkpoint", False))
+
+         self.vocab_size, self.context = int(V), int(Ctx)

          self.token_emb = nn.Embedding(V, D)
+         self.pos_emb = nn.Parameter(torch.zeros(1, Ctx, D))
+         self.drop = nn.Dropout(RD)

          self.blocks = nn.ModuleList([
              nn.ModuleDict({
                  "norm1": nn.LayerNorm(D),
+                 "attn": CausalSelfAttention(D, H, attn_dropout=AD),
                  "norm2": nn.LayerNorm(D),
+                 "mlp": MLP(D, mlp_ratio=MR, dropout=RD),
+             })
+             for _ in range(L)
          ])
+
+         self.norm = nn.LayerNorm(D)
          self.lm_head = nn.Linear(D, V, bias=False)
+
+         # Weight tying
          self.lm_head.weight = self.token_emb.weight

+         # Rose projection + anchors (present in checkpoints)
+         self.rose_proj = nn.Linear(D, D, bias=False)
+         self.rose_anchors = nn.Parameter(torch.randn(3, D) / (D ** 0.5))

+         # Pentachora banks (created lazily to match state dict)
          self.register_buffer("pent_inited", torch.tensor(0, dtype=torch.uint8), persistent=False)
+         self.penta_coarse: Optional[nn.Parameter] = None   # [C,5,D]
+         self.penta_medium: Optional[nn.Parameter] = None   # [T,5,D]
+         self.penta_fine: Optional[nn.Parameter] = None     # [M,5,D]

+         self.apply(self._init_weights)

      @staticmethod
+     def _init_weights(m: nn.Module):
          if isinstance(m, nn.Linear):
              nn.init.normal_(m.weight, mean=0.0, std=0.02)
              if m.bias is not None:

          elif isinstance(m, nn.Embedding):
              nn.init.normal_(m.weight, mean=0.0, std=0.02)

+     # ---- Pentachora creation (must match sizes in checkpoint before strict load) -------------------
+     def ensure_pentachora(self, coarse_C: int, medium_C: int, fine_C: int, dim: int, device: torch.device):
+         """
+         Initialize pentachora banks if not already present.
+         Shapes must match checkpoint entries for strict loading.
+         """
          if self.pent_inited.item() == 1:
              return

+         def bank(C: int) -> nn.Parameter:
+             if C <= 0:
+                 # Keep a zero-sized parameter to satisfy strict loading (rare).
+                 return nn.Parameter(torch.zeros((0, 5, dim), device=device))
+             pts = torch.randn(C, 5, dim, device=device)
+             pts = F.normalize(pts - pts.mean(dim=1, keepdim=True), dim=-1)
+             return nn.Parameter(pts)
+
+         self.penta_coarse = bank(int(coarse_C))
+         self.penta_medium = bank(int(medium_C))
+         self.penta_fine = bank(int(fine_C))
          self.pent_inited.fill_(1)

+     # ---- Backbone / forward -----------------------------------------------------------------------
+     def _block_forward(self, blk: nn.ModuleDict, x: torch.Tensor) -> torch.Tensor:
          x = x + blk["attn"](blk["norm1"](x))
          x = x + blk["mlp"](blk["norm2"](x))
          return x

+     def backbone(self, idx: torch.Tensor) -> torch.Tensor:
          B, T = idx.shape
          x = self.token_emb(idx) + self.pos_emb[:, :T, :]
          x = self.drop(x)
          if self.grad_checkpoint and self.training:
              from torch.utils.checkpoint import checkpoint
              for blk in self.blocks:
+                 x = checkpoint(lambda _x: self._block_forward(blk, _x), x)  # type: ignore[arg-type]
          else:
              for blk in self.blocks:
                  x = self._block_forward(blk, x)
          return self.norm(x)

+     def forward(self, idx: torch.Tensor) -> torch.Tensor:
          h = self.backbone(idx)
          return self.lm_head(h)

+     # ---- Utilities ---------------------------------------------------------------------------------
+     def hidden_states(self, idx: torch.Tensor) -> torch.Tensor:
+         """Return final hidden states (pre-LM head)."""
          return self.backbone(idx)

+     def rose_hidden_pool(self, h: torch.Tensor, mode: str = "mean") -> torch.Tensor:
+         """Pool hidden states for Rose-related terms (unused in plain generation)."""
          return h.mean(dim=1) if mode == "mean" else h[:, -1, :]
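The config keys listed in the class docstring match the defaults removed from get_default_config above. As a sketch, instantiating the model from those removed values would look like this (the coarse bank size of 3 is an arbitrary illustration, and mapping topic/mood bins to the medium/fine banks follows the loader helper further down):

    cfg = {
        "vocab_size": 8192, "dim": 512, "context": 512,
        "n_heads": 8, "n_layers": 6, "mlp_ratio": 4.0,
        "dropout": 0.0, "resid_dropout": 0.1, "grad_checkpoint": False,
    }
    model = BeeperRoseGPT(cfg)
    # Banks are lazily created as [C, 5, dim]; 3 coarse classes is an illustrative guess,
    # 512 and 7 come from the removed capoera defaults (topic_bins / mood_bins).
    model.ensure_pentachora(coarse_C=3, medium_C=512, fine_C=7, dim=512, device=torch.device("cpu"))
    assert model.penta_medium.shape == (512, 5, 512)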
+ # --------------------------------- Loader helpers ---------------------------------------------------
+ def prepare_model_for_state_dict(
+     model: BeeperRoseGPT,
+     state_dict: "dict[str, torch.Tensor]",
+     device: Optional[torch.device] = None,
+ ) -> None:
+     """
+     Ensure model has pentachora parameters sized to match the incoming state_dict,
+     so we can load with strict=True.
+
+     If the checkpoint has no pentachora (older versions), we do nothing.
+     """
+     device = device or next(model.parameters()).device
+     need = all(k in state_dict for k in ("penta_coarse", "penta_medium", "penta_fine"))
+     if not need:
+         return
+
+     pc, pt, pm = state_dict["penta_coarse"], state_dict["penta_medium"], state_dict["penta_fine"]
+     # Expect [C,5,D]
+     def dims_ok(t: torch.Tensor) -> bool:
+         return t.ndim == 3 and t.size(1) == 5 and t.size(2) == model.token_emb.embedding_dim
+
+     if not (dims_ok(pc) and dims_ok(pt) and dims_ok(pm)):
+         # Shapes inconsistent; fall back to non-strict load later.
+         return
+
+     coarse_C = pc.size(0)
+     topic_C = pt.size(0)
+     mood_C = pm.size(0)
+     model.ensure_pentachora(coarse_C, topic_C, mood_C, dim=pc.size(2), device=device)
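For context, the helper is meant to run before load_state_dict so the lazily created banks already have checkpoint-matching shapes. A minimal loading sketch, assuming a safetensors checkpoint (the file name is illustrative, not from the commit):

    from safetensors.torch import load_file

    sd = load_file("beeper.safetensors", device="cpu")   # hypothetical path
    model = BeeperRoseGPT(cfg)                           # cfg as sketched earlier
    prepare_model_for_state_dict(model, sd)              # sizes pentachora banks from sd
    model.load_state_dict(sd, strict=True)
    model.eval()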
+ # --------------------------------- Generation -------------------------------------------------------
  def _detok(text: str) -> str:
      text = re.sub(r"\s+([,.;:!?%])", r"\1", text)
      text = re.sub(r"\s+([\)\]\}])", r"\1", text)
      text = re.sub(r"([\(\[\{])\s+", r"\1", text)


  @torch.no_grad()
+ def generate(
+     model: BeeperRoseGPT,
+     tok,                                   # Hugging Face Tokenizers `Tokenizer`
+     cfg: dict,
+     prompt: str,
+     max_new_tokens: int = 120,
+     temperature: Optional[float] = None,
+     top_k: Optional[int] = None,
+     top_p: Optional[float] = None,
+     repetition_penalty: Optional[float] = None,
+     presence_penalty: Optional[float] = None,
+     frequency_penalty: Optional[float] = None,
+     device: Optional[torch.device] = None,
+     detokenize: bool = True,
+ ) -> str:
      """
+     Penalized nucleus sampling (same knobs as training script).
      """
+     temperature = cfg.get("temperature", 0.9) if temperature is None else float(temperature)
+     top_k = cfg.get("top_k", 40) if top_k is None else int(top_k)
+     top_p = cfg.get("top_p", 0.9) if top_p is None else float(top_p)
+     repetition_penalty = cfg.get("repetition_penalty", 1.10) if repetition_penalty is None else float(repetition_penalty)
+     presence_penalty = cfg.get("presence_penalty", 0.6) if presence_penalty is None else float(presence_penalty)
+     frequency_penalty = cfg.get("frequency_penalty", 0.0) if frequency_penalty is None else float(frequency_penalty)

      device = device or next(model.parameters()).device
      model.eval()
+
      ids = tok.encode(prompt).ids
      x = torch.tensor([ids], dtype=torch.long, device=device)
+     V = int(cfg["vocab_size"])
+     counts = torch.zeros(V, dtype=torch.int32, device=device)
      for t in ids:
+         if 0 <= t < V:
              counts[t] += 1

+     for _ in range(int(max_new_tokens)):
          logits = model(x[:, -cfg["context"]:])
          logits = logits[:, -1, :]

+         # Repetition penalty (CTRL-like)
          if repetition_penalty and repetition_penalty != 1.0:
              mask = counts > 0
              if mask.any():
                  pos = logits[:, mask] > 0
+                 logits[:, mask][pos] /= repetition_penalty
                  logits[:, mask][~pos] *= repetition_penalty

+         # Presence/frequency penalties (OpenAI-like)
          if presence_penalty or frequency_penalty:
              pen = counts.float() * (frequency_penalty or 0.0) + (counts > 0).float() * (presence_penalty or 0.0)
              logits = logits - pen.unsqueeze(0)

          logits = logits / max(1e-8, temperature)

          if top_k and top_k > 0:
              k = min(top_k, logits.size(-1))
              v, ix = torch.topk(logits, k, dim=-1)
              filt = torch.full_like(logits, float("-inf"))
              logits = filt.scatter_(-1, ix, v)

          if top_p and top_p < 1.0:
              sl, si = torch.sort(logits, descending=True)
              ps = F.softmax(sl, dim=-1)

              sl = sl.masked_fill(mask, float("-inf"))
              logits = torch.full_like(logits, float("-inf")).scatter(-1, si, sl)

          probs = F.softmax(logits, dim=-1)
          next_id = torch.multinomial(probs, num_samples=1)
          x = torch.cat([x, next_id], dim=1)
+         nid = next_id.item()
+         if 0 <= nid < V:
+             counts[nid] += 1

      out = tok.decode(x[0].tolist())
      return _detok(out) if detokenize else out
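Finally, a typical end-to-end call of the generation API, assuming a tokenizers.Tokenizer file saved next to the weights (the file name is illustrative; all sampling knobs fall back to the cfg defaults shown above):

    from tokenizers import Tokenizer

    tok = Tokenizer.from_file("beeper.tokenizer.json")   # hypothetical path
    text = generate(model, tok, cfg, prompt="Once upon a time", max_new_tokens=120)
    print(text)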