Commit
·
6f3d06e
1
Parent(s):
70abf20
fix returning duplicate labels
Browse files
src/synthetic_dataset_generator/apps/textcat.py
CHANGED
|
@@ -186,13 +186,15 @@ def generate_dataset(
|
|
| 186 |
if isinstance(x, str): # single label
|
| 187 |
return [x.lower().strip()]
|
| 188 |
elif isinstance(x, list): # multiple labels
|
| 189 |
-
return
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
|
|
|
|
|
|
| 194 |
else:
|
| 195 |
-
return [random.choice(labels)]
|
| 196 |
|
| 197 |
dataframe["labels"] = dataframe["labels"].apply(_validate_labels)
|
| 198 |
dataframe = dataframe[dataframe["labels"].notna()]
|
|
|
|
| 186 |
if isinstance(x, str): # single label
|
| 187 |
return [x.lower().strip()]
|
| 188 |
elif isinstance(x, list): # multiple labels
|
| 189 |
+
return list(
|
| 190 |
+
set(
|
| 191 |
+
label.lower().strip()
|
| 192 |
+
for label in x
|
| 193 |
+
if label.lower().strip() in labels
|
| 194 |
+
)
|
| 195 |
+
)
|
| 196 |
else:
|
| 197 |
+
return list(set([random.choice(labels)]))
|
| 198 |
|
| 199 |
dataframe["labels"] = dataframe["labels"].apply(_validate_labels)
|
| 200 |
dataframe = dataframe[dataframe["labels"].notna()]
|