Spaces:
Sleeping
Sleeping
import os | |
import torch | |
from torch.utils.data import Dataset | |
from PIL import Image | |
import torchvision.transforms as transforms | |
from config import IMAGE_SIZE | |
class CoreDataset(Dataset):
    """Map-style dataset over the image files found directly in one directory.

    Each item is an ``(image, path)`` pair: ``image`` is the file loaded with
    PIL, converted to RGB, and passed through ``transform``; ``path`` is the
    file it was read from.

    Args:
        image_dir: Directory to scan. Only its immediate entries are
            considered (no recursion into subdirectories).
        transform: Callable applied to each loaded PIL image. When ``None``,
            :meth:`default_transform` is used.
    """

    # Filename suffixes (matched case-insensitively) that are treated as images.
    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        candidates = (
            os.path.join(image_dir, name)
            for name in os.listdir(image_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        # os.listdir returns entries in arbitrary order; sort so that the
        # index -> file mapping is stable across runs and platforms. Keep only
        # regular files so a directory named e.g. "x.png" is not collected.
        self.image_paths = sorted(p for p in candidates if os.path.isfile(p))
        self.transform = (
            transform if transform is not None else self.default_transform()
        )

    def default_transform(self):
        """Return the fallback preprocessing pipeline.

        Resizes to ``IMAGE_SIZE``, converts the PIL image to a tensor, and
        normalizes each channel with the standard ImageNet mean/std values.
        """
        # NOTE(review): if IMAGE_SIZE is a single int, Resize scales only the
        # shorter side, so images with different aspect ratios yield tensors of
        # different shapes that the default DataLoader collate cannot stack --
        # confirm IMAGE_SIZE is an (h, w) pair if this dataset is batched.
        return transforms.Compose([
            transforms.Resize(IMAGE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        """Return the number of image files discovered in ``image_dir``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the ``idx``-th image as RGB, transform it, and return ``(image, path)``."""
        img_path = self.image_paths[idx]
        # Close the file handle promptly; convert() returns a fully loaded
        # copy, so the image remains usable after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, img_path