""" h/t to Adam Casson for easy-to-use function to calculate FLOPs, source: https://huggingface.co/spaces/adamcasson/transformer-flops-calculator/blob/main/app.py """ import gradio as gr import plotly.graph_objects as go import numpy as np # Fixed BPE parameters bpe_ps = 4.4 # determined by tokenizer n_ctx_base = 8192 n_heads = 20 n_vocab = 128000 n_layers = 26 # Used for BPE model and BLT Global model # Fixed local model parameters local_d_model = 1024 local_g_size = 1 local_n_ctx = 512 # in bytes local_n_heads = 16 local_n_vocab = 256 # Used for BLT Local model local_d_model_k = local_d_model / local_n_heads local_d_ff_multiplier = 4 def openai_flops_per_token(n_layers_val, n_heads_val, d_model_val, n_ctx_val, n_vocab_val, ff_ratio=4): """Open AI method for forward pass FLOPs counting of decoder-only Transformer""" d_attn = d_model_val // n_heads_val d_ff = d_model_val * ff_ratio embeddings = 4 * d_model_val # FLOPs for embeddings - not parameter count attn_qkv = 2 * n_layers_val * d_model_val * 3 * (d_attn * n_heads_val) attn_mask = 2 * n_layers_val * n_ctx_val * (d_attn * n_heads_val) attn_project = 2 * n_layers_val * (d_attn * n_heads_val) * d_model_val ff = 2 * n_layers_val * 2 * d_model_val * d_ff logits = 2 * d_model_val * n_vocab_val return embeddings + attn_qkv + attn_mask + attn_project + ff + logits def cross_attention_flops_per_token(n_layers_ca, n_ctx_cross_attn_kv_len, d_model_ca): ca_qo_proj_flops = ( # Cross Attention QO FLOPs + backward 2 * 4 * d_model_ca**2 ) ca_context_flops = 4 * n_ctx_cross_attn_kv_len * d_model_ca return n_layers_ca * (ca_qo_proj_flops + ca_context_flops) def calculate_flops(blt_ps, d_model_slider, local_n_layers_slider): # BPE calculations n_ctx = int(n_ctx_base / bpe_ps) bpe_flops_per_token_val = openai_flops_per_token(n_layers, n_heads, d_model_slider, n_ctx, n_vocab) bpe_per_byte = bpe_flops_per_token_val / bpe_ps # BLT Global calculations blt_n_ctx = int(n_ctx_base / blt_ps) blt_global_flops_per_token = openai_flops_per_token(n_layers, n_heads, d_model_slider, blt_n_ctx, n_vocab_val=0) # n_vocab=0 for global blt_global_flops_per_byte = blt_global_flops_per_token / blt_ps # BLT Local calculations local_models_transformer_flops_per_byte = openai_flops_per_token( local_n_layers_slider, local_n_heads, local_d_model, local_n_ctx, local_n_vocab, ff_ratio=local_d_ff_multiplier ) encoder_model_ca_flops_per_byte = cross_attention_flops_per_token( local_n_layers_slider / 2, local_n_ctx, local_d_model ) decoder_model_ca_flops_per_byte = cross_attention_flops_per_token( local_n_layers_slider / 2, local_n_ctx // blt_ps, local_d_model ) local_models_cross_attention_flops_per_byte = encoder_model_ca_flops_per_byte + decoder_model_ca_flops_per_byte local_models_flops = local_models_transformer_flops_per_byte + local_models_cross_attention_flops_per_byte # Calculate advantage blt_total = local_models_flops + blt_global_flops_per_byte advantage = 100 * ((blt_total - bpe_per_byte) / bpe_per_byte) if bpe_per_byte != 0 else 0 return { 'bpe_per_byte': bpe_per_byte, 'blt_global': blt_global_flops_per_byte, 'blt_local': local_models_flops, 'blt_total': blt_total, 'advantage': advantage, } def format_params_display(num_params): """Formats number of parameters into a string with M or B units.""" if num_params is None: return "" if abs(num_params) >= 1_000_000_000: return f"{num_params / 1_000_000_000:.1f}B Params" elif abs(num_params) >= 1_000_000: return f"{num_params / 1_000_000:.1f}M Params" else: # For numbers less than 1M return f"{num_params / 1_000_000:.2f}M 
Params" def create_visualization(blt_ps, d_model_slider, local_n_layers_slider): results = calculate_flops(blt_ps, d_model_slider, local_n_layers_slider) # Calculate model parameters # BPE Model Parameters: 12 * N * D^2 + 2 * V * D bpe_model_params = (12 * n_layers * d_model_slider**2) + (2 * n_vocab * d_model_slider) # BLT Model Parameters # Global Component: 12 * N * D^2 (no main vocab projection) blt_global_internal_params = 12 * n_layers * d_model_slider**2 # Local Component Transformer Part: 12 * N_local * D_local^2 + 2 * V_local * D_local blt_local_transformer_params = (12 * local_n_layers_slider * local_d_model**2) + \ (2 * local_n_vocab * local_d_model) # Local Component Cross-Attention Part: N_local * 4 * D_local^2 (estimated) blt_local_ca_params = local_n_layers_slider * 4 * local_d_model**2 blt_local_total_internal_params = blt_local_transformer_params + blt_local_ca_params bpe_params_str = format_params_display(bpe_model_params) # Format BLT global and local parameters separately blt_global_params_fmt_str = format_params_display(blt_global_internal_params) blt_local_params_fmt_str = format_params_display(blt_local_total_internal_params) # Combine for annotation text, using

    bpe_params_str = format_params_display(bpe_model_params)

    # Format BLT global and local parameters separately
    blt_global_params_fmt_str = format_params_display(blt_global_internal_params)
    blt_local_params_fmt_str = format_params_display(blt_local_total_internal_params)
    # Combine for annotation text, using <br> for the line break
    blt_combined_params_str = f"Global: {blt_global_params_fmt_str}<br>Local: {blt_local_params_fmt_str}"

    # Create the figure with subplots for better control
    fig = go.Figure()

    # Add BPE bar (only for the BPE category)
    fig.add_trace(go.Bar(
        name='BPE',
        x=['BPE'],
        y=[results['bpe_per_byte']],
        text=[f"{results['bpe_per_byte']:.2e}"],
        textposition='outside',
        marker_color='#FF6B6B',
        width=0.4,
        showlegend=True
    ))

    # Add BLT Global bar (base of stack)
    fig.add_trace(go.Bar(
        name='BLT Global',
        x=['BLT'],
        y=[results['blt_global']],
        text=[f"{results['blt_global']:.2e}"],
        textposition='inside',
        marker_color='#4ECDC4',
        width=0.4,
        showlegend=True
    ))

    # Add BLT Local bar (top of stack)
    fig.add_trace(go.Bar(
        name='BLT Local',
        x=['BLT'],
        y=[results['blt_local']],
        text=[f"{results['blt_local']:.2e}"],
        textposition='inside',
        marker_color='#45B7D1',
        width=0.4,
        showlegend=True
    ))

    # Update layout with proper stacking and scientific notation
    fig.update_layout(
        title={
            'text': f"FLOPs per Byte Comparison<br>BLT FLOPs comparison: {results['advantage']:.1f}%",
            'x': 0.5,
            'xanchor': 'center',
            'font': {'size': 20}
        },
        xaxis=dict(
            title="Architecture",
            tickfont=dict(size=14)
        ),
        yaxis=dict(
            title="FLOPs per Byte",
            tickformat=".1e",  # Scientific notation with 1 decimal
            tickfont=dict(size=12),
            gridcolor='lightgray'
        ),
        barmode='stack',
        showlegend=True,
        height=650,
        template="plotly_white",
        font=dict(size=14),
        bargap=0.3,
        plot_bgcolor='white',
        margin=dict(b=110)  # Increased bottom margin for two lines of annotation text
    )

    fig.add_annotation(
        x='BLT',
        y=results['blt_total'] * 1.05,
        text=f"Total FLOPs/Byte: {results['blt_total']:.2e}",
        showarrow=False,
        font=dict(size=12, color="black"),
        bgcolor="rgba(255,255,255,0.5)",
        bordercolor="black",
        borderwidth=1,
        xanchor='center',
        yanchor='bottom'
    )

    # Add parameter count annotations at the bottom of the bars
    fig.add_annotation(
        x='BPE',
        y=0,
        text=bpe_params_str,
        showarrow=False,
        xref="x",
        yref="paper",
        yanchor='top',
        xanchor='center',
        yshift=-35,
        font=dict(size=10, color="black", weight="bold"),  # Font size 10 for param text
    )
    fig.add_annotation(
        x='BLT',
        y=0,
        text=blt_combined_params_str,  # Combined string with the Global/Local breakdown
        showarrow=False,
        xref="x",
        yref="paper",
        yanchor='top',
        xanchor='center',
        yshift=-45,  # Adjusted yshift for two lines of text
        font=dict(size=10, color="black", weight="bold"),  # Font size 10 for param text
        align="center"  # Ensure text is centered if it wraps due to <br>
    )

    # Update traces to ensure proper stacking
    fig.update_traces(textfont_size=10)

    return fig


