qgallouedec (HF Staff) committed · verified
Commit f9089ef · Parent(s): f2593be

Create app.py

Files changed (1): app.py (+87, -0)
app.py ADDED
@@ -0,0 +1,87 @@
+ import gradio as gr
+ from datasets import load_dataset
+ from trl import SFTTrainer, SFTConfig
+ from transformers import AutoTokenizer
+ import pandas as pd
+ import numpy as np
+
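+ # Candidate max_length values to evaluate (powers of two from 128 to 32768),
+ # plus the sampling seed and the number of examples drawn from the dataset.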
+ TRUNCATION_LENGTHS = [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
+ SEED = 42
+ N_SAMPLES = 1000
+
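+ # Snippet shown to the user, with the recommended max_length filled in.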
+ CODE_TEMPLATE = """
+ training_args = SFTConfig(
+     ...,
+     max_length={},
+ )"""
+
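+ # Tokenize a sample of the dataset the way SFTTrainer would, then measure how
+ # much of it would be truncated at each candidate max_length.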
+ def benchmark(model_name, dataset_name):
+     print(f"Running benchmark for model: {model_name} on dataset: {dataset_name}...")
+
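+     # Stream the dataset to avoid downloading it in full; shuffle with a fixed
+     # seed and keep N_SAMPLES examples.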
+     print("Loading dataset...")
+     dataset = load_dataset(dataset_name, split="train", streaming=True).shuffle(seed=SEED).take(N_SAMPLES)
+
+     print("Loading tokenizer...")
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+
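+     # Note: _prepare_dataset is a private SFTTrainer method, called here without
+     # a trainer instance (self=None); its signature may change across TRL versions.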
+     print("Tokenizing dataset...")
+     config = SFTConfig(max_length=None, bf16=False)
+     tokenized_dataset = SFTTrainer._prepare_dataset(
+         None, dataset, tokenizer, config, packing=False, formatting_func=None, dataset_name="train"
+     )
+
+     print("Computing the sequence lengths and total tokens")
+     sequence_lengths = [len(sample["input_ids"]) for sample in tokenized_dataset]
+     total_tokens = sum(sequence_lengths)
+
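+     # For each candidate max_length, count the tokens falling beyond the cutoff
+     # across all sequences, as a percentage of all tokens, and recommend the
+     # smallest max_length that discards less than 5% of them.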
+     print("Computing the truncation ratios")
+     truncation_ratios = []
+     recommended = None
+     for max_len in TRUNCATION_LENGTHS:
+         total_truncated_tokens = sum(max(length - max_len, 0) for length in sequence_lengths)
+         truncation_ratio = total_truncated_tokens / total_tokens * 100
+         truncation_ratios.append(truncation_ratio)
+         if recommended is None and truncation_ratio < 5.0:
+             recommended = max_len
+
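+     # Bin the sequence lengths into 50 buckets; x is the bin midpoint and each
+     # bar shows the share of samples whose length falls in that bucket.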
+     hist = np.histogram(sequence_lengths, bins=50)
+     lengths_distribution = pd.DataFrame({
+         "max_length": (hist[1][:-1] + hist[1][1:]) / 2,
+         "Ratio (%)": hist[0] / N_SAMPLES * 100,
+     })
+
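+     # One bar per candidate max_length: the percentage of tokens that would be
+     # truncated at that value.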
+     truncation_data = pd.DataFrame({
+         "max_length": [str(value) for value in TRUNCATION_LENGTHS],
+         "Ratio (%)": truncation_ratios,
+     })
+
+     return lengths_distribution, truncation_data, CODE_TEMPLATE.format(recommended)
+
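+ # Gradio UI: model and dataset inputs, a run button, the two bar plots, and the
+ # recommended SFTConfig snippet.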
+ with gr.Blocks() as demo:
+     model_input = gr.Textbox(label="Model Name", value="Qwen/Qwen3-0.6B")
+     dataset_input = gr.Textbox(label="Dataset Name", value="trl-lib/tldr")
+     run_button = gr.Button("Run estimation")
+     lengths_plot = gr.BarPlot(None, title="Length distribution", x="max_length", y="Ratio (%)")
+     truncation_ratio_plot = gr.BarPlot(None, title="Truncation ratio (how many tokens are discarded)", x="max_length", y="Ratio (%)")
+
+     recommended_code = gr.Code(CODE_TEMPLATE.format("..."), language="python", label="Recommended configuration")
+
+     run_button.click(fn=benchmark, inputs=[model_input, dataset_input], outputs=[lengths_plot, truncation_ratio_plot, recommended_code])
+
+     with gr.Accordion("See details", open=False):
+         gr.Markdown("""
+ This tool helps you choose an appropriate `max_length` value for your SFT training (`SFTConfig`) by analyzing the tokenized dataset.
+
+ **How it works:**
+ - Randomly samples 1,000 examples from your dataset.
+ - Prepares and tokenizes the data exactly as `SFTTrainer` would.
+ - Generates two visualizations:
+   - **Sequence Length Distribution:** Shows how long your tokenized sequences are.
+   - **Truncation Ratio:** Estimates the percentage of tokens that would be discarded (truncated) for different `max_length` values.
+ - Recommends the smallest `max_length` where truncation affects less than 5% of the tokens.
+
+ Use this tool to balance efficiency and memory usage when setting your `max_length` parameter.
+ """)
+
+
+ demo.launch()