# ParamDev's picture
# Create dataset.py
# ac2eaad verified
import os
import pandas as pd
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms
from torchvision.transforms.functional import InterpolationMode
class Mvtec(Dataset):
    """Anomaly-detection dataset read from an MVTec-style folder tree.

    Expected layout: ``<root_dir>/<object_type>/<split>/<subdir>/<image>``.
    The ``train`` split reads only the ``good`` subdirectory. The ``test``
    split reads ``good`` plus the requested defect subdirectories. Images
    under ``good`` are labelled 0, images under any other subdirectory 1.

    Args:
        root_dir: Directory holding one folder per object type.
        object_type: Name of the object folder (e.g. ``"bottle"``).
        split: ``"train"`` or ``"test"``.
        defect_type: Test split only. ``"all"`` loads every defect folder;
            any other value loads only the defect folder with that exact
            name. ``good`` is always loaded, so the default ``None`` (which
            matches no folder) yields good images only.
        im_size: Output size of the default transform: a single number for
            a square resize, or a ``(height, width)`` pair. ``None`` means
            224x224. Stored untouched when ``transform`` is supplied.
        transform: Callable applied to every RGB PIL image. When ``None``,
            a resize (Lanczos) + ``ToTensor`` + ImageNet normalisation
            pipeline is built.

    Raises:
        ValueError: If ``object_type`` or ``split`` is ``None``, or
            ``split`` is not ``"train"``/``"test"``.
        FileNotFoundError: If the split (or train ``good``) directory is
            missing.
        RuntimeError: If no image file was found.
    """

    # File suffixes (compared case-insensitively) that count as images.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')

    def __init__(self, root_dir, object_type=None, split=None, defect_type=None, im_size=None, transform=None):
        super().__init__()
        self.root_dir = root_dir
        self.object_type = object_type
        self.split = split
        self.defect_type = defect_type  # 'all' or a specific defect folder name (test split)
        self.im_size = im_size
        self.image_paths = []  # full paths to the images
        self.labels = []  # parallel to image_paths: 0 = good, 1 = anomaly

        # `is not None` so that a falsy-but-valid callable is still honoured.
        if transform is not None:
            self.transform = transform
        else:
            imagenet_mean = [0.485, 0.456, 0.406]
            imagenet_std = [0.229, 0.224, 0.225]
            self.im_size = self._normalize_size(im_size)
            self.transform = transforms.Compose([
                transforms.Resize(self.im_size, interpolation=InterpolationMode.LANCZOS),
                transforms.ToTensor(),
                transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
            ])

        self._load_data()  # populates image_paths and labels
        self.num_classes = 1  # single output: normal vs. anomaly

    @staticmethod
    def _normalize_size(im_size):
        """Return ``im_size`` as a 2-tuple; ``None`` becomes ``(224, 224)``."""
        if im_size is None:
            return (224, 224)
        try:
            height, width = im_size
        except TypeError:
            # Scalar: square output, same as the historical behaviour.
            return (im_size, im_size)
        return (height, width)

    def _list_images(self, directory):
        """Return the sorted full paths of image files directly in *directory*."""
        # Sorting makes the sample order independent of the filesystem.
        return [
            os.path.join(directory, name)
            for name in sorted(os.listdir(directory))
            if name.lower().endswith(self._IMAGE_EXTENSIONS)
        ]

    def _add_images(self, directory, label):
        """Append every image in *directory* to the dataset with *label*."""
        found = self._list_images(directory)
        self.image_paths.extend(found)
        self.labels.extend([label] * len(found))

    def _load_data(self):
        """Fill ``image_paths`` and ``labels`` from the folder structure."""
        # Fail with a clear message instead of a TypeError from os.path.join.
        if self.object_type is None:
            raise ValueError("object_type must be provided (e.g. 'bottle').")
        if self.split is None:
            raise ValueError(f"Invalid split: '{self.split}'. Must be 'train' or 'test'.")

        # Start from a clean slate so a repeated call cannot duplicate samples.
        self.image_paths = []
        self.labels = []

        # e.g. data/bottle/train or data/bottle/test
        split_path = os.path.join(self.root_dir, self.object_type, self.split)
        if not os.path.isdir(split_path):
            raise FileNotFoundError(f"Split directory not found: {split_path}")

        if self.split == 'train':
            # Training uses defect-free images only.
            good_images_path = os.path.join(split_path, 'good')
            if not os.path.isdir(good_images_path):
                raise FileNotFoundError(f"Training 'good' images directory not found: {good_images_path}")
            self._add_images(good_images_path, 0)
        elif self.split == 'test':
            subdirs = sorted(
                d for d in os.listdir(split_path)
                if os.path.isdir(os.path.join(split_path, d))
            )
            for subdir_name in subdirs:
                is_good = subdir_name == 'good'
                # 'good' is always kept; defects only when requested.
                wanted = is_good or self.defect_type == 'all' or subdir_name == self.defect_type
                if not wanted:
                    continue
                self._add_images(os.path.join(split_path, subdir_name), 0 if is_good else 1)
        else:
            raise ValueError(f"Invalid split: '{self.split}'. Must be 'train' or 'test'.")

        if not self.image_paths:
            raise RuntimeError(f"No images found for object_type '{self.object_type}' in '{self.split}' split.")

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``{'data', 'label', 'image_path'}`` for the sample at *idx*."""
        img_path = self.image_paths[idx]
        # The context manager closes the file handle; convert() loads the
        # pixels first, so the result stays valid after the file is closed.
        # Converting every mode (L, LA, P, RGBA, ...) keeps the channel
        # count at 3, which the default Normalize requires.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')
        image = self.transform(image)
        label = self.labels[idx]
        # image_path is included for debugging / traceability.
        return {'data': image, 'label': label, 'image_path': img_path}

    def getclasses(self):
        """Return a dict mapping each class index to its string name."""
        return {i: str(i) for i in range(self.num_classes)}