# Create Gradio interface
with gr.Blocks(title="BLT vs BPE FLOPs Comparison") as demo:
    gr.Markdown("""
    # BLT vs BPE FLOPs Comparison
    Companion blog post [can be found here](https://lucalp.dev/bitter-lesson-tokenization-and-blt).
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Adjustable Parameters")

            blt_ps_slider = gr.Slider(
                minimum=1.0,
                maximum=10.0,
                value=4.4,
                step=0.1,
                label="BLT Patch Size (blt_ps)",
                info="Patch size for BLT architecture"
            )

            d_model_slider = gr.Slider(
                minimum=512,
                maximum=8192,
                value=2560,
                step=128,
                label="Global Model Dimension (d_model)",
                info="Hidden dimension size of the BPE model and BLT's Global model"
            )

            local_n_layers_slider = gr.Slider(
                minimum=2,
                maximum=24,  # Max value for local_n_layers
                value=10,
                step=2,  # Ensure even numbers for CA split
                label="Local Model Layers (local_n_layers)",
                info="Number of layers in the BLT's local model"
            )

            gr.Markdown("""
            For inspiration, have a look at the paper's [BLT architecture configurations](https://arxiv.org/html/2412.09871v1#:~:text=%5Cbeginappendix-,11,Table%C2%A010%20shows%20different%20hyper%20parameter%20settings%20for%20BLT%20models.,-Encoder).

            A few things you'll notice:
            1. Patch size reduces global model FLOPs but not local model FLOPs
            2. Increasing the patch size and the global model dimension together can leave total FLOPs unchanged
            3. In smaller BLTs, the local models constitute a larger portion of the total FLOPs

            Parameter counts are displayed below each bar. A core hypothesis of the paper is "that larger models taking fewer steps on larger patches might perform better than smaller models taking more steps." [source](https://arxiv.org/html/2412.09871v1#:~:text=the%20hypothesis%20that%20larger%20models%20taking%20fewer%20steps%20on%20larger%20patches%20might%20perform%20better%20than%20smaller%20models%20taking%20more%20steps)

            The **purpose** of this tool is to show the relationship between patch size, global model dimension and local model layers in terms of FLOPs and parameters. It implies _nothing_ about the **effectiveness** of those FLOPs relative to loss (cf. [FLOPs/BPB plots from the paper](https://arxiv.org/html/2412.09871v1#:~:text=Introduction-,Figure%201%3A,-Scaling%20trends%20for)) or downstream benchmarks. To fully compare BPE-based transformers and BLT, you'll need to investigate those claims in the paper itself.
""") # --- UPDATED SECTION 1: Fixed Parameters dropdown --- with gr.Accordion("Fixed Parameters", open=False): gr.Markdown(f""" - **BPE's bytes per token (bpe_ps)**: {bpe_ps} - **BPE/BLT Global - Num Layers (n_layers)**: {n_layers} - **BPE/BLT Global - Num Heads (n_heads)**: {n_heads} - **BPE - Vocabulary Size (n_vocab)**: {n_vocab:,} - **BPE/BLT - Context Length (n_ctx_base)**: {n_ctx_base:,} bytes - **BLT Local - Model Dimension (local_d_model)**: {local_d_model} - **BLT Local - Num Heads (local_n_heads)**: {local_n_heads} - **BLT Local - Vocabulary Size (local_n_vocab)**: {local_n_vocab} - **BLT Local - FF Multiplier (local_d_ff_multiplier)**: {local_d_ff_multiplier} """) # --- UPDATED SECTION 2: Current Values & Totals dropdown --- with gr.Accordion("Current Values & Totals", open=False): info_text = gr.Markdown("") with gr.Column(scale=2): plot = gr.Plot(label="FLOPs Comparison & Model Parameters") # Set up interactivity def update_plot_and_info(blt_ps_val, d_model_val, local_n_layers_val): fig = create_visualization(blt_ps_val, d_model_val, local_n_layers_val) results = calculate_flops(blt_ps_val, d_model_val, local_n_layers_val) # Recalculate parameters for info text (could also be returned by create_visualization or calculate_flops) bpe_model_p = (12 * n_layers * d_model_val**2) + (2 * n_vocab * d_model_val) blt_global_p = 12 * n_layers * d_model_val**2 blt_local_transformer_p = (12 * local_n_layers_val * local_d_model**2) + \ (2 * local_n_vocab * local_d_model) blt_local_ca_p = local_n_layers_val * 4 * local_d_model**2 blt_local_total_internal_p = blt_local_transformer_p + blt_local_ca_p blt_total_model_p = blt_global_p + blt_local_total_internal_p info_str = f""" **BPE FLOPs/byte**: {results['bpe_per_byte']:.2e} **BPE Total Params**: {format_params_display(bpe_model_p)} **BLT Global FLOPs/byte**: {results['blt_global']:.2e} **BLT Local FLOPs/byte**: {results['blt_local']:.2e} **BLT Total FLOPs/byte**: {results['blt_total']:.2e} **BLT Total Params**: {format_params_display(blt_total_model_p)} (Global: {format_params_display(blt_global_p)}, Local: {format_params_display(blt_local_total_internal_p)}) **BLT Advantage (FLOPs/byte vs BPE)**: {results['advantage']:.1f}% """ return fig, info_str # Update plot when any slider changes inputs_list = [blt_ps_slider, d_model_slider, local_n_layers_slider] blt_ps_slider.change( update_plot_and_info, inputs=inputs_list, outputs=[plot, info_text] ) d_model_slider.change( update_plot_and_info, inputs=inputs_list, outputs=[plot, info_text] ) local_n_layers_slider.change( update_plot_and_info, inputs=inputs_list, outputs=[plot, info_text] ) # Initial plot demo.load( update_plot_and_info, inputs=inputs_list, outputs=[plot, info_text] ) # Launch the app if __name__ == "__main__": demo.launch()