Spaces:
Configuration error
Configuration error
import os | |
import pandas as pd | |
from torch.utils.data import Dataset | |
from PIL import Image | |
import torchvision.transforms as transforms | |
from torchvision.transforms.functional import InterpolationMode | |
class Mvtec(Dataset):
    """MVTec-style anomaly-detection dataset read from a folder tree.

    Expected layout: ``<root_dir>/<object_type>/<split>/<subdir>/<image>``.
    The ``train`` split is read only from its ``good`` subdirectory. The
    ``test`` split is read from ``good`` plus one subdirectory per defect type.

    Each sample is a dict ``{'data': transformed image, 'label': int,
    'image_path': str}``. The label is 0 for images from ``good`` and 1 for
    images from any other (defect) subdirectory.
    """

    # Lower-case file extensions recognized as images; other files are ignored.
    IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')

    def __init__(self, root_dir, object_type=None, split=None, defect_type=None, im_size=None, transform=None):
        """Index the images of one object type and split.

        Args:
            root_dir: dataset root containing one folder per object type.
            object_type: object folder name (e.g. 'bottle'). Required.
            split: 'train' or 'test'.
            defect_type: test split only. 'all' or None loads every defect
                folder; any other value loads only that defect folder.
                'good' images are always included. Ignored for 'train'.
            im_size: output size used by the default transform. Either a
                single value for a square image or an (height, width) pair;
                defaults to 224. Ignored when ``transform`` is given.
            transform: callable applied to each RGB PIL image. When None, a
                Lanczos resize + ToTensor + ImageNet normalization is used.

        Raises:
            ValueError: invalid split, missing object_type, unknown
                defect_type, or a malformed im_size.
            FileNotFoundError: the split (or train 'good') directory is missing.
            RuntimeError: no image files were found.
        """
        self.root_dir = root_dir
        self.object_type = object_type
        self.split = split
        self.defect_type = defect_type
        self.im_size = im_size
        self.image_paths = []  # full paths to the indexed images
        self.labels = []  # parallel to image_paths: 0 = good, 1 = anomaly

        if transform is not None:
            self.transform = transform
        else:
            self.im_size = self._as_size_pair(im_size)
            self.transform = self._default_transform(self.im_size)

        self._load_data()
        self.num_classes = 1  # binary task (normal vs. anomaly) scored as one output

    @staticmethod
    def _as_size_pair(im_size):
        """Return ``im_size`` as an (height, width) tuple; None means (224, 224)."""
        if im_size is None:
            return (224, 224)
        if isinstance(im_size, (tuple, list)):
            if len(im_size) != 2:
                raise ValueError(f"im_size must be a single value or an (height, width) pair, got {im_size!r}")
            return tuple(im_size)
        return (im_size, im_size)

    @staticmethod
    def _default_transform(size):
        """Build the default pipeline: Lanczos resize, ToTensor, ImageNet normalization."""
        return transforms.Compose([
            transforms.Resize(tuple(size), interpolation=InterpolationMode.LANCZOS),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def _add_images(self, directory, label):
        """Append every image file in ``directory`` (sorted by name) with ``label``."""
        # Sorting makes sample order independent of the filesystem's listdir order.
        for img_name in sorted(os.listdir(directory)):
            if img_name.lower().endswith(self.IMG_EXTENSIONS):
                self.image_paths.append(os.path.join(directory, img_name))
                self.labels.append(label)

    def _load_data(self):
        """Populate ``image_paths`` and ``labels`` from the folder structure."""
        # Validate arguments before touching the filesystem so bad input gives a
        # clear ValueError instead of a path-related error.
        if self.split not in ('train', 'test'):
            raise ValueError(f"Invalid split: '{self.split}'. Must be 'train' or 'test'.")
        if self.object_type is None:
            raise ValueError("object_type must be given (e.g. 'bottle').")

        split_path = os.path.join(self.root_dir, self.object_type, self.split)
        if not os.path.isdir(split_path):
            raise FileNotFoundError(f"Split directory not found: {split_path}")

        if self.split == 'train':
            # Training uses normal images only.
            good_images_path = os.path.join(split_path, 'good')
            if not os.path.isdir(good_images_path):
                raise FileNotFoundError(f"Training 'good' images directory not found: {good_images_path}")
            self._add_images(good_images_path, label=0)
        else:
            subdirs = sorted(
                d for d in os.listdir(split_path) if os.path.isdir(os.path.join(split_path, d))
            )
            # None is treated like 'all'; otherwise a single defect folder plus 'good'.
            load_all = self.defect_type in (None, 'all')
            if not load_all and self.defect_type not in subdirs:
                available = [d for d in subdirs if d != 'good']
                raise ValueError(
                    f"Unknown defect_type '{self.defect_type}' for object_type "
                    f"'{self.object_type}'. Available: {available}"
                )
            for subdir_name in subdirs:
                if not load_all and subdir_name not in (self.defect_type, 'good'):
                    continue
                self._add_images(
                    os.path.join(split_path, subdir_name),
                    label=0 if subdir_name == 'good' else 1,
                )

        if not self.image_paths:
            raise RuntimeError(f"No images found for object_type '{self.object_type}' in '{self.split}' split.")

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at ``idx`` as RGB, apply the transform, and return a sample dict."""
        img_path = self.image_paths[idx]
        # The context manager closes the file handle. convert() returns a new,
        # loaded image, so it stays usable after the file is closed. Converting
        # every non-RGB mode (not just 'L') keeps 3 channels, which the default
        # 3-channel normalization requires.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        image = self.transform(image)
        return {'data': image, 'label': self.labels[idx], 'image_path': img_path}

    def getclasses(self):
        """Return a mapping from class index to its string name, e.g. ``{0: '0'}``."""
        return {i: str(i) for i in range(self.num_classes)}