# Provenance: Hugging Face Space by lucalp (initial commit 6d3d780) — web-page
# chrome from the scrape removed; see the module docstring below for credits.
"""
h/t to Adam Casson for easy-to-use function to calculate FLOPs, source: https://huggingface.co/spaces/adamcasson/transformer-flops-calculator/blob/main/app.py
"""
import gradio as gr
import plotly.graph_objects as go
import numpy as np
# --- Fixed parameters for the BPE (token-based) baseline transformer ---
bpe_ps = 4.4  # bytes per BPE token (determined by tokenizer)
n_ctx_base = 8192  # context length in bytes, shared by both architectures
n_heads = 20  # attention heads in the BPE / BLT-global model
n_vocab = 128000  # BPE vocabulary size
n_layers = 26  # layers in the BPE / BLT-global model
# --- Fixed parameters for the BLT local (byte-level) encoder/decoder models ---
local_d_model = 1024  # hidden dimension of the local models
local_g_size = 1  # NOTE(review): not referenced below — confirm before removing
local_n_ctx = 512  # in bytes
local_n_heads = 16  # attention heads in the local models
local_n_vocab = 256  # byte-level vocabulary (one id per byte value)
local_d_model_k = local_d_model / local_n_heads  # per-head dim; NOTE(review): unused below
local_d_ff_multiplier = 4  # NOTE(review): unused below — the ff_ratio=4 default is used instead
def openai_flops_per_token(n_layers, n_heads, d_model, n_ctx, n_vocab, ff_ratio=4):
    """Forward-pass FLOPs per token of a decoder-only Transformer.

    Uses the OpenAI (Kaplan et al. scaling-laws) accounting: 2 FLOPs per
    multiply-accumulate, embeddings and final logits counted once, and the
    attention/feed-forward matmuls counted once per layer.
    """
    d_head = d_model // n_heads
    attn_width = d_head * n_heads  # total width across heads (== d_model when divisible)
    ff_width = ff_ratio * d_model

    per_layer = (
        2 * d_model * 3 * attn_width   # Q, K, V projections
        + 2 * n_ctx * attn_width       # attention scores + weighted sum over the context
        + 2 * attn_width * d_model     # output projection
        + 2 * 2 * d_model * ff_width   # the two feed-forward matmuls
    )
    # embeddings + all layers + output logits
    return 4 * d_model + n_layers * per_layer + 2 * d_model * n_vocab
def cross_attention_flops_per_token(n_layers, n_ctx_cross_attn_kv_len, d_model):
    """Per-token FLOPs of cross-attention across `n_layers` layers.

    Counts the Q and O projections (8 * d_model^2; the original labels the
    extra factor of 2 "+ backward" — kept as-is) plus the attention itself
    over a key/value sequence of length `n_ctx_cross_attn_kv_len`.
    """
    projection_flops = 8 * d_model ** 2  # Q/O projections (2 matmuls x 2 FLOPs/MAC x 2)
    context_flops = 4 * n_ctx_cross_attn_kv_len * d_model  # scores + weighted sum
    return n_layers * (projection_flops + context_flops)
def calculate_flops(blt_ps, d_model, local_n_layers):
    """Return a dict of per-byte FLOPs for the BPE baseline and the BLT stack.

    Keys: 'bpe_per_byte', 'blt_global', 'blt_local', 'blt_total', and
    'advantage' (percentage difference of BLT total vs the BPE baseline).
    Module-level constants supply all fixed architecture parameters.
    """
    # BPE baseline: one token covers bpe_ps bytes, so per-token cost / bpe_ps.
    bpe_n_ctx = int(n_ctx_base / bpe_ps)
    bpe_per_byte = (
        openai_flops_per_token(n_layers, n_heads, d_model, bpe_n_ctx, n_vocab) / bpe_ps
    )

    # BLT global model: one "token" per patch of blt_ps bytes; no output vocab
    # (byte prediction happens in the local decoder), hence n_vocab=0.
    global_n_ctx = int(n_ctx_base / blt_ps)
    blt_global_flops_per_byte = (
        openai_flops_per_token(n_layers, n_heads, d_model, global_n_ctx, n_vocab=0) / blt_ps
    )

    # BLT local models: byte-level transformer cost plus cross-attention, with
    # the layer budget split evenly between encoder and decoder halves.
    local_transformer_flops = openai_flops_per_token(
        local_n_layers, local_n_heads, local_d_model, local_n_ctx, local_n_vocab
    )
    encoder_ca_flops = cross_attention_flops_per_token(
        local_n_layers / 2, local_n_ctx, local_d_model
    )
    # Decoder cross-attends to patches, so its KV length is bytes per window
    # floor-divided by the patch size.
    decoder_ca_flops = cross_attention_flops_per_token(
        local_n_layers / 2, local_n_ctx // blt_ps, local_d_model
    )
    local_flops = local_transformer_flops + encoder_ca_flops + decoder_ca_flops

    blt_total = local_flops + blt_global_flops_per_byte
    return {
        'bpe_per_byte': bpe_per_byte,
        'blt_global': blt_global_flops_per_byte,
        'blt_local': local_flops,
        'blt_total': blt_total,
        'advantage': 100 * ((blt_total - bpe_per_byte) / bpe_per_byte),
    }
