Adding model size breakdown
Browse files
app.py
CHANGED
@@ -94,33 +94,37 @@ def format_params_display(num_params):
|
|
94 |
return f"{num_params / 1_000_000:.2f}M Params"
|
95 |
|
96 |
|
|
|
97 |
def create_visualization(blt_ps, d_model_slider, local_n_layers_slider):
|
98 |
results = calculate_flops(blt_ps, d_model_slider, local_n_layers_slider)
|
99 |
|
100 |
# Calculate model parameters
|
101 |
# BPE Model Parameters: 12 * N * D^2 + 2 * V * D
|
102 |
-
# N = n_layers (global), D = d_model_slider, V = n_vocab (global)
|
103 |
bpe_model_params = (12 * n_layers * d_model_slider**2) + (2 * n_vocab * d_model_slider)
|
104 |
|
105 |
# BLT Model Parameters
|
106 |
# Global Component: 12 * N * D^2 (no main vocab projection)
|
107 |
-
# N = n_layers (global), D = d_model_slider
|
108 |
blt_global_internal_params = 12 * n_layers * d_model_slider**2
|
109 |
|
110 |
# Local Component Transformer Part: 12 * N_local * D_local^2 + 2 * V_local * D_local
|
111 |
-
# N_local = local_n_layers_slider, D_local = local_d_model, V_local = local_n_vocab
|
112 |
blt_local_transformer_params = (12 * local_n_layers_slider * local_d_model**2) + \
|
113 |
(2 * local_n_vocab * local_d_model)
|
|
|
114 |
# Local Component Cross-Attention Part: N_local * 4 * D_local^2 (estimated)
|
115 |
-
# This assumes 4*D^2 params per CA block (Q,K,V,O projections within local_d_model or from global to local)
|
116 |
-
# and local_n_layers_slider effective CA blocks.
|
117 |
blt_local_ca_params = local_n_layers_slider * 4 * local_d_model**2
|
118 |
blt_local_total_internal_params = blt_local_transformer_params + blt_local_ca_params
|
119 |
|
120 |
-
blt_total_model_params = blt_global_internal_params + blt_local_total_internal_params
|
121 |
|
122 |
bpe_params_str = format_params_display(bpe_model_params)
|
123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
|
125 |
# Create the figure with subplots for better control
|
126 |
fig = go.Figure()
|
@@ -181,21 +185,21 @@ def create_visualization(blt_ps, d_model_slider, local_n_layers_slider):
|
|
181 |
),
|
182 |
barmode='stack',
|
183 |
showlegend=True,
|
184 |
-
height=650,
|
185 |
template="plotly_white",
|
186 |
font=dict(size=14),
|
187 |
bargap=0.3,
|
188 |
plot_bgcolor='white',
|
189 |
-
margin=dict(b=
|
190 |
)
|
191 |
|
192 |
fig.add_annotation(
|
193 |
x='BLT',
|
194 |
-
y=results['blt_total'] * 1.05,
|
195 |
text=f"Total FLOPs/Byte: {results['blt_total']:.2e}",
|
196 |
showarrow=False,
|
197 |
-
font=dict(size=12, color="black"),
|
198 |
-
bgcolor="rgba(255,255,255,0.5)",
|
199 |
bordercolor="black",
|
200 |
borderwidth=1,
|
201 |
xanchor='center',
|
@@ -209,24 +213,25 @@ def create_visualization(blt_ps, d_model_slider, local_n_layers_slider):
|
|
209 |
text=bpe_params_str,
|
210 |
showarrow=False,
|
211 |
xref="x",
|
212 |
-
yref="paper",
|
213 |
yanchor='top',
|
214 |
xanchor='center',
|
215 |
-
yshift=-35,
|
216 |
-
font=dict(size=
|
217 |
)
|
218 |
|
219 |
fig.add_annotation(
|
220 |
x='BLT',
|
221 |
y=0,
|
222 |
-
text=
|
223 |
showarrow=False,
|
224 |
xref="x",
|
225 |
yref="paper",
|
226 |
yanchor='top',
|
227 |
xanchor='center',
|
228 |
-
yshift=-
|
229 |
-
font=dict(size=
|
|
|
230 |
)
|
231 |
|
232 |
|
@@ -246,7 +251,7 @@ with gr.Blocks(title="BLT vs BPE FLOPs Comparison") as demo:
|
|
246 |
|
247 |
A few things you'll notice:
|
248 |
1. Patch size reduces global model FLOPs but not local model
|
249 |
-
2. Increasing patch size and global model dimension doesn't change total FLOPs
|
250 |
3. In smaller BLTs, local models constitute a larger portion of the total FLOPs
|
251 |
Parameter counts are displayed below each bar.
|
252 |
""")
|
|
|
94 |
return f"{num_params / 1_000_000:.2f}M Params"
|
95 |
|
96 |
|
97 |
+
|
98 |
def create_visualization(blt_ps, d_model_slider, local_n_layers_slider):
|
99 |
results = calculate_flops(blt_ps, d_model_slider, local_n_layers_slider)
|
100 |
|
101 |
# Calculate model parameters
|
102 |
# BPE Model Parameters: 12 * N * D^2 + 2 * V * D
|
|
|
103 |
bpe_model_params = (12 * n_layers * d_model_slider**2) + (2 * n_vocab * d_model_slider)
|
104 |
|
105 |
# BLT Model Parameters
|
106 |
# Global Component: 12 * N * D^2 (no main vocab projection)
|
|
|
107 |
blt_global_internal_params = 12 * n_layers * d_model_slider**2
|
108 |
|
109 |
# Local Component Transformer Part: 12 * N_local * D_local^2 + 2 * V_local * D_local
|
|
|
110 |
blt_local_transformer_params = (12 * local_n_layers_slider * local_d_model**2) + \
|
111 |
(2 * local_n_vocab * local_d_model)
|
112 |
+
|
113 |
# Local Component Cross-Attention Part: N_local * 4 * D_local^2 (estimated)
|
|
|
|
|
114 |
blt_local_ca_params = local_n_layers_slider * 4 * local_d_model**2
|
115 |
blt_local_total_internal_params = blt_local_transformer_params + blt_local_ca_params
|
116 |
|
117 |
+
# blt_total_model_params = blt_global_internal_params + blt_local_total_internal_params # Kept for potential other uses, not directly for this annotation
|
118 |
|
119 |
bpe_params_str = format_params_display(bpe_model_params)
|
120 |
+
|
121 |
+
# Format BLT global and local parameters separately
|
122 |
+
blt_global_params_fmt_str = format_params_display(blt_global_internal_params)
|
123 |
+
blt_local_params_fmt_str = format_params_display(blt_local_total_internal_params)
|
124 |
+
|
125 |
+
# Combine for annotation text, using <br> for line break
|
126 |
+
blt_combined_params_str = f"Global: {blt_global_params_fmt_str}<br>Local: {blt_local_params_fmt_str}"
|
127 |
+
|
128 |
|
129 |
# Create the figure with subplots for better control
|
130 |
fig = go.Figure()
|
|
|
185 |
),
|
186 |
barmode='stack',
|
187 |
showlegend=True,
|
188 |
+
height=650,
|
189 |
template="plotly_white",
|
190 |
font=dict(size=14),
|
191 |
bargap=0.3,
|
192 |
plot_bgcolor='white',
|
193 |
+
margin=dict(b=110) # Increased bottom margin slightly more for two lines of text
|
194 |
)
|
195 |
|
196 |
fig.add_annotation(
|
197 |
x='BLT',
|
198 |
+
y=results['blt_total'] * 1.05,
|
199 |
text=f"Total FLOPs/Byte: {results['blt_total']:.2e}",
|
200 |
showarrow=False,
|
201 |
+
font=dict(size=12, color="black"),
|
202 |
+
bgcolor="rgba(255,255,255,0.5)",
|
203 |
bordercolor="black",
|
204 |
borderwidth=1,
|
205 |
xanchor='center',
|
|
|
213 |
text=bpe_params_str,
|
214 |
showarrow=False,
|
215 |
xref="x",
|
216 |
+
yref="paper",
|
217 |
yanchor='top',
|
218 |
xanchor='center',
|
219 |
+
yshift=-35,
|
220 |
+
font=dict(size=10, color="black", weight="bold"), # Font size 10 for param text
|
221 |
)
|
222 |
|
223 |
fig.add_annotation(
|
224 |
x='BLT',
|
225 |
y=0,
|
226 |
+
text=blt_combined_params_str, # Using the new combined string with breakdown
|
227 |
showarrow=False,
|
228 |
xref="x",
|
229 |
yref="paper",
|
230 |
yanchor='top',
|
231 |
xanchor='center',
|
232 |
+
yshift=-45, # Adjusted yshift for two lines of text
|
233 |
+
font=dict(size=10, color="black", weight="bold"), # Font size 10 for param text
|
234 |
+
align="center" # Ensure text is centered if it wraps due to <br>
|
235 |
)
|
236 |
|
237 |
|
|
|
251 |
|
252 |
A few things you'll notice:
|
253 |
1. Patch size reduces global model FLOPs but not local model
|
254 |
+
2. Increasing patch size and global model dimension doesn't change total FLOPs
|
255 |
3. In smaller BLTs, local models constitute a larger portion of the total FLOPs
|
256 |
Parameter counts are displayed below each bar.
|
257 |
""")
|