"""
h/t to Adam Casson for the easy-to-use function to calculate FLOPs; source: https://huggingface.co/spaces/adamcasson/transformer-flops-calculator/blob/main/app.py
"""
import gradio as gr
import plotly.graph_objects as go
import numpy as np
# Fixed BPE parameters
bpe_ps = 4.4 # determined by tokenizer
n_ctx_base = 8192
n_heads = 20
n_vocab = 128000
n_layers = 26 # Used for BPE model and BLT Global model
# Fixed local model parameters
local_d_model = 1024
local_g_size = 1
local_n_ctx = 512 # in bytes
local_n_heads = 16
local_n_vocab = 256 # Used for BLT Local model
local_d_model_k = local_d_model / local_n_heads
local_d_ff_multiplier = 4
def openai_flops_per_token(n_layers_val, n_heads_val, d_model_val, n_ctx_val, n_vocab_val, ff_ratio=4):
"""Open AI method for forward pass FLOPs counting of decoder-only Transformer"""
d_attn = d_model_val // n_heads_val
d_ff = d_model_val * ff_ratio
embeddings = 4 * d_model_val # FLOPs for embeddings - not parameter count
attn_qkv = 2 * n_layers_val * d_model_val * 3 * (d_attn * n_heads_val)
attn_mask = 2 * n_layers_val * n_ctx_val * (d_attn * n_heads_val)
attn_project = 2 * n_layers_val * (d_attn * n_heads_val) * d_model_val
ff = 2 * n_layers_val * 2 * d_model_val * d_ff
logits = 2 * d_model_val * n_vocab_val
return embeddings + attn_qkv + attn_mask + attn_project + ff + logits
def cross_attention_flops_per_token(n_layers_ca, n_ctx_cross_attn_kv_len, d_model_ca):
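"""FLOPs per token of the cross-attention blocks: projections plus attention over a KV context of length n_ctx_cross_attn_kv_len."""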
ca_qo_proj_flops = (
# Cross Attention QO FLOPs + backward
2 * 4 * d_model_ca**2
)
ca_context_flops = 4 * n_ctx_cross_attn_kv_len * d_model_ca
return n_layers_ca * (ca_qo_proj_flops + ca_context_flops)
def calculate_flops(blt_ps, d_model_slider, local_n_layers_slider):
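"""Return FLOPs/byte for a BPE transformer and for BLT (global + local) given the patch size, global model dimension and number of local layers."""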
# BPE calculations
n_ctx = int(n_ctx_base / bpe_ps)
bpe_flops_per_token_val = openai_flops_per_token(n_layers, n_heads, d_model_slider, n_ctx, n_vocab)
bpe_per_byte = bpe_flops_per_token_val / bpe_ps
# BLT Global calculations
blt_n_ctx = int(n_ctx_base / blt_ps)
blt_global_flops_per_token = openai_flops_per_token(n_layers, n_heads, d_model_slider, blt_n_ctx, n_vocab_val=0) # n_vocab=0 for global
blt_global_flops_per_byte = blt_global_flops_per_token / blt_ps
# BLT Local calculations
local_models_transformer_flops_per_byte = openai_flops_per_token(
local_n_layers_slider, local_n_heads, local_d_model, local_n_ctx, local_n_vocab, ff_ratio=local_d_ff_multiplier
)
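# Cross-attention FLOPs: the local layers are split evenly between encoder and decoder;
# the encoder's KV context is in bytes, the decoder's in patches (hence local_n_ctx // blt_ps)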
encoder_model_ca_flops_per_byte = cross_attention_flops_per_token(
local_n_layers_slider / 2, local_n_ctx, local_d_model
)
decoder_model_ca_flops_per_byte = cross_attention_flops_per_token(
local_n_layers_slider / 2, local_n_ctx // blt_ps, local_d_model
)
local_models_cross_attention_flops_per_byte = encoder_model_ca_flops_per_byte + decoder_model_ca_flops_per_byte
local_models_flops = local_models_transformer_flops_per_byte + local_models_cross_attention_flops_per_byte
# Percentage difference in FLOPs/byte of BLT relative to BPE (positive means BLT uses more)
blt_total = local_models_flops + blt_global_flops_per_byte
advantage = 100 * ((blt_total - bpe_per_byte) / bpe_per_byte) if bpe_per_byte != 0 else 0
return {
'bpe_per_byte': bpe_per_byte,
'blt_global': blt_global_flops_per_byte,
'blt_local': local_models_flops,
'blt_total': blt_total,
'advantage': advantage,
}
def format_params_display(num_params):
"""Formats number of parameters into a string with M or B units."""
if num_params is None:
return ""
if abs(num_params) >= 1_000_000_000:
return f"{num_params / 1_000_000_000:.1f}B Params"
elif abs(num_params) >= 1_000_000:
return f"{num_params / 1_000_000:.1f}M Params"
else: # For numbers less than 1M
return f"{num_params / 1_000_000:.2f}M Params"
def create_visualization(blt_ps, d_model_slider, local_n_layers_slider):
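"""Build the stacked bar chart comparing BPE and BLT FLOPs/byte, annotated with approximate parameter counts."""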
results = calculate_flops(blt_ps, d_model_slider, local_n_layers_slider)
# Calculate model parameters
# BPE Model Parameters: 12 * N * D^2 + 2 * V * D
bpe_model_params = (12 * n_layers * d_model_slider**2) + (2 * n_vocab * d_model_slider)
# BLT Model Parameters
# Global Component: 12 * N * D^2 (no main vocab projection)
blt_global_internal_params = 12 * n_layers * d_model_slider**2
# Local Component Transformer Part: 12 * N_local * D_local^2 + 2 * V_local * D_local
blt_local_transformer_params = (12 * local_n_layers_slider * local_d_model**2) + \
(2 * local_n_vocab * local_d_model)
# Local Component Cross-Attention Part: N_local * 4 * D_local^2 (estimated)
blt_local_ca_params = local_n_layers_slider * 4 * local_d_model**2
blt_local_total_internal_params = blt_local_transformer_params + blt_local_ca_params
bpe_params_str = format_params_display(bpe_model_params)
# Format BLT global and local parameters separately
blt_global_params_fmt_str = format_params_display(blt_global_internal_params)
blt_local_params_fmt_str = format_params_display(blt_local_total_internal_params)
# Combine for annotation text, using <br> for line break
blt_combined_params_str = f"Global: {blt_global_params_fmt_str}<br>Local: {blt_local_params_fmt_str}"
# Create the figure with subplots for better control
fig = go.Figure()
# Add BPE bar (only for BPE category)
fig.add_trace(go.Bar(
name='BPE',
x=['BPE'],
y=[results['bpe_per_byte']],
text=[f"{results['bpe_per_byte']:.2e}"],
textposition='outside',
marker_color='#FF6B6B',
width=0.4,
showlegend=True
))
# Add BLT Global bar (base of stack)
fig.add_trace(go.Bar(
name='BLT Global',
x=['BLT'],
y=[results['blt_global']],
text=[f"{results['blt_global']:.2e}"],
textposition='inside',
marker_color='#4ECDC4',
width=0.4,
showlegend=True
))
# Add BLT Local bar (top of stack)
fig.add_trace(go.Bar(
name='BLT Local',
x=['BLT'],
y=[results['blt_local']],
text=[f"{results['blt_local']:.2e}"],
textposition='inside',
marker_color='#45B7D1',
width=0.4,
showlegend=True
))
# Update layout with proper stacking and scientific notation
fig.update_layout(
title={
'text': f"FLOPs per Byte Comparison<br><sub>BLT FLOPs comparison: {results['advantage']:.1f}%</sub>",
'x': 0.5,
'xanchor': 'center',
'font': {'size': 20}
},
xaxis=dict(
title="Architecture",
tickfont=dict(size=14)
),
yaxis=dict(
title="FLOPs per Byte",
tickformat=".1e", # Scientific notation with 1 decimal
tickfont=dict(size=12),
gridcolor='lightgray'
),
barmode='stack',
showlegend=True,
height=650,
template="plotly_white",
font=dict(size=14),
bargap=0.3,
plot_bgcolor='white',
margin=dict(b=110) # Increased bottom margin slightly more for two lines of text
)
fig.add_annotation(
x='BLT',
y=results['blt_total'] * 1.05,
text=f"Total FLOPs/Byte: {results['blt_total']:.2e}",
showarrow=False,
font=dict(size=12, color="black"),
bgcolor="rgba(255,255,255,0.5)",
bordercolor="black",
borderwidth=1,
xanchor='center',
yanchor='bottom'
)
# Add parameter count annotations at the bottom of bars
fig.add_annotation(
x='BPE',
y=0,
text=bpe_params_str,
showarrow=False,
xref="x",
yref="paper",
yanchor='top',
xanchor='center',
yshift=-35,
font=dict(size=10, color="black", weight="bold"), # Font size 10 for param text
)
fig.add_annotation(
x='BLT',
y=0,
text=blt_combined_params_str, # Using the new combined string with breakdown
showarrow=False,
xref="x",
yref="paper",
yanchor='top',
xanchor='center',
yshift=-45, # Adjusted yshift for two lines of text
font=dict(size=10, color="black", weight="bold"), # Font size 10 for param text
align="center" # Ensure text is centered if it wraps due to <br>
)
# Shrink the bar value labels slightly
fig.update_traces(textfont_size=10)
return fig
# Create Gradio interface
with gr.Blocks(title="BLT vs BPE FLOPs Comparison") as demo:
gr.Markdown("""
# BLT vs BPE FLOPs Comparison
Companion blog post [can be found here](https://lucalp.dev/bitter-lesson-tokenization-and-blt).
""")
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### Adjustable Parameters")
blt_ps_slider = gr.Slider(
minimum=1.0,
maximum=10.0,
value=4.4,
step=0.1,
label="BLT Patch Size (blt_ps)",
info="Patch size for BLT architecture"
)
d_model_slider = gr.Slider(
minimum=512,
maximum=8192,
value=2560,
step=128,
label="Global Model Dimension (d_model)",
info="Hidden dimension size of the BPE model and BLT's Global model"
)
local_n_layers_slider = gr.Slider(
minimum=2,
maximum=24, # Max value for local_n_layers
value=10,
step=2, # Ensure even numbers for CA split
label="Local Model Layers (local_n_layers)",
info="Number of layers in the BLT's local model"
)
gr.Markdown("""
Have a look at the paper's [BLT architecture configurations](https://arxiv.org/html/2412.09871v1#:~:text=%5Cbeginappendix-,11,Table%C2%A010%20shows%20different%20hyper%20parameter%20settings%20for%20BLT%20models.,-Encoder) for some inspiration.
A few things you'll notice:
1. Increasing the patch size reduces the global model's FLOPs but not the local models'
2. Increasing the patch size together with the global model dimension keeps total FLOPs roughly unchanged
3. In smaller BLTs, the local models account for a larger share of the total FLOPs
Parameter counts are displayed below each bar.
A core
hypothesis of the paper is "that larger models taking fewer steps on larger patches
might perform better than smaller models taking more steps." [source](https://arxiv.org/html/2412.09871v1#:~:text=the%20hypothesis%20that%20larger%20models%20taking%20fewer%20steps%20on%20larger%20patches%20might%20perform%20better%20than%20smaller%20models%20taking%20more%20steps)
The **purpose** of this tool is to show the relationship between patch size, global
model dimension and local model layers in terms of FLOPs and parameters. This tool
implies _nothing_ about the **effectiveness** of the FLOPs relative to loss (cf.
[FLOPs/BPB plots from the paper](https://arxiv.org/html/2412.09871v1#:~:text=Introduction-,Figure%201%3A,-Scaling%20trends%20for)) or downstream benchmarks. In order to
fully compare BPE-based transformers and BLT, you'll need to investigate those
claims in the paper itself.
""")
# Fixed Parameters accordion
with gr.Accordion("Fixed Parameters", open=False):
gr.Markdown(f"""
- **BPE's bytes per token (bpe_ps)**: {bpe_ps}
- **BPE/BLT Global - Num Layers (n_layers)**: {n_layers}
- **BPE/BLT Global - Num Heads (n_heads)**: {n_heads}
- **BPE - Vocabulary Size (n_vocab)**: {n_vocab:,}
- **BPE/BLT - Context Length (n_ctx_base)**: {n_ctx_base:,} bytes
- **BLT Local - Model Dimension (local_d_model)**: {local_d_model}
- **BLT Local - Num Heads (local_n_heads)**: {local_n_heads}
- **BLT Local - Vocabulary Size (local_n_vocab)**: {local_n_vocab}
- **BLT Local - FF Multiplier (local_d_ff_multiplier)**: {local_d_ff_multiplier}
""")
# Current Values & Totals accordion
with gr.Accordion("Current Values & Totals", open=False):
info_text = gr.Markdown("")
with gr.Column(scale=2):
plot = gr.Plot(label="FLOPs Comparison & Model Parameters")
# Set up interactivity
def update_plot_and_info(blt_ps_val, d_model_val, local_n_layers_val):
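"""Refresh the plot and the summary text from the current slider values."""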
fig = create_visualization(blt_ps_val, d_model_val, local_n_layers_val)
results = calculate_flops(blt_ps_val, d_model_val, local_n_layers_val)
# Recalculate parameters for info text (could also be returned by create_visualization or calculate_flops)
bpe_model_p = (12 * n_layers * d_model_val**2) + (2 * n_vocab * d_model_val)
blt_global_p = 12 * n_layers * d_model_val**2
blt_local_transformer_p = (12 * local_n_layers_val * local_d_model**2) + \
(2 * local_n_vocab * local_d_model)
blt_local_ca_p = local_n_layers_val * 4 * local_d_model**2
blt_local_total_internal_p = blt_local_transformer_p + blt_local_ca_p
blt_total_model_p = blt_global_p + blt_local_total_internal_p
info_str = f"""
**BPE FLOPs/byte**: {results['bpe_per_byte']:.2e}
**BPE Total Params**: {format_params_display(bpe_model_p)}
**BLT Global FLOPs/byte**: {results['blt_global']:.2e}
**BLT Local FLOPs/byte**: {results['blt_local']:.2e}
**BLT Total FLOPs/byte**: {results['blt_total']:.2e}
**BLT Total Params**: {format_params_display(blt_total_model_p)}
(Global: {format_params_display(blt_global_p)}, Local: {format_params_display(blt_local_total_internal_p)})
**BLT FLOPs/byte vs BPE**: {results['advantage']:+.1f}%
"""
return fig, info_str
# Update plot when any slider changes
inputs_list = [blt_ps_slider, d_model_slider, local_n_layers_slider]
blt_ps_slider.change(
update_plot_and_info,
inputs=inputs_list,
outputs=[plot, info_text]
)
d_model_slider.change(
update_plot_and_info,
inputs=inputs_list,
outputs=[plot, info_text]
)
local_n_layers_slider.change(
update_plot_and_info,
inputs=inputs_list,
outputs=[plot, info_text]
)
# Initial plot
demo.load(
update_plot_and_info,
inputs=inputs_list,
outputs=[plot, info_text]
)
# Launch the app
if __name__ == "__main__":
demo.launch()