"""
beeper_model.py - Core model module for Beeper
Extracted from the training code for use in inference/apps
"""

import os
import re
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
from safetensors.torch import load_file as load_safetensors

# =========================================================================================
# Model Components
# =========================================================================================

class CausalSelfAttention(nn.Module):
    def __init__(self, dim: int, n_heads: int, attn_dropout: float = 0.0):
        super().__init__()
        assert dim % n_heads == 0
        self.nh = n_heads
        self.hd = dim // n_heads
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.proj = nn.Linear(dim, dim, bias=False)
        self.attn_dropout = attn_dropout

    def forward(self, x):
        B, T, C = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.view(B, T, self.nh, self.hd).transpose(1, 2)
        k = k.view(B, T, self.nh, self.hd).transpose(1, 2)
        v = v.view(B, T, self.nh, self.hd).transpose(1, 2)

        # Fused scaled dot-product attention (PyTorch 2.x) with a causal mask
        y = F.scaled_dot_product_attention(
            q, k, v,
            is_causal=True,
            dropout_p=self.attn_dropout if self.training else 0.0,
        )
        
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(y)


class MLP(nn.Module):
    def __init__(self, dim, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        hidden = int(dim * mlp_ratio)
        self.fc1 = nn.Linear(dim, hidden)
        self.fc2 = nn.Linear(hidden, dim)
        self.drop = nn.Dropout(dropout)
    
    def forward(self, x):
        x = self.fc1(x)
        x = F.gelu(x, approximate="tanh")
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class BeeperRoseGPT(nn.Module):
    def __init__(self, cfg: dict):
        super().__init__()
        V = cfg.get("vocab_size", 8192)
        D = cfg.get("dim", 512)
        Ctx = cfg.get("context", 512)
        H = cfg.get("n_heads", 8)
        L = cfg.get("n_layers", 6)
        MR = cfg.get("mlp_ratio", 4.0)
        RD = cfg.get("resid_dropout", 0.1)
        AD = cfg.get("dropout", 0.0)

        self.vocab_size = V
        self.context = Ctx
        
        # Core transformer components
        self.token_emb = nn.Embedding(V, D)
        self.pos_emb = nn.Parameter(torch.zeros(1, Ctx, D))
        self.drop = nn.Dropout(RD)

        self.blocks = nn.ModuleList([
            nn.ModuleDict({
                "norm1": nn.LayerNorm(D),
                "attn": CausalSelfAttention(D, H, attn_dropout=AD),
                "norm2": nn.LayerNorm(D),
                "mlp": MLP(D, mlp_ratio=MR, dropout=RD),
            }) for _ in range(L)
        ])
        
        self.norm = nn.LayerNorm(D)
        self.lm_head = nn.Linear(D, V, bias=False)
        
        # Weight tying
        self.lm_head.weight = self.token_emb.weight

        # Rose components (for compatibility, may not be used in inference)
        self.rose_proj = nn.Linear(D, D, bias=False)
        self.rose_anchors = nn.Parameter(torch.randn(3, D) / (D**0.5))

        # Pentachora placeholders (not used at inference; kept so checkpoints load cleanly)
        self.register_buffer("pent_inited", torch.tensor(0, dtype=torch.uint8), persistent=False)
        self.penta_coarse = None
        self.penta_medium = None
        self.penta_fine = None

        self.apply(self._init)
        self.grad_checkpoint = False

    @staticmethod
    def _init(m):
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def _block_forward(self, blk, x):
        x = x + blk["attn"](blk["norm1"](x))
        x = x + blk["mlp"](blk["norm2"](x))
        return x

    def backbone(self, idx):
        B, T = idx.shape
        x = self.token_emb(idx) + self.pos_emb[:, :T, :]
        x = self.drop(x)
        
        for blk in self.blocks:
            x = self._block_forward(blk, x)
        
        return self.norm(x)

    def forward(self, idx):
        h = self.backbone(idx)
        return self.lm_head(h)

    def hidden_states(self, idx):
        return self.backbone(idx)

    def load_state_dict(self, state_dict, strict=False):
        """Custom load that handles pentachora bank initialization gracefully"""
        # Clean state dict keys
        cleaned = {}
        for k, v in state_dict.items():
            if k.startswith("_orig_mod."):
                k = k[10:]
            if k.startswith("module."):
                k = k[7:]
            cleaned[k] = v
        
        # Initialize pentachora if present in checkpoint
        if "penta_coarse" in cleaned:
            self.penta_coarse = nn.Parameter(cleaned["penta_coarse"])
        if "penta_medium" in cleaned:
            self.penta_medium = nn.Parameter(cleaned["penta_medium"])
        if "penta_fine" in cleaned:
            self.penta_fine = nn.Parameter(cleaned["penta_fine"])
        
        return super().load_state_dict(cleaned, strict=strict)
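

# =========================================================================================
# Convenience loader (a sketch, not part of the original training code). It shows one way
# to build the model from a config dict and read weights with the `load_safetensors`
# import above; the flat-tensor file layout and the `weights_path` argument are
# assumptions, not a documented Beeper API.
# =========================================================================================

def load_beeper(cfg: dict, weights_path: str, device: str = "cpu") -> BeeperRoseGPT:
    """Instantiate BeeperRoseGPT from `cfg` and load safetensors weights (sketch)."""
    model = BeeperRoseGPT(cfg)
    state = load_safetensors(weights_path, device=device)  # dict[str, Tensor]
    model.load_state_dict(state, strict=False)  # custom loader handles prefixes / penta banks
    return model.to(device).eval()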


# =========================================================================================
# Generation
# =========================================================================================

def _detokenize(text: str) -> str:
    """Clean up tokenization artifacts"""
    text = re.sub(r"\s+([,.;:!?%])", r"\1", text)
    text = re.sub(r"\s+([\)\]\}])", r"\1", text)
    text = re.sub(r"([\(\[\{])\s+", r"\1", text)
    return text


@torch.no_grad()
def generate(
    model: BeeperRoseGPT,
    tok,  # tokenizer exposing .encode(text).ids and .decode(ids) (e.g. a `tokenizers.Tokenizer`)
    cfg: dict,
    prompt: str,
    max_new_tokens: int = 120,
    temperature: Optional[float] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: Optional[float] = None,
    presence_penalty: Optional[float] = None,
    frequency_penalty: Optional[float] = None,
    device: Optional[torch.device] = None,
    detokenize: bool = True
) -> str:
    """
    Generate text from Beeper model with various sampling strategies.
    """
    # Use defaults from config if not specified
    temperature = temperature if temperature is not None else cfg.get("temperature", 0.9)
    top_k = top_k if top_k is not None else cfg.get("top_k", 40)
    top_p = top_p if top_p is not None else cfg.get("top_p", 0.9)
    repetition_penalty = repetition_penalty if repetition_penalty is not None else cfg.get("repetition_penalty", 1.1)
    presence_penalty = presence_penalty if presence_penalty is not None else cfg.get("presence_penalty", 0.6)
    frequency_penalty = frequency_penalty if frequency_penalty is not None else cfg.get("frequency_penalty", 0.0)

    device = device or next(model.parameters()).device
    model.eval()
    
    # Encode prompt
    ids = tok.encode(prompt).ids
    x = torch.tensor([ids], dtype=torch.long, device=device)
    
    # Track token frequencies for penalties
    vocab_size = cfg.get("vocab_size", 8192)
    counts = torch.zeros(vocab_size, dtype=torch.int32, device=device)
    for t in ids:
        if 0 <= t < vocab_size:
            counts[t] += 1

    # Generate tokens
    for _ in range(max_new_tokens):
        # Get logits for next token
        context_window = cfg.get("context", 512)
        logits = model(x[:, -context_window:])
        logits = logits[:, -1, :]

        # Apply repetition penalty. Boolean indexing returns a copy, so compute the
        # penalized values and write them back rather than mutating the copy in place.
        if repetition_penalty and repetition_penalty != 1.0:
            mask = counts > 0
            if mask.any():
                sel = logits[:, mask]
                sel = torch.where(sel > 0, sel / repetition_penalty, sel * repetition_penalty)
                logits[:, mask] = sel

        # Apply presence and frequency penalties
        if presence_penalty or frequency_penalty:
            pen = counts.float() * (frequency_penalty or 0.0) + (counts > 0).float() * (presence_penalty or 0.0)
            logits = logits - pen.unsqueeze(0)

        # Temperature scaling
        logits = logits / max(1e-8, temperature)

        # Top-k filtering
        if top_k and top_k > 0:
            k = min(top_k, logits.size(-1))
            v, ix = torch.topk(logits, k, dim=-1)
            filt = torch.full_like(logits, float("-inf"))
            logits = filt.scatter_(-1, ix, v)

        # Top-p (nucleus) filtering
        if top_p and top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            probs = F.softmax(sorted_logits, dim=-1)
            cumulative_probs = torch.cumsum(probs, dim=-1)
            
            # Keep tokens up to and including the first index where the cumulative
            # probability exceeds top_p; mask out everything after it
            cutoff_idx = (cumulative_probs > top_p).float().argmax(dim=-1)
            mask = torch.arange(logits.size(-1), device=device).unsqueeze(0) > cutoff_idx.unsqueeze(-1)
            sorted_logits = sorted_logits.masked_fill(mask, float("-inf"))
            logits = torch.full_like(logits, float("-inf")).scatter(-1, sorted_indices, sorted_logits)

        # Sample next token
        probs = F.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        
        # Append to sequence
        x = torch.cat([x, next_id], dim=1)
        counts[next_id.item()] += 1

    # Decode output
    output = tok.decode(x[0].tolist())
    return _detokenize(output) if detokenize else output
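

# =========================================================================================
# Usage sketch (illustrative only). The tokenizer and checkpoint paths are hypothetical,
# as is the exact config; any tokenizer exposing `.encode(text).ids` / `.decode(ids)`
# should work. `load_beeper` is the convenience loader sketched above.
# =========================================================================================
if __name__ == "__main__":
    from tokenizers import Tokenizer

    cfg = {"vocab_size": 8192, "dim": 512, "context": 512, "n_heads": 8, "n_layers": 6}
    tok = Tokenizer.from_file("beeper_tokenizer.json")            # hypothetical path
    model = load_beeper(cfg, "beeper.safetensors", device="cpu")  # hypothetical path
    print(generate(model, tok, cfg, "Once upon a time, Beeper", max_new_tokens=60))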