def create_visualization(blt_ps, d_model, local_n_layers):
    """Build the stacked-bar Plotly figure comparing FLOPs/byte of BPE vs BLT."""
    results = calculate_flops(blt_ps, d_model, local_n_layers)
    fig = go.Figure()

    # (legend name, x category, value, label position, colour).  Trace order
    # matters: it fixes the stacking order of the two BLT segments.
    bar_specs = [
        ('BPE', 'BPE', results['bpe_per_byte'], 'outside', '#FF6B6B'),
        ('BLT Global', 'BLT', results['blt_global'], 'inside', '#4ECDC4'),
        ('BLT Local', 'BLT', results['blt_local'], 'inside', '#45B7D1'),
    ]
    for legend_name, category, value, label_pos, colour in bar_specs:
        fig.add_trace(go.Bar(
            name=legend_name,
            x=[category],
            y=[value],
            text=[f"{value:.2e}"],
            textposition=label_pos,
            marker_color=colour,
            width=0.4,
            showlegend=True
        ))

    # Stacked bars, scientific-notation y axis, centred title carrying the
    # percentage difference in its subtitle.
    fig.update_layout(
        title={
            'text': f"FLOPs per Byte Comparison<br><sub>BLT FLOPs comparison: {results['advantage']:.1f}%</sub>",
            'x': 0.5,
            'xanchor': 'center',
            'font': {'size': 20}
        },
        xaxis=dict(
            title="Architecture",
            tickfont=dict(size=14)
        ),
        yaxis=dict(
            title="FLOPs per Byte",
            tickformat=".1e",  # scientific notation, one decimal
            tickfont=dict(size=12),
            gridcolor='lightgray'
        ),
        barmode='stack',
        showlegend=True,
        height=600,
        template="plotly_white",
        font=dict(size=14),
        bargap=0.3,
        plot_bgcolor='white'
    )

    # Floating "Total" label just above the stacked BLT bar.
    fig.add_annotation(
        x='BLT',
        y=results['blt_total'] * 1.1,
        text=f"Total: {results['blt_total']:.2e}",
        showarrow=False,
        font=dict(size=12, color="black", weight="bold"),
        bgcolor="white",
        bordercolor="black",
        borderwidth=1
    )

    fig.update_traces(textfont_size=10)
    return fig
# Create Gradio interface: parameter sliders on the left, stacked-bar plot on
# the right.  Component creation order inside the context managers defines the
# page layout.
with gr.Blocks(title="BLT vs BPE FLOPs Comparison") as demo:
    gr.Markdown("""
# BLT vs BPE FLOPs Comparison
This interactive visualization compares the computational efficiency (FLOPs per byte) between:
- **BPE (Byte Pair Encoding)**: Traditional transformer architecture
- **BLT (Byte Latent Transformer)**: Novel architecture with Global and Local components with a dynamic patch size to segment bytes.
A few things you'll notice:
1. Patch size reduces global model FLOPs but not local model
2. Increasing patch size and global model dimension doesn't change total FLOPs
3. In smaller BLTs, local models constitute a larger portion of the total FLOPs
""")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Adjustable Parameters")
            # Patch size: how many bytes the BLT groups into one global "token"
            blt_ps_slider = gr.Slider(
                minimum=1.0,
                maximum=10.0,
                value=4.4,
                step=0.1,
                label="BLT Patch Size (blt_ps)",
                info="Patch size for BLT architecture"
            )
            # Hidden dimension of the BPE model / BLT global model
            d_model_slider = gr.Slider(
                minimum=512,
                maximum=8192,
                value=2560,
                step=128,
                label="Model Dimension (d_model)",
                info="Hidden dimension size of the model"
            )
            # Layer budget of the local models; step=2 because calculate_flops
            # splits it evenly between encoder and decoder halves
            local_n_layers_slider = gr.Slider(
                minimum=2,
                maximum=24,
                value=10,
                step=2,
                label="Local Model Layers (local_n_layers)",
                info="Number of layers in the local model"
            )
            gr.Markdown("### Fixed Parameters")
            gr.Markdown("""
- **BPE's bytes per token**: 4.4
- **BPE/BLT Number of Layers**: 26
- **BPE/BLT Number of Heads**: 20
- **BPE's Vocabulary Size**: 128,000
- **BPE/BLT Context Length**: 8,192 bytes
- **Local Model Dimension**: 1,024
- **Local Model Heads**: 16
""")
            gr.Markdown("### Current Values")
            # Placeholder; refreshed with the numeric summary by update_plot
            info_text = gr.Markdown("")
        with gr.Column(scale=2):
            plot = gr.Plot(label="FLOPs Comparison")

    # Set up interactivity
    def update_plot(blt_ps, d_model, local_n_layers):
        """Rebuild the figure and the markdown summary for the current sliders."""
        fig = create_visualization(blt_ps, d_model, local_n_layers)
        # Calculate values for info display (recomputed here for the text panel)
        results = calculate_flops(blt_ps, d_model, local_n_layers)
        info_str = f"""
**BPE FLOPs/byte**: {results['bpe_per_byte']:.2e}
**BLT Global FLOPs/byte**: {results['blt_global']:.2e}
**BLT Local FLOPs/byte**: {results['blt_local']:.2e}
**BLT Total FLOPs/byte**: {results['blt_total']:.2e}
"""
        return fig, info_str

    # Update plot when any slider changes; all three listeners share the same
    # handler, inputs, and outputs
    blt_ps_slider.change(
        update_plot,
        inputs=[blt_ps_slider, d_model_slider, local_n_layers_slider],
        outputs=[plot, info_text]
    )
    d_model_slider.change(
        update_plot,
        inputs=[blt_ps_slider, d_model_slider, local_n_layers_slider],
        outputs=[plot, info_text]
    )
    local_n_layers_slider.change(
        update_plot,
        inputs=[blt_ps_slider, d_model_slider, local_n_layers_slider],
        outputs=[plot, info_text]
    )
    # Initial plot: render once when the page loads, using the slider defaults
    demo.load(
        update_plot,
        inputs=[blt_ps_slider, d_model_slider, local_n_layers_slider],
        outputs=[plot, info_text]
    )

# Launch the app when run as a script
if __name__ == "__main__":
    demo.launch()