""" |
|
h/t to Adam Casson for easy-to-use function to calculate FLOPs, source: https://huggingface.co/spaces/adamcasson/transformer-flops-calculator/blob/main/app.py |
|
""" |
|
import gradio as gr |
|
import plotly.graph_objects as go |
|
import numpy as np |
|
|
|
|
|
# BPE baseline / BLT global model parameters (fixed; d_model comes from the UI slider)
bpe_ps = 4.4        # bytes per BPE token
n_ctx_base = 8192   # context length in bytes
n_heads = 20
n_vocab = 128000
n_layers = 26

# BLT local (byte-level) model parameters
local_d_model = 1024
local_g_size = 1
local_n_ctx = 512
local_n_heads = 16
local_n_vocab = 256
local_d_model_k = local_d_model / local_n_heads
local_d_ff_multiplier = 4


def openai_flops_per_token(n_layers, n_heads, d_model, n_ctx, n_vocab, ff_ratio=4):
    """OpenAI method for forward-pass FLOPs counting of a decoder-only Transformer."""
    d_attn = d_model // n_heads
    d_ff = d_model * ff_ratio

    embeddings = 4 * d_model                                    # embeddings (negligible share)
    attn_qkv = 2 * n_layers * d_model * 3 * (d_attn * n_heads)  # Q, K, V projections
    attn_mask = 2 * n_layers * n_ctx * (d_attn * n_heads)       # attention scores over the context
    attn_project = 2 * n_layers * (d_attn * n_heads) * d_model  # attention output projection
    ff = 2 * n_layers * 2 * d_model * d_ff                      # two feed-forward matmuls per layer
    logits = 2 * d_model * n_vocab                              # projection to the output vocabulary

    return embeddings + attn_qkv + attn_mask + attn_project + ff + logits
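
# Illustrative sanity check (not part of the app's computation): with the fixed settings above,
# the default slider value d_model=2560, and a BPE context of int(8192 / 4.4) tokens, this
# function gives roughly 5e9 FLOPs per token, i.e. about 1.1e9 FLOPs per byte after dividing
# by 4.4 bytes per token.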
def cross_attention_flops_per_token(n_layers, n_ctx_cross_attn_kv_len, d_model):
    """Forward-pass FLOPs per token for the cross-attention in BLT's local models."""
    # Per-token projection cost of the cross-attention blocks (2 FLOPs per multiply-accumulate).
    ca_qo_proj_flops = 2 * 4 * d_model**2
    # Attending over n_ctx_cross_attn_kv_len keys/values: score and weighted-value matmuls.
    ca_context_flops = 4 * n_ctx_cross_attn_kv_len * d_model
    return n_layers * (ca_qo_proj_flops + ca_context_flops)
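
# Note: calculate_flops below applies this once with keys/values over the byte sequence
# (the local encoder) and once with keys/values over the patch sequence of roughly
# local_n_ctx / blt_ps entries (the local decoder), splitting local_n_layers between the two.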
def calculate_flops(blt_ps, d_model, local_n_layers):
    # BPE baseline: the tokenizer shortens the byte context by bpe_ps bytes per token.
    n_ctx = int(n_ctx_base / bpe_ps)
    bpe_flops_per_token = openai_flops_per_token(n_layers, n_heads, d_model, n_ctx, n_vocab)
    bpe_per_byte = bpe_flops_per_token / bpe_ps

    # BLT global (patch-level) model: the context shrinks by blt_ps bytes per patch and
    # there is no softmax over a token vocabulary (n_vocab=0).
    blt_n_ctx = int(n_ctx_base / blt_ps)
    blt_global_flops_per_token = openai_flops_per_token(n_layers, n_heads, d_model, blt_n_ctx, n_vocab=0)
    blt_global_flops_per_byte = blt_global_flops_per_token / blt_ps

    # BLT local (byte-level) models run once per byte, so per-token FLOPs equal per-byte FLOPs.
    # local_n_layers counts encoder and decoder layers together; cross-attention is added for each half.
    local_models_transformer_flops_per_byte = openai_flops_per_token(
        local_n_layers, local_n_heads, local_d_model, local_n_ctx, local_n_vocab
    )
    encoder_model_ca_flops_per_byte = cross_attention_flops_per_token(
        local_n_layers / 2, local_n_ctx, local_d_model
    )
    decoder_model_ca_flops_per_byte = cross_attention_flops_per_token(
        local_n_layers / 2, local_n_ctx // blt_ps, local_d_model
    )
    local_models_cross_attention_flops_per_byte = encoder_model_ca_flops_per_byte + decoder_model_ca_flops_per_byte
    local_models_flops = local_models_transformer_flops_per_byte + local_models_cross_attention_flops_per_byte

    blt_total = local_models_flops + blt_global_flops_per_byte
    # Percentage difference in FLOPs/byte relative to BPE; negative means BLT is cheaper.
    advantage = 100 * ((blt_total - bpe_per_byte) / bpe_per_byte)

    return {
        'bpe_per_byte': bpe_per_byte,
        'blt_global': blt_global_flops_per_byte,
        'blt_local': local_models_flops,
        'blt_total': blt_total,
        'advantage': advantage
    }
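
# Example (illustrative, using the app's default slider values):
#   calculate_flops(blt_ps=4.4, d_model=2560, local_n_layers=10)
# yields FLOPs/byte on the order of 1e9 for both architectures; 'advantage' is the percentage
# difference of the BLT total relative to BPE (negative means BLT is cheaper).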
def create_visualization(blt_ps, d_model, local_n_layers):
    results = calculate_flops(blt_ps, d_model, local_n_layers)

    fig = go.Figure()

    fig.add_trace(go.Bar(
        name='BPE',
        x=['BPE'],
        y=[results['bpe_per_byte']],
        text=[f"{results['bpe_per_byte']:.2e}"],
        textposition='outside',
        marker_color='#FF6B6B',
        width=0.4,
        showlegend=True
    ))

    fig.add_trace(go.Bar(
        name='BLT Global',
        x=['BLT'],
        y=[results['blt_global']],
        text=[f"{results['blt_global']:.2e}"],
        textposition='inside',
        marker_color='#4ECDC4',
        width=0.4,
        showlegend=True
    ))

    fig.add_trace(go.Bar(
        name='BLT Local',
        x=['BLT'],
        y=[results['blt_local']],
        text=[f"{results['blt_local']:.2e}"],
        textposition='inside',
        marker_color='#45B7D1',
        width=0.4,
        showlegend=True
    ))

    fig.update_layout(
        title={
            'text': f"FLOPs per Byte Comparison<br><sub>BLT total vs BPE: {results['advantage']:+.1f}%</sub>",
            'x': 0.5,
            'xanchor': 'center',
            'font': {'size': 20}
        },
        xaxis=dict(
            title="Architecture",
            tickfont=dict(size=14)
        ),
        yaxis=dict(
            title="FLOPs per Byte",
            tickformat=".1e",
            tickfont=dict(size=12),
            gridcolor='lightgray'
        ),
        barmode='stack',
        showlegend=True,
        height=600,
        template="plotly_white",
        font=dict(size=14),
        bargap=0.3,
        plot_bgcolor='white'
    )

    fig.add_annotation(
        x='BLT',
        y=results['blt_total'] * 1.1,
        text=f"Total: {results['blt_total']:.2e}",
        showarrow=False,
        font=dict(size=12, color="black", weight="bold"),
        bgcolor="white",
        bordercolor="black",
        borderwidth=1
    )

    fig.update_traces(textfont_size=10)

    return fig
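
# Reading the chart: the BPE column is a single bar, while the BLT column stacks the global
# (patch-level) and local (byte-level) FLOPs/byte, so the full stack height equals the
# 'blt_total' value annotated above it.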
with gr.Blocks(title="BLT vs BPE FLOPs Comparison") as demo:
    gr.Markdown("""
    # BLT vs BPE FLOPs Comparison

    This interactive visualization compares the computational cost (FLOPs per byte) of two architectures:
    - **BPE (Byte Pair Encoding)**: a standard decoder-only Transformer over BPE tokens
    - **BLT (Byte Latent Transformer)**: a global patch-level model paired with small local byte-level models, using a dynamic patch size to segment bytes

    A few things you'll notice:
    1. Increasing the patch size reduces the global model's FLOPs per byte, but not the local models'
    2. Increasing the patch size and the global model dimension together leaves total FLOPs roughly unchanged
    3. In smaller BLTs, the local models account for a larger share of the total FLOPs
    """)
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Adjustable Parameters")
            blt_ps_slider = gr.Slider(
                minimum=1.0,
                maximum=10.0,
                value=4.4,
                step=0.1,
                label="BLT Patch Size (blt_ps)",
                info="Patch size for BLT architecture"
            )

            d_model_slider = gr.Slider(
                minimum=512,
                maximum=8192,
                value=2560,
                step=128,
                label="Model Dimension (d_model)",
                info="Hidden dimension size of the model"
            )

            local_n_layers_slider = gr.Slider(
                minimum=2,
                maximum=24,
                value=10,
                step=2,
                label="Local Model Layers (local_n_layers)",
                info="Number of layers in the local model"
            )

            gr.Markdown("### Fixed Parameters")
            gr.Markdown("""
            - **BPE's bytes per token**: 4.4
            - **BPE/BLT Number of Layers**: 26
            - **BPE/BLT Number of Heads**: 20
            - **BPE's Vocabulary Size**: 128,000
            - **BPE/BLT Context Length**: 8,192 bytes
            - **Local Model Dimension**: 1,024
            - **Local Model Heads**: 16
            """)

gr.Markdown("### Current Values") |
|
info_text = gr.Markdown("") |
|
|
|
with gr.Column(scale=2): |
|
plot = gr.Plot(label="FLOPs Comparison") |
|
|
|
|
|
    def update_plot(blt_ps, d_model, local_n_layers):
        fig = create_visualization(blt_ps, d_model, local_n_layers)

        results = calculate_flops(blt_ps, d_model, local_n_layers)
        info_str = f"""
        **BPE FLOPs/byte**: {results['bpe_per_byte']:.2e}

        **BLT Global FLOPs/byte**: {results['blt_global']:.2e}

        **BLT Local FLOPs/byte**: {results['blt_local']:.2e}

        **BLT Total FLOPs/byte**: {results['blt_total']:.2e}
        """

        return fig, info_str

    blt_ps_slider.change(
        update_plot,
        inputs=[blt_ps_slider, d_model_slider, local_n_layers_slider],
        outputs=[plot, info_text]
    )
    d_model_slider.change(
        update_plot,
        inputs=[blt_ps_slider, d_model_slider, local_n_layers_slider],
        outputs=[plot, info_text]
    )
    local_n_layers_slider.change(
        update_plot,
        inputs=[blt_ps_slider, d_model_slider, local_n_layers_slider],
        outputs=[plot, info_text]
    )

    demo.load(
        update_plot,
        inputs=[blt_ps_slider, d_model_slider, local_n_layers_slider],
        outputs=[plot, info_text]
    )

if __name__ == "__main__":
    demo.launch()