# Geologist_AI / core_dataset.py
# Uploaded by solfedge in commit 71c32d5 ("Upload 9 files").
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms
from config import IMAGE_SIZE
class CoreDataset(Dataset):
    """Unlabelled image dataset over the image files in a single directory.

    Each item is an ``(image, img_path)`` pair. ``image`` is the result of
    applying ``self.transform`` to the RGB-converted PIL image. ``img_path`` is
    ``image_dir`` joined with the file name.

    Only regular files directly inside ``image_dir`` are used (no recursion).
    Subdirectories are skipped even if their names end in an image extension.
    Paths are sorted so that the index -> file mapping is stable across runs
    and filesystems, because ``os.listdir`` guarantees no particular order.
    """

    # Default file extensions (matched case-insensitively) treated as images.
    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, image_dir, transform=None, extensions=None):
        """Index the image files in ``image_dir``.

        Args:
            image_dir: Directory to scan (non-recursively).
            transform: Callable applied to each RGB PIL image in
                ``__getitem__``. If ``None``, ``default_transform()`` is used.
            extensions: A single extension string or an iterable of them,
                including the leading dot (e.g. ``".tif"``). They are compared
                case-insensitively against file names. Defaults to
                ``IMAGE_EXTENSIONS``.

        Raises:
            OSError: propagated from ``os.listdir`` (e.g. FileNotFoundError)
                when ``image_dir`` cannot be listed.
        """
        self.image_dir = image_dir

        if extensions is None:
            extensions = self.IMAGE_EXTENSIONS
        elif isinstance(extensions, str):
            # Iterating a bare string would yield single characters.
            extensions = (extensions,)
        # str.endswith accepts a tuple; lower-case once for case-insensitive matching.
        suffixes = tuple(ext.lower() for ext in extensions)

        candidates = (
            os.path.join(image_dir, name)
            for name in os.listdir(image_dir)
            if name.lower().endswith(suffixes)
        )
        # isfile() drops directories (e.g. a folder named "x.png") that
        # Image.open would fail on later.
        self.image_paths = sorted(p for p in candidates if os.path.isfile(p))

        # Compare with None instead of relying on truthiness. A user transform
        # that defines __len__() == 0 must not be silently swapped for the default.
        self.transform = transform if transform is not None else self.default_transform()

    def default_transform(self):
        """Return the preprocessing used when no ``transform`` is supplied.

        Resizes with ``IMAGE_SIZE`` from ``config``. ``ToTensor`` then gives a
        float C x H x W tensor scaled to [0, 1]. Finally each RGB channel is
        normalized with the standard ImageNet mean/std used by torchvision's
        pretrained models.
        """
        # NOTE(review): if IMAGE_SIZE is a single int, Resize scales only the
        # shorter side and keeps the aspect ratio. Images with different aspect
        # ratios then give tensors of different shapes, which the default
        # DataLoader collate cannot stack. A (height, width) pair gives a fixed
        # size. Confirm which form config.IMAGE_SIZE uses.
        return transforms.Compose([
            transforms.Resize(IMAGE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply ``self.transform``, and return it.

        Returns:
            ``(image, img_path)``: the transformed image (a normalized tensor
            under the default transform) and the source file path.
        """
        img_path = self.image_paths[idx]
        # The context manager releases the file handle deterministically.
        # convert() loads the pixels and returns a new image that does not
        # depend on the open file.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, img_path