Spaces:
Running
Running
Upload app.py
Browse files
app.py
ADDED
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import subprocess
|
3 |
+
import signal
|
4 |
+
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
5 |
+
import gradio as gr
|
6 |
+
import tempfile
|
7 |
+
import torch
|
8 |
+
from datasets import load_dataset
|
9 |
+
from tqdm.auto import tqdm
|
10 |
+
import re
|
11 |
+
import numpy as np
|
12 |
+
import gc
|
13 |
+
import unicodedata
|
14 |
+
from multiprocessing import cpu_count
|
15 |
+
from transformers import LlamaTokenizerFast
|
16 |
+
import fasttext
|
17 |
+
from typing import Tuple, Dict, List
|
18 |
+
import json
|
19 |
+
import matplotlib.pyplot as plt
|
20 |
+
import seaborn as sns
|
21 |
+
from datetime import datetime
|
22 |
+
import warnings
|
23 |
+
from huggingface_hub import HfApi, create_repo, upload_file, snapshot_download, whoami
|
24 |
+
from gradio_huggingfacehub_search import HuggingfaceHubSearch
|
25 |
+
from pathlib import Path
|
26 |
+
from textwrap import dedent
|
27 |
+
from scipy import stats
|
28 |
+
from apscheduler.schedulers.background import BackgroundScheduler
|
29 |
+
|
30 |
+
warnings.filterwarnings('ignore')
|
31 |
+
|
32 |
+
# Environment variables
|
33 |
+
HF_TOKEN = os.environ.get("HF_TOKEN")
|
34 |
+
|
35 |
+
# Global variables for model caching
|
36 |
+
MODEL_CACHE_DIR = Path.home() / ".cache" / "ultra_fineweb"
|
37 |
+
MODEL_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
38 |
+
MODEL_LOADED = False
|
39 |
+
fasttext_model = None
|
40 |
+
tokenizer = None
|
41 |
+
|
42 |
+
# CSS
|
43 |
+
css = """
|
44 |
+
.gradio-container {overflow-y: auto;}
|
45 |
+
.gr-button-primary {
|
46 |
+
background-color: #ff6b00 !important;
|
47 |
+
border-color: #ff6b00 !important;
|
48 |
+
}
|
49 |
+
.gr-button-primary:hover {
|
50 |
+
background-color: #ff8534 !important;
|
51 |
+
border-color: #ff8534 !important;
|
52 |
+
}
|
53 |
+
.gr-button-secondary {
|
54 |
+
background-color: #475467 !important;
|
55 |
+
}
|
56 |
+
#login-button {
|
57 |
+
background-color: #FFD21E !important;
|
58 |
+
color: #000000 !important;
|
59 |
+
}
|
60 |
+
"""
|
61 |
+
|
62 |
+
# HTML templates
|
63 |
+
TITLE = """
|
64 |
+
<div style="text-align: center; margin-bottom: 30px;">
|
65 |
+
<h1 style="font-size: 36px; margin-bottom: 10px;">Create your own Dataset Quality Scores, blazingly fast β‘!</h1>
|
66 |
+
<p style="font-size: 16px; color: #666;">The space takes a HF dataset as input, scores it and provides statistics and quality distribution.</p>
|
67 |
+
</div>
|
68 |
+
"""
|
69 |
+
|
70 |
+
# FIXED: Added `color: #444;` to ensure text is visible on the light background.
|
71 |
+
DESCRIPTION = """
|
72 |
+
<div style="padding: 20px; background-color: #f0f0f0; border-radius: 10px; margin-bottom: 20px; color: #444;">
|
73 |
+
<h3>π How it works:</h3>
|
74 |
+
<ol>
|
75 |
+
<li>Choose a dataset from Hugging Face Hub.</li>
|
76 |
+
<li>The Ultra-FineWeb classifier will score each text sample.</li>
|
77 |
+
<li>View quality distribution and download the scored dataset.</li>
|
78 |
+
<li>Optionally, upload the results to a new repository on your Hugging Face account.</li>
|
79 |
+
</ol>
|
80 |
+
<p><strong>Note:</strong> The first run will download the model (~347MB), which may take a moment.</p>
|
81 |
+
</div>
|
82 |
+
"""
|
83 |
+
|
84 |
+
# --- Helper Functions ---
|
85 |
+
def escape(s: str) -> str:
|
86 |
+
"""Escape HTML for safe display"""
|
87 |
+
s = str(s).replace("&", "&").replace("<", "<").replace(">", ">").replace('"', """).replace("\n", "<br/>")
|
88 |
+
return s
|
89 |
+
|
90 |
+
def fasttext_preprocess(content: str, tokenizer) -> str:
|
91 |
+
"""Preprocess text for FastText model"""
|
92 |
+
if not isinstance(content, str):
|
93 |
+
return ""
|
94 |
+
content = re.sub(r'\n{3,}', '\n\n', content).lower()
|
95 |
+
content = ''.join(c for c in unicodedata.normalize('NFKD', content)
|
96 |
+
if unicodedata.category(c) != 'Mn')
|
97 |
+
token_ids = tokenizer.encode(content, add_special_tokens=False)
|
98 |
+
single_text_list = [tokenizer.decode([token_id]) for token_id in token_ids]
|
99 |
+
content = ' '.join(single_text_list)
|
100 |
+
content = re.sub(r'\n', ' n ', content)
|
101 |
+
content = re.sub(r'\r', '', content)
|
102 |
+
content = re.sub(r'\t', ' ', content)
|
103 |
+
content = re.sub(r' +', ' ', content).strip()
|
104 |
+
return content
|
105 |
+
|
106 |
+
def fasttext_infer(norm_content: str, model) -> Tuple[str, float]:
|
107 |
+
"""Run FastText inference"""
|
108 |
+
pred_label, pred_prob = model.predict(norm_content)
|
109 |
+
pred_label = pred_label[0]
|
110 |
+
_score = min(pred_prob.tolist()[0], 1.0)
|
111 |
+
if pred_label == "__label__neg":
|
112 |
+
_score = 1 - _score
|
113 |
+
return pred_label, _score
|
114 |
+
|
115 |
+
def load_models():
|
116 |
+
"""Load models with caching"""
|
117 |
+
global MODEL_LOADED, fasttext_model, tokenizer
|
118 |
+
if MODEL_LOADED:
|
119 |
+
return True
|
120 |
+
|
121 |
+
try:
|
122 |
+
model_dir = MODEL_CACHE_DIR / "Ultra-FineWeb-classifier"
|
123 |
+
if not model_dir.exists():
|
124 |
+
print("Downloading Ultra-FineWeb-classifier...")
|
125 |
+
snapshot_download(repo_id="openbmb/Ultra-FineWeb-classifier", local_dir=str(model_dir), local_dir_use_symlinks=False)
|
126 |
+
|
127 |
+
fasttext_path = model_dir / "classifiers" / "ultra_fineweb_en.bin"
|
128 |
+
tokenizer_path = model_dir / "local_tokenizer"
|
129 |
+
|
130 |
+
if not fasttext_path.exists():
|
131 |
+
raise FileNotFoundError(f"FastText model not found at {fasttext_path}")
|
132 |
+
|
133 |
+
print("Loading models...")
|
134 |
+
fasttext_model = fasttext.load_model(str(fasttext_path))
|
135 |
+
tokenizer = LlamaTokenizerFast.from_pretrained(str(tokenizer_path) if tokenizer_path.exists() else "meta-llama/Llama-2-7b-hf")
|
136 |
+
|
137 |
+
MODEL_LOADED = True
|
138 |
+
print("Models loaded successfully!")
|
139 |
+
return True
|
140 |
+
except Exception as e:
|
141 |
+
print(f"Error loading models: {e}")
|
142 |
+
gr.Warning(f"Failed to load models: {e}")
|
143 |
+
return False
|
144 |
+
|
145 |
+
def create_quality_plot(scores: List[float], dataset_name: str) -> str:
|
146 |
+
"""Create quality distribution plot"""
|
147 |
+
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
|
148 |
+
output_path = tmpfile.name
|
149 |
+
|
150 |
+
plt.figure(figsize=(10, 6))
|
151 |
+
sns.histplot(scores, bins=50, kde=True, color='#6B7FD7', edgecolor='black', line_kws={'linewidth': 2, 'color': 'red'})
|
152 |
+
|
153 |
+
mean_score = np.mean(scores)
|
154 |
+
median_score = np.median(scores)
|
155 |
+
|
156 |
+
plt.axvline(mean_score, color='green', linestyle='--', linewidth=2, label=f'Mean: {mean_score:.3f}')
|
157 |
+
plt.axvline(median_score, color='orange', linestyle=':', linewidth=2, label=f'Median: {median_score:.3f}')
|
158 |
+
|
159 |
+
plt.xlabel('Quality Score', fontsize=12)
|
160 |
+
plt.ylabel('Density', fontsize=12)
|
161 |
+
plt.title(f'Quality Score Distribution - {dataset_name}', fontsize=14, fontweight='bold')
|
162 |
+
plt.legend()
|
163 |
+
plt.grid(axis='y', alpha=0.3)
|
164 |
+
plt.xlim(0, 1)
|
165 |
+
|
166 |
+
plt.tight_layout()
|
167 |
+
plt.savefig(output_path, dpi=150, bbox_inches='tight')
|
168 |
+
plt.close()
|
169 |
+
|
170 |
+
return output_path
|
171 |
+
|
172 |
+
def process_dataset(
|
173 |
+
model_id: str,
|
174 |
+
dataset_split: str,
|
175 |
+
text_column: str,
|
176 |
+
sample_size: int,
|
177 |
+
batch_size: int,
|
178 |
+
progress=gr.Progress(track_tqdm=True)
|
179 |
+
) -> Tuple[str, str, str, str, gr.update, gr.update]:
|
180 |
+
"""Process dataset and return results, including visibility updates for UI components."""
|
181 |
+
|
182 |
+
try:
|
183 |
+
progress(0, desc="Loading models...")
|
184 |
+
if not load_models():
|
185 |
+
raise gr.Error("Failed to load scoring models. Please check the logs.")
|
186 |
+
|
187 |
+
progress(0.1, desc="Loading dataset...")
|
188 |
+
dataset = load_dataset(model_id, split=dataset_split, streaming=False)
|
189 |
+
|
190 |
+
if text_column not in dataset.column_names:
|
191 |
+
raise gr.Error(f"Column '{text_column}' not found. Available columns: {', '.join(dataset.column_names)}")
|
192 |
+
|
193 |
+
total_samples = len(dataset)
|
194 |
+
actual_samples = min(sample_size, total_samples)
|
195 |
+
dataset = dataset.select(range(actual_samples))
|
196 |
+
|
197 |
+
scores, scored_data = [], []
|
198 |
+
|
199 |
+
for i in tqdm(range(0, actual_samples, batch_size), desc="Scoring batches"):
|
200 |
+
batch = dataset[i:min(i + batch_size, actual_samples)]
|
201 |
+
for text in batch[text_column]:
|
202 |
+
norm_content = fasttext_preprocess(text, tokenizer)
|
203 |
+
label, score = (0.0, "__label__neg") if not norm_content else fasttext_infer(norm_content, fasttext_model)
|
204 |
+
scores.append(score)
|
205 |
+
scored_data.append({'text': text, 'quality_score': score, 'predicted_label': label})
|
206 |
+
|
207 |
+
progress(0.9, desc="Generating statistics and plot...")
|
208 |
+
stats_dict = {
|
209 |
+
'dataset_id': model_id,
|
210 |
+
'dataset_split': dataset_split,
|
211 |
+
'processed_samples': actual_samples,
|
212 |
+
'statistics': {
|
213 |
+
'mean': float(np.mean(scores)), 'median': float(np.median(scores)),
|
214 |
+
'std': float(np.std(scores)), 'min': float(np.min(scores)), 'max': float(np.max(scores)),
|
215 |
+
'p90': float(np.percentile(scores, 90)),
|
216 |
+
},
|
217 |
+
}
|
218 |
+
|
219 |
+
plot_file = create_quality_plot(scores, model_id.split('/')[-1])
|
220 |
+
|
221 |
+
with tempfile.NamedTemporaryFile(mode='w', suffix=".jsonl", delete=False, encoding='utf-8') as f_out:
|
222 |
+
output_file_path = f_out.name
|
223 |
+
for item in scored_data:
|
224 |
+
f_out.write(json.dumps(item, ensure_ascii=False) + '\n')
|
225 |
+
|
226 |
+
with tempfile.NamedTemporaryFile(mode='w', suffix=".json", delete=False, encoding='utf-8') as f_stats:
|
227 |
+
stats_file_path = f_stats.name
|
228 |
+
json.dump(stats_dict, f_stats, indent=2)
|
229 |
+
|
230 |
+
summary_html = f"""
|
231 |
+
<div style="padding: 15px; background-color: #f9f9f9; border-radius: 10px;">
|
232 |
+
<h4>β
Scoring Completed!</h4>
|
233 |
+
<p><strong>Dataset:</strong> {escape(model_id)}<br>
|
234 |
+
<strong>Processed Samples:</strong> {actual_samples:,}<br>
|
235 |
+
<strong>Mean Score:</strong> {stats_dict['statistics']['mean']:.3f}<br>
|
236 |
+
<strong>Median Score:</strong> {stats_dict['statistics']['median']:.3f}</p>
|
237 |
+
</div>
|
238 |
+
"""
|
239 |
+
|
240 |
+
return summary_html, output_file_path, stats_file_path, plot_file, gr.update(visible=True), gr.update(visible=True)
|
241 |
+
|
242 |
+
except Exception as e:
|
243 |
+
error_html = f"""
|
244 |
+
<div style="padding: 20px; background-color: #fee; border: 1px solid #d00; border-radius: 10px;">
|
245 |
+
<h4>β Error</h4><pre style="white-space: pre-wrap; font-size: 14px;">{escape(e)}</pre>
|
246 |
+
</div>
|
247 |
+
"""
|
248 |
+
return error_html, None, None, None, gr.update(visible=False), gr.update(visible=False)
|
249 |
+
|
250 |
+
def upload_to_hub(
|
251 |
+
scored_file: str, stats_file: str, plot_file: str, new_dataset_id: str,
|
252 |
+
private: bool, hf_token: str, progress=gr.Progress(track_tqdm=True)
|
253 |
+
) -> str:
|
254 |
+
"""Upload results to Hugging Face Hub"""
|
255 |
+
if not hf_token: return 'β <span style="color: red;">Please provide your Hugging Face token.</span>'
|
256 |
+
if not all([scored_file, new_dataset_id]): return 'β <span style="color: red;">Missing scored file or new dataset ID.</span>'
|
257 |
+
|
258 |
+
try:
|
259 |
+
progress(0.1, desc="Connecting to Hub...")
|
260 |
+
api = HfApi(token=hf_token)
|
261 |
+
username = whoami(token=hf_token)["name"]
|
262 |
+
repo_id = f"{username}/{new_dataset_id}" if "/" not in new_dataset_id else new_dataset_id
|
263 |
+
|
264 |
+
progress(0.2, desc=f"Creating repo: {repo_id}")
|
265 |
+
repo_url = create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True, private=private, token=hf_token).repo_url
|
266 |
+
|
267 |
+
progress(0.4, desc="Uploading scored dataset...")
|
268 |
+
upload_file(path_or_fileobj=scored_file, path_in_repo="data/scored_dataset.jsonl", repo_id=repo_id, repo_type="dataset", token=hf_token)
|
269 |
+
|
270 |
+
progress(0.6, desc="Uploading assets...")
|
271 |
+
if stats_file and os.path.exists(stats_file):
|
272 |
+
upload_file(path_or_fileobj=stats_file, path_in_repo="statistics.json", repo_id=repo_id, repo_type="dataset", token=hf_token)
|
273 |
+
if plot_file and os.path.exists(plot_file):
|
274 |
+
upload_file(path_or_fileobj=plot_file, path_in_repo="quality_distribution.png", repo_id=repo_id, repo_type="dataset", token=hf_token)
|
275 |
+
|
276 |
+
readme_content = dedent(f"""
|
277 |
+
---
|
278 |
+
license: apache-2.0
|
279 |
+
---
|
280 |
+
# Quality-Scored Dataset: {repo_id.split('/')[-1]}
|
281 |
+
This dataset was scored for quality using the [Dataset Quality Scorer Space](https://huggingface.co/spaces/ggml-org/dataset-quality-scorer).
|
282 |
+

|
283 |
+
## Usage
|
284 |
+
```python
|
285 |
+
from datasets import load_dataset
|
286 |
+
dataset = load_dataset("{repo_id}", split="train")
|
287 |
+
```
|
288 |
+
""").strip()
|
289 |
+
|
290 |
+
upload_file(path_or_fileobj=readme_content.encode(), path_in_repo="README.md", repo_id=repo_id, repo_type="dataset", token=hf_token)
|
291 |
+
progress(1.0, "Done!")
|
292 |
+
return f'β
<span style="color: green;">Successfully uploaded to <a href="{repo_url}" target="_blank">{repo_id}</a></span>'
|
293 |
+
|
294 |
+
except Exception as e:
|
295 |
+
return f'β <span style="color: red;">Upload failed: {escape(e)}</span>'
|
296 |
+
|
297 |
+
def create_demo():
|
298 |
+
with gr.Blocks(css=css, title="Dataset Quality Scorer") as demo:
|
299 |
+
gr.HTML(TITLE)
|
300 |
+
gr.HTML(DESCRIPTION)
|
301 |
+
|
302 |
+
gr.Markdown("### 1. Configure & Score Dataset")
|
303 |
+
with gr.Row():
|
304 |
+
with gr.Column(scale=3):
|
305 |
+
dataset_search = HuggingfaceHubSearch(label="Hub Dataset ID", search_type="dataset", value="roneneldan/TinyStories")
|
306 |
+
text_column = gr.Textbox(label="Text Column Name", value="text", info="The column containing the text to score.")
|
307 |
+
with gr.Column(scale=2):
|
308 |
+
dataset_split = gr.Dropdown(["train", "validation", "test"], label="Split", value="train")
|
309 |
+
with gr.Row():
|
310 |
+
sample_size = gr.Number(label="Sample Size", value=1000, minimum=100, step=100, info="Max samples.")
|
311 |
+
batch_size = gr.Number(label="Batch Size", value=32, minimum=1, step=1, info="Processing batch.")
|
312 |
+
|
313 |
+
with gr.Row():
|
314 |
+
clear_btn = gr.Button("Clear", variant="secondary")
|
315 |
+
process_btn = gr.Button("π Start Scoring", variant="primary", size="lg")
|
316 |
+
|
317 |
+
# --- Results and Upload Sections (Initially Hidden) ---
|
318 |
+
with gr.Accordion("β
Results", open=True, visible=False) as results_accordion:
|
319 |
+
gr.Markdown("### 2. Review Results")
|
320 |
+
with gr.Row():
|
321 |
+
with gr.Column(scale=2):
|
322 |
+
summary_output = gr.HTML(label="Summary")
|
323 |
+
with gr.Column(scale=1):
|
324 |
+
plot_output = gr.Image(label="Quality Distribution", show_label=True)
|
325 |
+
with gr.Row():
|
326 |
+
scored_file_output = gr.File(label="π Download Scored Dataset (.jsonl)", type="filepath")
|
327 |
+
stats_file_output = gr.File(label="π Download Statistics (.json)", type="filepath")
|
328 |
+
|
329 |
+
with gr.Accordion("βοΈ Upload to Hub", open=False, visible=False) as upload_accordion:
|
330 |
+
gr.Markdown("### 3. (Optional) Upload to Hugging Face Hub")
|
331 |
+
hf_token_input = gr.Textbox(label="Hugging Face Token", type="password", placeholder="hf_...", value=HF_TOKEN or "", info="Your HF token with 'write' permissions.")
|
332 |
+
new_dataset_id = gr.Textbox(label="New Dataset Name", placeholder="my-scored-dataset", info="Will be created under your username.")
|
333 |
+
private_checkbox = gr.Checkbox(label="Make dataset private", value=False)
|
334 |
+
upload_btn = gr.Button("π€ Upload to Hub", variant="primary")
|
335 |
+
upload_status = gr.HTML()
|
336 |
+
|
337 |
+
# --- Event Handlers ---
|
338 |
+
def clear_form():
|
339 |
+
return "roneneldan/TinyStories", "train", "text", 1000, 32, None, None, None, None, gr.update(visible=False), gr.update(visible=False), ""
|
340 |
+
|
341 |
+
clear_btn.click(
|
342 |
+
fn=clear_form,
|
343 |
+
outputs=[
|
344 |
+
dataset_search, dataset_split, text_column, sample_size, batch_size,
|
345 |
+
summary_output, scored_file_output, stats_file_output, plot_output,
|
346 |
+
results_accordion, upload_accordion, upload_status
|
347 |
+
]
|
348 |
+
)
|
349 |
+
|
350 |
+
process_btn.click(
|
351 |
+
fn=process_dataset,
|
352 |
+
inputs=[dataset_search, dataset_split, text_column, sample_size, batch_size],
|
353 |
+
outputs=[summary_output, scored_file_output, stats_file_output, plot_output, results_accordion, upload_accordion]
|
354 |
+
)
|
355 |
+
|
356 |
+
upload_btn.click(
|
357 |
+
fn=upload_to_hub,
|
358 |
+
inputs=[scored_file_output, stats_file_output, plot_output, new_dataset_id, private_checkbox, hf_token_input],
|
359 |
+
outputs=[upload_status]
|
360 |
+
)
|
361 |
+
return demo
|
362 |
+
|
363 |
+
# --- App Execution ---
|
364 |
+
demo = create_demo()
|
365 |
+
|
366 |
+
if os.environ.get("SPACE_ID"):
|
367 |
+
def restart_space():
|
368 |
+
if HF_TOKEN:
|
369 |
+
try:
|
370 |
+
print("Scheduler: Triggering space restart...")
|
371 |
+
api = HfApi()
|
372 |
+
api.restart_space(repo_id=os.environ["SPACE_ID"], token=HF_TOKEN)
|
373 |
+
except Exception as e:
|
374 |
+
print(f"Scheduler: Failed to restart space: {e}")
|
375 |
+
|
376 |
+
scheduler = BackgroundScheduler()
|
377 |
+
scheduler.add_job(restart_space, "interval", hours=6)
|
378 |
+
scheduler.start()
|
379 |
+
print("Background scheduler for periodic restarts is active.")
|
380 |
+
|
381 |
+
if __name__ == "__main__":
|
382 |
+
demo.queue().launch(debug=False, show_api=False)
|