import os
import pandas as pd
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms
from torchvision.transforms.functional import InterpolationMode


class Mvtec(Dataset):
    """MVTec-AD style anomaly-detection dataset.

    Expects the layout ``<root_dir>/<object_type>/<split>/<subdir>/<image>``.
    The ``train`` split reads only ``train/good``; the ``test`` split reads
    ``test/good`` plus the defect subdirectories. Labels are 0 for good
    (normal) images and 1 for anomalous ones.

    Args:
        root_dir: Dataset root directory.
        object_type: Object category subdirectory (e.g. ``"bottle"``). Required.
        split: ``"train"`` or ``"test"``. Required.
        defect_type: Test split only. ``"all"`` or ``None`` loads every defect
            subdirectory; any other value loads only that defect plus ``good``.
        im_size: Side length (int) or ``(height, width)`` pair used by the
            default transform; defaults to 224x224. Only applied when
            ``transform`` is not given.
        transform: Optional callable applied to each RGB PIL image. When
            omitted, images are resized (Lanczos), converted to tensors and
            normalized with ImageNet mean/std.
    """

    # File extensions (lower-case) recognised as images.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')
    _IMAGENET_MEAN = [0.485, 0.456, 0.406]
    _IMAGENET_STD = [0.229, 0.224, 0.225]

    def __init__(self, root_dir, object_type=None, split=None, defect_type=None, im_size=None, transform=None):
        self.root_dir = root_dir
        self.object_type = object_type
        self.split = split
        self.defect_type = defect_type  # 'all'/None or a specific defect type for the test split
        self.im_size = im_size
        self.image_paths = []  # Full paths to images
        self.labels = []  # Matching labels (0 for good, 1 for anomaly)

        # `is not None` rather than truthiness: container-like transforms
        # (e.g. an empty nn.Sequential) are falsy but still valid.
        if transform is not None:
            self.transform = transform
        else:
            if im_size is None:
                size = (224, 224)
            elif isinstance(im_size, int):
                size = (im_size, im_size)
            else:
                size = tuple(im_size)  # already an (h, w) pair
            self.im_size = size
            self.transform = transforms.Compose([
                transforms.Resize(self.im_size, interpolation=InterpolationMode.LANCZOS),
                transforms.ToTensor(),
                transforms.Normalize(mean=self._IMAGENET_MEAN, std=self._IMAGENET_STD),
            ])

        self._load_data()  # Populate image_paths and labels
        self.num_classes = 1  # Binary anomaly detection (normal/anomaly)

    def _add_images(self, directory, label):
        """Append every image file in ``directory`` (sorted by name) with ``label``."""
        # Sorting makes the sample order independent of the filesystem.
        for img_name in sorted(os.listdir(directory)):
            if img_name.lower().endswith(self._IMAGE_EXTENSIONS):
                self.image_paths.append(os.path.join(directory, img_name))
                self.labels.append(label)

    def _load_data(self):
        """Load image paths and assign labels based on the folder structure.

        Raises:
            ValueError: if ``split`` is not ``'train'``/``'test'`` or
                ``object_type`` is missing.
            FileNotFoundError: if an expected directory does not exist,
                including a requested ``defect_type`` subdirectory.
            RuntimeError: if no images were found.
        """
        # Validate arguments before touching the filesystem so callers get the
        # intended error instead of a TypeError from os.path.join(None).
        if self.split not in ('train', 'test'):
            raise ValueError(f"Invalid split: '{self.split}'. Must be 'train' or 'test'.")
        if self.object_type is None:
            raise ValueError("object_type must be specified (e.g. 'bottle').")

        # e.g. data/bottle/train or data/bottle/test
        split_path = os.path.join(self.root_dir, self.object_type, self.split)
        if not os.path.isdir(split_path):
            raise FileNotFoundError(f"Split directory not found: {split_path}")

        if self.split == 'train':
            # Training uses only normal images.
            good_images_path = os.path.join(split_path, 'good')
            if not os.path.isdir(good_images_path):
                raise FileNotFoundError(f"Training 'good' images directory not found: {good_images_path}")
            self._add_images(good_images_path, label=0)
        else:
            subdirs = sorted(
                d for d in os.listdir(split_path)
                if os.path.isdir(os.path.join(split_path, d))
            )
            # None is the constructor default and means "every defect type".
            load_all = self.defect_type is None or self.defect_type == 'all'
            if not load_all and self.defect_type not in subdirs:
                raise FileNotFoundError(
                    f"Defect type directory not found: {os.path.join(split_path, str(self.defect_type))}"
                )
            for subdir_name in subdirs:
                # A specific defect request still keeps the 'good' images.
                if not load_all and subdir_name not in (self.defect_type, 'good'):
                    continue
                self._add_images(
                    os.path.join(split_path, subdir_name),
                    label=0 if subdir_name == 'good' else 1,
                )

        if not self.image_paths:
            raise RuntimeError(f"No images found for object_type '{self.object_type}' in '{self.split}' split.")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``{'data': transformed image, 'label': int, 'image_path': str}``."""
        img_path = self.image_paths[idx]
        # The context manager closes the file handle right away. convert()
        # returns a fully loaded copy, so the result outlives the handle.
        # Converting every non-RGB mode (L, RGBA, P, ...) keeps the channel
        # count at 3, as the default ImageNet normalization requires.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        image = self.transform(image)
        label = self.labels[idx]
        return {'data': image, 'label': label, 'image_path': img_path}  # image_path for debugging/info

    def getclasses(self):
        """Map each class index to its string name."""
        return {i: str(i) for i in range(self.num_classes)}