lucalp commited on
Commit
810a93a
·
1 Parent(s): 7afe1ac

Adding model size breakdown

Browse files
Files changed (1) hide show
  1. app.py +24 -19
app.py CHANGED
@@ -94,33 +94,37 @@ def format_params_display(num_params):
94
  return f"{num_params / 1_000_000:.2f}M Params"
95
 
96
 
 
97
  def create_visualization(blt_ps, d_model_slider, local_n_layers_slider):
98
  results = calculate_flops(blt_ps, d_model_slider, local_n_layers_slider)
99
 
100
  # Calculate model parameters
101
  # BPE Model Parameters: 12 * N * D^2 + 2 * V * D
102
- # N = n_layers (global), D = d_model_slider, V = n_vocab (global)
103
  bpe_model_params = (12 * n_layers * d_model_slider**2) + (2 * n_vocab * d_model_slider)
104
 
105
  # BLT Model Parameters
106
  # Global Component: 12 * N * D^2 (no main vocab projection)
107
- # N = n_layers (global), D = d_model_slider
108
  blt_global_internal_params = 12 * n_layers * d_model_slider**2
109
 
110
  # Local Component Transformer Part: 12 * N_local * D_local^2 + 2 * V_local * D_local
111
- # N_local = local_n_layers_slider, D_local = local_d_model, V_local = local_n_vocab
112
  blt_local_transformer_params = (12 * local_n_layers_slider * local_d_model**2) + \
113
  (2 * local_n_vocab * local_d_model)
 
114
  # Local Component Cross-Attention Part: N_local * 4 * D_local^2 (estimated)
115
- # This assumes 4*D^2 params per CA block (Q,K,V,O projections within local_d_model or from global to local)
116
- # and local_n_layers_slider effective CA blocks.
117
  blt_local_ca_params = local_n_layers_slider * 4 * local_d_model**2
118
  blt_local_total_internal_params = blt_local_transformer_params + blt_local_ca_params
119
 
120
- blt_total_model_params = blt_global_internal_params + blt_local_total_internal_params
121
 
122
  bpe_params_str = format_params_display(bpe_model_params)
123
- blt_params_str = format_params_display(blt_total_model_params)
 
 
 
 
 
 
 
124
 
125
  # Create the figure with subplots for better control
126
  fig = go.Figure()
@@ -181,21 +185,21 @@ def create_visualization(blt_ps, d_model_slider, local_n_layers_slider):
181
  ),
182
  barmode='stack',
183
  showlegend=True,
184
- height=650, # Increased height slightly for param text
185
  template="plotly_white",
186
  font=dict(size=14),
187
  bargap=0.3,
188
  plot_bgcolor='white',
189
- margin=dict(b=100) # Add bottom margin for parameter text
190
  )
191
 
192
  fig.add_annotation(
193
  x='BLT',
194
- y=results['blt_total'] * 1.05, # Position above stacked bar, adjust if needed
195
  text=f"Total FLOPs/Byte: {results['blt_total']:.2e}",
196
  showarrow=False,
197
- font=dict(size=12, color="black"), # Removed bold to differentiate from param text
198
- bgcolor="rgba(255,255,255,0.5)", # Slight background for readability
199
  bordercolor="black",
200
  borderwidth=1,
201
  xanchor='center',
@@ -209,24 +213,25 @@ def create_visualization(blt_ps, d_model_slider, local_n_layers_slider):
209
  text=bpe_params_str,
210
  showarrow=False,
211
  xref="x",
212
- yref="paper", # Use paper coordinates for y to position below x-axis
213
  yanchor='top',
214
  xanchor='center',
215
- yshift=-35, # Adjust this value to position correctly below the bar
216
- font=dict(size=11, color="black", weight="bold"),
217
  )
218
 
219
  fig.add_annotation(
220
  x='BLT',
221
  y=0,
222
- text=blt_params_str,
223
  showarrow=False,
224
  xref="x",
225
  yref="paper",
226
  yanchor='top',
227
  xanchor='center',
228
- yshift=-35, # Adjust this value
229
- font=dict(size=11, color="black", weight="bold"),
 
230
  )
231
 
232
 
@@ -246,7 +251,7 @@ with gr.Blocks(title="BLT vs BPE FLOPs Comparison") as demo:
246
 
247
  A few things you'll notice:
