v_gonghuilin committed
Commit ea1186f · 1 Parent(s): 194cca8

Add application file

Files changed (1):
  1. app.py (+544, -0)
app.py ADDED
@@ -0,0 +1,544 @@
import gradio as gr
import argparse


def num_floating_point_operations(args):
    def calculate_layer_counts():
        """Calculate the number of attention, Mamba, and MLP layers."""
        if args.hybrid_override_pattern:
            counts = {"M": 0, "*": 0, "-": 0}
            for layer_type in args.hybrid_override_pattern:
                if layer_type in counts:
                    counts[layer_type] += 1
            return counts["*"], counts["M"], counts["-"]
        else:
            num_attn_layers = round(args.num_layers * args.hybrid_attention_ratio)
            num_mlp_layers = round(args.num_layers * args.hybrid_mlp_ratio)
            num_mamba_layers = args.num_layers - num_attn_layers - num_mlp_layers
            return num_attn_layers, num_mamba_layers, num_mlp_layers

    def mlp_layer_flops(batch_size, seq_len, hidden_size, expansion=4.0, swiglu=False):
        """Calculate FLOPs for an MLP layer."""
        scale_factor = 3.0 / 2.0 if swiglu else 1.0
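        # Roughly: the two MLP GEMMs (h -> expansion*h and expansion*h -> h) each cost
        # 2 * batch_size * seq_len * hidden_size * (expansion * hidden_size) FLOPs, giving the
        # 4x term below; SwiGLU adds a third projection of the same size, hence the 3/2 factor.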
        return 4 * expansion * scale_factor * batch_size * seq_len * hidden_size**2

    def attn_layer_flops(
        batch_size,
        seq_len,
        hidden_size,
        num_heads,
        gqa=True,
        gqa_groups=8,
        kv_channels=None,
    ):
        """Calculate FLOPs for an attention layer."""
        p = (kv_channels * num_heads / hidden_size) if kv_channels else 1
        g = gqa_groups if gqa else num_heads
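        # Term by term: the Q and output projections contribute 4*B*s*(p*h)*h, the K/V
        # projections are scaled by g/num_heads under GQA, and the QK^T and attention-over-V
        # GEMMs contribute 4*B*s*(p*h)*(s/2) after halving for the causal mask.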
        return (
            4
            * batch_size
            * seq_len
            * hidden_size
            * p
            * (hidden_size + (hidden_size * (g / num_heads)) + (seq_len / 2))
        )

    def mamba_layer_flops(
        batch_size, seq_len, hidden_size, state_dim=16, head_dim=64, num_groups=1
    ):
        """Calculate FLOPs for a Mamba layer."""
        # Note (rwaleffe): flops estimate for scan should be updated based on new SSD kernels,
        # but small percent of overall layer flops
        d_in = 2 * hidden_size
        nheads = d_in // head_dim
        return (
            (
                2
                * batch_size
                * seq_len
                * hidden_size
                * (2 * d_in + 2 * num_groups * state_dim + nheads)
            )  # in_proj
            + (7 * batch_size * seq_len * d_in * state_dim)  # scan
            + (2 * batch_size * seq_len * d_in * hidden_size)  # out_proj
        )

    def hybrid_flops(
        batch_size,
        seq_len,
        hidden_size,
        num_attn_layers,
        num_mamba_layers,
        num_mlp_layers,
        mamba_state_dim=128,
        mamba_head_dim=64,
        mamba_num_groups=8,
        num_attn_heads=32,
        gqa=True,
        gqa_groups=8,
        kv_channels=None,
        mlp_expansion=4.0,
        swiglu=False,
        vocab_size=256000,
    ):
        """Calculate total FLOPs for the hybrid model."""
        flops_fwd = (
            num_attn_layers
            * attn_layer_flops(
                batch_size,
                seq_len,
                hidden_size,
                num_attn_heads,
                gqa,
                gqa_groups,
                kv_channels,
            )
            + num_mlp_layers
            * mlp_layer_flops(batch_size, seq_len, hidden_size, mlp_expansion, swiglu)
            + num_mamba_layers
            * mamba_layer_flops(
                batch_size,
                seq_len,
                hidden_size,
                mamba_state_dim,
                mamba_head_dim,
                mamba_num_groups,
            )
            + (
                2 * batch_size * seq_len * hidden_size * vocab_size
            )  # logits computation
        )
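        # The 3x below folds in the backward pass (weight-gradient and data-gradient GEMMs)
        # on top of the forward FLOPs, matching the 3x GEMM factor documented in
        # transformer_flops() below.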
        return flops_fwd * 3

    def transformer_flops():
        """Calculate FLOPs for a standard Transformer model."""
        # TODO(helenn/dnarayanan): Refactor this to reuse the helper methods.
        # Attention projection size.
        query_projection_size = args.kv_channels * args.num_attention_heads
        query_projection_to_hidden_size_ratio = query_projection_size / args.hidden_size
        # Group Query Attention.
        if not args.group_query_attention:
            args.num_query_groups = args.num_attention_heads
        # MoE.
        if args.num_experts is None:
            # Every Transformer MLP is dense.
            num_dense_layers = args.num_layers
            num_moe_layers = 0
            num_experts_routed_to = 0
            last_layer_is_moe = 0
        else:
            # Calculate the number of dense and MoE Transformer MLPs.
            if isinstance(args.moe_layer_freq, int):
                moe_layer_pattern = [
                    1 if (i % args.moe_layer_freq == 0) else 0
                    for i in range(args.num_layers)
                ]
            elif isinstance(args.moe_layer_freq, list):
                moe_layer_pattern = args.moe_layer_freq
            else:
                raise RuntimeError("Illegal --moe-layer-freq argument provided!")
            assert len(moe_layer_pattern) == args.num_layers, (
                f"Invalid length of moe_layer_pattern: {len(moe_layer_pattern)}, "
                f"expected {args.num_layers}, "
                f"current moe layer pattern: {args.moe_layer_freq}"
            )
            num_moe_layers = sum(
                moe_layer_pattern
            )  # Number of 1s in `moe_layer_pattern`.
            num_dense_layers = args.num_layers - num_moe_layers
            num_experts_routed_to = args.moe_router_topk
            last_layer_is_moe = moe_layer_pattern[-1]

        if args.mtp_num_layers is not None:
            mtp_num_layers = args.mtp_num_layers
            num_moe_layers += last_layer_is_moe * mtp_num_layers
            num_dense_layers += (1 - last_layer_is_moe) * mtp_num_layers
            num_layers = args.num_layers + mtp_num_layers
        else:
            mtp_num_layers = 0
            num_layers = args.num_layers

        moe_ffn_hidden_size = (
            args.moe_ffn_hidden_size
            if args.moe_ffn_hidden_size is not None
            else args.ffn_hidden_size
        )
        shared_expert_ffn_hidden_size = (
            0
            if args.moe_shared_expert_intermediate_size is None
            else args.moe_shared_expert_intermediate_size
        )
        # SwiGLU.
        gated_linear_multiplier = 3 / 2 if args.swiglu else 1

        # The 12x term below comes from the following factors; for more details, see
        # "APPENDIX: FLOATING-POINT OPERATIONS" in https://arxiv.org/abs/2104.04473.
        # - 3x: Each GEMM in the model needs to be performed 3 times (forward pass,
        #       backward wgrad [weight gradient], backward dgrad [data gradient]).
        # - 2x: GEMMs of a particular size are stacked twice in the standard Transformer
        #       architectures implemented in this codebase (e.g., the h->ffn_h GEMM and the
        #       ffn_h->h GEMM in the MLP layer).
        # - 2x: A GEMM of an m*n tensor with an n*k tensor requires 2mnk floating-point operations.
        expansion_factor = 3 * 2 * 2

        if args.multi_latent_attention:
            assert not args.group_query_attention
            """
            Basic arithmetic:
            let B be the batch size, s the seq_len, and h the embedding dim;
            for one self-attention block (prenorm not included):
            qkv projection: 6Bsh^2
            attn: 2Bs^2h
            attn over value: 2Bs^2h
            oproj: 2Bsh^2

            References:
            https://arxiv.org/abs/2305.10403
            https://arxiv.org/abs/2205.05198
            """
            ## MLA
            if args.q_lora_rank is None:
                q_term = (
                    args.hidden_size
                    * args.num_attention_heads
                    * (args.qk_head_dim + args.qk_pos_emb_head_dim)
                )
            else:
                q_term = args.q_lora_rank * (
                    args.hidden_size
                    + args.num_attention_heads
                    * (args.qk_head_dim + args.qk_pos_emb_head_dim)
                    + 1
                )
            self_attn_term = (
                3  # fwd(1) + bwd(2)
                * 2  # FMA: 2 FLOPs per multiply-accumulate
                * num_layers
                * (
                    ## q lora + rope + q norm
                    q_term
                    ## kv lora + rope + kv norm
                    + args.kv_lora_rank
                    * (
                        args.hidden_size
                        + args.num_attention_heads
                        * (args.qk_head_dim + args.v_head_dim)
                        + 1
                    )
                    + args.hidden_size * args.qk_pos_emb_head_dim
                    ## o proj
                    + (args.num_attention_heads * args.v_head_dim) * args.hidden_size
                    ## core attn
                    + args.seq_length
                    * (
                        args.num_attention_heads
                        * (args.qk_head_dim + args.qk_pos_emb_head_dim)
                    )
                    / 2
                    + args.seq_length * args.num_attention_heads * args.v_head_dim / 2
                )
            )

        else:
            ## MHA or GQA
            self_attn_term = (
                expansion_factor
                * num_layers
                * args.hidden_size
                * args.hidden_size
                * (
                    (
                        1
                        + (args.num_query_groups / args.num_attention_heads)
                        # Only half of the attention matrix is non-zero and needs to be multiplied with V.
                        + (args.seq_length / args.hidden_size / 2)
                    )
                    * query_projection_to_hidden_size_ratio
                )
            )

        total_floating_point_operations = (
            args.batch_size
            * args.seq_length
            * (
                # MLP
                expansion_factor
                * num_layers
                * args.hidden_size
                * (
                    # Dense layers (DeepSeek-V2/V3 style).
                    (args.ffn_hidden_size * gated_linear_multiplier)
                    * (num_dense_layers / num_layers)
                    # Routed experts.
                    + (
                        moe_ffn_hidden_size
                        * num_experts_routed_to
                        * gated_linear_multiplier
                    )
                    * (num_moe_layers / num_layers)
                    # Shared experts.
                    + (shared_expert_ffn_hidden_size * gated_linear_multiplier)
                    * (num_moe_layers / num_layers)
                )
                # Self attention.
                + self_attn_term
                # MTP norms and projection.
                + 3
                * 2
                * mtp_num_layers
                * (
                    # MTP eh norm + final norm.
                    3 * args.hidden_size
                    # MTP eh projection.
                    + 2 * args.hidden_size * args.hidden_size
                )
                # Logits.
                + 3
                * 2
                * args.hidden_size
                * args.padded_vocab_size
                * (mtp_num_layers + 1)
            )
        )
        return total_floating_point_operations

    # Main entrypoint for FLOPs calculation.
    if args.is_hybrid_model:
        # Calculate the number of each type of layer.
        num_attn_layers, num_mamba_layers, num_mlp_layers = calculate_layer_counts()

        # Compute hybrid model FLOPs.
        return hybrid_flops(
            batch_size=args.batch_size,
            seq_len=args.seq_length,
            hidden_size=args.hidden_size,
            num_attn_layers=num_attn_layers,
            num_mamba_layers=num_mamba_layers,
            num_mlp_layers=num_mlp_layers,
            mamba_state_dim=args.mamba_state_dim,
            mamba_head_dim=args.mamba_head_dim,
            mamba_num_groups=args.mamba_num_groups,
            num_attn_heads=args.num_attention_heads,
            gqa=args.group_query_attention,
            gqa_groups=args.num_query_groups,
            kv_channels=args.kv_channels,
            mlp_expansion=args.ffn_hidden_size / args.hidden_size,
            swiglu=args.swiglu,
            vocab_size=args.padded_vocab_size,
        )
    else:
        # Compute standard Transformer model FLOPs.
        return transformer_flops()


def calculate_flops(args):
    model_flops = num_floating_point_operations(args)
    flops_per_token = model_flops / (args.batch_size * args.seq_length)
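    # Rough sanity check: for a dense Transformer, FLOPs per token should land near 6x the
    # parameter count (2 FLOPs per parameter forward, 4 backward), plus the attention-score terms.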
    print(f"FLOPs Per Iteration: {model_flops}\nFLOPs Per Token: {flops_per_token}")
    return model_flops


def calculate_mfu(model_flops, *, iter_elapsed_time, num_p800_cards):
    assert (
        model_flops and iter_elapsed_time and num_p800_cards
    ), "Iter elapsed time and P800 cards must be provided"
    mfu = model_flops / (iter_elapsed_time * num_p800_cards * 3.5e14)
    print(f"MFU P800 bf16: {mfu:.2%}")


def calculate_mfu_web(is_hybrid_model, group_query_attention, swiglu, num_layers, hidden_size,
                      ffn_hidden_size, padded_vocab_size, num_attention_heads, kv_channels,
                      num_experts, moe_layer_freq, moe_router_topk, moe_ffn_hidden_size, moe_shared_expert_intermediate_size,
                      multi_latent_attention, q_lora_rank, kv_lora_rank, qk_head_dim, v_head_dim, qk_pos_emb_head_dim,
                      mtp_num_layers, seq_length, batch_size, iter_elapsed_time, num_p800_cards
                      ):
    # The Gradio dropdowns pass booleans as strings; convert them back.
    is_hybrid_model = is_hybrid_model == "True"
    group_query_attention = group_query_attention == "True"
    swiglu = swiglu == "True"
    multi_latent_attention = multi_latent_attention == "True"

    # Pack the parameters into a single args-like object so the calculate_flops(args)-style
    # interface can be called directly.
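    # (An argparse.Namespace(**kwargs) would serve the same purpose as the helper class below.
    # Note that the hybrid/Mamba path also expects hybrid_attention_ratio, hybrid_mlp_ratio, and
    # mamba_* fields, which this web wrapper does not populate.)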
    class parameter:
        def __init__(self,
                     is_hybrid_model, group_query_attention, swiglu, num_layers, hidden_size,
                     ffn_hidden_size, padded_vocab_size, num_attention_heads, kv_channels,
                     num_experts, moe_layer_freq, moe_router_topk, moe_ffn_hidden_size, moe_shared_expert_intermediate_size,
                     multi_latent_attention, q_lora_rank, kv_lora_rank, qk_head_dim, v_head_dim, qk_pos_emb_head_dim,
                     mtp_num_layers, seq_length, batch_size, iter_elapsed_time, num_p800_cards,
                     hybrid_override_pattern=None):
            self.is_hybrid_model = is_hybrid_model
            self.group_query_attention = group_query_attention
            self.swiglu = swiglu
            self.num_layers = num_layers
            self.hidden_size = hidden_size
            self.ffn_hidden_size = ffn_hidden_size
            self.padded_vocab_size = padded_vocab_size
            self.num_attention_heads = num_attention_heads
            self.kv_channels = kv_channels
            self.num_experts = num_experts
            self.moe_layer_freq = moe_layer_freq
            self.moe_router_topk = moe_router_topk
            self.moe_ffn_hidden_size = moe_ffn_hidden_size
            self.moe_shared_expert_intermediate_size = moe_shared_expert_intermediate_size
            self.multi_latent_attention = multi_latent_attention
            self.q_lora_rank = q_lora_rank
            self.kv_lora_rank = kv_lora_rank
            self.qk_head_dim = qk_head_dim
            self.v_head_dim = v_head_dim
            self.qk_pos_emb_head_dim = qk_pos_emb_head_dim
            self.mtp_num_layers = mtp_num_layers
            self.seq_length = seq_length
            self.batch_size = batch_size
            self.iter_elapsed_time = iter_elapsed_time
            self.num_p800_cards = num_p800_cards
            self.hybrid_override_pattern = hybrid_override_pattern

    mfu_parameter = parameter(is_hybrid_model, group_query_attention, swiglu, num_layers, hidden_size,
                              ffn_hidden_size, padded_vocab_size, num_attention_heads, kv_channels,
                              num_experts, moe_layer_freq, moe_router_topk, moe_ffn_hidden_size, moe_shared_expert_intermediate_size,
                              multi_latent_attention, q_lora_rank, kv_lora_rank, qk_head_dim, v_head_dim, qk_pos_emb_head_dim,
                              mtp_num_layers, seq_length, batch_size, iter_elapsed_time, num_p800_cards,
                              hybrid_override_pattern=None)

    model_flops = num_floating_point_operations(mfu_parameter)
    flops_per_token = model_flops / (batch_size * seq_length)
    print(f"FLOPs Per Iteration: {model_flops}\nFLOPs Per Token: {flops_per_token}")

    assert (
        model_flops and iter_elapsed_time and num_p800_cards
    ), "Iter elapsed time and P800 cards must be provided"

    mfu = model_flops / (iter_elapsed_time * num_p800_cards * 3.5e14)
    print(f"MFU P800 bf16: {mfu:.2%}")
    return model_flops, flops_per_token, "{:.2f}".format(mfu * 100)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    args = parser.parse_args()

    # Standard Transformer config
    args.is_hybrid_model = False
    args.group_query_attention = False
    args.swiglu = True
    args.num_layers = 61
    args.hidden_size = 7168
    args.ffn_hidden_size = 18432
    args.padded_vocab_size = 100002
    args.num_attention_heads = 128
    args.kv_channels = 128

    # MoE config
    args.num_experts = 256
    args.moe_layer_freq = 1
    args.moe_router_topk = 8
    args.moe_ffn_hidden_size = 2048
    args.moe_shared_expert_intermediate_size = 2048

    # MLA config
    args.multi_latent_attention = True
    args.q_lora_rank = 1536
    args.kv_lora_rank = 512
    args.qk_head_dim = 128
    args.v_head_dim = 128
    args.qk_pos_emb_head_dim = 64

    # MTP config
    args.mtp_num_layers = 1

    # Data config
    args.seq_length = 4096
    args.batch_size = 1024

    # MFU config
    args.iter_elapsed_time = 100
    args.num_p800_cards = 512

    # calculate_mfu(calculate_flops(args), iter_elapsed_time=args.iter_elapsed_time, num_p800_cards=args.num_p800_cards)
    with gr.Blocks(title="Compute MFU") as demo:
        gr.Markdown("## Compute MFU")

        with gr.Group() as custom_group:
            gr.Markdown("Standard Transformer config:")
            with gr.Row():
                is_hybrid_model = gr.Dropdown(["True", "False"],
                                              label="hybrid model",
                                              value="True" if args.is_hybrid_model else "False")

                group_query_attention = gr.Dropdown(["True", "False"],
                                                    label="group query attention",
                                                    value="True" if args.group_query_attention else "False")

                swiglu = gr.Dropdown(["True", "False"],
                                     label="swiglu",
                                     value="True" if args.swiglu else "False")

                num_layers = gr.Number(label="num layers", value=args.num_layers, precision=0)
                hidden_size = gr.Number(label="hidden size", value=args.hidden_size, precision=0)
                ffn_hidden_size = gr.Number(label="ffn hidden size", value=args.ffn_hidden_size, precision=0)
                padded_vocab_size = gr.Number(label="padded vocab size", value=args.padded_vocab_size, precision=0)
                num_attention_heads = gr.Number(label="num attention heads", value=args.num_attention_heads, precision=0)
                kv_channels = gr.Number(label="kv channels", value=args.kv_channels, precision=0)

        with gr.Group() as custom_group:
            gr.Markdown("MoE config:")
            with gr.Row():
                num_experts = gr.Number(label="num experts", value=args.num_experts, precision=0)
                moe_layer_freq = gr.Number(label="moe layer freq", value=args.moe_layer_freq, precision=0)
                moe_router_topk = gr.Number(label="moe router topk", value=args.moe_router_topk, precision=0)
                moe_ffn_hidden_size = gr.Number(label="moe ffn hidden size", value=args.moe_ffn_hidden_size, precision=0)
                moe_shared_expert_intermediate_size = gr.Number(label="moe shared expert intermediate size", value=args.moe_shared_expert_intermediate_size, precision=0)

        with gr.Group() as custom_group:
            gr.Markdown("MLA config:")
            with gr.Row():
                multi_latent_attention = gr.Dropdown(["True", "False"],
                                                     label="multi_latent_attention",
                                                     value="True" if args.multi_latent_attention else "False")
                q_lora_rank = gr.Number(label="q lora rank", value=args.q_lora_rank, precision=0)
                kv_lora_rank = gr.Number(label="kv lora rank", value=args.kv_lora_rank, precision=0)
                qk_head_dim = gr.Number(label="qk head dim", value=args.qk_head_dim, precision=0)
                v_head_dim = gr.Number(label="v head dim", value=args.v_head_dim, precision=0)
                qk_pos_emb_head_dim = gr.Number(label="qk pos emb head dim", value=args.qk_pos_emb_head_dim, precision=0)

        with gr.Group() as custom_group:
            with gr.Row():
                with gr.Group():
                    gr.Markdown("MTP config:")
                    mtp_num_layers = gr.Number(label="mtp num layers", value=args.mtp_num_layers, precision=0)

                with gr.Group():
                    gr.Markdown("Data config:")
                    with gr.Row():
                        seq_length = gr.Number(label="seq length", value=args.seq_length, precision=0)
                        batch_size = gr.Number(label="batch size", value=args.batch_size, precision=0)

                with gr.Group():
                    gr.Markdown("MFU config:")
                    with gr.Row():
                        iter_elapsed_time = gr.Number(label="iter elapsed time", value=args.iter_elapsed_time, precision=0)
                        num_p800_cards = gr.Number(label="num p800 cards", value=args.num_p800_cards, precision=0)

        # Widgets for displaying the computed results
        with gr.Group() as custom_group:
            gr.Markdown("Compute results:")
            with gr.Row():
                model_flops = gr.Number(label="model flops", precision=0)
                flops_per_token = gr.Number(label="flops per token", precision=0)
                # mfu = gr.Number(label="mfu", precision=0)
                mfu = gr.Textbox(label="MFU P800 bf16")

        # Calculate button
        btn = gr.Button("Calculate")
        btn.click(fn=calculate_mfu_web,
                  inputs=[is_hybrid_model, group_query_attention, swiglu, num_layers, hidden_size,
                          ffn_hidden_size, padded_vocab_size, num_attention_heads, kv_channels,
                          num_experts, moe_layer_freq, moe_router_topk, moe_ffn_hidden_size, moe_shared_expert_intermediate_size,
                          multi_latent_attention, q_lora_rank, kv_lora_rank, qk_head_dim, v_head_dim, qk_pos_emb_head_dim,
                          mtp_num_layers, seq_length, batch_size, iter_elapsed_time, num_p800_cards],
                  outputs=[model_flops, flops_per_token, mfu])

    # Launch the Gradio app
    demo.launch()
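    # Run with `python app.py`; by default Gradio serves the UI at http://127.0.0.1:7860
    # (override via launch(server_name=..., server_port=...)).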