File size: 4,865 Bytes
ac2eaad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import pandas as pd
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms
from torchvision.transforms.functional import InterpolationMode


class Mvtec(Dataset):
    """MVTec-AD style anomaly-detection dataset.

    Expects the layout ``<root_dir>/<object_type>/<split>/<subdir>/<image>``.
    For ``split='train'`` only ``<split>/good`` is read (all labelled 0).
    For ``split='test'`` every sub-directory is read; images in ``good`` are
    labelled 0 and images in any other sub-directory (a defect type) are
    labelled 1. Image paths are collected in sorted order so indices are
    reproducible across filesystems.

    Args:
        root_dir: Dataset root containing one folder per object type.
        object_type: Name of the object folder (e.g. ``'bottle'``). Required.
        split: ``'train'`` or ``'test'``. Required.
        defect_type: Test split only. ``None`` or ``'all'`` loads every
            defect folder; any other value loads only that defect folder
            plus ``good``.
        im_size: Output size used to build the default transform: an int
            (square) or a ``(height, width)`` pair; 224x224 when ``None``.
            Only used when ``transform`` is not given.
        transform: Callable applied to each RGB PIL image. When omitted, a
            Resize (Lanczos) -> ToTensor -> ImageNet-Normalize pipeline is used.

    Raises:
        ValueError: ``object_type``/``split`` missing or invalid, or the
            requested ``defect_type`` folder does not exist.
        FileNotFoundError: The split directory (or train ``good`` folder)
            does not exist.
        RuntimeError: No image files were found.
    """

    # Lower-case file extensions recognised as images.
    IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')

    def __init__(self, root_dir, object_type=None, split=None, defect_type=None, im_size=None, transform=None):
        self.root_dir = root_dir
        self.object_type = object_type
        self.split = split
        self.defect_type = defect_type  # None/'all' or a specific defect folder (test split only)
        self.im_size = im_size

        self.image_paths = []  # Full paths to images, in sorted order
        self.labels = []       # Matching labels: 0 for good, 1 for anomaly

        if transform is not None:
            self.transform = transform
        else:
            imagenet_mean = [0.485, 0.456, 0.406]
            imagenet_std = [0.229, 0.224, 0.225]
            self.im_size = self._to_size(im_size)
            self.transform = transforms.Compose([
                transforms.Resize(self.im_size, interpolation=InterpolationMode.LANCZOS),
                transforms.ToTensor(),
                transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
            ])

        self._load_data()

        self.num_classes = 1  # Binary classification (normal/anomaly)

    @staticmethod
    def _to_size(im_size):
        """Normalise ``im_size`` (None, int, or (h, w) pair) to an ``(h, w)`` tuple."""
        if im_size is None:
            return (224, 224)
        try:
            height, width = im_size
        except TypeError:  # Scalar size -> square output.
            return (im_size, im_size)
        return (height, width)

    def _add_images(self, directory, label):
        """Append every image file in ``directory`` (sorted by name) with ``label``."""
        for img_name in sorted(os.listdir(directory)):
            img_path = os.path.join(directory, img_name)
            if img_name.lower().endswith(self.IMG_EXTENSIONS) and os.path.isfile(img_path):
                self.image_paths.append(img_path)
                self.labels.append(label)

    def _load_data(self):
        """Populate ``image_paths`` and ``labels`` from the folder structure."""
        # Validate arguments before touching the filesystem, so a missing
        # split/object_type yields a clear error instead of an os.path TypeError.
        if self.split not in ('train', 'test'):
            raise ValueError(f"Invalid split: '{self.split}'. Must be 'train' or 'test'.")
        if self.object_type is None:
            raise ValueError("object_type must be provided (e.g. 'bottle').")

        # e.g. data/bottle/train or data/bottle/test
        split_path = os.path.join(self.root_dir, self.object_type, self.split)
        if not os.path.isdir(split_path):
            raise FileNotFoundError(f"Split directory not found: {split_path}")

        if self.split == 'train':
            # Training uses only normal ('good') images.
            good_images_path = os.path.join(split_path, 'good')
            if not os.path.isdir(good_images_path):
                raise FileNotFoundError(f"Training 'good' images directory not found: {good_images_path}")
            self._add_images(good_images_path, label=0)
        else:
            subdirs = sorted(
                d for d in os.listdir(split_path) if os.path.isdir(os.path.join(split_path, d))
            )
            # None and 'all' both mean "every defect type".
            wanted = None if self.defect_type in (None, 'all') else self.defect_type
            if wanted is not None and wanted not in subdirs:
                # Fail loudly: silently returning only 'good' images would
                # produce a meaningless evaluation set.
                raise ValueError(
                    f"Defect type '{wanted}' not found in {split_path}; available: {subdirs}"
                )
            for subdir_name in subdirs:
                if wanted is not None and subdir_name not in (wanted, 'good'):
                    continue
                self._add_images(
                    os.path.join(split_path, subdir_name),
                    label=0 if subdir_name == 'good' else 1,
                )

        if not self.image_paths:
            raise RuntimeError(f"No images found for object_type '{self.object_type}' in '{self.split}' split.")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``{'data': transformed image, 'label': 0/1, 'image_path': str}``."""
        img_path = self.image_paths[idx]
        # The context manager releases the file handle immediately; convert()
        # loads the pixels (returning a copy even when already RGB) before close.
        # Converting every mode (L, P, RGBA, ...) keeps inputs 3-channel.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        image = self.transform(image)
        label = self.labels[idx]

        return {'data': image, 'label': label, 'image_path': img_path}

    def getclasses(self):
        """Map each class index to its string name, e.g. ``{0: '0'}``."""
        return {i: str(i) for i in range(self.num_classes)}