"""
h/t to Adam Casson for the easy-to-use function to calculate FLOPs, source: https://huggingface.co/spaces/adamcasson/transformer-flops-calculator/blob/main/app.py
"""

import gradio as gr
import plotly.graph_objects as go
import numpy as np

# Fixed settings for the BPE baseline and BLT's global model
bpe_ps = 4.4
n_ctx_base = 8192
n_heads = 20
n_vocab = 128000
n_layers = 26

# Fixed settings for BLT's local (byte-level) encoder/decoder models
local_d_model = 1024
local_g_size = 1
local_n_ctx = 512
local_n_heads = 16
local_n_vocab = 256
local_d_model_k = local_d_model / local_n_heads
local_d_ff_multiplier = 4

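# Note on the accounting used throughout: FLOPs are computed per token (BPE model and
# BLT's global model) or per byte (BLT's local models) and then divided by the relevant
# bytes-per-token / bytes-per-patch so that everything is compared in FLOPs per byte.
# For example, with bpe_ps = 4.4 and n_ctx_base = 8192 bytes, the BPE model runs over
# int(8192 / 4.4) = 1861 tokens and its per-token FLOPs are divided by 4.4.
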
def openai_flops_per_token(n_layers_val, n_heads_val, d_model_val, n_ctx_val, n_vocab_val, ff_ratio=4):
    """OpenAI method for forward-pass FLOPs counting of a decoder-only Transformer."""
    d_attn = d_model_val // n_heads_val
    d_ff = d_model_val * ff_ratio

    embeddings = 4 * d_model_val
    attn_qkv = 2 * n_layers_val * d_model_val * 3 * (d_attn * n_heads_val)  # Q, K, V projections
    attn_mask = 2 * n_layers_val * n_ctx_val * (d_attn * n_heads_val)  # attention over the context window
    attn_project = 2 * n_layers_val * (d_attn * n_heads_val) * d_model_val  # output projection
    ff = 2 * n_layers_val * 2 * d_model_val * d_ff  # two feed-forward matmuls per layer
    logits = 2 * d_model_val * n_vocab_val  # unembedding to the vocabulary

    return embeddings + attn_qkv + attn_mask + attn_project + ff + logits

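# Rough sanity check (not from the paper, just the formula above): with the default
# global settings (n_layers=26, n_heads=20, d_model=2560, n_ctx=int(8192 / 4.4)=1861,
# n_vocab=128000) this comes to roughly 5e9 FLOPs per token, i.e. close to the usual
# "2 FLOPs per non-embedding parameter" rule of thumb (12 * 26 * 2560**2 ≈ 2.0e9 params)
# plus the unembedding and attention-over-context terms.
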
def cross_attention_flops_per_token(n_layers_ca, n_ctx_cross_attn_kv_len, d_model_ca):
    """Forward-pass FLOPs per token for the cross-attention blocks in BLT's local models."""
    # projection matrices (2 FLOPs per multiply-accumulate, 4 * d_model**2 weights per layer)
    ca_qo_proj_flops = 2 * 4 * d_model_ca**2
    # attending over the cross-attention key/value sequence
    ca_context_flops = 4 * n_ctx_cross_attn_kv_len * d_model_ca
    return n_layers_ca * (ca_qo_proj_flops + ca_context_flops)

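# In the BLT budget below, cross-attention is counted for the local encoder (keys/values
# from the byte sequence) and the local decoder (keys/values from the patch sequence),
# with each assumed to take half of the local layers. This mirrors the calls in
# calculate_flops and is an approximation, not an exact match to any specific BLT config.
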
def calculate_flops(blt_ps, d_model_slider, local_n_layers_slider):
    # BPE baseline: tokens are bpe_ps bytes long on average
    n_ctx = int(n_ctx_base / bpe_ps)
    bpe_flops_per_token_val = openai_flops_per_token(n_layers, n_heads, d_model_slider, n_ctx, n_vocab)
    bpe_per_byte = bpe_flops_per_token_val / bpe_ps

    # BLT global model: runs over patches (blt_ps bytes each) and has no output vocabulary
    blt_n_ctx = int(n_ctx_base / blt_ps)
    blt_global_flops_per_token = openai_flops_per_token(n_layers, n_heads, d_model_slider, blt_n_ctx, n_vocab_val=0)
    blt_global_flops_per_byte = blt_global_flops_per_token / blt_ps

    # BLT local models: byte-level transformer layers plus encoder/decoder cross-attention
    local_models_transformer_flops_per_byte = openai_flops_per_token(
        local_n_layers_slider, local_n_heads, local_d_model, local_n_ctx, local_n_vocab, ff_ratio=local_d_ff_multiplier
    )
    encoder_model_ca_flops_per_byte = cross_attention_flops_per_token(
        local_n_layers_slider / 2, local_n_ctx, local_d_model
    )
    decoder_model_ca_flops_per_byte = cross_attention_flops_per_token(
        local_n_layers_slider / 2, local_n_ctx // blt_ps, local_d_model
    )
    local_models_cross_attention_flops_per_byte = encoder_model_ca_flops_per_byte + decoder_model_ca_flops_per_byte
    local_models_flops = local_models_transformer_flops_per_byte + local_models_cross_attention_flops_per_byte

    blt_total = local_models_flops + blt_global_flops_per_byte
    # percentage difference of BLT's total FLOPs/byte relative to BPE (negative = fewer FLOPs)
    advantage = 100 * ((blt_total - bpe_per_byte) / bpe_per_byte) if bpe_per_byte != 0 else 0

    return {
        'bpe_per_byte': bpe_per_byte,
        'blt_global': blt_global_flops_per_byte,
        'blt_local': local_models_flops,
        'blt_total': blt_total,
        'advantage': advantage,
    }

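# Example (default slider values): calculate_flops(4.4, 2560, 10) returns FLOPs/byte for
# the BPE baseline ('bpe_per_byte'), BLT's global and local components ('blt_global',
# 'blt_local'), their sum ('blt_total'), and the percentage difference of that sum
# relative to BPE ('advantage'; negative means BLT spends fewer FLOPs per byte).
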
def format_params_display(num_params):
    """Formats number of parameters into a string with M or B units."""
    if num_params is None:
        return ""
    if abs(num_params) >= 1_000_000_000:
        return f"{num_params / 1_000_000_000:.1f}B Params"
    elif abs(num_params) >= 1_000_000:
        return f"{num_params / 1_000_000:.1f}M Params"
    else:
        return f"{num_params / 1_000_000:.2f}M Params"


def create_visualization(blt_ps, d_model_slider, local_n_layers_slider):
    results = calculate_flops(blt_ps, d_model_slider, local_n_layers_slider)

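    # Parameter counts use common approximations: 12 * n_layers * d_model**2 for a
    # Transformer stack's non-embedding parameters, 2 * vocab * d_model for input and
    # output embeddings, and 4 * d_model**2 of projection weights per cross-attention
    # layer. These are rough display figures, not exact BLT checkpoint sizes.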
    bpe_model_params = (12 * n_layers * d_model_slider**2) + (2 * n_vocab * d_model_slider)
    blt_global_internal_params = 12 * n_layers * d_model_slider**2
    blt_local_transformer_params = (12 * local_n_layers_slider * local_d_model**2) + \
                                   (2 * local_n_vocab * local_d_model)
    blt_local_ca_params = local_n_layers_slider * 4 * local_d_model**2
    blt_local_total_internal_params = blt_local_transformer_params + blt_local_ca_params

    bpe_params_str = format_params_display(bpe_model_params)
    blt_global_params_fmt_str = format_params_display(blt_global_internal_params)
    blt_local_params_fmt_str = format_params_display(blt_local_total_internal_params)
    blt_combined_params_str = f"Global: {blt_global_params_fmt_str}<br>Local: {blt_local_params_fmt_str}"

    fig = go.Figure()

    fig.add_trace(go.Bar(
        name='BPE',
        x=['BPE'],
        y=[results['bpe_per_byte']],
        text=[f"{results['bpe_per_byte']:.2e}"],
        textposition='outside',
        marker_color='#FF6B6B',
        width=0.4,
        showlegend=True
    ))

    fig.add_trace(go.Bar(
        name='BLT Global',
        x=['BLT'],
        y=[results['blt_global']],
        text=[f"{results['blt_global']:.2e}"],
        textposition='inside',
        marker_color='#4ECDC4',
        width=0.4,
        showlegend=True
    ))

    fig.add_trace(go.Bar(
        name='BLT Local',
        x=['BLT'],
        y=[results['blt_local']],
        text=[f"{results['blt_local']:.2e}"],
        textposition='inside',
        marker_color='#45B7D1',
        width=0.4,
        showlegend=True
    ))

    fig.update_layout(
        title={
            'text': f"FLOPs per Byte Comparison<br><sub>BLT total vs BPE: {results['advantage']:+.1f}%</sub>",
            'x': 0.5,
            'xanchor': 'center',
            'font': {'size': 20}
        },
        xaxis=dict(
            title="Architecture",
            tickfont=dict(size=14)
        ),
        yaxis=dict(
            title="FLOPs per Byte",
            tickformat=".1e",
            tickfont=dict(size=12),
            gridcolor='lightgray'
        ),
        barmode='stack',
        showlegend=True,
        height=650,
        template="plotly_white",
        font=dict(size=14),
        bargap=0.3,
        plot_bgcolor='white',
        margin=dict(b=110)
    )

    # total BLT FLOPs/byte annotation above the stacked bar
    fig.add_annotation(
        x='BLT',
        y=results['blt_total'] * 1.05,
        text=f"Total FLOPs/Byte: {results['blt_total']:.2e}",
        showarrow=False,
        font=dict(size=12, color="black"),
        bgcolor="rgba(255,255,255,0.5)",
        bordercolor="black",
        borderwidth=1,
        xanchor='center',
        yanchor='bottom'
    )

    # parameter counts rendered below the x-axis labels
    fig.add_annotation(
        x='BPE',
        y=0,
        text=bpe_params_str,
        showarrow=False,
        xref="x",
        yref="paper",
        yanchor='top',
        xanchor='center',
        yshift=-35,
        font=dict(size=10, color="black", weight="bold"),
    )

    fig.add_annotation(
        x='BLT',
        y=0,
        text=blt_combined_params_str,
        showarrow=False,
        xref="x",
        yref="paper",
        yanchor='top',
        xanchor='center',
        yshift=-45,
        font=dict(size=10, color="black", weight="bold"),
        align="center"
    )

    fig.update_traces(textfont_size=10)

    return fig


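# Gradio UI: three sliders (patch size, global model dimension, local layer count) drive
# both the comparison plot and a short textual summary.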
with gr.Blocks(title="BLT vs BPE FLOPs Comparison") as demo:
    gr.Markdown("""
    # BLT vs BPE FLOPs Comparison
    Companion blog post [can be found here](https://lucalp.dev/bitter-lesson-tokenization-and-blt).
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Adjustable Parameters")
            blt_ps_slider = gr.Slider(
                minimum=1.0,
                maximum=10.0,
                value=4.4,
                step=0.1,
                label="BLT Patch Size (blt_ps)",
                info="Patch size for BLT architecture"
            )

            d_model_slider = gr.Slider(
                minimum=512,
                maximum=8192,
                value=2560,
                step=128,
                label="Global Model Dimension (d_model)",
                info="Hidden dimension size of the BPE model and BLT's Global model"
            )

            local_n_layers_slider = gr.Slider(
                minimum=2,
                maximum=24,
                value=10,
                step=2,
                label="Local Model Layers (local_n_layers)",
                info="Number of layers in the BLT's local model"
            )

            gr.Markdown("""
            Have a look at the paper's [BLT architecture configurations](https://arxiv.org/html/2412.09871v1#:~:text=%5Cbeginappendix-,11,Table%C2%A010%20shows%20different%20hyper%20parameter%20settings%20for%20BLT%20models.,-Encoder) for inspiration.

            A few things you'll notice:
            1. Increasing the patch size reduces the global model's FLOPs but not the local models'
            2. Increasing the patch size together with the global model dimension can leave total FLOPs unchanged
            3. In smaller BLTs, the local models make up a larger portion of the total FLOPs

            Parameter counts are displayed below each bar.

            A core hypothesis of the paper is "that larger models taking fewer steps on larger patches
            might perform better than smaller models taking more steps." [source](https://arxiv.org/html/2412.09871v1#:~:text=the%20hypothesis%20that%20larger%20models%20taking%20fewer%20steps%20on%20larger%20patches%20might%20perform%20better%20than%20smaller%20models%20taking%20more%20steps)

            The **purpose** of this tool is to show the relationship between patch size, global
            model dimension and local model layers in terms of FLOPs and parameters. This tool
            implies _nothing_ about the **effectiveness** of those FLOPs relative to loss (cf.
            [FLOPs/BPB plots from the paper](https://arxiv.org/html/2412.09871v1#:~:text=Introduction-,Figure%201%3A,-Scaling%20trends%20for)) or downstream benchmarks. To fully
            compare BPE-based transformers and BLT, you'll need to investigate those claims in
            the paper itself.
            """)

with gr.Accordion("Fixed Parameters", open=False): |
|
gr.Markdown(f""" |
|
- **BPE's bytes per token (bpe_ps)**: {bpe_ps} |
|
- **BPE/BLT Global - Num Layers (n_layers)**: {n_layers} |
|
- **BPE/BLT Global - Num Heads (n_heads)**: {n_heads} |
|
- **BPE - Vocabulary Size (n_vocab)**: {n_vocab:,} |
|
- **BPE/BLT - Context Length (n_ctx_base)**: {n_ctx_base:,} bytes |
|
- **BLT Local - Model Dimension (local_d_model)**: {local_d_model} |
|
- **BLT Local - Num Heads (local_n_heads)**: {local_n_heads} |
|
- **BLT Local - Vocabulary Size (local_n_vocab)**: {local_n_vocab} |
|
- **BLT Local - FF Multiplier (local_d_ff_multiplier)**: {local_d_ff_multiplier} |
|
""") |
|
|
|
|
|
with gr.Accordion("Current Values & Totals", open=False): |
|
info_text = gr.Markdown("") |
|
|
|
with gr.Column(scale=2): |
|
plot = gr.Plot(label="FLOPs Comparison & Model Parameters") |
|
|
|
|
|
    def update_plot_and_info(blt_ps_val, d_model_val, local_n_layers_val):
        fig = create_visualization(blt_ps_val, d_model_val, local_n_layers_val)
        results = calculate_flops(blt_ps_val, d_model_val, local_n_layers_val)

        # same parameter-count approximations as in create_visualization
        bpe_model_p = (12 * n_layers * d_model_val**2) + (2 * n_vocab * d_model_val)
        blt_global_p = 12 * n_layers * d_model_val**2
        blt_local_transformer_p = (12 * local_n_layers_val * local_d_model**2) + \
                                  (2 * local_n_vocab * local_d_model)
        blt_local_ca_p = local_n_layers_val * 4 * local_d_model**2
        blt_local_total_internal_p = blt_local_transformer_p + blt_local_ca_p
        blt_total_model_p = blt_global_p + blt_local_total_internal_p

        info_str = f"""
        **BPE FLOPs/byte**: {results['bpe_per_byte']:.2e}
        **BPE Total Params**: {format_params_display(bpe_model_p)}

        **BLT Global FLOPs/byte**: {results['blt_global']:.2e}
        **BLT Local FLOPs/byte**: {results['blt_local']:.2e}
        **BLT Total FLOPs/byte**: {results['blt_total']:.2e}
        **BLT Total Params**: {format_params_display(blt_total_model_p)}
        (Global: {format_params_display(blt_global_p)}, Local: {format_params_display(blt_local_total_internal_p)})

        **BLT FLOPs/byte vs BPE**: {results['advantage']:+.1f}% (negative = fewer FLOPs than BPE)
        """
        return fig, info_str

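    # Re-render the plot and the summary text whenever any slider changes; demo.load below
    # runs the same callback once so the page opens with the default values rendered.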
    inputs_list = [blt_ps_slider, d_model_slider, local_n_layers_slider]
    blt_ps_slider.change(
        update_plot_and_info,
        inputs=inputs_list,
        outputs=[plot, info_text]
    )
    d_model_slider.change(
        update_plot_and_info,
        inputs=inputs_list,
        outputs=[plot, info_text]
    )
    local_n_layers_slider.change(
        update_plot_and_info,
        inputs=inputs_list,
        outputs=[plot, info_text]
    )

    demo.load(
        update_plot_and_info,
        inputs=inputs_list,
        outputs=[plot, info_text]
    )

if __name__ == "__main__":
    demo.launch()