C10X committed
Commit 400d8ff · verified · Parent: e11dbd9

Upload app.py

Files changed (1)
app.py +382 -0
app.py ADDED
@@ -0,0 +1,382 @@
+ import os
+ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
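+ # The analytics opt-out is set before `import gradio` below, so gradio sees it at import time.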
+ import gradio as gr
+ import tempfile
+ from datasets import load_dataset
+ from tqdm.auto import tqdm
+ import re
+ import numpy as np
+ import unicodedata
+ from transformers import LlamaTokenizerFast
+ import fasttext
+ from typing import Tuple, List
+ import json
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ import warnings
+ from huggingface_hub import HfApi, create_repo, upload_file, snapshot_download, whoami
+ from gradio_huggingfacehub_search import HuggingfaceHubSearch
+ from pathlib import Path
+ from textwrap import dedent
+ from apscheduler.schedulers.background import BackgroundScheduler
+
+ warnings.filterwarnings('ignore')
+
+ # Environment variables
+ HF_TOKEN = os.environ.get("HF_TOKEN")
+
+ # Global variables for model caching
+ MODEL_CACHE_DIR = Path.home() / ".cache" / "ultra_fineweb"
+ MODEL_CACHE_DIR.mkdir(parents=True, exist_ok=True)
+ MODEL_LOADED = False
+ fasttext_model = None
+ tokenizer = None
+
+ # CSS
+ css = """
+ .gradio-container {overflow-y: auto;}
+ .gr-button-primary {
+     background-color: #ff6b00 !important;
+     border-color: #ff6b00 !important;
+ }
+ .gr-button-primary:hover {
+     background-color: #ff8534 !important;
+     border-color: #ff8534 !important;
+ }
+ .gr-button-secondary {
+     background-color: #475467 !important;
+ }
+ #login-button {
+     background-color: #FFD21E !important;
+     color: #000000 !important;
+ }
+ """
+
+ # HTML templates
+ TITLE = """
+ <div style="text-align: center; margin-bottom: 30px;">
+     <h1 style="font-size: 36px; margin-bottom: 10px;">Create your own Dataset Quality Scores, blazingly fast ⚡!</h1>
+     <p style="font-size: 16px; color: #666;">The Space takes an HF dataset as input, scores it, and provides statistics and a quality distribution.</p>
+ </div>
+ """
+
+ # FIXED: Added `color: #444;` to ensure text is visible on the light background.
+ DESCRIPTION = """
+ <div style="padding: 20px; background-color: #f0f0f0; border-radius: 10px; margin-bottom: 20px; color: #444;">
+     <h3>📋 How it works:</h3>
+     <ol>
+         <li>Choose a dataset from Hugging Face Hub.</li>
+         <li>The Ultra-FineWeb classifier will score each text sample.</li>
+         <li>View quality distribution and download the scored dataset.</li>
+         <li>Optionally, upload the results to a new repository on your Hugging Face account.</li>
+     </ol>
+     <p><strong>Note:</strong> The first run will download the model (~347MB), which may take a moment.</p>
+ </div>
+ """
+
+ # --- Helper Functions ---
+ def escape(s: str) -> str:
+     """Escape HTML for safe display"""
+     s = str(s).replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;").replace('"', "&quot;").replace("\n", "<br/>")
+     return s
+
+ def fasttext_preprocess(content: str, tokenizer) -> str:
+     """Preprocess text for FastText model"""
+     if not isinstance(content, str):
+         return ""
+     content = re.sub(r'\n{3,}', '\n\n', content).lower()
+     content = ''.join(c for c in unicodedata.normalize('NFKD', content)
+                       if unicodedata.category(c) != 'Mn')
+     token_ids = tokenizer.encode(content, add_special_tokens=False)
+     single_text_list = [tokenizer.decode([token_id]) for token_id in token_ids]
+     content = ' '.join(single_text_list)
+     content = re.sub(r'\n', ' n ', content)
+     content = re.sub(r'\r', '', content)
+     content = re.sub(r'\t', ' ', content)
+     content = re.sub(r' +', ' ', content).strip()
+     return content
+
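+ # Rough illustration (hypothetical input): fasttext_preprocess("Hello,\n\n\n\nWORLD!", tokenizer)
+ # lowercases, strips combining marks via NFKD, re-tokenizes with the Llama tokenizer, and
+ # space-joins the pieces, yielding something like "hello , n n world !" (newlines become ' n ').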
+ def fasttext_infer(norm_content: str, model) -> Tuple[str, float]:
+     """Run FastText inference"""
+     pred_label, pred_prob = model.predict(norm_content)
+     pred_label = pred_label[0]
+     _score = min(pred_prob.tolist()[0], 1.0)
+     if pred_label == "__label__neg":
+         _score = 1 - _score
+     return pred_label, _score
+
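+ # The returned score is always the probability of the positive ("high quality") class:
+ # e.g. a hypothetical __label__neg prediction at 0.9 confidence comes back as ("__label__neg", 0.1).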
+ def load_models():
+     """Load models with caching"""
+     global MODEL_LOADED, fasttext_model, tokenizer
+     if MODEL_LOADED:
+         return True
+
+     try:
+         model_dir = MODEL_CACHE_DIR / "Ultra-FineWeb-classifier"
+         if not model_dir.exists():
+             print("Downloading Ultra-FineWeb-classifier...")
+             snapshot_download(repo_id="openbmb/Ultra-FineWeb-classifier", local_dir=str(model_dir), local_dir_use_symlinks=False)
+
+         fasttext_path = model_dir / "classifiers" / "ultra_fineweb_en.bin"
+         tokenizer_path = model_dir / "local_tokenizer"
+
+         if not fasttext_path.exists():
+             raise FileNotFoundError(f"FastText model not found at {fasttext_path}")
+
+         print("Loading models...")
+         fasttext_model = fasttext.load_model(str(fasttext_path))
+         tokenizer = LlamaTokenizerFast.from_pretrained(str(tokenizer_path) if tokenizer_path.exists() else "meta-llama/Llama-2-7b-hf")
+
+         MODEL_LOADED = True
+         print("Models loaded successfully!")
+         return True
+     except Exception as e:
+         print(f"Error loading models: {e}")
+         gr.Warning(f"Failed to load models: {e}")
+         return False
+
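+ # After the first successful call the snapshot persists under ~/.cache/ultra_fineweb/, so
+ # subsequent runs skip the ~347MB download and load both the classifier and tokenizer from disk.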
+ def create_quality_plot(scores: List[float], dataset_name: str) -> str:
+     """Create quality distribution plot"""
+     with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
+         output_path = tmpfile.name
+
+     plt.figure(figsize=(10, 6))
+     sns.histplot(scores, bins=50, kde=True, color='#6B7FD7', edgecolor='black', line_kws={'linewidth': 2, 'color': 'red'})
+
+     mean_score = np.mean(scores)
+     median_score = np.median(scores)
+
+     plt.axvline(mean_score, color='green', linestyle='--', linewidth=2, label=f'Mean: {mean_score:.3f}')
+     plt.axvline(median_score, color='orange', linestyle=':', linewidth=2, label=f'Median: {median_score:.3f}')
+
+     plt.xlabel('Quality Score', fontsize=12)
+     plt.ylabel('Count', fontsize=12)  # histplot's default stat is 'count'; 'Density' would mislabel the axis
+     plt.title(f'Quality Score Distribution - {dataset_name}', fontsize=14, fontweight='bold')
+     plt.legend()
+     plt.grid(axis='y', alpha=0.3)
+     plt.xlim(0, 1)
+
+     plt.tight_layout()
+     plt.savefig(output_path, dpi=150, bbox_inches='tight')
+     plt.close()
+
+     return output_path
+
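+ # Usage sketch (illustrative values): create_quality_plot([0.12, 0.87, 0.45], "TinyStories")
+ # returns the path of a temporary PNG; the file is not deleted here because gr.Image and
+ # upload_to_hub still need to read it afterwards.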
+ def process_dataset(
+     model_id: str,
+     dataset_split: str,
+     text_column: str,
+     sample_size: int,
+     batch_size: int,
+     progress=gr.Progress(track_tqdm=True)
+ ) -> Tuple[str, str, str, str, gr.update, gr.update]:
+     """Process dataset and return results, including visibility updates for UI components."""
+
+     try:
+         progress(0, desc="Loading models...")
+         if not load_models():
+             raise gr.Error("Failed to load scoring models. Please check the logs.")
+
+         progress(0.1, desc="Loading dataset...")
+         dataset = load_dataset(model_id, split=dataset_split, streaming=False)
+
+         if text_column not in dataset.column_names:
+             raise gr.Error(f"Column '{text_column}' not found. Available columns: {', '.join(dataset.column_names)}")
+
+         total_samples = len(dataset)
+         actual_samples = min(sample_size, total_samples)
+         dataset = dataset.select(range(actual_samples))
+
+         scores, scored_data = [], []
+
+         for i in tqdm(range(0, actual_samples, batch_size), desc="Scoring batches"):
+             batch = dataset[i:min(i + batch_size, actual_samples)]
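+             # NOTE: slicing an HF Dataset yields a dict mapping column names to lists,
+             # not a list of rows, hence batch[text_column] below.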
+             for text in batch[text_column]:
+                 norm_content = fasttext_preprocess(text, tokenizer)
+                 # Empty or non-string inputs fall back to the negative label with score 0.0.
+                 label, score = ("__label__neg", 0.0) if not norm_content else fasttext_infer(norm_content, fasttext_model)
+                 scores.append(score)
+                 scored_data.append({'text': text, 'quality_score': score, 'predicted_label': label})
+
+         progress(0.9, desc="Generating statistics and plot...")
+         stats_dict = {
+             'dataset_id': model_id,
+             'dataset_split': dataset_split,
+             'processed_samples': actual_samples,
+             'statistics': {
+                 'mean': float(np.mean(scores)), 'median': float(np.median(scores)),
+                 'std': float(np.std(scores)), 'min': float(np.min(scores)), 'max': float(np.max(scores)),
+                 'p90': float(np.percentile(scores, 90)),
+             },
+         }
+
+         plot_file = create_quality_plot(scores, model_id.split('/')[-1])
+
+         with tempfile.NamedTemporaryFile(mode='w', suffix=".jsonl", delete=False, encoding='utf-8') as f_out:
+             output_file_path = f_out.name
+             for item in scored_data:
+                 f_out.write(json.dumps(item, ensure_ascii=False) + '\n')
+
+         with tempfile.NamedTemporaryFile(mode='w', suffix=".json", delete=False, encoding='utf-8') as f_stats:
+             stats_file_path = f_stats.name
+             json.dump(stats_dict, f_stats, indent=2)
+
+         summary_html = f"""
+         <div style="padding: 15px; background-color: #f9f9f9; border-radius: 10px;">
+             <h4>✅ Scoring Completed!</h4>
+             <p><strong>Dataset:</strong> {escape(model_id)}<br>
+             <strong>Processed Samples:</strong> {actual_samples:,}<br>
+             <strong>Mean Score:</strong> {stats_dict['statistics']['mean']:.3f}<br>
+             <strong>Median Score:</strong> {stats_dict['statistics']['median']:.3f}</p>
+         </div>
+         """
+
+         return summary_html, output_file_path, stats_file_path, plot_file, gr.update(visible=True), gr.update(visible=True)
+
+     except Exception as e:
+         error_html = f"""
+         <div style="padding: 20px; background-color: #fee; border: 1px solid #d00; border-radius: 10px;">
+             <h4>❌ Error</h4><pre style="white-space: pre-wrap; font-size: 14px;">{escape(e)}</pre>
+         </div>
+         """
+         return error_html, None, None, None, gr.update(visible=False), gr.update(visible=False)
+
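+ # Each line of the scored .jsonl is one JSON record, e.g. (values illustrative):
+ # {"text": "Once upon a time...", "quality_score": 0.83, "predicted_label": "__label__pos"}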
+ def upload_to_hub(
+     scored_file: str, stats_file: str, plot_file: str, new_dataset_id: str,
+     private: bool, hf_token: str, progress=gr.Progress(track_tqdm=True)
+ ) -> str:
+     """Upload results to Hugging Face Hub"""
+     if not hf_token: return '❌ <span style="color: red;">Please provide your Hugging Face token.</span>'
+     if not all([scored_file, new_dataset_id]): return '❌ <span style="color: red;">Missing scored file or new dataset ID.</span>'
+
+     try:
+         progress(0.1, desc="Connecting to Hub...")
+         username = whoami(token=hf_token)["name"]
+         repo_id = f"{username}/{new_dataset_id}" if "/" not in new_dataset_id else new_dataset_id
+
+         progress(0.2, desc=f"Creating repo: {repo_id}")
+         # create_repo returns a RepoUrl (a str subclass), so it can be used directly as the URL.
+         repo_url = create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True, private=private, token=hf_token)
+
+         progress(0.4, desc="Uploading scored dataset...")
+         upload_file(path_or_fileobj=scored_file, path_in_repo="data/scored_dataset.jsonl", repo_id=repo_id, repo_type="dataset", token=hf_token)
+
+         progress(0.6, desc="Uploading assets...")
+         if stats_file and os.path.exists(stats_file):
+             upload_file(path_or_fileobj=stats_file, path_in_repo="statistics.json", repo_id=repo_id, repo_type="dataset", token=hf_token)
+         if plot_file and os.path.exists(plot_file):
+             upload_file(path_or_fileobj=plot_file, path_in_repo="quality_distribution.png", repo_id=repo_id, repo_type="dataset", token=hf_token)
+
+         readme_content = dedent(f"""
+         ---
+         license: apache-2.0
+         ---
+         # Quality-Scored Dataset: {repo_id.split('/')[-1]}
+         This dataset was scored for quality using the [Dataset Quality Scorer Space](https://huggingface.co/spaces/ggml-org/dataset-quality-scorer).
+         ![Quality Distribution](quality_distribution.png)
+         ## Usage
+         ```python
+         from datasets import load_dataset
+         dataset = load_dataset("{repo_id}", split="train")
+         ```
+         """).strip()
+
+         upload_file(path_or_fileobj=readme_content.encode(), path_in_repo="README.md", repo_id=repo_id, repo_type="dataset", token=hf_token)
+         progress(1.0, desc="Done!")
+         return f'✅ <span style="color: green;">Successfully uploaded to <a href="{repo_url}" target="_blank">{repo_id}</a></span>'
+
+     except Exception as e:
+         return f'❌ <span style="color: red;">Upload failed: {escape(e)}</span>'
+
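+ # Resulting repo layout (as uploaded above): data/scored_dataset.jsonl, statistics.json,
+ # quality_distribution.png, and a README.md dataset card with a load_dataset usage snippet.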
+ def create_demo():
+     with gr.Blocks(css=css, title="Dataset Quality Scorer") as demo:
+         gr.HTML(TITLE)
+         gr.HTML(DESCRIPTION)
+
+         gr.Markdown("### 1. Configure & Score Dataset")
+         with gr.Row():
+             with gr.Column(scale=3):
+                 dataset_search = HuggingfaceHubSearch(label="Hub Dataset ID", search_type="dataset", value="roneneldan/TinyStories")
+                 text_column = gr.Textbox(label="Text Column Name", value="text", info="The column containing the text to score.")
+             with gr.Column(scale=2):
+                 dataset_split = gr.Dropdown(["train", "validation", "test"], label="Split", value="train")
+                 with gr.Row():
+                     # precision=0 makes these Numbers return ints; process_dataset feeds them to range()/select()
+                     sample_size = gr.Number(label="Sample Size", value=1000, minimum=100, step=100, precision=0, info="Max samples.")
+                     batch_size = gr.Number(label="Batch Size", value=32, minimum=1, step=1, precision=0, info="Processing batch.")
+
+         with gr.Row():
+             clear_btn = gr.Button("Clear", variant="secondary")
+             process_btn = gr.Button("🚀 Start Scoring", variant="primary", size="lg")
+
+         # --- Results and Upload Sections (Initially Hidden) ---
+         with gr.Accordion("✅ Results", open=True, visible=False) as results_accordion:
+             gr.Markdown("### 2. Review Results")
+             with gr.Row():
+                 with gr.Column(scale=2):
+                     summary_output = gr.HTML(label="Summary")
+                 with gr.Column(scale=1):
+                     plot_output = gr.Image(label="Quality Distribution", show_label=True)
+             with gr.Row():
+                 scored_file_output = gr.File(label="📄 Download Scored Dataset (.jsonl)", type="filepath")
+                 stats_file_output = gr.File(label="📊 Download Statistics (.json)", type="filepath")
+
+         with gr.Accordion("☁️ Upload to Hub", open=False, visible=False) as upload_accordion:
+             gr.Markdown("### 3. (Optional) Upload to Hugging Face Hub")
+             hf_token_input = gr.Textbox(label="Hugging Face Token", type="password", placeholder="hf_...", value=HF_TOKEN or "", info="Your HF token with 'write' permissions.")
+             new_dataset_id = gr.Textbox(label="New Dataset Name", placeholder="my-scored-dataset", info="Will be created under your username.")
+             private_checkbox = gr.Checkbox(label="Make dataset private", value=False)
+             upload_btn = gr.Button("📤 Upload to Hub", variant="primary")
+             upload_status = gr.HTML()
+
+         # --- Event Handlers ---
+         def clear_form():
+             return "roneneldan/TinyStories", "train", "text", 1000, 32, None, None, None, None, gr.update(visible=False), gr.update(visible=False), ""
+
+         clear_btn.click(
+             fn=clear_form,
+             outputs=[
+                 dataset_search, dataset_split, text_column, sample_size, batch_size,
+                 summary_output, scored_file_output, stats_file_output, plot_output,
+                 results_accordion, upload_accordion, upload_status
+             ]
+         )
+
+         process_btn.click(
+             fn=process_dataset,
+             inputs=[dataset_search, dataset_split, text_column, sample_size, batch_size],
+             outputs=[summary_output, scored_file_output, stats_file_output, plot_output, results_accordion, upload_accordion]
+         )
+
+         upload_btn.click(
+             fn=upload_to_hub,
+             inputs=[scored_file_output, stats_file_output, plot_output, new_dataset_id, private_checkbox, hf_token_input],
+             outputs=[upload_status]
+         )
+     return demo
+
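+ # The last two values returned by process_dataset and clear_form are gr.update(visible=...)
+ # objects targeting the two accordions, which is how the Results and Upload sections stay
+ # hidden until a scoring run has produced output.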
+ # --- App Execution ---
+ demo = create_demo()
+
+ if os.environ.get("SPACE_ID"):
+     def restart_space():
+         if HF_TOKEN:
+             try:
+                 print("Scheduler: Triggering space restart...")
+                 api = HfApi()
+                 api.restart_space(repo_id=os.environ["SPACE_ID"], token=HF_TOKEN)
+             except Exception as e:
+                 print(f"Scheduler: Failed to restart space: {e}")
+
+     scheduler = BackgroundScheduler()
+     scheduler.add_job(restart_space, "interval", hours=6)
+     scheduler.start()
+     print("Background scheduler for periodic restarts is active.")
+
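+ # SPACE_ID is only present on Hugging Face Spaces, so the periodic self-restart (every 6 hours,
+ # presumably to reclaim temp files and memory) never fires during local development.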
+ if __name__ == "__main__":
+     demo.queue().launch(debug=False, show_api=False)