248
  1. Patch size reduces global model FLOPs but not local model
249
- 2. Increasing patch size and global model dimension doesn't change total FLOPs (Note: FLOPs yes, parameters will change with d_model)
250
  3. In smaller BLTs, local models constitute a larger portion of the total FLOPs
251
  Parameter counts are displayed below each bar.
252
  """)
 
94
  return f"{num_params / 1_000_000:.2f}M Params"
95
 
96
 
97
+
98
  def create_visualization(blt_ps, d_model_slider, local_n_layers_slider):
99
  results = calculate_flops(blt_ps, d_model_slider, local_n_layers_slider)
100
 
101
  # Calculate model parameters
102
  # BPE Model Parameters: 12 * N * D^2 + 2 * V * D
 
103
  bpe_model_params = (12 * n_layers * d_model_slider**2) + (2 * n_vocab * d_model_slider)
104
 
105
  # BLT Model Parameters
106
  # Global Component: 12 * N * D^2 (no main vocab projection)
 
107
  blt_global_internal_params = 12 * n_layers * d_model_slider**2
108
 
109
  # Local Component Transformer Part: 12 * N_local * D_local^2 + 2 * V_local * D_local
 
110
  blt_local_transformer_params = (12 * local_n_layers_slider * local_d_model**2) + \
111
  (2 * local_n_vocab * local_d_model)
112
+
113
  # Local Component Cross-Attention Part: N_local * 4 * D_local^2 (estimated)
 
 
114
  blt_local_ca_params = local_n_layers_slider * 4 * local_d_model**2
115
  blt_local_total_internal_params = blt_local_transformer_params + blt_local_ca_params
116
 
117
+ # blt_total_model_params = blt_global_internal_params + blt_local_total_internal_params # Kept for potential other uses, not directly for this annotation
118
 
119
  bpe_params_str = format_params_display(bpe_model_params)
120
+
121
+ # Format BLT global and local parameters separately
122
+ blt_global_params_fmt_str = format_params_display(blt_global_internal_params)
123
+ blt_local_params_fmt_str = format_params_display(blt_local_total_internal_params)
124
+
125
+ # Combine for annotation text, using <br> for line break
126
+ blt_combined_params_str = f"Global: {blt_global_params_fmt_str}<br>Local: {blt_local_params_fmt_str}"
127
+
128
 
129
  # Create the figure with subplots for better control
130
  fig = go.Figure()
 
185
  ),
186
  barmode='stack',
187
  showlegend=True,
188
+ height=650,
189
  template="plotly_white",
190
  font=dict(size=14),
191
  bargap=0.3,
192
  plot_bgcolor='white',
193
+ margin=dict(b=110) # Increased bottom margin slightly more for two lines of text
194
  )
195
 
196
  fig.add_annotation(
197
  x='BLT',
198
+ y=results['blt_total'] * 1.05,
199
  text=f"Total FLOPs/Byte: {results['blt_total']:.2e}",
200
  showarrow=False,
201
+ font=dict(size=12, color="black"),
202
+ bgcolor="rgba(255,255,255,0.5)",
203
  bordercolor="black",
204
  borderwidth=1,
205
  xanchor='center',
 
213
  text=bpe_params_str,
214
  showarrow=False,
215
  xref="x",
216
+ yref="paper",
217
  yanchor='top',
218
  xanchor='center',
219
+ yshift=-35,
220
+ font=dict(size=10, color="black", weight="bold"), # Font size 10 for param text
221
  )
222
 
223
  fig.add_annotation(
224
  x='BLT',
225
  y=0,
226
+ text=blt_combined_params_str, # Using the new combined string with breakdown
227
  showarrow=False,
228
  xref="x",
229
  yref="paper",
230
  yanchor='top',
231
  xanchor='center',
232
+ yshift=-45, # Adjusted yshift for two lines of text
233
+ font=dict(size=10, color="black", weight="bold"), # Font size 10 for param text
234
+ align="center" # Ensure text is centered if it wraps due to <br>
235
  )
236
 
237
 
 
251
 
252
  A few things you'll notice:
253
  1. Patch size reduces global model FLOPs but not local model
254
+ 2. Increasing patch size and global model dimension doesn't change total FLOPs
255
  3. In smaller BLTs, local models constitute a larger portion of the total FLOPs
256
  Parameter counts are displayed below each bar.
257
  """)