import argparse

import gradio as gr
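
# Estimates the floating-point operations per training iteration for a standard
# Transformer (optionally with MoE, MLA and MTP layers) or for a hybrid
# Mamba/attention model, then converts the estimate into Model FLOPs Utilization
# (MFU) on P800 cards. Run the script directly to use the hard-coded
# configuration in __main__ and the Gradio web form it launches.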


def num_floating_point_operations(args):
    def calculate_layer_counts():
        """Calculate the number of attention, Mamba, and MLP layers."""
        if args.hybrid_override_pattern:
            counts = {"M": 0, "*": 0, "-": 0}
            for layer_type in args.hybrid_override_pattern:
                if layer_type in counts:
                    counts[layer_type] += 1
            return counts["*"], counts["M"], counts["-"]
        else:
            num_attn_layers = round(args.num_layers * args.hybrid_attention_ratio)
            num_mlp_layers = round(args.num_layers * args.hybrid_mlp_ratio)
            num_mamba_layers = args.num_layers - num_attn_layers - num_mlp_layers
            return num_attn_layers, num_mamba_layers, num_mlp_layers

    def mlp_layer_flops(batch_size, seq_len, hidden_size, expansion=4.0, swiglu=False):
        """Calculate FLOPs for an MLP layer."""
        # Two GEMMs (h -> expansion*h and expansion*h -> h) at 2 FLOPs per
        # multiply-add give 4 * expansion * B * s * h^2; SwiGLU adds a third
        # (gate) projection, hence the 3/2 factor.
        scale_factor = 3.0 / 2.0 if swiglu else 1.0
        return 4 * expansion * scale_factor * batch_size * seq_len * hidden_size**2

    def attn_layer_flops(
        batch_size,
        seq_len,
        hidden_size,
        num_heads,
        gqa=True,
        gqa_groups=8,
        kv_channels=None,
    ):
        """Calculate FLOPs for an attention layer."""
        p = (kv_channels * num_heads / hidden_size) if kv_channels else 1
        g = gqa_groups if gqa else num_heads
        return (
            4
            * batch_size
            * seq_len
            * hidden_size
            * p
            * (hidden_size + (hidden_size * (g / num_heads)) + (seq_len / 2))
        )

    def mamba_layer_flops(
        batch_size, seq_len, hidden_size, state_dim=16, head_dim=64, num_groups=1
    ):
        """Calculate FLOPs for a Mamba layer."""
        # Note (rwaleffe): flops estimate for scan should be updated based on new SSD kernels,
        # but small percent of overall layer flops
        d_in = 2 * hidden_size
        nheads = d_in // head_dim
        return (
            (
                2
                * batch_size
                * seq_len
                * hidden_size
                * (2 * d_in + 2 * num_groups * state_dim + nheads)
            )  # in_proj
            + (7 * batch_size * seq_len * d_in * state_dim)  # scan
            + (2 * batch_size * seq_len * d_in * hidden_size)  # out_proj
        )

    def hybrid_flops(
        batch_size,
        seq_len,
        hidden_size,
        num_attn_layers,
        num_mamba_layers,
        num_mlp_layers,
        mamba_state_dim=128,
        mamba_head_dim=64,
        mamba_num_groups=8,
        num_attn_heads=32,
        gqa=True,
        gqa_groups=8,
        kv_channels=None,
        mlp_expansion=4.0,
        swiglu=False,
        vocab_size=256000,
    ):
        """Calculate total FLOPs for the hybrid model."""
        flops_fwd = (
            num_attn_layers
            * attn_layer_flops(
                batch_size,
                seq_len,
                hidden_size,
                num_attn_heads,
                gqa,
                gqa_groups,
                kv_channels,
            )
            + num_mlp_layers
            * mlp_layer_flops(batch_size, seq_len, hidden_size, mlp_expansion, swiglu)
            + num_mamba_layers
            * mamba_layer_flops(
                batch_size,
                seq_len,
                hidden_size,
                mamba_state_dim,
                mamba_head_dim,
                mamba_num_groups,
            )
            + (
                2 * batch_size * seq_len * hidden_size * vocab_size
            )  # logits computation
        )
        return flops_fwd * 3
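
    # A minimal sketch of the hybrid path (illustrative values, not a tested
    # configuration). With an override pattern the layer mix is read directly
    # from the string; without one, hybrid_attention_ratio / hybrid_mlp_ratio
    # are used. In addition to the usual batch/size attributes, set e.g.:
    #
    #   args.is_hybrid_model = True
    #   args.hybrid_override_pattern = "M*-M*-"  # 2 Mamba, 2 attention, 2 MLP layers
    #   args.mamba_state_dim, args.mamba_head_dim, args.mamba_num_groups = 128, 64, 8
    #   num_floating_point_operations(args)      # dispatches to hybrid_flops()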

    def transformer_flops():
        """Calculate FLOPs for a standard Transformer model."""
        # TODO(helenn/dnarayanan): Refactor this to reuse the helper methods.
        # Attention projection size.
        query_projection_size = args.kv_channels * args.num_attention_heads
        query_projection_to_hidden_size_ratio = query_projection_size / args.hidden_size
        # Group Query Attention.
        if not args.group_query_attention:
            args.num_query_groups = args.num_attention_heads
        # MoE.
        if args.num_experts is None:
            # Every Transformer MLP is dense.
            num_dense_layers = args.num_layers
            num_moe_layers = 0
            num_experts_routed_to = 0
            last_layer_is_moe = 0
        else:
            # Calculate number of dense and MoE Transformer MLPs.
            if isinstance(args.moe_layer_freq, int):
                moe_layer_pattern = [
                    1 if (i % args.moe_layer_freq == 0) else 0
                    for i in range(args.num_layers)
                ]
            elif isinstance(args.moe_layer_freq, list):
                moe_layer_pattern = args.moe_layer_freq
            else:
                raise RuntimeError("Illegal --moe-layer-freq argument provided!")
            assert len(moe_layer_pattern) == args.num_layers, (
                f"Invalid length of moe_layer_pattern: {len(moe_layer_pattern)}, "
                f"expected {args.num_layers}, "
                f"current moe layer pattern: {args.moe_layer_freq}"
            )
            num_moe_layers = sum(
                moe_layer_pattern
            )  # Number of 1s in `moe_layer_pattern`.
            num_dense_layers = args.num_layers - num_moe_layers
            num_experts_routed_to = args.moe_router_topk
            last_layer_is_moe = moe_layer_pattern[-1]

        if args.mtp_num_layers is not None:
            mtp_num_layers = args.mtp_num_layers
            num_moe_layers += last_layer_is_moe * mtp_num_layers
            num_dense_layers += (1 - last_layer_is_moe) * mtp_num_layers
            num_layers = args.num_layers + mtp_num_layers
        else:
            mtp_num_layers = 0
            num_layers = args.num_layers

        moe_ffn_hidden_size = (
            args.moe_ffn_hidden_size
            if args.moe_ffn_hidden_size is not None
            else args.ffn_hidden_size
        )
        shared_expert_ffn_hidden_size = (
            0
            if args.moe_shared_expert_intermediate_size is None
            else args.moe_shared_expert_intermediate_size
        )
        # SwiGLU.
        gated_linear_multiplier = 3 / 2 if args.swiglu else 1

        # The 12x term below comes from the following factors; for more details, see
        # "APPENDIX: FLOATING-POINT OPERATIONS" in https://arxiv.org/abs/2104.04473.
        # - 3x: Each GEMM in the model needs to be performed 3 times (forward pass,
        #       backward wgrad [weight gradient], backward dgrad [data gradient]).
        # - 2x: GEMMs of a particular size are stacked twice in the standard Transformer model
        #       architectures implemented in this codebase (e.g., h->ffn_h GEMM and ffn_h->h GEMM
        #       in MLP layer).
        # - 2x: A GEMM of a m*n tensor with a n*k tensor requires 2mnk floating-point operations.
        expansion_factor = 3 * 2 * 2

        if args.multi_latent_attention:
            assert not args.group_query_attention
            """
            Basic arithmetic:
            let B be the batch size, s the seq_len, h the embedding dim;
            for one self_attention block (prenorm is not included):
            qkv projection:  6Bsh^2
            attn:            2Bs^2h
            attn over value: 2Bs^2h
            oproj:           2Bsh^2

            references
            https://arxiv.org/abs/2305.10403
            https://arxiv.org/abs/2205.05198
            """
            ## MLA
            if args.q_lora_rank is None:
                q_term = (
                    args.hidden_size
                    * args.num_attention_heads
                    * (args.qk_head_dim + args.qk_pos_emb_head_dim)
                )
            else:
                q_term = args.q_lora_rank * (
                    args.hidden_size
                    + args.num_attention_heads
                    * (args.qk_head_dim + args.qk_pos_emb_head_dim)
                    + 1
                )
            self_attn_term = (
                3  # fwd(1) + bwd(2)
                * 2  # 2 FLOPs per multiply-add (FMA)
                * num_layers
                * (
                    ## q lora + rope + q norm
                    q_term
                    ## kv lora + rope + kv norm
                    + args.kv_lora_rank
                    * (
                        args.hidden_size
                        + args.num_attention_heads
                        * (args.qk_head_dim + args.v_head_dim)
                        + 1
                    )
                    + args.hidden_size * args.qk_pos_emb_head_dim
                    ## o proj
                    + (args.num_attention_heads * args.v_head_dim) * args.hidden_size
                    ## core attn
                    + args.seq_length
                    * (
                        args.num_attention_heads
                        * (args.qk_head_dim + args.qk_pos_emb_head_dim)
                    )
                    / 2
                    + args.seq_length * args.num_attention_heads * args.v_head_dim / 2
                )
            )

        else:
            ## MHA or GQA
            self_attn_term = (
                expansion_factor
                * num_layers
                * args.hidden_size
                * args.hidden_size
                * (
                    (
                        1
                        + (args.num_query_groups / args.num_attention_heads)
                        # Only half of the attention matrix is non-zero and needs to be multiplied with V.
                        + (args.seq_length / args.hidden_size / 2)
                    )
                    * query_projection_to_hidden_size_ratio
                )
            )

        total_floating_point_operations = (
            args.batch_size
            * args.seq_length
            * (
                # MLP
                expansion_factor
                * num_layers
                * args.hidden_size
                * (
                    # dense layer (deepseek v2, v3 style)
                    (args.ffn_hidden_size * gated_linear_multiplier)
                    * (num_dense_layers / num_layers)
                    # routed experts
                    + (
                        moe_ffn_hidden_size
                        * num_experts_routed_to
                        * gated_linear_multiplier
                    )
                    * (num_moe_layers / num_layers)
                    # Shared Experts.
                    + (shared_expert_ffn_hidden_size * gated_linear_multiplier)
                    * (num_moe_layers / num_layers)
                )
                # Self Attention
                + self_attn_term
                # MTP norms and proj
                + 3
                * 2
                * mtp_num_layers
                * (
                    # MTP eh norm + final norm
                    3 * args.hidden_size
                    # MTP eh proj
                    + 2 * args.hidden_size * args.hidden_size
                )
                # Logit.
                + 3
                * 2
                * args.hidden_size
                * args.padded_vocab_size
                * (mtp_num_layers + 1)
            )
        )
        return total_floating_point_operations
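
        # Sanity check: without MoE, MLA or MTP the expression above reduces to
        # roughly 6 * tokens * parameters for the weight GEMMs, plus a
        # 6 * batch_size * seq_length^2 * num_layers * hidden_size term for the
        # causal attention matrix.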

    # Main entrypoint for FLOPs calculation.
    if args.is_hybrid_model:
        # Calculate the number of each type of layer.
        num_attn_layers, num_mamba_layers, num_mlp_layers = calculate_layer_counts()

        # Compute hybrid model FLOPs.
        return hybrid_flops(
            batch_size=args.batch_size,
            seq_len=args.seq_length,
            hidden_size=args.hidden_size,
            num_attn_layers=num_attn_layers,
            num_mamba_layers=num_mamba_layers,
            num_mlp_layers=num_mlp_layers,
            mamba_state_dim=args.mamba_state_dim,
            mamba_head_dim=args.mamba_head_dim,
            mamba_num_groups=args.mamba_num_groups,
            num_attn_heads=args.num_attention_heads,
            gqa=args.group_query_attention,
            gqa_groups=args.num_query_groups,
            kv_channels=args.kv_channels,
            mlp_expansion=args.ffn_hidden_size / args.hidden_size,
            swiglu=args.swiglu,
            vocab_size=args.padded_vocab_size,
        )
    else:
        # Compute standard Transformer model FLOPs.
        return transformer_flops()


def calculate_flops(args):
    model_flops = num_floating_point_operations(args)
    flops_per_token = model_flops / (args.batch_size * args.seq_length)
    print(f"FLOPs Per Iteration: {model_flops}\nFLOPs Per Token: {flops_per_token}")
    return model_flops


def calculate_mfu(model_flops, *, iter_elapsed_time, num_p800_cards):
    assert (
        model_flops and iter_elapsed_time and num_p800_cards
    ), "model_flops, iter_elapsed_time and num_p800_cards must all be non-zero"
    # 3.5e14 FLOPs/s: assumed peak bf16 throughput of a single P800 card.
    mfu = model_flops / (iter_elapsed_time * num_p800_cards * 3.5e14)
    print(f"MFU P800 bf16: {mfu:.2%}")


def calculate_mfu_web(  is_hybrid_model, group_query_attention, swiglu, num_layers, hidden_size,
                        ffn_hidden_size, padded_vocab_size, num_attention_heads, kv_channels,
                        num_experts, moe_layer_freq, moe_router_topk, moe_ffn_hidden_size, moe_shared_expert_intermediate_size,
                        multi_latent_attention, q_lora_rank, kv_lora_rank, qk_head_dim, v_head_dim, qk_pos_emb_head_dim,
                        mtp_num_layers, seq_length, batch_size, iter_elapsed_time, num_p800_cards
                    ):
    # The Gradio dropdowns pass booleans as the strings "True" / "False".
    is_hybrid_model = is_hybrid_model == "True"
    group_query_attention = group_query_attention == "True"
    swiglu = swiglu == "True"
    multi_latent_attention = multi_latent_attention == "True"

    # Pack the parameters into an args-like namespace so they can be passed to
    # the same interface used by calculate_flops(args).
    # Note: the web form does not expose num_query_groups or the Mamba-specific
    # fields, so "group query attention" and "hybrid model" should be left as
    # "False" here.
    mfu_parameter = argparse.Namespace(
        is_hybrid_model=is_hybrid_model,
        group_query_attention=group_query_attention,
        swiglu=swiglu,
        num_layers=num_layers,
        hidden_size=hidden_size,
        ffn_hidden_size=ffn_hidden_size,
        padded_vocab_size=padded_vocab_size,
        num_attention_heads=num_attention_heads,
        kv_channels=kv_channels,
        num_experts=num_experts,
        moe_layer_freq=moe_layer_freq,
        moe_router_topk=moe_router_topk,
        moe_ffn_hidden_size=moe_ffn_hidden_size,
        moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size,
        multi_latent_attention=multi_latent_attention,
        q_lora_rank=q_lora_rank,
        kv_lora_rank=kv_lora_rank,
        qk_head_dim=qk_head_dim,
        v_head_dim=v_head_dim,
        qk_pos_emb_head_dim=qk_pos_emb_head_dim,
        mtp_num_layers=mtp_num_layers,
        seq_length=seq_length,
        batch_size=batch_size,
        iter_elapsed_time=iter_elapsed_time,
        num_p800_cards=num_p800_cards,
        hybrid_override_pattern=None,
    )

    model_flops = num_floating_point_operations(mfu_parameter)
    flops_per_token = model_flops / (batch_size * seq_length)
    print(f"FLOPs Per Iteration: {model_flops}\nFLOPs Per Token: {flops_per_token}")

    assert (
        model_flops and iter_elapsed_time and num_p800_cards
    ), "model_flops, iter_elapsed_time and num_p800_cards must all be non-zero"

    # Same definition as calculate_mfu(): achieved FLOPs/s divided by the
    # aggregate peak bf16 throughput (3.5e14 FLOPs/s per P800 card).
    mfu = model_flops / (iter_elapsed_time * num_p800_cards * 3.5e14)
    print(f"MFU P800 bf16: {mfu:.2%}")
    return model_flops, flops_per_token, f"{mfu:.2%}"
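
# The three returned values feed the "model flops", "flops per token" and
# "MFU P800 bf16" output widgets of the Gradio form below.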

if __name__ == "__main__":
    # No CLI flags are defined; the configuration is hard-coded below.
    args = argparse.Namespace()

    # Standard Transformer config
    args.is_hybrid_model = False
    args.group_query_attention = False
    args.swiglu = True
    args.num_layers = 61
    args.hidden_size = 7168
    args.ffn_hidden_size = 18432
    args.padded_vocab_size = 100002
    args.num_attention_heads = 128
    args.kv_channels = 128

    # MoE config
    args.num_experts = 256
    args.moe_layer_freq = 1
    args.moe_router_topk = 8
    args.moe_ffn_hidden_size = 2048
    args.moe_shared_expert_intermediate_size = 2048

    # MLA config
    args.multi_latent_attention = True
    args.q_lora_rank = 1536
    args.kv_lora_rank = 512
    args.qk_head_dim = 128
    args.v_head_dim = 128
    args.qk_pos_emb_head_dim = 64

    # MTP config
    args.mtp_num_layers = 1

    # Data config
    args.seq_length = 4096
    args.batch_size = 1024

    # mfu config
    args.iter_elapsed_time = 100
    args.num_p800_cards = 512

    #calculate_mfu(calculate_flops(args), iter_elapsed_time=args.iter_elapsed_time, num_p800_cards=args.num_p800_cards)
    with gr.Blocks(title="Compute MFU") as demo:
        gr.Markdown("## Compute MFU")
        
        with gr.Group():
            gr.Markdown("Standard Transformer config:")
            with gr.Row():
                is_hybrid_model = gr.Dropdown(["True", "False"], 
                                                label="hybrid model", 
                                                value="True" if args.is_hybrid_model else "False")

                group_query_attention = gr.Dropdown(["True", "False"], 
                                                label="group query attention", 
                                                value="True" if args.group_query_attention else "False")

                swiglu = gr.Dropdown(["True", "False"], 
                                        label="swiglu", 
                                        value="True" if args.swiglu else "False")

                num_layers = gr.Number(label="num layers", value=args.num_layers, precision=0)
                hidden_size = gr.Number(label="hidden size", value=args.hidden_size, precision=0)
                ffn_hidden_size = gr.Number(label="ffn hidden size", value=args.ffn_hidden_size, precision=0)
                padded_vocab_size = gr.Number(label="padded vocab size", value=args.padded_vocab_size, precision=0)
                num_attention_heads = gr.Number(label="num attention heads", value=args.num_attention_heads, precision=0)
                kv_channels = gr.Number(label="kv channels", value=args.kv_channels, precision=0)

        with gr.Group():
            gr.Markdown("MoE config:")
            with gr.Row():
                num_experts = gr.Number(label="num experts", value=args.num_experts, precision=0)
                moe_layer_freq = gr.Number(label="moe layer freq", value=args.moe_layer_freq, precision=0)
                moe_router_topk = gr.Number(label="moe router topk", value=args.moe_router_topk, precision=0)
                moe_ffn_hidden_size = gr.Number(label="moe ffn hidden size", value=args.moe_ffn_hidden_size, precision=0)
                moe_shared_expert_intermediate_size = gr.Number(label="moe shared expert intermediate size", value=args.moe_shared_expert_intermediate_size, precision=0)

        with gr.Group():
            gr.Markdown("MLA config:")
            with gr.Row():
                multi_latent_attention = gr.Dropdown(["True", "False"], 
                                                label="multi_latent_attention", 
                                                value="True" if args.multi_latent_attention else "False")
                q_lora_rank = gr.Number(label="q lora rank", value=args.q_lora_rank, precision=0)
                kv_lora_rank = gr.Number(label="kv lora rank", value=args.kv_lora_rank, precision=0)
                qk_head_dim = gr.Number(label="qk head dim", value=args.qk_head_dim, precision=0)
                v_head_dim = gr.Number(label="v head dim", value=args.v_head_dim, precision=0)
                qk_pos_emb_head_dim = gr.Number(label="qk pos emb head dim", value=args.qk_pos_emb_head_dim, precision=0)

        with gr.Group():
            with gr.Row():
                with gr.Group():
                    gr.Markdown("MTP config:")
                    mtp_num_layers = gr.Number(label="mtp num layers", value=args.mtp_num_layers, precision=0)

                with gr.Group():
                    gr.Markdown("Data config:")
                    with gr.Row():
                        seq_length = gr.Number(label="seq length", value=args.seq_length, precision=0)
                        batch_size = gr.Number(label="batch size", value=args.batch_size, precision=0)

                with gr.Group():
                    gr.Markdown("MFU config:")
                    with gr.Row():
                        iter_elapsed_time = gr.Number(label="iter elapsed time", value=args.iter_elapsed_time, precision=0)
                        num_p800_cards = gr.Number(label="num p800 cards", value=args.num_p800_cards, precision=0)

        # Widgets that display the computed results.
        with gr.Group():
            gr.Markdown("Compute results:")
            with gr.Row():
                model_flops = gr.Number(label="model flops", precision=0)
                flops_per_token = gr.Number(label="flops per token", precision=0)
                # mfu = gr.Number(label="mfu", precision=0)
                mfu = gr.Textbox(label="MFU P800 bf16")

        # Calculate button.
        btn = gr.Button("Calculate")
        btn.click(  fn=calculate_mfu_web, 
                    inputs=[is_hybrid_model, group_query_attention, swiglu, num_layers, hidden_size,
                        ffn_hidden_size, padded_vocab_size, num_attention_heads, kv_channels,
                        num_experts, moe_layer_freq, moe_router_topk, moe_ffn_hidden_size, moe_shared_expert_intermediate_size,
                        multi_latent_attention, q_lora_rank, kv_lora_rank, qk_head_dim, v_head_dim, qk_pos_emb_head_dim,
                        mtp_num_layers, seq_length, batch_size, iter_elapsed_time, num_p800_cards], 
                    outputs=[model_flops, flops_per_token, mfu]
                )

    # Launch the Gradio app.
    demo.launch()