Spaces:
Sleeping
Sleeping
Update dataset.py
Browse files- dataset.py +14 -11
dataset.py
CHANGED
|
@@ -1,16 +1,19 @@
|
|
| 1 |
from datasets import load_dataset, concatenate_datasets
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
|
|
|
| 5 |
|
| 6 |
-
# Normalize column names
|
| 7 |
-
def
|
| 8 |
-
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
return
|
| 15 |
-
|
| 16 |
-
train_ds = disease_ds.with_transform(preprocess)
|
|
|
|
| 1 |
from datasets import load_dataset, concatenate_datasets
|
| 2 |
|
| 3 |
+
def load_disease_dataset():
|
| 4 |
+
pv = load_dataset("DScomp380/plant_village", split="train")
|
| 5 |
+
pd = load_dataset("agyaatcoder/PlantDoc", split="train")
|
| 6 |
|
| 7 |
+
# Normalize column names
|
| 8 |
+
def normalize(ds, image_col, label_col):
|
| 9 |
+
return ds.rename_column(image_col, "image").rename_column(label_col, "label")
|
| 10 |
|
| 11 |
+
# Combine into one dataset
|
| 12 |
+
disease_ds = concatenate_datasets([
|
| 13 |
+
normalize(pv, "image", "label"),
|
| 14 |
+
normalize(pd, "image", "label")
|
| 15 |
+
])
|
| 16 |
|
| 17 |
+
# Split into train/validation
|
| 18 |
+
split = disease_ds.train_test_split(test_size=0.2, stratify_by_column="label")
|
| 19 |
+
return split["train"], split["test"]
|
|
|
|
|
|