crop_health_monitor / dataset.py
CZerion's picture
Update dataset.py
cf62d40 verified
raw
history blame contribute delete
693 Bytes
from datasets import load_dataset, concatenate_datasets
def load_disease_dataset(*, test_size=0.2, seed=None):
    """Load, merge and split the plant-disease image datasets.

    Loads the ``train`` split of ``DScomp380/plant_village`` and of
    ``agyaatcoder/PlantDoc``, normalizes their column names to ``image`` and
    ``label``, concatenates them, and splits the result into a training and a
    validation part, stratified on ``label``.

    Args:
        test_size: Size of the validation part, forwarded unchanged to
            ``Dataset.train_test_split``. Defaults to ``0.2``.
        seed: Optional seed forwarded to ``Dataset.train_test_split`` so the
            split can be reproduced. ``None`` (the default) leaves the split
            unseeded, as before.

    Returns:
        A ``(train, validation)`` tuple of datasets.
    """
    plant_village = load_dataset("DScomp380/plant_village", split="train")
    plant_doc = load_dataset("agyaatcoder/PlantDoc", split="train")

    def normalize(ds, image_col, label_col):
        """Return ``ds`` with its image/label columns named ``image``/``label``."""
        for current, target in ((image_col, "image"), (label_col, "label")):
            # rename_column raises ValueError when the target name already
            # exists in the dataset, so a column that is already correctly
            # named must be left alone rather than renamed onto itself.
            if current != target:
                ds = ds.rename_column(current, target)
        return ds

    # Combine into one dataset.
    # NOTE(review): concatenation presumably requires both datasets to share
    # the same features (including the label class names) -- confirm that the
    # two sources actually agree, or cast/remap labels before this step.
    disease_ds = concatenate_datasets([
        normalize(plant_village, "image", "label"),
        normalize(plant_doc, "image", "label"),
    ])

    # Split into train/validation, keeping label proportions in both parts.
    # NOTE(review): stratification presumably needs ``label`` to be a
    # ClassLabel feature -- verify against the loaded datasets.
    split = disease_ds.train_test_split(
        test_size=test_size,
        stratify_by_column="label",
        seed=seed,
    )
    return split["train"], split["test"]