# dataset.py — "Create dataset.py" by ParamDev, commit ac2eaad (4.87 kB)
import os
import pandas as pd
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms
from torchvision.transforms.functional import InterpolationMode
class Mvtec(Dataset):
    """MVTec-AD-style anomaly-detection dataset.

    Expects the directory layout ``<root_dir>/<object_type>/<split>/<subdir>/<image>``.
    The ``train`` split reads only ``train/good`` (label 0). The ``test`` split
    reads ``test/good`` (label 0) plus defect subdirectories (label 1).

    Each item is a dict ``{'data': transformed image, 'label': 0 or 1,
    'image_path': str}``.
    """

    # Lower-case file extensions treated as images when scanning directories.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')

    def __init__(self, root_dir, object_type=None, split=None, defect_type=None, im_size=None, transform=None):
        """
        Args:
            root_dir: Dataset root containing one folder per object type.
            object_type: Object category folder name (e.g. 'bottle'). Required.
            split: 'train' or 'test'.
            defect_type: Test split only. 'all' or None loads every defect
                subdirectory. Any other value loads only that defect folder
                plus 'good'. Ignored for the train split.
            im_size: Target size for the default transform: an int (square) or
                a (height, width) pair. Defaults to 224. Unused when
                ``transform`` is given.
            transform: Callable applied to each PIL image. If falsy, a default
                Resize(LANCZOS) -> ToTensor -> ImageNet Normalize pipeline is built.

        Raises:
            ValueError: invalid split, missing object_type, unknown defect_type,
                or malformed im_size.
            FileNotFoundError: split directory or train/good directory missing.
            RuntimeError: no images were found.
        """
        self.root_dir = root_dir
        self.object_type = object_type
        self.split = split
        self.defect_type = defect_type  # 'all'/None or a specific defect folder name (test split)
        self.im_size = im_size
        self.image_paths = []  # Full paths to images
        self.labels = []  # Corresponding labels (0 for good, 1 for anomaly)

        if transform:
            self.transform = transform
        else:
            self.im_size = self._as_size(im_size)
            imagenet_mean = [0.485, 0.456, 0.406]
            imagenet_std = [0.229, 0.224, 0.225]
            self.transform = transforms.Compose([
                transforms.Resize(self.im_size, interpolation=InterpolationMode.LANCZOS),
                transforms.ToTensor(),
                transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
            ])

        self._load_data()
        # Single output class; labels themselves are 0 (good) / 1 (anomaly).
        self.num_classes = 1

    @staticmethod
    def _as_size(im_size):
        """Normalise ``im_size`` (None, int, or (h, w) pair) to an (h, w) tuple."""
        if im_size is None:
            return (224, 224)
        if isinstance(im_size, (tuple, list)):
            if len(im_size) != 2:
                raise ValueError(f"im_size must be an int or a (height, width) pair, got {im_size!r}")
            return tuple(im_size)
        return (im_size, im_size)

    def _add_images(self, dir_path, label):
        """Append every image file in ``dir_path`` (sorted by name) with ``label``."""
        # Sorting makes the sample order independent of filesystem listing order.
        for img_name in sorted(os.listdir(dir_path)):
            full_path = os.path.join(dir_path, img_name)
            if img_name.lower().endswith(self.IMAGE_EXTENSIONS) and os.path.isfile(full_path):
                self.image_paths.append(full_path)
                self.labels.append(label)

    def _load_data(self):
        """
        Loads image paths and assigns labels based on the folder structure.
        """
        # Validate the split up front so a bad value gives a clear ValueError
        # instead of a path error.
        if self.split not in ('train', 'test'):
            raise ValueError(f"Invalid split: '{self.split}'. Must be 'train' or 'test'.")
        if self.object_type is None:
            raise ValueError("object_type must be specified (e.g. 'bottle').")

        # e.g. data/bottle/train or data/bottle/test
        split_path = os.path.join(self.root_dir, self.object_type, self.split)
        if not os.path.isdir(split_path):
            raise FileNotFoundError(f"Split directory not found: {split_path}")

        if self.split == 'train':
            # Training uses only normal ('good') images.
            good_images_path = os.path.join(split_path, 'good')
            if not os.path.isdir(good_images_path):
                raise FileNotFoundError(f"Training 'good' images directory not found: {good_images_path}")
            self._add_images(good_images_path, label=0)
        else:
            subdirs = sorted(
                d for d in os.listdir(split_path) if os.path.isdir(os.path.join(split_path, d))
            )
            # None is treated like 'all'; previously it silently loaded only 'good'.
            load_all = self.defect_type in (None, 'all')
            if not load_all and self.defect_type not in subdirs:
                raise ValueError(
                    f"Unknown defect_type '{self.defect_type}' for object_type "
                    f"'{self.object_type}'. Available: {subdirs}"
                )
            for subdir_name in subdirs:
                # When a specific defect is requested, keep 'good' plus that defect only.
                if not load_all and subdir_name not in ('good', self.defect_type):
                    continue
                self._add_images(
                    os.path.join(split_path, subdir_name),
                    label=0 if subdir_name == 'good' else 1,
                )

        if not self.image_paths:
            raise RuntimeError(f"No images found for object_type '{self.object_type}' in '{self.split}' split.")

    def __len__(self):
        """Return the number of images in this split."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, convert to RGB, and transform the image at ``idx``.

        Returns:
            dict with keys 'data' (transformed image), 'label' (0 good /
            1 anomaly) and 'image_path' (source file path).
        """
        img_path = self.image_paths[idx]
        # The context manager releases the file handle. convert() loads the
        # pixels and returns an independent image, so using it after the file
        # closes is safe. Converting every mode (L, LA, P, RGBA, ...) to RGB
        # gives the 3 channels the default ImageNet Normalize expects.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        image = self.transform(image)
        return {'data': image, 'label': self.labels[idx], 'image_path': img_path}

    def getclasses(self):
        """Return a mapping from class index to its string name (e.g. {0: '0'})."""
        return {i: str(i) for i in range(self.num_classes)}