lucalp committed
Commit 6d3d780 · 1 Parent(s): dbde1bc

Initial commit

Files changed (1): app.py (+275 -0)
app.py ADDED
"""
h/t to Adam Casson for the easy-to-use function to calculate FLOPs; source: https://huggingface.co/spaces/adamcasson/transformer-flops-calculator/blob/main/app.py
"""
import gradio as gr
import plotly.graph_objects as go
import numpy as np

# Fixed BPE parameters
bpe_ps = 4.4  # bytes per token, determined by the tokenizer
n_ctx_base = 8192
n_heads = 20
n_vocab = 128000
n_layers = 26

# Fixed local model parameters
local_d_model = 1024
local_g_size = 1
local_n_ctx = 512  # in bytes
local_n_heads = 16
local_n_vocab = 256
local_d_model_k = local_d_model / local_n_heads
local_d_ff_multiplier = 4
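
# With these defaults, 8192 bytes of BPE context correspond to
# int(8192 / 4.4) = 1861 tokens; calculate_flops() below applies the same
# bytes-to-sequence-length conversion to BLT, using its patch size instead.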

def openai_flops_per_token(n_layers, n_heads, d_model, n_ctx, n_vocab, ff_ratio=4):
    """OpenAI method for forward-pass FLOPs counting of a decoder-only Transformer."""
    d_attn = d_model // n_heads
    d_ff = d_model * ff_ratio

    embeddings = 4 * d_model
    attn_qkv = 2 * n_layers * d_model * 3 * (d_attn * n_heads)
    attn_mask = 2 * n_layers * n_ctx * (d_attn * n_heads)
    attn_project = 2 * n_layers * (d_attn * n_heads) * d_model
    ff = 2 * n_layers * 2 * d_model * d_ff
    logits = 2 * d_model * n_vocab

    return embeddings + attn_qkv + attn_mask + attn_project + ff + logits
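
# Sanity check on the accounting above: at 2 FLOPs per multiply-accumulate, the
# d_model**2 terms sum to (6 + 2 + 16) * n_layers * d_model**2
# = 24 * n_layers * d_model**2 = 2 * (12 * n_layers * d_model**2), i.e. twice the
# standard non-embedding parameter count; attention over the context contributes
# the n_ctx term and the logits layer the n_vocab term.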

def cross_attention_flops_per_token(n_layers, n_ctx_cross_attn_kv_len, d_model):
    ca_qo_proj_flops = (
        # Cross Attention QO FLOPs + backward
        2 * 4 * d_model**2
    )
    ca_context_flops = 4 * n_ctx_cross_attn_kv_len * d_model
    return n_layers * (ca_qo_proj_flops + ca_context_flops)
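
# Context-term accounting: 4 * kv_len * d_model covers the score matmul
# (2 * kv_len * d_model) plus the weighted sum over values (2 * kv_len * d_model)
# per query token. For the byte-level local models a "token" is a byte, so these
# per-token figures are used directly as per-byte costs in calculate_flops() below.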

def calculate_flops(blt_ps, d_model, local_n_layers):
    """Return FLOPs-per-byte estimates for BPE and BLT (global + local)."""
    # BPE calculations
    n_ctx = int(n_ctx_base / bpe_ps)
    bpe_flops_per_token = openai_flops_per_token(n_layers, n_heads, d_model, n_ctx, n_vocab)
    bpe_per_byte = bpe_flops_per_token / bpe_ps

    # BLT Global calculations (n_vocab=0 drops the logits term: the global model
    # outputs patch representations rather than vocabulary logits)
    blt_n_ctx = int(n_ctx_base / blt_ps)
    blt_global_flops_per_token = openai_flops_per_token(n_layers, n_heads, d_model, blt_n_ctx, n_vocab=0)
    blt_global_flops_per_byte = blt_global_flops_per_token / blt_ps

    # BLT Local calculations: local_n_layers is split evenly between the byte-level
    # encoder and decoder; the decoder cross-attends to patches, hence the
    # KV length of local_n_ctx // blt_ps
    local_models_transformer_flops_per_byte = openai_flops_per_token(
        local_n_layers, local_n_heads, local_d_model, local_n_ctx, local_n_vocab
    )
    encoder_model_ca_flops_per_byte = cross_attention_flops_per_token(
        local_n_layers / 2, local_n_ctx, local_d_model
    )
    decoder_model_ca_flops_per_byte = cross_attention_flops_per_token(
        local_n_layers / 2, local_n_ctx // blt_ps, local_d_model
    )
    local_models_cross_attention_flops_per_byte = encoder_model_ca_flops_per_byte + decoder_model_ca_flops_per_byte
    local_models_flops = local_models_transformer_flops_per_byte + local_models_cross_attention_flops_per_byte

    # Calculate advantage
    blt_total = local_models_flops + blt_global_flops_per_byte
    advantage = 100 * ((blt_total - bpe_per_byte) / bpe_per_byte)

    return {
        'bpe_per_byte': bpe_per_byte,
        'blt_global': blt_global_flops_per_byte,
        'blt_local': local_models_flops,
        'blt_total': blt_total,
        'advantage': advantage
    }
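
# Sign convention: 'advantage' is the relative difference in FLOPs/byte of BLT
# versus BPE, so negative values mean BLT spends fewer FLOPs per byte than the
# BPE baseline at the same global width and byte-context budget.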

def create_visualization(blt_ps, d_model, local_n_layers):
    results = calculate_flops(blt_ps, d_model, local_n_layers)

    # Create the figure with separate traces for finer control
    fig = go.Figure()

    # Add BPE bar (only for the BPE category)
    fig.add_trace(go.Bar(
        name='BPE',
        x=['BPE'],
        y=[results['bpe_per_byte']],
        text=[f"{results['bpe_per_byte']:.2e}"],
        textposition='outside',
        marker_color='#FF6B6B',
        width=0.4,
        showlegend=True
    ))

    # Add BLT Global bar (base of stack)
    fig.add_trace(go.Bar(
        name='BLT Global',
        x=['BLT'],
        y=[results['blt_global']],
        text=[f"{results['blt_global']:.2e}"],
        textposition='inside',
        marker_color='#4ECDC4',
        width=0.4,
        showlegend=True
    ))

    # Add BLT Local bar (top of stack)
    fig.add_trace(go.Bar(
        name='BLT Local',
        x=['BLT'],
        y=[results['blt_local']],
        text=[f"{results['blt_local']:.2e}"],
        textposition='inside',
        marker_color='#45B7D1',
        width=0.4,
        showlegend=True
    ))

    # Update layout with stacking and scientific notation
    fig.update_layout(
        title={
            'text': f"FLOPs per Byte Comparison<br><sub>BLT total vs BPE: {results['advantage']:+.1f}%</sub>",
            'x': 0.5,
            'xanchor': 'center',
            'font': {'size': 20}
        },
        xaxis=dict(
            title="Architecture",
            tickfont=dict(size=14)
        ),
        yaxis=dict(
            title="FLOPs per Byte",
            tickformat=".1e",  # scientific notation with one decimal
            tickfont=dict(size=12),
            gridcolor='lightgray'
        ),
        barmode='stack',  # traces sharing x='BLT' stack; 'BPE' stays a single bar
        showlegend=True,
        height=600,
        template="plotly_white",
        font=dict(size=14),
        bargap=0.3,
        plot_bgcolor='white'
    )

    fig.add_annotation(
        x='BLT',
        y=results['blt_total'] * 1.1,  # position above the stacked bar
        text=f"Total: {results['blt_total']:.2e}",
        showarrow=False,
        font=dict(size=12, color="black", weight="bold"),
        bgcolor="white",
        bordercolor="black",
        borderwidth=1
    )

    # Shrink the in-bar value labels slightly
    fig.update_traces(textfont_size=10)

    return fig

# Create Gradio interface
with gr.Blocks(title="BLT vs BPE FLOPs Comparison") as demo:
    gr.Markdown("""
    # BLT vs BPE FLOPs Comparison

    This interactive visualization compares the computational cost (FLOPs per byte) of:
    - **BPE (Byte Pair Encoding)**: traditional tokenizer-based transformer architecture
    - **BLT (Byte Latent Transformer)**: architecture with Global and Local components that segments bytes using a dynamic patch size

    A few things you'll notice:
    1. Increasing the patch size reduces global-model FLOPs but not local-model FLOPs
    2. Increasing the patch size while also growing the global model dimension can leave total FLOPs unchanged
    3. In smaller BLTs, the local models account for a larger share of total FLOPs
    """)
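
    # Why (2) holds, roughly: global FLOPs/byte scale as ~24 * n_layers * d_model**2 / blt_ps,
    # so raising blt_ps offsets growth in d_model**2; the local models' cost does not
    # depend on the global d_model, which is also why they dominate in smaller BLTs (3).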

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Adjustable Parameters")
            blt_ps_slider = gr.Slider(
                minimum=1.0,
                maximum=10.0,
                value=4.4,
                step=0.1,
                label="BLT Patch Size (blt_ps)",
                info="Average patch size in bytes for the BLT architecture"
            )

            d_model_slider = gr.Slider(
                minimum=512,
                maximum=8192,
                value=2560,
                step=128,
                label="Model Dimension (d_model)",
                info="Hidden dimension of the BPE model and the BLT global model"
            )

            local_n_layers_slider = gr.Slider(
                minimum=2,
                maximum=24,
                value=10,
                step=2,
                label="Local Model Layers (local_n_layers)",
                info="Total layers across the local encoder and decoder (split evenly)"
            )

            gr.Markdown("### Fixed Parameters")
            gr.Markdown("""
            - **BPE's bytes per token**: 4.4
            - **BPE/BLT Number of Layers**: 26
            - **BPE/BLT Number of Heads**: 20
            - **BPE's Vocabulary Size**: 128,000
            - **BPE/BLT Context Length**: 8,192 bytes
            - **Local Model Context**: 512 bytes
            - **Local Model Dimension**: 1,024
            - **Local Model Heads**: 16
            """)

            gr.Markdown("### Current Values")
            info_text = gr.Markdown("")

        with gr.Column(scale=2):
            plot = gr.Plot(label="FLOPs Comparison")

    # Set up interactivity
    def update_plot(blt_ps, d_model, local_n_layers):
        fig = create_visualization(blt_ps, d_model, local_n_layers)

        # Calculate values for the info display
        results = calculate_flops(blt_ps, d_model, local_n_layers)
        info_str = f"""
        **BPE FLOPs/byte**: {results['bpe_per_byte']:.2e}

        **BLT Global FLOPs/byte**: {results['blt_global']:.2e}

        **BLT Local FLOPs/byte**: {results['blt_local']:.2e}

        **BLT Total FLOPs/byte**: {results['blt_total']:.2e}
        """

        return fig, info_str

    # Update the plot whenever any slider changes
    blt_ps_slider.change(
        update_plot,
        inputs=[blt_ps_slider, d_model_slider, local_n_layers_slider],
        outputs=[plot, info_text]
    )
    d_model_slider.change(
        update_plot,
        inputs=[blt_ps_slider, d_model_slider, local_n_layers_slider],
        outputs=[plot, info_text]
    )
    local_n_layers_slider.change(
        update_plot,
        inputs=[blt_ps_slider, d_model_slider, local_n_layers_slider],
        outputs=[plot, info_text]
    )

    # Initial plot
    demo.load(
        update_plot,
        inputs=[blt_ps_slider, d_model_slider, local_n_layers_slider],
        outputs=[plot, info_text]
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()
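
# Running `python app.py` starts the Gradio server (on http://localhost:7860 by
# default); passing demo.launch(share=True) instead serves a temporary public link.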