Spaces:
Runtime error
Runtime error
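Fix the dataset cast error caused by the DataFrame index (reset the index before `Dataset.from_pandas`) and raise an error when fewer than two unique, non-empty labels are provided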
Browse files
.gitignore
CHANGED
|
@@ -129,6 +129,7 @@ venv/
|
|
| 129 |
ENV/
|
| 130 |
env.bak/
|
| 131 |
venv.bak/
|
|
|
|
| 132 |
|
| 133 |
# Spyder project settings
|
| 134 |
.spyderproject
|
|
|
|
| 129 |
ENV/
|
| 130 |
env.bak/
|
| 131 |
venv.bak/
|
| 132 |
+
.python-version
|
| 133 |
|
| 134 |
# Spyder project settings
|
| 135 |
.spyderproject
|
src/synthetic_dataset_generator/apps/textcat.py
CHANGED
|
@@ -64,7 +64,7 @@ def generate_system_prompt(dataset_description, progress=gr.Progress()):
|
|
| 64 |
progress(1.0, desc="Prompt generated")
|
| 65 |
data = json.loads(result)
|
| 66 |
system_prompt = data["classification_task"]
|
| 67 |
-
labels = data["labels"]
|
| 68 |
return system_prompt, labels
|
| 69 |
|
| 70 |
|
|
@@ -177,14 +177,20 @@ def generate_dataset(
|
|
| 177 |
distiset_results.append(record)
|
| 178 |
|
| 179 |
dataframe = pd.DataFrame(distiset_results)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
if multi_label:
|
| 181 |
dataframe["labels"] = dataframe["labels"].apply(
|
| 182 |
lambda x: list(
|
| 183 |
set(
|
| 184 |
[
|
| 185 |
-
label.lower().strip()
|
| 186 |
for label in x
|
| 187 |
-
if label is not None and label.lower().strip() in labels
|
| 188 |
]
|
| 189 |
)
|
| 190 |
)
|
|
@@ -214,6 +220,7 @@ def push_dataset_to_hub(
|
|
| 214 |
pipeline_code: str = "",
|
| 215 |
progress=gr.Progress(),
|
| 216 |
):
|
|
|
|
| 217 |
progress(0.0, desc="Validating")
|
| 218 |
repo_id = validate_push_to_hub(org_name, repo_name)
|
| 219 |
progress(0.3, desc="Preprocessing")
|
|
@@ -230,7 +237,10 @@ def push_dataset_to_hub(
|
|
| 230 |
features = Features(
|
| 231 |
{"text": Value("string"), "label": ClassLabel(names=labels)}
|
| 232 |
)
|
| 233 |
-
dataset = Dataset.from_pandas(
|
|
|
|
|
|
|
|
|
|
| 234 |
dataset = combine_datasets(repo_id, dataset)
|
| 235 |
distiset = Distiset({"default": dataset})
|
| 236 |
progress(0.9, desc="Pushing dataset")
|
|
@@ -269,6 +279,7 @@ def push_dataset(
|
|
| 269 |
num_rows=num_rows,
|
| 270 |
temperature=temperature,
|
| 271 |
)
|
|
|
|
| 272 |
push_dataset_to_hub(
|
| 273 |
dataframe,
|
| 274 |
org_name,
|
|
@@ -365,7 +376,7 @@ def push_dataset(
|
|
| 365 |
and all(label in labels for label in sample["labels"])
|
| 366 |
)
|
| 367 |
)
|
| 368 |
-
else
|
| 369 |
),
|
| 370 |
)
|
| 371 |
for sample in hf_dataset
|
|
|
|
| 64 |
progress(1.0, desc="Prompt generated")
|
| 65 |
data = json.loads(result)
|
| 66 |
system_prompt = data["classification_task"]
|
| 67 |
+
labels = get_preprocess_labels(data["labels"])
|
| 68 |
return system_prompt, labels
|
| 69 |
|
| 70 |
|
|
|
|
| 177 |
distiset_results.append(record)
|
| 178 |
|
| 179 |
dataframe = pd.DataFrame(distiset_results)
|
| 180 |
+
if (
|
| 181 |
+
not labels
|
| 182 |
+
or len(set(label.lower().strip() for label in labels if label.strip())) < 2
|
| 183 |
+
):
|
| 184 |
+
raise gr.Error(
|
| 185 |
+
"Please provide at least 2 unique, non-empty labels to classify your text."
|
| 186 |
+
)
|
| 187 |
if multi_label:
|
| 188 |
dataframe["labels"] = dataframe["labels"].apply(
|
| 189 |
lambda x: list(
|
| 190 |
set(
|
| 191 |
[
|
| 192 |
+
label.lower().strip() if (label is not None and label.lower().strip() in labels) else random.choice(labels)
|
| 193 |
for label in x
|
|
|
|
| 194 |
]
|
| 195 |
)
|
| 196 |
)
|
|
|
|
| 220 |
pipeline_code: str = "",
|
| 221 |
progress=gr.Progress(),
|
| 222 |
):
|
| 223 |
+
gr.Info(message=f"Dataframe columns in push dataset to hub: {dataframe.columns}", duration=20)
|
| 224 |
progress(0.0, desc="Validating")
|
| 225 |
repo_id = validate_push_to_hub(org_name, repo_name)
|
| 226 |
progress(0.3, desc="Preprocessing")
|
|
|
|
| 237 |
features = Features(
|
| 238 |
{"text": Value("string"), "label": ClassLabel(names=labels)}
|
| 239 |
)
|
| 240 |
+
dataset = Dataset.from_pandas(
|
| 241 |
+
dataframe.reset_index(drop=True),
|
| 242 |
+
features=features,
|
| 243 |
+
)
|
| 244 |
dataset = combine_datasets(repo_id, dataset)
|
| 245 |
distiset = Distiset({"default": dataset})
|
| 246 |
progress(0.9, desc="Pushing dataset")
|
|
|
|
| 279 |
num_rows=num_rows,
|
| 280 |
temperature=temperature,
|
| 281 |
)
|
| 282 |
+
gr.Info(message=f"Dataframe columns: {dataframe.columns}", duration=20)
|
| 283 |
push_dataset_to_hub(
|
| 284 |
dataframe,
|
| 285 |
org_name,
|
|
|
|
| 376 |
and all(label in labels for label in sample["labels"])
|
| 377 |
)
|
| 378 |
)
|
| 379 |
+
else None
|
| 380 |
),
|
| 381 |
)
|
| 382 |
for sample in hf_dataset
|