from datasets import load_dataset, concatenate_datasets


def _normalize_columns(ds, image_col, label_col):
    """Return *ds* with its image/label columns named ``image`` and ``label``.

    Args:
        ds: Dataset object exposing ``rename_column``.
        image_col: Current name of the image column.
        label_col: Current name of the label column.

    Returns:
        The dataset with canonical column names; the input object itself
        when no rename was needed.
    """
    # rename_column rejects a target name that already exists in the
    # dataset, so a same-name "rename" must be skipped rather than attempted.
    if image_col != "image":
        ds = ds.rename_column(image_col, "image")
    if label_col != "label":
        ds = ds.rename_column(label_col, "label")
    return ds


def load_disease_dataset(test_size=0.2, seed=None):
    """Load the PlantVillage and PlantDoc train splits and re-split them.

    Both datasets are fetched with ``load_dataset``, their columns are
    normalized to ``image`` / ``label``, they are concatenated, and the
    combined dataset is split with stratification on ``label``.

    Args:
        test_size: Value forwarded to ``train_test_split`` as ``test_size``
            (defaults to 0.2, the value previously hard-coded).
        seed: Value forwarded to ``train_test_split`` as ``seed``. ``None``
            (the default) keeps the previous unseeded behavior; pass an int
            for a reproducible split.

    Returns:
        A ``(train, validation)`` tuple taken from the ``"train"`` and
        ``"test"`` entries of the split result.
    """
    plant_village = load_dataset("DScomp380/plant_village", split="train")
    plant_doc = load_dataset("agyaatcoder/PlantDoc", split="train")

    # NOTE(review): concatenation presumably requires both datasets to share
    # the same features (including identical label class names), and
    # stratification presumably needs a ClassLabel-typed "label" column --
    # confirm against the actual dataset schemas.
    disease_ds = concatenate_datasets(
        [
            _normalize_columns(plant_village, "image", "label"),
            _normalize_columns(plant_doc, "image", "label"),
        ]
    )

    split = disease_ds.train_test_split(
        test_size=test_size,
        stratify_by_column="label",
        seed=seed,
    )
    return split["train"], split["test"]