"""
h/t to Adam Casson for the easy-to-use function to calculate FLOPs; source: https://huggingface.co/spaces/adamcasson/transformer-flops-calculator/blob/main/app.py
"""
import gradio as gr
import plotly.graph_objects as go
import numpy as np

# Fixed BPE parameters
bpe_ps = 4.4  # determined by tokenizer
n_ctx_base = 8192
n_heads = 20
n_vocab = 128000
n_layers = 26 # Used for BPE model and BLT Global model

# Fixed local model parameters
local_d_model = 1024
local_g_size = 1
local_n_ctx = 512  # in bytes
local_n_heads = 16
local_n_vocab = 256 # Used for BLT Local model
local_d_model_k = local_d_model / local_n_heads
local_d_ff_multiplier = 4
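
# Illustrative sanity check (not used by the app): with the defaults above, the BPE
# model sees int(n_ctx_base / bpe_ps) == int(8192 / 4.4) == 1861 tokens of context,
# while the BLT local models operate directly on bytes (local_n_vocab = 256).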

def openai_flops_per_token(n_layers_val, n_heads_val, d_model_val, n_ctx_val, n_vocab_val, ff_ratio=4):
    """OpenAI method for forward-pass FLOPs counting of a decoder-only Transformer"""
    d_attn = d_model_val // n_heads_val
    d_ff = d_model_val * ff_ratio

    embeddings = 4 * d_model_val # FLOPs for embeddings - not parameter count
    attn_qkv = 2 * n_layers_val * d_model_val * 3 * (d_attn * n_heads_val)
    attn_mask = 2 * n_layers_val * n_ctx_val * (d_attn * n_heads_val)
    attn_project = 2 * n_layers_val * (d_attn * n_heads_val) * d_model_val
    ff = 2 * n_layers_val * 2 * d_model_val * d_ff
    logits = 2 * d_model_val * n_vocab_val

    return embeddings + attn_qkv + attn_mask + attn_project + ff + logits
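
# Rough reading of the terms above (illustrative, assuming the default ff_ratio=4):
# the context-independent parts sum to ~24 * n_layers * d_model^2 FLOPs per token,
# i.e. about 2 FLOPs per parameter (params ~= 12 * n_layers * d_model^2), plus
# 2 * n_layers * n_ctx * d_model for attention over the context and
# 2 * d_model * n_vocab for the output logits, e.g.
#   openai_flops_per_token(26, 20, 2560, 1861, 128000)  # BPE defaults used in this app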


def cross_attention_flops_per_token(n_layers_ca, n_ctx_cross_attn_kv_len, d_model_ca):
    ca_qo_proj_flops = (
        # Cross Attention QO FLOPs + backward
        2 * 4 * d_model_ca**2
    )
    ca_context_flops = 4 * n_ctx_cross_attn_kv_len * d_model_ca
    return n_layers_ca * (ca_qo_proj_flops + ca_context_flops)
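
# How calculate_flops() below uses this (illustrative): the encoder's cross-attention
# attends over local_n_ctx bytes and the decoder's over local_n_ctx // blt_ps patches,
# with each side getting half of the local layers, e.g.
#   cross_attention_flops_per_token(5, 512, 1024)         # encoder side, 10 local layers
#   cross_attention_flops_per_token(5, 512 // 4.4, 1024)  # decoder side, patch size 4.4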


def calculate_flops(blt_ps, d_model_slider, local_n_layers_slider):
    # BPE calculations
    n_ctx = int(n_ctx_base / bpe_ps)
    bpe_flops_per_token_val = openai_flops_per_token(n_layers, n_heads, d_model_slider, n_ctx, n_vocab)
    bpe_per_byte = bpe_flops_per_token_val / bpe_ps

    # BLT Global calculations
    blt_n_ctx = int(n_ctx_base / blt_ps)
    blt_global_flops_per_token = openai_flops_per_token(n_layers, n_heads, d_model_slider, blt_n_ctx, n_vocab_val=0) # n_vocab=0 for global
    blt_global_flops_per_byte = blt_global_flops_per_token / blt_ps

    # BLT Local calculations
    local_models_transformer_flops_per_byte = openai_flops_per_token(
        local_n_layers_slider, local_n_heads, local_d_model, local_n_ctx, local_n_vocab, ff_ratio=local_d_ff_multiplier
    )
    encoder_model_ca_flops_per_byte = cross_attention_flops_per_token(
        local_n_layers_slider / 2, local_n_ctx, local_d_model
    )
    decoder_model_ca_flops_per_byte = cross_attention_flops_per_token(
        local_n_layers_slider / 2, local_n_ctx // blt_ps, local_d_model
    )
    local_models_cross_attention_flops_per_byte = encoder_model_ca_flops_per_byte + decoder_model_ca_flops_per_byte
    local_models_flops = local_models_transformer_flops_per_byte + local_models_cross_attention_flops_per_byte

    # Calculate advantage
    blt_total = local_models_flops + blt_global_flops_per_byte
    advantage = 100 * ((blt_total - bpe_per_byte) / bpe_per_byte) if bpe_per_byte != 0 else 0


    return {
        'bpe_per_byte': bpe_per_byte,
        'blt_global': blt_global_flops_per_byte,
        'blt_local': local_models_flops,
        'blt_total': blt_total,
        'advantage': advantage,
    }
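
# Example usage (matches the default slider values further below; illustrative only):
#   results = calculate_flops(blt_ps=4.4, d_model_slider=2560, local_n_layers_slider=10)
#   results['blt_total'] is the BLT FLOPs/byte (global + local), to be compared with
#   results['bpe_per_byte']; 'advantage' is the signed % difference between the two.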

def format_params_display(num_params):
    """Formats number of parameters into a string with M or B units."""
    if num_params is None:
        return ""
    if abs(num_params) >= 1_000_000_000:
        return f"{num_params / 1_000_000_000:.1f}B Params"
    elif abs(num_params) >= 1_000_000:
        return f"{num_params / 1_000_000:.1f}M Params"
    else: # For numbers less than 1M
        return f"{num_params / 1_000_000:.2f}M Params"
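
# e.g. (illustrative) format_params_display(1_930_000_000) -> "1.9B Params",
#      format_params_display(450_000) -> "0.45M Params"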



def create_visualization(blt_ps, d_model_slider, local_n_layers_slider):
    results = calculate_flops(blt_ps, d_model_slider, local_n_layers_slider)

    # Calculate model parameters
    # BPE Model Parameters: 12 * N * D^2 + 2 * V * D
    bpe_model_params = (12 * n_layers * d_model_slider**2) + (2 * n_vocab * d_model_slider)

    # BLT Model Parameters
    # Global Component: 12 * N * D^2 (no main vocab projection)
    blt_global_internal_params = 12 * n_layers * d_model_slider**2

    # Local Component Transformer Part: 12 * N_local * D_local^2 + 2 * V_local * D_local
    blt_local_transformer_params = (12 * local_n_layers_slider * local_d_model**2) + \
                                     (2 * local_n_vocab * local_d_model)

    # Local Component Cross-Attention Part: N_local * 4 * D_local^2 (estimated)
    blt_local_ca_params = local_n_layers_slider * 4 * local_d_model**2
    blt_local_total_internal_params = blt_local_transformer_params + blt_local_ca_params

    bpe_params_str = format_params_display(bpe_model_params)

    # Format BLT global and local parameters separately
    blt_global_params_fmt_str = format_params_display(blt_global_internal_params)
    blt_local_params_fmt_str = format_params_display(blt_local_total_internal_params)

    # Combine for annotation text, using <br> for line break
    blt_combined_params_str = f"Global: {blt_global_params_fmt_str}<br>Local: {blt_local_params_fmt_str}"


    # Create the figure with subplots for better control
    fig = go.Figure()

    # Add BPE bar (only for BPE category)
    fig.add_trace(go.Bar(
        name='BPE',
        x=['BPE'],
        y=[results['bpe_per_byte']],
        text=[f"{results['bpe_per_byte']:.2e}"],
        textposition='outside',
        marker_color='#FF6B6B',
        width=0.4,
        showlegend=True
    ))

    # Add BLT Global bar (base of stack)
    fig.add_trace(go.Bar(
        name='BLT Global',
        x=['BLT'],
        y=[results['blt_global']],
        text=[f"{results['blt_global']:.2e}"],
        textposition='inside',
        marker_color='#4ECDC4',
        width=0.4,
        showlegend=True
    ))

    # Add BLT Local bar (top of stack)
    fig.add_trace(go.Bar(
        name='BLT Local',
        x=['BLT'],
        y=[results['blt_local']],
        text=[f"{results['blt_local']:.2e}"],
        textposition='inside',
        marker_color='#45B7D1',
        width=0.4,
        showlegend=True
    ))

    # Update layout with proper stacking and scientific notation
    fig.update_layout(
        title={
            'text': f"FLOPs per Byte Comparison<br><sub>BLT total FLOPs/byte vs BPE: {results['advantage']:.1f}%</sub>",
            'x': 0.5,
            'xanchor': 'center',
            'font': {'size': 20}
        },
        xaxis=dict(
            title="Architecture",
            tickfont=dict(size=14)
        ),
        yaxis=dict(
            title="FLOPs per Byte",
            tickformat=".1e",  # Scientific notation with 1 decimal
            tickfont=dict(size=12),
            gridcolor='lightgray'
        ),
        barmode='stack',
        showlegend=True,
        height=650,
        template="plotly_white",
        font=dict(size=14),
        bargap=0.3,
        plot_bgcolor='white',
        margin=dict(b=110) # Increased bottom margin slightly more for two lines of text
    )

    fig.add_annotation(
        x='BLT',
        y=results['blt_total'] * 1.05,
        text=f"Total FLOPs/Byte: {results['blt_total']:.2e}",
        showarrow=False,
        font=dict(size=12, color="black"),
        bgcolor="rgba(255,255,255,0.5)",
        bordercolor="black",
        borderwidth=1,
        xanchor='center',
        yanchor='bottom'
    )

    # Add parameter count annotations at the bottom of bars
    fig.add_annotation(
        x='BPE',
        y=0,
        text=bpe_params_str,
        showarrow=False,
        xref="x",
        yref="paper",
        yanchor='top',
        xanchor='center',
        yshift=-35,
        font=dict(size=10, color="black", weight="bold"), # Font size 10 for param text
    )

    fig.add_annotation(
        x='BLT',
        y=0,
        text=blt_combined_params_str, # Using the new combined string with breakdown
        showarrow=False,
        xref="x",
        yref="paper",
        yanchor='top',
        xanchor='center',
        yshift=-45, # Adjusted yshift for two lines of text
        font=dict(size=10, color="black", weight="bold"), # Font size 10 for param text
        align="center" # Ensure text is centered if it wraps due to <br>
    )


    # Update traces to ensure proper stacking
    fig.update_traces(textfont_size=10)

    return fig
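
# Example (illustrative, outside the Gradio UI): create_visualization(4.4, 2560, 10)
# returns a plotly Figure for the default slider values; fig.show() renders it locally.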

# Create Gradio interface
with gr.Blocks(title="BLT vs BPE FLOPs Comparison") as demo:
    gr.Markdown("""
    # BLT vs BPE FLOPs Comparison
    Companion blog post [can be found here](https://lucalp.dev/bitter-lesson-tokenization-and-blt).
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Adjustable Parameters")
            blt_ps_slider = gr.Slider(
                minimum=1.0,
                maximum=10.0,
                value=4.4,
                step=0.1,
                label="BLT Patch Size (blt_ps)",
                info="Patch size for BLT architecture"
            )

            d_model_slider = gr.Slider(
                minimum=512,
                maximum=8192,
                value=2560,
                step=128,
                label="Global Model Dimension (d_model)",
                info="Hidden dimension size of the BPE model and BLT's Global model"
            )

            local_n_layers_slider = gr.Slider(
                minimum=2,
                maximum=24, # Max value for local_n_layers
                value=10,
                step=2,   # Ensure even numbers for CA split
                label="Local Model Layers (local_n_layers)",
                info="Number of layers in the BLT's local model"
            )

            gr.Markdown("""
    Have a look at the paper's [BLT architecture configurations](https://arxiv.org/html/2412.09871v1#:~:text=%5Cbeginappendix-,11,Table%C2%A010%20shows%20different%20hyper%20parameter%20settings%20for%20BLT%20models.,-Encoder) for some inspiration.

    A few things you'll notice:
    1. Patch size reduces global model FLOPs but not local model
    2. Increasing the patch size and the global model dimension together can leave total FLOPs roughly unchanged
    3. In smaller BLTs, local models constitute a larger portion of the total FLOPs

    Parameter counts are displayed below each bar.

    A core
    hypothesis of the paper is "that larger models taking fewer steps on larger patches
    might perform better than smaller models taking more steps." [source](https://arxiv.org/html/2412.09871v1#:~:text=the%20hypothesis%20that%20larger%20models%20taking%20fewer%20steps%20on%20larger%20patches%20might%20perform%20better%20than%20smaller%20models%20taking%20more%20steps)

    The **purpose** of this tool is to show the relationship between patch size, global
    model dimension and local model layers in terms of FLOPs and parameters. This tool
    implies _nothing_ about the **effectiveness** of the FLOPs relative to loss (cf.
    [FLOPs/BPB plots from the paper](https://arxiv.org/html/2412.09871v1#:~:text=Introduction-,Figure%201%3A,-Scaling%20trends%20for)) or downstream benchmarks. In order to
    fully compare BPE-based transformers and BLT, you'll need to investigate those
    claims in the paper itself.
            """)

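            # Illustrative check of point 1 in the text above (not executed by the app):
            # with the default d_model=2560 and 10 local layers, comparing
            # calculate_flops(2.0, 2560, 10) with calculate_flops(8.0, 2560, 10) shows
            # 'blt_global' shrinking roughly in proportion to the patch size, while
            # 'blt_local' barely moves (only the decoder cross-attention KV length
            # depends on the patch size).
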
            # --- UPDATED SECTION 1: Fixed Parameters dropdown ---
            with gr.Accordion("Fixed Parameters", open=False):
                gr.Markdown(f"""
                - **BPE's bytes per token (bpe_ps)**: {bpe_ps}
                - **BPE/BLT Global - Num Layers (n_layers)**: {n_layers}
                - **BPE/BLT Global - Num Heads (n_heads)**: {n_heads}
                - **BPE - Vocabulary Size (n_vocab)**: {n_vocab:,}
                - **BPE/BLT - Context Length (n_ctx_base)**: {n_ctx_base:,} bytes
                - **BLT Local - Model Dimension (local_d_model)**: {local_d_model}
                - **BLT Local - Num Heads (local_n_heads)**: {local_n_heads}
                - **BLT Local - Vocabulary Size (local_n_vocab)**: {local_n_vocab}
                - **BLT Local - FF Multiplier (local_d_ff_multiplier)**: {local_d_ff_multiplier}
                """)

            # --- UPDATED SECTION 2: Current Values & Totals dropdown ---
            with gr.Accordion("Current Values & Totals", open=False):
                info_text = gr.Markdown("")

        with gr.Column(scale=2):
            plot = gr.Plot(label="FLOPs Comparison & Model Parameters")

    # Set up interactivity
    def update_plot_and_info(blt_ps_val, d_model_val, local_n_layers_val):
        fig = create_visualization(blt_ps_val, d_model_val, local_n_layers_val)
        results = calculate_flops(blt_ps_val, d_model_val, local_n_layers_val)

        # Recalculate parameters for info text (could also be returned by create_visualization or calculate_flops)
        bpe_model_p = (12 * n_layers * d_model_val**2) + (2 * n_vocab * d_model_val)
        blt_global_p = 12 * n_layers * d_model_val**2
        blt_local_transformer_p = (12 * local_n_layers_val * local_d_model**2) + \
                                    (2 * local_n_vocab * local_d_model)
        blt_local_ca_p = local_n_layers_val * 4 * local_d_model**2
        blt_local_total_internal_p = blt_local_transformer_p + blt_local_ca_p
        blt_total_model_p = blt_global_p + blt_local_total_internal_p

        info_str = f"""
        **BPE FLOPs/byte**: {results['bpe_per_byte']:.2e}
        **BPE Total Params**: {format_params_display(bpe_model_p)}

        **BLT Global FLOPs/byte**: {results['blt_global']:.2e}
        **BLT Local FLOPs/byte**: {results['blt_local']:.2e}
        **BLT Total FLOPs/byte**: {results['blt_total']:.2e}
        **BLT Total Params**: {format_params_display(blt_total_model_p)}
        (Global: {format_params_display(blt_global_p)}, Local: {format_params_display(blt_local_total_internal_p)})

        **BLT vs BPE FLOPs/byte difference**: {results['advantage']:.1f}% (negative means BLT uses fewer FLOPs per byte)
        """
        return fig, info_str

    # Update plot when any slider changes
    inputs_list = [blt_ps_slider, d_model_slider, local_n_layers_slider]
    blt_ps_slider.change(
        update_plot_and_info,
        inputs=inputs_list,
        outputs=[plot, info_text]
    )
    d_model_slider.change(
        update_plot_and_info,
        inputs=inputs_list,
        outputs=[plot, info_text]
    )
    local_n_layers_slider.change(
        update_plot_and_info,
        inputs=inputs_list,
        outputs=[plot, info_text]
    )

    # Initial plot
    demo.load(
        update_plot_and_info,
        inputs=inputs_list,
        outputs=[plot, info_text]
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()