File size: 693 Bytes
617796d
 
cf62d40
 
 
617796d
cf62d40
 
 
617796d
cf62d40
 
 
 
 
1481619
cf62d40
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from datasets import ClassLabel, Value, concatenate_datasets, load_dataset

def _normalize_columns(ds, image_col, label_col):
    """Return ``ds`` with only an ``image`` and a ``label`` column.

    ``Dataset.rename_column`` raises ``ValueError`` when the target name is
    already a column, including the no-op case where source and target names
    are identical. A rename is therefore only issued when the names differ.
    """
    if image_col != "image":
        ds = ds.rename_column(image_col, "image")
    if label_col != "label":
        ds = ds.rename_column(label_col, "label")
    # concatenate_datasets requires identical schemas, so drop any extra columns.
    extra = [col for col in ds.column_names if col not in ("image", "label")]
    return ds.remove_columns(extra) if extra else ds


def _labels_as_names(ds):
    """Replace ``ClassLabel`` integer ids in ``label`` with their class names.

    Each source dataset numbers its classes independently, so raw ids from
    different sources cannot be mixed. Converting to names first lets the
    merged dataset be re-encoded with one shared vocabulary. A ``label``
    column that is not a ``ClassLabel`` is returned unchanged.
    """
    feature = ds.features["label"]
    if not isinstance(feature, ClassLabel):
        # NOTE(review): a plain integer label column would be merged by raw id,
        # which is only meaningful if both sources share a numbering. Confirm.
        return ds
    features = ds.features.copy()
    features["label"] = Value("string")
    return ds.map(
        lambda labels: {"label": feature.int2str(labels)},
        batched=True,
        input_columns="label",
        features=features,
    )


def load_disease_dataset(test_size=0.2, seed=None):
    """Load PlantVillage + PlantDoc as one image dataset and split it.

    Args:
        test_size: Fraction (or absolute row count) held out for validation.
            It is forwarded to ``Dataset.train_test_split``. The default of 0.2
            matches the previous hard-coded value.
        seed: Optional RNG seed for the split. Pass an int for a reproducible
            train/validation partition. ``None`` keeps the unseeded behaviour.

    Returns:
        A ``(train, validation)`` tuple of ``datasets.Dataset`` objects with
        ``image`` and ``label`` columns. ``label`` is one ``ClassLabel`` whose
        (sorted) class names are the union of both sources' labels.

    Raises:
        ValueError: propagated from ``datasets`` if the sources' schemas still
            cannot be aligned, or if a class has too few rows to stratify.
    """
    plant_village = load_dataset("DScomp380/plant_village", split="train")
    plant_doc = load_dataset("agyaatcoder/PlantDoc", split="train")

    # NOTE(review): both hub datasets are assumed to expose "image"/"label"
    # columns. Confirm against their dataset cards and adjust these names if not.
    sources = [
        _labels_as_names(_normalize_columns(plant_village, "image", "label")),
        _labels_as_names(_normalize_columns(plant_doc, "image", "label")),
    ]
    disease_ds = concatenate_datasets(sources)

    # Re-encode the merged labels as a single ClassLabel. This yields a shared
    # vocabulary, and stratify_by_column only accepts ClassLabel columns.
    disease_ds = disease_ds.class_encode_column("label")

    split = disease_ds.train_test_split(
        test_size=test_size,
        stratify_by_column="label",
        seed=seed,
    )
    return split["train"], split["test"]