""" beeper_model.py - Core model module for Beeper Extracted from the training code for use in inference/apps """ import os import re import math import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional from safetensors.torch import load_file as load_safetensors # ========================================================================================= # Model Components # ========================================================================================= class CausalSelfAttention(nn.Module): def __init__(self, dim: int, n_heads: int, attn_dropout: float = 0.0): super().__init__() assert dim % n_heads == 0 self.nh = n_heads self.hd = dim // n_heads self.qkv = nn.Linear(dim, 3 * dim, bias=False) self.proj = nn.Linear(dim, dim, bias=False) self.attn_dropout = attn_dropout def forward(self, x): B, T, C = x.shape qkv = self.qkv(x) q, k, v = qkv.chunk(3, dim=-1) q = q.view(B, T, self.nh, self.hd).transpose(1, 2) k = k.view(B, T, self.nh, self.hd).transpose(1, 2) v = v.view(B, T, self.nh, self.hd).transpose(1, 2) # Use scaled_dot_product_attention when available y = F.scaled_dot_product_attention( q, k, v, is_causal=True, dropout_p=self.attn_dropout if self.training else 0.0, ) y = y.transpose(1, 2).contiguous().view(B, T, C) return self.proj(y) class MLP(nn.Module): def __init__(self, dim, mlp_ratio=4.0, dropout=0.1): super().__init__() hidden = int(dim * mlp_ratio) self.fc1 = nn.Linear(dim, hidden) self.fc2 = nn.Linear(hidden, dim) self.drop = nn.Dropout(dropout) def forward(self, x): x = self.fc1(x) x = F.gelu(x, approximate="tanh") x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class BeeperRoseGPT(nn.Module): def __init__(self, cfg: dict): super().__init__() V = cfg.get("vocab_size", 8192) D = cfg.get("dim", 512) Ctx = cfg.get("context", 512) H = cfg.get("n_heads", 8) L = cfg.get("n_layers", 6) MR = cfg.get("mlp_ratio", 4.0) RD = cfg.get("resid_dropout", 0.1) AD = cfg.get("dropout", 0.0) self.vocab_size = V self.context = Ctx # Core transformer components self.token_emb = nn.Embedding(V, D) self.pos_emb = nn.Parameter(torch.zeros(1, Ctx, D)) self.drop = nn.Dropout(RD) self.blocks = nn.ModuleList([ nn.ModuleDict({ "norm1": nn.LayerNorm(D), "attn": CausalSelfAttention(D, H, attn_dropout=AD), "norm2": nn.LayerNorm(D), "mlp": MLP(D, mlp_ratio=MR, dropout=RD), }) for _ in range(L) ]) self.norm = nn.LayerNorm(D) self.lm_head = nn.Linear(D, V, bias=False) # Weight tying self.lm_head.weight = self.token_emb.weight # Rose components (for compatibility, may not be used in inference) self.rose_proj = nn.Linear(D, D, bias=False) self.rose_anchors = nn.Parameter(torch.randn(3, D) / (D**0.5)) # Pentachora placeholders (not needed for inference but for weight compatibility) self.register_buffer("pent_inited", torch.tensor(0, dtype=torch.uint8), persistent=False) self.penta_coarse = None self.penta_medium = None self.penta_fine = None self.apply(self._init) self.grad_checkpoint = False @staticmethod def _init(m): if isinstance(m, nn.Linear): nn.init.normal_(m.weight, mean=0.0, std=0.02) if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.Embedding): nn.init.normal_(m.weight, mean=0.0, std=0.02) def _block_forward(self, blk, x): x = x + blk["attn"](blk["norm1"](x)) x = x + blk["mlp"](blk["norm2"](x)) return x def backbone(self, idx): B, T = idx.shape x = self.token_emb(idx) + self.pos_emb[:, :T, :] x = self.drop(x) for blk in self.blocks: x = self._block_forward(blk, x) return self.norm(x) def forward(self, idx): h = self.backbone(idx) 
return self.lm_head(h) def hidden_states(self, idx): return self.backbone(idx) def load_state_dict(self, state_dict, strict=False): """Custom load that handles pentachora bank initialization gracefully""" # Clean state dict keys cleaned = {} for k, v in state_dict.items(): if k.startswith("_orig_mod."): k = k[10:] if k.startswith("module."): k = k[7:] cleaned[k] = v # Initialize pentachora if present in checkpoint if "penta_coarse" in cleaned: self.penta_coarse = nn.Parameter(cleaned["penta_coarse"]) if "penta_medium" in cleaned: self.penta_medium = nn.Parameter(cleaned["penta_medium"]) if "penta_fine" in cleaned: self.penta_fine = nn.Parameter(cleaned["penta_fine"]) return super().load_state_dict(cleaned, strict=strict) # ========================================================================================= # Generation # ========================================================================================= def _detokenize(text: str) -> str: """Clean up tokenization artifacts""" text = re.sub(r"\s+([,.;:!?%])", r"\1", text) text = re.sub(r"\s+([\)\]\}])", r"\1", text) text = re.sub(r"([\(\[\{])\s+", r"\1", text) return text @torch.no_grad() def generate( model: BeeperRoseGPT, tok, # Tokenizer cfg: dict, prompt: str, max_new_tokens: int = 120, temperature: float = None, top_k: int = None, top_p: float = None, repetition_penalty: float = None, presence_penalty: float = None, frequency_penalty: float = None, device: Optional[torch.device] = None, detokenize: bool = True ) -> str: """ Generate text from Beeper model with various sampling strategies. """ # Use defaults from config if not specified temperature = temperature if temperature is not None else cfg.get("temperature", 0.9) top_k = top_k if top_k is not None else cfg.get("top_k", 40) top_p = top_p if top_p is not None else cfg.get("top_p", 0.9) repetition_penalty = repetition_penalty if repetition_penalty is not None else cfg.get("repetition_penalty", 1.1) presence_penalty = presence_penalty if presence_penalty is not None else cfg.get("presence_penalty", 0.6) frequency_penalty = frequency_penalty if frequency_penalty is not None else cfg.get("frequency_penalty", 0.0) device = device or next(model.parameters()).device model.eval() # Encode prompt ids = tok.encode(prompt).ids x = torch.tensor([ids], dtype=torch.long, device=device) # Track token frequencies for penalties vocab_size = cfg.get("vocab_size", 8192) counts = torch.zeros(vocab_size, dtype=torch.int32, device=device) for t in ids: if 0 <= t < vocab_size: counts[t] += 1 # Generate tokens for _ in range(max_new_tokens): # Get logits for next token context_window = cfg.get("context", 512) logits = model(x[:, -context_window:]) logits = logits[:, -1, :] # Apply repetition penalty if repetition_penalty and repetition_penalty != 1.0: mask = counts > 0 if mask.any(): pos = logits[:, mask] > 0 logits[:, mask][pos] /= repetition_penalty logits[:, mask][~pos] *= repetition_penalty # Apply presence and frequency penalties if presence_penalty or frequency_penalty: pen = counts.float() * (frequency_penalty or 0.0) + (counts > 0).float() * (presence_penalty or 0.0) logits = logits - pen.unsqueeze(0) # Temperature scaling logits = logits / max(1e-8, temperature) # Top-k filtering if top_k and top_k > 0: k = min(top_k, logits.size(-1)) v, ix = torch.topk(logits, k, dim=-1) filt = torch.full_like(logits, float("-inf")) logits = filt.scatter_(-1, ix, v) # Top-p (nucleus) filtering if top_p and top_p < 1.0: sorted_logits, sorted_indices = torch.sort(logits, descending=True) probs = 
F.softmax(sorted_logits, dim=-1) cumulative_probs = torch.cumsum(probs, dim=-1) # Find cutoff cutoff_idx = (cumulative_probs > top_p).float().argmax(dim=-1) mask = torch.arange(logits.size(-1), device=device).unsqueeze(0) > cutoff_idx.unsqueeze(-1) sorted_logits = sorted_logits.masked_fill(mask, float("-inf")) logits = torch.full_like(logits, float("-inf")).scatter(-1, sorted_indices, sorted_logits) # Sample next token probs = F.softmax(logits, dim=-1) next_id = torch.multinomial(probs, num_samples=1) # Append to sequence x = torch.cat([x, next_id], dim=1) counts[next_id.item()] += 1 # Decode output output = tok.decode(x[0].tolist()) return _detokenize(output) if detokenize else output
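

# =========================================================================================
# Example usage (illustrative sketch)
# =========================================================================================
# A minimal sketch of how this module might be wired together for inference. The config
# values, checkpoint filename, and tokenizer filename below are assumptions for
# illustration only; substitute the artifacts that ship with your Beeper checkpoint.

if __name__ == "__main__":
    from tokenizers import Tokenizer  # assumes a HuggingFace `tokenizers` JSON tokenizer

    cfg = {
        "vocab_size": 8192, "dim": 512, "context": 512,
        "n_heads": 8, "n_layers": 6,
    }
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = BeeperRoseGPT(cfg).to(device)
    # Hypothetical paths; replace with the real checkpoint and tokenizer files.
    state = load_safetensors("beeper.safetensors")
    model.load_state_dict(state, strict=False)
    tok = Tokenizer.from_file("beeper.tokenizer.json")

    print(generate(model, tok, cfg, "Once upon a time,", max_new_tokens=60, device=device))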