CZerion commited on
Commit
cf62d40
·
verified ·
1 Parent(s): 9692df7

Update dataset.py

Browse files
Files changed (1) hide show
  1. dataset.py +14 -11
dataset.py CHANGED
@@ -1,16 +1,19 @@
1
  from datasets import load_dataset, concatenate_datasets
2
 
3
- pv = load_dataset("DScomp380/plant_village", split="train") # Disease
4
- pd = load_dataset("agyaatcoder/PlantDoc", split="train") # Disease
 
5
 
6
- # Normalize column names and combine
7
- def norm(ds, image_col, label_col):
8
- return ds.rename_column(image_col, "image").rename_column(label_col, "label")
9
 
10
- disease_ds = concatenate_datasets([norm(pv, "image", "label"), norm(pd, "image", "label")])
 
 
 
 
11
 
12
- def preprocess(batch):
13
- imgs = [augment(img.convert("RGB")) for img in batch["image"]]
14
- return {"pixel_values": imgs, "label": batch["label"]}
15
-
16
- train_ds = disease_ds.with_transform(preprocess)
 
1
  from datasets import load_dataset, concatenate_datasets
2
 
3
+ def load_disease_dataset():
4
+ pv = load_dataset("DScomp380/plant_village", split="train")
5
+ pd = load_dataset("agyaatcoder/PlantDoc", split="train")
6
 
7
+ # Normalize column names
8
+ def normalize(ds, image_col, label_col):
9
+ return ds.rename_column(image_col, "image").rename_column(label_col, "label")
10
 
11
+ # Combine into one dataset
12
+ disease_ds = concatenate_datasets([
13
+ normalize(pv, "image", "label"),
14
+ normalize(pd, "image", "label")
15
+ ])
16
 
17
+ # Split into train/validation
18
+ split = disease_ds.train_test_split(test_size=0.2, stratify_by_column="label")
19
+ return split["train"], split["test"]