lucalp committed
Commit 7afe1ac · 1 Parent(s): 51dd960

Added model parameter count

Files changed (1): app.py (+148 -65)

app.py CHANGED
@@ -10,79 +10,117 @@ bpe_ps = 4.4 # determined by tokenizer
 n_ctx_base = 8192
 n_heads = 20
 n_vocab = 128000
-n_layers = 26
+n_layers = 26 # Used for BPE model and BLT Global model

 # Fixed local model parameters
 local_d_model = 1024
 local_g_size = 1
 local_n_ctx = 512 # in bytes
 local_n_heads = 16
-local_n_vocab = 256
+local_n_vocab = 256 # Used for BLT Local model
 local_d_model_k = local_d_model / local_n_heads
 local_d_ff_multiplier = 4

-def openai_flops_per_token(n_layers, n_heads, d_model, n_ctx, n_vocab, ff_ratio=4):
+def openai_flops_per_token(n_layers_val, n_heads_val, d_model_val, n_ctx_val, n_vocab_val, ff_ratio=4):
     """Open AI method for forward pass FLOPs counting of decoder-only Transformer"""
-    d_attn = d_model // n_heads
-    d_ff = d_model * ff_ratio
+    d_attn = d_model_val // n_heads_val
+    d_ff = d_model_val * ff_ratio

-    embeddings = 4 * d_model
-    attn_qkv = 2 * n_layers * d_model * 3 * (d_attn * n_heads)
-    attn_mask = 2 * n_layers * n_ctx * (d_attn * n_heads)
-    attn_project = 2 * n_layers * (d_attn * n_heads) * d_model
-    ff = 2 * n_layers * 2 * d_model * d_ff
-    logits = 2 * d_model * n_vocab
+    embeddings = 4 * d_model_val # FLOPs for embeddings - not parameter count
+    attn_qkv = 2 * n_layers_val * d_model_val * 3 * (d_attn * n_heads_val)
+    attn_mask = 2 * n_layers_val * n_ctx_val * (d_attn * n_heads_val)
+    attn_project = 2 * n_layers_val * (d_attn * n_heads_val) * d_model_val
+    ff = 2 * n_layers_val * 2 * d_model_val * d_ff
+    logits = 2 * d_model_val * n_vocab_val

     return embeddings + attn_qkv + attn_mask + attn_project + ff + logits


-def cross_attention_flops_per_token(n_layers, n_ctx_cross_attn_kv_len, d_model):
+def cross_attention_flops_per_token(n_layers_ca, n_ctx_cross_attn_kv_len, d_model_ca):
     ca_qo_proj_flops = (
         # Cross Attention QO FLOPs + backward
-        2 * 4 * d_model**2
+        2 * 4 * d_model_ca**2
     )
-    ca_context_flops = 4 * n_ctx_cross_attn_kv_len * d_model
-    return n_layers * (ca_qo_proj_flops + ca_context_flops)
+    ca_context_flops = 4 * n_ctx_cross_attn_kv_len * d_model_ca
+    return n_layers_ca * (ca_qo_proj_flops + ca_context_flops)


-def calculate_flops(blt_ps, d_model, local_n_layers):
+def calculate_flops(blt_ps, d_model_slider, local_n_layers_slider):
     # BPE calculations
     n_ctx = int(n_ctx_base / bpe_ps)
-    bpe_flops_per_token = openai_flops_per_token(n_layers, n_heads, d_model, n_ctx, n_vocab)
-    bpe_per_byte = bpe_flops_per_token / bpe_ps
+    bpe_flops_per_token_val = openai_flops_per_token(n_layers, n_heads, d_model_slider, n_ctx, n_vocab)
+    bpe_per_byte = bpe_flops_per_token_val / bpe_ps

     # BLT Global calculations
     blt_n_ctx = int(n_ctx_base / blt_ps)
-    blt_global_flops_per_token = openai_flops_per_token(n_layers, n_heads, d_model, blt_n_ctx, n_vocab=0)
+    blt_global_flops_per_token = openai_flops_per_token(n_layers, n_heads, d_model_slider, blt_n_ctx, n_vocab_val=0) # n_vocab=0 for global
     blt_global_flops_per_byte = blt_global_flops_per_token / blt_ps

     # BLT Local calculations
     local_models_transformer_flops_per_byte = openai_flops_per_token(
-        local_n_layers, local_n_heads, local_d_model, local_n_ctx, local_n_vocab
+        local_n_layers_slider, local_n_heads, local_d_model, local_n_ctx, local_n_vocab, ff_ratio=local_d_ff_multiplier
     )
     encoder_model_ca_flops_per_byte = cross_attention_flops_per_token(
-        local_n_layers/2, local_n_ctx, local_d_model
+        local_n_layers_slider / 2, local_n_ctx, local_d_model
     )
     decoder_model_ca_flops_per_byte = cross_attention_flops_per_token(
-        local_n_layers/2, local_n_ctx // blt_ps, local_d_model
+        local_n_layers_slider / 2, local_n_ctx // blt_ps, local_d_model
     )
     local_models_cross_attention_flops_per_byte = encoder_model_ca_flops_per_byte + decoder_model_ca_flops_per_byte
     local_models_flops = local_models_transformer_flops_per_byte + local_models_cross_attention_flops_per_byte

     # Calculate advantage
     blt_total = local_models_flops + blt_global_flops_per_byte
-    advantage = 100 * ((blt_total - bpe_per_byte) / bpe_per_byte)
+    advantage = 100 * ((blt_total - bpe_per_byte) / bpe_per_byte) if bpe_per_byte != 0 else 0
+

     return {
         'bpe_per_byte': bpe_per_byte,
         'blt_global': blt_global_flops_per_byte,
         'blt_local': local_models_flops,
         'blt_total': blt_total,
-        'advantage': advantage
+        'advantage': advantage,
     }

-def create_visualization(blt_ps, d_model, local_n_layers):
-    results = calculate_flops(blt_ps, d_model, local_n_layers)
+def format_params_display(num_params):
+    """Formats number of parameters into a string with M or B units."""
+    if num_params is None:
+        return ""
+    if abs(num_params) >= 1_000_000_000:
+        return f"{num_params / 1_000_000_000:.1f}B Params"
+    elif abs(num_params) >= 1_000_000:
+        return f"{num_params / 1_000_000:.1f}M Params"
+    else: # For numbers less than 1M
+        return f"{num_params / 1_000_000:.2f}M Params"
+
+
+def create_visualization(blt_ps, d_model_slider, local_n_layers_slider):
+    results = calculate_flops(blt_ps, d_model_slider, local_n_layers_slider)
+
+    # Calculate model parameters
+    # BPE Model Parameters: 12 * N * D^2 + 2 * V * D
+    # N = n_layers (global), D = d_model_slider, V = n_vocab (global)
+    bpe_model_params = (12 * n_layers * d_model_slider**2) + (2 * n_vocab * d_model_slider)
+
+    # BLT Model Parameters
+    # Global Component: 12 * N * D^2 (no main vocab projection)
+    # N = n_layers (global), D = d_model_slider
+    blt_global_internal_params = 12 * n_layers * d_model_slider**2
+
+    # Local Component Transformer Part: 12 * N_local * D_local^2 + 2 * V_local * D_local
+    # N_local = local_n_layers_slider, D_local = local_d_model, V_local = local_n_vocab
+    blt_local_transformer_params = (12 * local_n_layers_slider * local_d_model**2) + \
+                                   (2 * local_n_vocab * local_d_model)
+    # Local Component Cross-Attention Part: N_local * 4 * D_local^2 (estimated)
+    # This assumes 4*D^2 params per CA block (Q,K,V,O projections within local_d_model or from global to local)
+    # and local_n_layers_slider effective CA blocks.
+    blt_local_ca_params = local_n_layers_slider * 4 * local_d_model**2
+    blt_local_total_internal_params = blt_local_transformer_params + blt_local_ca_params
+
+    blt_total_model_params = blt_global_internal_params + blt_local_total_internal_params
+
+    bpe_params_str = format_params_display(bpe_model_params)
+    blt_params_str = format_params_display(blt_total_model_params)

     # Create the figure with subplots for better control
     fig = go.Figure()
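The parameter counts this commit adds use the standard decoder-only estimate of 12·N·D² weights for the transformer body plus 2·V·D for the embedding and logit matrices. A quick standalone sanity check (a sketch evaluated at the fixed values above and the d_model slider default of 2560; not code from app.py):

```python
# Worked example of the 12*N*D^2 + 2*V*D estimate used in create_visualization,
# at the app's defaults: n_layers=26, n_vocab=128000, d_model=2560.
n_layers, n_vocab, d_model = 26, 128_000, 2560

transformer_body = 12 * n_layers * d_model**2  # 2,044,723,200
embed_and_logits = 2 * n_vocab * d_model       # 655,360,000
bpe_model_params = transformer_body + embed_and_logits

print(f"{bpe_model_params / 1e9:.1f}B Params")  # -> "2.7B Params"
```

This matches what format_params_display would render for the BPE bar at the default slider settings.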
@@ -143,24 +181,55 @@ def create_visualization(blt_ps, d_model, local_n_layers):
         ),
         barmode='stack',
         showlegend=True,
-        height=600,
+        height=650, # Increased height slightly for param text
         template="plotly_white",
         font=dict(size=14),
         bargap=0.3,
-        plot_bgcolor='white'
+        plot_bgcolor='white',
+        margin=dict(b=100) # Add bottom margin for parameter text
     )

     fig.add_annotation(
         x='BLT',
-        y=results['blt_total'] * 1.1, # Position above stacked bar
-        text=f"Total: {results['blt_total']:.2e}",
+        y=results['blt_total'] * 1.05, # Position above stacked bar, adjust if needed
+        text=f"Total FLOPs/Byte: {results['blt_total']:.2e}",
         showarrow=False,
-        font=dict(size=12, color="black", weight="bold"),
-        bgcolor="white",
+        font=dict(size=12, color="black"), # Removed bold to differentiate from param text
+        bgcolor="rgba(255,255,255,0.5)", # Slight background for readability
         bordercolor="black",
-        borderwidth=1
+        borderwidth=1,
+        xanchor='center',
+        yanchor='bottom'
+    )
+
+    # Add parameter count annotations at the bottom of bars
+    fig.add_annotation(
+        x='BPE',
+        y=0,
+        text=bpe_params_str,
+        showarrow=False,
+        xref="x",
+        yref="paper", # Use paper coordinates for y to position below x-axis
+        yanchor='top',
+        xanchor='center',
+        yshift=-35, # Adjust this value to position correctly below the bar
+        font=dict(size=11, color="black", weight="bold"),
     )

+    fig.add_annotation(
+        x='BLT',
+        y=0,
+        text=blt_params_str,
+        showarrow=False,
+        xref="x",
+        yref="paper",
+        yanchor='top',
+        xanchor='center',
+        yshift=-35, # Adjust this value
+        font=dict(size=11, color="black", weight="bold"),
+    )
+
+
     # Update traces to ensure proper stacking
     fig.update_traces(textfont_size=10)
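The new bottom annotations mix coordinate references: x stays in axis coordinates while yref="paper" pins y to the figure, so y=0 with yanchor='top' and a negative yshift drops the label into the widened bottom margin, below the x-axis. A minimal standalone Plotly sketch of that pattern (illustrative values, not app.py code):

```python
import plotly.graph_objects as go

# Bar chart with a label pinned below the x-axis, as in create_visualization:
# xref="x" keeps the label centered on the category, while yref="paper" plus
# a negative yshift pushes it into the space reserved by margin=dict(b=100).
fig = go.Figure(go.Bar(x=["BPE", "BLT"], y=[3.1e8, 2.4e8]))
fig.update_layout(margin=dict(b=100), template="plotly_white")
fig.add_annotation(
    x="BPE", y=0, xref="x", yref="paper",
    xanchor="center", yanchor="top", yshift=-35,
    text="2.7B Params", showarrow=False,
)
fig.show()
```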
 
@@ -171,14 +240,15 @@ with gr.Blocks(title="BLT vs BPE FLOPs Comparison") as demo:
     gr.Markdown("""
     # BLT vs BPE FLOPs Comparison

-    This interactive visualization compares the computational efficiency (FLOPs per byte) between:
+    This interactive visualization compares the computational efficiency (FLOPs per byte) and total model parameters between:
     - **BPE (Byte Pair Encoding)**: Traditional transformer architecture
     - **BLT (Byte Latent Transformer)**: Novel architecture with Global and Local components with a dynamic patch size to segment bytes.

     A few things you'll notice:
     1. Patch size reduces global model FLOPs but not local model
-    2. Increasing patch size and global model dimension doesn't change total FLOPs
+    2. Increasing patch size and global model dimension doesn't change total FLOPs (Note: FLOPs yes, parameters will change with d_model)
     3. In smaller BLTs, local models constitute a larger portion of the total FLOPs
+    Parameter counts are displayed below each bar.
     """)

     with gr.Row():
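Point 1 of that list can be checked numerically with the functions from the first hunk; a hedged sketch, assuming app.py's functions are importable (or pasted into a REPL). Note the local share is not perfectly flat: the decoder's cross-attention context term, local_n_ctx // blt_ps, still shrinks slightly as the patch size grows.

```python
from app import calculate_flops  # assumes app.py imports cleanly without launching the UI

# How FLOPs/byte move as the BLT patch size grows, at the default sliders
# (d_model=2560, local_n_layers=10): global drops roughly in proportion,
# local barely moves apart from the decoder cross-attention context term.
for patch_size in (2.0, 4.0, 8.0):
    r = calculate_flops(patch_size, d_model_slider=2560, local_n_layers_slider=10)
    print(f"patch={patch_size}: global={r['blt_global']:.2e}  "
          f"local={r['blt_local']:.2e}  total={r['blt_total']:.2e}")
```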
@@ -198,75 +268,88 @@ with gr.Blocks(title="BLT vs BPE FLOPs Comparison") as demo:
                 maximum=8192,
                 value=2560,
                 step=128,
-                label="Model Dimension (d_model)",
-                info="Hidden dimension size of the model"
+                label="Global Model Dimension (d_model)",
+                info="Hidden dimension size of the BPE model and BLT's Global model"
             )

             local_n_layers_slider = gr.Slider(
                 minimum=2,
-                maximum=24,
+                maximum=24, # Max value for local_n_layers
                 value=10,
-                step=2,
+                step=2, # Ensure even numbers for CA split
                 label="Local Model Layers (local_n_layers)",
-                info="Number of layers in the local model"
+                info="Number of layers in the BLT's local model"
             )

             gr.Markdown("### Fixed Parameters")
-            gr.Markdown("""
-            - **BPE's bytes per token**: 4.4
-            - **BPE/BLT Number of Layers**: 26
-            - **BPE/BLT Number of Heads**: 20
-            - **BPE's Vocabulary Size**: 128,000
-            - **BPE/BLT Context Length**: 8,192 bytes
-            - **Local Model Dimension**: 1,024
-            - **Local Model Heads**: 16
+            gr.Markdown(f"""
+            - **BPE's bytes per token (bpe_ps)**: {bpe_ps}
+            - **BPE/BLT Global - Num Layers (n_layers)**: {n_layers}
+            - **BPE/BLT Global - Num Heads (n_heads)**: {n_heads}
+            - **BPE - Vocabulary Size (n_vocab)**: {n_vocab:,}
+            - **BPE/BLT - Context Length (n_ctx_base)**: {n_ctx_base:,} bytes
+            - **BLT Local - Model Dimension (local_d_model)**: {local_d_model}
+            - **BLT Local - Num Heads (local_n_heads)**: {local_n_heads}
+            - **BLT Local - Vocabulary Size (local_n_vocab)**: {local_n_vocab}
+            - **BLT Local - FF Multiplier (local_d_ff_multiplier)**: {local_d_ff_multiplier}
             """)

-            gr.Markdown("### Current Values")
+            gr.Markdown("### Current Values & Totals")
             info_text = gr.Markdown("")

         with gr.Column(scale=2):
-            plot = gr.Plot(label="FLOPs Comparison")
+            plot = gr.Plot(label="FLOPs Comparison & Model Parameters")

     # Set up interactivity
-    def update_plot(blt_ps, d_model, local_n_layers):
-        fig = create_visualization(blt_ps, d_model, local_n_layers)
+    def update_plot_and_info(blt_ps_val, d_model_val, local_n_layers_val):
+        fig = create_visualization(blt_ps_val, d_model_val, local_n_layers_val)
+        results = calculate_flops(blt_ps_val, d_model_val, local_n_layers_val)
+
+        # Recalculate parameters for info text (could also be returned by create_visualization or calculate_flops)
+        bpe_model_p = (12 * n_layers * d_model_val**2) + (2 * n_vocab * d_model_val)
+        blt_global_p = 12 * n_layers * d_model_val**2
+        blt_local_transformer_p = (12 * local_n_layers_val * local_d_model**2) + \
+                                  (2 * local_n_vocab * local_d_model)
+        blt_local_ca_p = local_n_layers_val * 4 * local_d_model**2
+        blt_local_total_internal_p = blt_local_transformer_p + blt_local_ca_p
+        blt_total_model_p = blt_global_p + blt_local_total_internal_p

-        # Calculate values for info display
-        results = calculate_flops(blt_ps, d_model, local_n_layers)
         info_str = f"""
         **BPE FLOPs/byte**: {results['bpe_per_byte']:.2e}
+        **BPE Total Params**: {format_params_display(bpe_model_p)}

         **BLT Global FLOPs/byte**: {results['blt_global']:.2e}
-
         **BLT Local FLOPs/byte**: {results['blt_local']:.2e}
-
         **BLT Total FLOPs/byte**: {results['blt_total']:.2e}
-        """
+        **BLT Total Params**: {format_params_display(blt_total_model_p)}
+        (Global: {format_params_display(blt_global_p)}, Local: {format_params_display(blt_local_total_internal_p)})

+        **BLT Advantage (FLOPs/byte vs BPE)**: {results['advantage']:.1f}%
+        """
         return fig, info_str

     # Update plot when any slider changes
+    inputs_list = [blt_ps_slider, d_model_slider, local_n_layers_slider]
     blt_ps_slider.change(
-        update_plot,
-        inputs=[blt_ps_slider, d_model_slider, local_n_layers_slider],
+        update_plot_and_info,
+        inputs=inputs_list,
         outputs=[plot, info_text]
     )
     d_model_slider.change(
-        update_plot,
-        inputs=[blt_ps_slider, d_model_slider, local_n_layers_slider],
+        update_plot_and_info,
+        inputs=inputs_list,
         outputs=[plot, info_text]
     )
     local_n_layers_slider.change(
-        update_plot,
-        inputs=[blt_ps_slider, d_model_slider, local_n_layers_slider],
+        update_plot_and_info,
+        inputs=inputs_list,
         outputs=[plot, info_text]
     )

     # Initial plot
     demo.load(
-        update_plot,
-        inputs=[blt_ps_slider, d_model_slider, local_n_layers_slider],
+        update_plot_and_info,
+        inputs=inputs_list,
         outputs=[plot, info_text]
     )
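After this change, the three .change handlers and demo.load share one callback and one inputs_list. If the Space runs on Gradio 4.x, the three slider registrations could be collapsed further with gr.on; a sketch of the equivalent wiring (an assumption about the Gradio version, not part of this commit):

```python
# Inside the same gr.Blocks() context: one gr.on registration replaces the
# three identical slider .change handlers; demo.load stays as the initial draw.
gr.on(
    triggers=[blt_ps_slider.change, d_model_slider.change, local_n_layers_slider.change],
    fn=update_plot_and_info,
    inputs=inputs_list,
    outputs=[plot, info_text],
)
```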
 
 
355