"""
h/t to Adam Casson for easy-to-use function to calculate FLOPs, source: https://huggingface.co/spaces/adamcasson/transformer-flops-calculator/blob/main/app.py
"""
import gradio as gr
import plotly.graph_objects as go

# Fixed BPE (token-based baseline) parameters
bpe_ps = 4.4        # bytes per token, determined by the tokenizer
n_ctx_base = 8192   # context length in bytes
n_heads = 20
n_vocab = 128000
n_layers = 26

# Fixed local (byte-level) model parameters
local_d_model = 1024
local_g_size = 1
local_n_ctx = 512   # in bytes
local_n_heads = 16
local_n_vocab = 256
local_d_model_k = local_d_model // local_n_heads  # per-head dimension
local_d_ff_multiplier = 4
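
# Quick derived-value check (for reference, assuming the constants above):
#   BPE context in tokens:    int(8192 / 4.4) = 1861
#   local per-head dimension: 1024 // 16 = 64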

def openai_flops_per_token(n_layers, n_heads, d_model, n_ctx, n_vocab, ff_ratio=4):
    """OpenAI method (Kaplan et al., 2020) for counting the forward-pass FLOPs
    per token of a decoder-only Transformer."""
    d_attn = d_model // n_heads
    d_ff = d_model * ff_ratio

    embeddings = 4 * d_model
    attn_qkv = 2 * n_layers * d_model * 3 * (d_attn * n_heads)  # Q, K, V projections
    attn_mask = 2 * n_layers * n_ctx * (d_attn * n_heads)       # attention over the context
    attn_project = 2 * n_layers * (d_attn * n_heads) * d_model  # output projection
    ff = 2 * n_layers * 2 * d_model * d_ff                      # two feed-forward matmuls
    logits = 2 * d_model * n_vocab                              # unembedding / logits

    return embeddings + attn_qkv + attn_mask + attn_project + ff + logits
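
# Worked example (a sketch for intuition, not used by the app): with the fixed
# config above, d_model=2560 (the slider default) and n_ctx = int(8192/4.4) = 1861,
# the per-token terms come out roughly as:
#   attn_qkv     = 2*26*2560*3*2560  ~ 1.02e9
#   attn_mask    = 2*26*1861*2560    ~ 2.48e8
#   attn_project = 2*26*2560*2560    ~ 3.41e8
#   ff           = 2*26*2*2560*10240 ~ 2.73e9
#   logits       = 2*2560*128000     ~ 6.55e8
# i.e. about 5.0e9 forward FLOPs per token, dominated by the FFN and QKV terms.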


def cross_attention_flops_per_token(n_layers, n_ctx_cross_attn_kv_len, d_model):
    """FLOPs per token for the local models' cross-attention blocks."""
    # Cross-attention Q and output projection FLOPs (+ backward, per the source)
    ca_qo_proj_flops = 2 * 4 * d_model**2
    # Attending over the cross-attention key/value context
    ca_context_flops = 4 * n_ctx_cross_attn_kv_len * d_model
    return n_layers * (ca_qo_proj_flops + ca_context_flops)
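
# Worked example (a sketch, assuming the fixed local config above): per layer,
# the Q/O projection term is 2*4*1024**2 ~ 8.4e6 FLOPs/token, while attending
# over a 512-entry KV context adds 4*512*1024 ~ 2.1e6, so the projections
# dominate until the KV context exceeds 2*d_model = 2048 entries.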


def calculate_flops(blt_ps, d_model, local_n_layers):
    # BPE baseline: context length in tokens, then FLOPs/token -> FLOPs/byte
    n_ctx = int(n_ctx_base / bpe_ps)
    bpe_flops_per_token = openai_flops_per_token(n_layers, n_heads, d_model, n_ctx, n_vocab)
    bpe_per_byte = bpe_flops_per_token / bpe_ps

    # BLT global model: runs on patches, so its context shrinks with the patch
    # size; it has no output vocabulary, hence n_vocab=0
    blt_n_ctx = int(n_ctx_base / blt_ps)
    blt_global_flops_per_token = openai_flops_per_token(n_layers, n_heads, d_model, blt_n_ctx, n_vocab=0)
    blt_global_flops_per_byte = blt_global_flops_per_token / blt_ps

    # BLT local models operate on bytes, so FLOPs/token is already FLOPs/byte
    local_models_transformer_flops_per_byte = openai_flops_per_token(
        local_n_layers, local_n_heads, local_d_model, local_n_ctx, local_n_vocab
    )
    # Encoder and decoder each get half the local layers; the decoder
    # cross-attends to a patch-compressed context
    encoder_model_ca_flops_per_byte = cross_attention_flops_per_token(
        local_n_layers // 2, local_n_ctx, local_d_model
    )
    decoder_model_ca_flops_per_byte = cross_attention_flops_per_token(
        local_n_layers // 2, int(local_n_ctx / blt_ps), local_d_model
    )
    local_models_cross_attention_flops_per_byte = encoder_model_ca_flops_per_byte + decoder_model_ca_flops_per_byte
    local_models_flops = local_models_transformer_flops_per_byte + local_models_cross_attention_flops_per_byte

    # Relative difference of BLT vs. BPE (negative => BLT uses fewer FLOPs/byte)
    blt_total = local_models_flops + blt_global_flops_per_byte
    advantage = 100 * ((blt_total - bpe_per_byte) / bpe_per_byte)

    return {
        'bpe_per_byte': bpe_per_byte,
        'blt_global': blt_global_flops_per_byte,
        'blt_local': local_models_flops,
        'blt_total': blt_total,
        'advantage': advantage
    }
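
# Usage sketch (illustrative; exact numbers depend on the sliders):
#   calculate_flops(blt_ps=4.4, d_model=2560, local_n_layers=10)
# returns a dict whose 'advantage' entry is the percentage difference between
# BLT and BPE FLOPs/byte; a negative value means BLT is the cheaper of the two.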

def create_visualization(blt_ps, d_model, local_n_layers):
    results = calculate_flops(blt_ps, d_model, local_n_layers)

    # Build the figure: one standalone BPE bar and a stacked BLT bar
    fig = go.Figure()

    # Add BPE bar (only for BPE category)
    fig.add_trace(go.Bar(
        name='BPE',
        x=['BPE'],
        y=[results['bpe_per_byte']],
        text=[f"{results['bpe_per_byte']:.2e}"],
        textposition='outside',
        marker_color='#FF6B6B',
        width=0.4,
        showlegend=True
    ))

    # Add BLT Global bar (base of stack)
    fig.add_trace(go.Bar(
        name='BLT Global',
        x=['BLT'],
        y=[results['blt_global']],
        text=[f"{results['blt_global']:.2e}"],
        textposition='inside',
        marker_color='#4ECDC4',
        width=0.4,
        showlegend=True
    ))

    # Add BLT Local bar (top of stack)
    fig.add_trace(go.Bar(
        name='BLT Local',
        x=['BLT'],
        y=[results['blt_local']],
        text=[f"{results['blt_local']:.2e}"],
        textposition='inside',
        marker_color='#45B7D1',
        width=0.4,
        showlegend=True
    ))

    # Update layout with proper stacking and scientific notation
    fig.update_layout(
        title={
            'text': f"FLOPs per Byte Comparison<br><sub>BLT FLOPs comparison: {results['advantage']:.1f}%</sub>",
            'x': 0.5,
            'xanchor': 'center',
            'font': {'size': 20}
        },
        xaxis=dict(
            title="Architecture",
            tickfont=dict(size=14)
        ),
        yaxis=dict(
            title="FLOPs per Byte",
            tickformat=".1e",  # Scientific notation with 1 decimal
            tickfont=dict(size=12),
            gridcolor='lightgray'
        ),
        barmode='stack',
        showlegend=True,
        height=600,
        template="plotly_white",
        font=dict(size=14),
        bargap=0.3,
        plot_bgcolor='white'
    )

    fig.add_annotation(
        x='BLT',
        y=results['blt_total'] * 1.1,  # Position above stacked bar
        text=f"Total: {results['blt_total']:.2e}",
        showarrow=False,
        font=dict(size=12, color="black", weight="bold"),
        bgcolor="white",
        bordercolor="black",
        borderwidth=1
    )

    # Shrink the in-bar value labels slightly
    fig.update_traces(textfont_size=10)

    return fig
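
# Standalone usage (a sketch): the figure does not depend on Gradio, so e.g.
#   create_visualization(blt_ps=4.4, d_model=2560, local_n_layers=10).show()
# renders the same chart in a notebook or script.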

# Create Gradio interface
with gr.Blocks(title="BLT vs BPE FLOPs Comparison") as demo:
    gr.Markdown("""
    # BLT vs BPE FLOPs Comparison

    This interactive visualization compares the computational efficiency (FLOPs per byte) between:
    - **BPE (Byte Pair Encoding)**: Traditional transformer architecture
    - **BLT (Byte Latent Transformer)**: Novel architecture with Global and Local components with a dynamic patch size to segment bytes.

    A few things you'll notice:
    1. Patch size reduces global model FLOPs but not local model
    2. Increasing patch size and global model dimension doesn't change total FLOPs
    3. In smaller BLTs, local models constitute a larger portion of the total FLOPs
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Adjustable Parameters")
            blt_ps_slider = gr.Slider(
                minimum=1.0,
                maximum=10.0,
                value=4.4,
                step=0.1,
                label="BLT Patch Size (blt_ps)",
                info="Patch size for BLT architecture"
            )

            d_model_slider = gr.Slider(
                minimum=512,
                maximum=8192,
                value=2560,
                step=128,
                label="Model Dimension (d_model)",
                info="Hidden dimension size of the model"
            )

            local_n_layers_slider = gr.Slider(
                minimum=2,
                maximum=24,
                value=10,
                step=2,
                label="Local Model Layers (local_n_layers)",
                info="Number of layers in the local model"
            )

            gr.Markdown("### Fixed Parameters")
            gr.Markdown("""
            - **BPE's bytes per token**: 4.4
            - **BPE/BLT Number of Layers**: 26
            - **BPE/BLT Number of Heads**: 20
            - **BPE's Vocabulary Size**: 128,000
            - **BPE/BLT Context Length**: 8,192 bytes
            - **Local Model Dimension**: 1,024
            - **Local Model Heads**: 16
            """)

            gr.Markdown("### Current Values")
            info_text = gr.Markdown("")

        with gr.Column(scale=2):
            plot = gr.Plot(label="FLOPs Comparison")

    # Set up interactivity
    def update_plot(blt_ps, d_model, local_n_layers):
        fig = create_visualization(blt_ps, d_model, local_n_layers)

        # Calculate values for info display
        results = calculate_flops(blt_ps, d_model, local_n_layers)
        info_str = f"""
        **BPE FLOPs/byte**: {results['bpe_per_byte']:.2e}

        **BLT Global FLOPs/byte**: {results['blt_global']:.2e}

        **BLT Local FLOPs/byte**: {results['blt_local']:.2e}

        **BLT Total FLOPs/byte**: {results['blt_total']:.2e}
        """

        return fig, info_str

    # Update the plot when any slider changes
    for slider in (blt_ps_slider, d_model_slider, local_n_layers_slider):
        slider.change(
            update_plot,
            inputs=[blt_ps_slider, d_model_slider, local_n_layers_slider],
            outputs=[plot, info_text]
        )

    # Initial plot
    demo.load(
        update_plot,
        inputs=[blt_ps_slider, d_model_slider, local_n_layers_slider],
        outputs=[plot, info_text]
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()