Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -20,8 +20,7 @@ import matplotlib.pyplot as plt
|
|
20 |
import seaborn as sns
|
21 |
from datetime import datetime
|
22 |
import warnings
|
23 |
-
from huggingface_hub import HfApi, create_repo, upload_file, snapshot_download, whoami
|
24 |
-
from gradio_huggingfacehub_search import HuggingfaceHubSearch
|
25 |
from pathlib import Path
|
26 |
from textwrap import dedent
|
27 |
from scipy import stats
|
@@ -78,10 +77,10 @@ DESCRIPTION_MD = """
|
|
78 |
def escape(s: str) -> str:
|
79 |
"""Escape special characters for safe HTML display."""
|
80 |
s = str(s)
|
81 |
-
s = s.replace("&", "&amp;")
|
82 |
-
s = s.replace("<", "&lt;")
|
83 |
-
s = s.replace(">", "&gt;")
|
84 |
-
s = s.replace('"', "&quot;")
|
85 |
s = s.replace("\n", "<br/>")
|
86 |
return s
|
87 |
|
@@ -95,29 +94,68 @@ def fasttext_preprocess(content: str, tokenizer) -> str:
|
|
95 |
return re.sub(r' +', ' ', content).strip()
|
96 |
|
97 |
def fasttext_infer(norm_content: str, model) -> Tuple[str, float]:
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
|
105 |
def load_models():
|
106 |
global MODEL_LOADED, fasttext_model, tokenizer
|
107 |
-
if MODEL_LOADED
|
|
|
|
|
108 |
try:
|
109 |
model_dir = MODEL_CACHE_DIR / "Ultra-FineWeb-classifier"
|
110 |
if not model_dir.exists():
|
111 |
snapshot_download(repo_id="openbmb/Ultra-FineWeb-classifier", local_dir=str(model_dir), local_dir_use_symlinks=False)
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
|
|
116 |
MODEL_LOADED = True
|
117 |
-
return
|
118 |
except Exception as e:
|
119 |
-
|
120 |
-
return
|
121 |
|
122 |
def create_quality_plot(scores: List[float], dataset_name: str) -> str:
|
123 |
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
|
@@ -261,7 +299,11 @@ def create_demo():
|
|
261 |
with gr.Row():
|
262 |
with gr.Column(scale=3):
|
263 |
gr.Markdown("### 1. Configure Dataset")
|
264 |
-
|
|
|
|
|
|
|
|
|
265 |
text_column = gr.Textbox(label="Text Column Name", value="text")
|
266 |
with gr.Column(scale=2):
|
267 |
gr.Markdown("### 2. Configure Scoring")
|
@@ -304,15 +346,15 @@ def create_demo():
|
|
304 |
|
305 |
process_btn.click(
|
306 |
fn=process_dataset,
|
307 |
-
inputs=[
|
308 |
outputs=outputs_list
|
309 |
)
|
310 |
|
311 |
clear_btn.click(
|
312 |
fn=clear_form,
|
313 |
outputs=[
|
314 |
-
|
315 |
-
|
316 |
results_group, upload_group, upload_status
|
317 |
]
|
318 |
)
|
|
|
20 |
import seaborn as sns
|
21 |
from datetime import datetime
|
22 |
import warnings
|
23 |
+
from huggingface_hub import HfApi, create_repo, upload_file, snapshot_download, whoami, HfFolder
|
|
|
24 |
from pathlib import Path
|
25 |
from textwrap import dedent
|
26 |
from scipy import stats
|
|
|
77 |
def escape(s: str) -> str:
|
78 |
"""Escape special characters for safe HTML display."""
|
79 |
s = str(s)
|
80 |
+
s = s.replace("&", "&amp;")
|
81 |
+
s = s.replace("<", "&lt;")
|
82 |
+
s = s.replace(">", "&gt;")
|
83 |
+
s = s.replace('"', "&quot;")
|
84 |
s = s.replace("\n", "<br/>")
|
85 |
return s
|
86 |
|
|
|
94 |
return re.sub(r' +', ' ', content).strip()
|
95 |
|
96 |
def fasttext_infer(norm_content: str, model) -> Tuple[str, float]:
|
97 |
+
"""Run inference using the FastText model.
|
98 |
+
|
99 |
+
Args:
|
100 |
+
norm_content: Normalized text content to score
|
101 |
+
model: Loaded FastText model
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
Tuple of (predicted_label, score) where score is between 0 and 1
|
105 |
+
"""
|
106 |
+
try:
|
107 |
+
# Get prediction from model
|
108 |
+
pred_label, pred_prob = model.predict(norm_content)
|
109 |
+
|
110 |
+
# Handle different label formats
|
111 |
+
if isinstance(pred_label, (list, np.ndarray)) and len(pred_label) > 0:
|
112 |
+
pred_label = pred_label[0]
|
113 |
+
|
114 |
+
# Default score if we can't process it
|
115 |
+
score = 0.5
|
116 |
+
|
117 |
+
# Handle different probability formats
|
118 |
+
if pred_prob is not None:
|
119 |
+
# If it's a numpy array, convert to list
|
120 |
+
if hasattr(pred_prob, 'tolist'):
|
121 |
+
pred_prob = pred_prob.tolist()
|
122 |
+
|
123 |
+
# Handle list/array formats
|
124 |
+
if isinstance(pred_prob, (list, np.ndarray)) and len(pred_prob) > 0:
|
125 |
+
# Get first element if it's a nested structure
|
126 |
+
first_prob = pred_prob[0] if not isinstance(pred_prob[0], (list, np.ndarray)) else pred_prob[0][0]
|
127 |
+
score = float(first_prob)
|
128 |
+
else:
|
129 |
+
# Try direct conversion if it's a single value
|
130 |
+
score = float(pred_prob)
|
131 |
+
|
132 |
+
# Ensure score is between 0 and 1
|
133 |
+
score = max(0.0, min(1.0, score))
|
134 |
+
return pred_label, score
|
135 |
+
|
136 |
+
except Exception as e:
|
137 |
+
print(f"Error in fasttext_infer: {e}")
|
138 |
+
return "__label__neg", 0.0
|
139 |
|
140 |
def load_models():
|
141 |
global MODEL_LOADED, fasttext_model, tokenizer
|
142 |
+
if MODEL_LOADED and tokenizer is not None and fasttext_model is not None:
|
143 |
+
return tokenizer, fasttext_model
|
144 |
+
|
145 |
try:
|
146 |
model_dir = MODEL_CACHE_DIR / "Ultra-FineWeb-classifier"
|
147 |
if not model_dir.exists():
|
148 |
snapshot_download(repo_id="openbmb/Ultra-FineWeb-classifier", local_dir=str(model_dir), local_dir_use_symlinks=False)
|
149 |
+
|
150 |
+
# Load tokenizer and model
|
151 |
+
tokenizer = LlamaTokenizerFast.from_pretrained(str(model_dir / "tokenizer"))
|
152 |
+
fasttext_model = fasttext.load_model(str(model_dir / "classifier.bin"))
|
153 |
+
|
154 |
MODEL_LOADED = True
|
155 |
+
return tokenizer, fasttext_model
|
156 |
except Exception as e:
|
157 |
+
print(f"Error loading models: {e}")
|
158 |
+
return None, None
|
159 |
|
160 |
def create_quality_plot(scores: List[float], dataset_name: str) -> str:
|
161 |
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
|
|
|
299 |
with gr.Row():
|
300 |
with gr.Column(scale=3):
|
301 |
gr.Markdown("### 1. Configure Dataset")
|
302 |
+
dataset_id = gr.Textbox(
|
303 |
+
label="Hugging Face Dataset ID",
|
304 |
+
value="roneneldan/TinyStories",
|
305 |
+
placeholder="username/dataset_name"
|
306 |
+
)
|
307 |
text_column = gr.Textbox(label="Text Column Name", value="text")
|
308 |
with gr.Column(scale=2):
|
309 |
gr.Markdown("### 2. Configure Scoring")
|
|
|
346 |
|
347 |
process_btn.click(
|
348 |
fn=process_dataset,
|
349 |
+
inputs=[dataset_id, dataset_split, text_column, sample_size, batch_size],
|
350 |
outputs=outputs_list
|
351 |
)
|
352 |
|
353 |
clear_btn.click(
|
354 |
fn=clear_form,
|
355 |
outputs=[
|
356 |
+
dataset_id, dataset_split, text_column, sample_size, batch_size, live_log,
|
357 |
+
summary_output, scored_file_output, stats_file_output, plot_output,
|
358 |
results_group, upload_group, upload_status
|
359 |
]
|
360 |
)
|