import os

import torch
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms

from config import IMAGE_SIZE


class CoreDataset(Dataset):
    """Dataset yielding ``(image_tensor, image_path)`` pairs from a directory.

    Scans ``image_dir`` (non-recursively) for files with a ``.png``, ``.jpg``
    or ``.jpeg`` extension (case-insensitive). Each item is the image loaded
    as RGB, passed through ``transform``, together with its absolute-or-given
    path string.
    """

    def __init__(self, image_dir, transform=None):
        """
        Args:
            image_dir: Directory containing the image files.
            transform: Optional callable applied to each PIL image. Defaults
                to :meth:`default_transform` (resize + tensor + ImageNet
                normalization).
        """
        self.image_dir = image_dir
        # sorted() fixes the os.listdir ordering, which is otherwise
        # filesystem-dependent: without it, index -> image mapping is
        # non-deterministic across runs/machines, breaking reproducibility
        # and any index-based train/val splits.
        self.image_paths = sorted(
            os.path.join(image_dir, f)
            for f in os.listdir(image_dir)
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        )
        self.transform = transform or self.default_transform()

    def default_transform(self):
        """Return the default preprocessing pipeline.

        Resizes to ``IMAGE_SIZE``, converts to a float tensor in [0, 1],
        then normalizes with the standard ImageNet channel statistics
        (required by torchvision's pretrained backbones).
        """
        return transforms.Compose([
            transforms.Resize(IMAGE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        """Return the number of discovered image files."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load and transform the image at ``idx``.

        Returns:
            Tuple ``(image, img_path)`` where ``image`` is the transformed
            image (a tensor under the default transform) and ``img_path``
            is the file path it was loaded from.
        """
        img_path = self.image_paths[idx]
        # Force RGB so grayscale/RGBA files still produce 3 channels,
        # matching the 3-channel Normalize above.
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, img_path