Spaces:
Running
Running
Update dataset.py
Browse files- dataset.py +14 -11
dataset.py
CHANGED
@@ -1,16 +1,19 @@
|
|
1 |
from datasets import load_dataset, concatenate_datasets
|
2 |
|
3 |
-
|
4 |
-
|
|
|
5 |
|
6 |
-
# Normalize column names
|
7 |
-
def
|
8 |
-
|
9 |
|
10 |
-
|
|
|
|
|
|
|
|
|
11 |
|
12 |
-
|
13 |
-
|
14 |
-
return
|
15 |
-
|
16 |
-
train_ds = disease_ds.with_transform(preprocess)
|
|
|
1 |
from datasets import load_dataset, concatenate_datasets
|
2 |
|
3 |
+
def load_disease_dataset():
|
4 |
+
pv = load_dataset("DScomp380/plant_village", split="train")
|
5 |
+
pd = load_dataset("agyaatcoder/PlantDoc", split="train")
|
6 |
|
7 |
+
# Normalize column names
|
8 |
+
def normalize(ds, image_col, label_col):
|
9 |
+
return ds.rename_column(image_col, "image").rename_column(label_col, "label")
|
10 |
|
11 |
+
# Combine into one dataset
|
12 |
+
disease_ds = concatenate_datasets([
|
13 |
+
normalize(pv, "image", "label"),
|
14 |
+
normalize(pd, "image", "label")
|
15 |
+
])
|
16 |
|
17 |
+
# Split into train/validation
|
18 |
+
split = disease_ds.train_test_split(test_size=0.2, stratify_by_column="label")
|
19 |
+
return split["train"], split["test"]
|
|
|
|