ParamDev committed on
Commit
ac2eaad
·
verified ·
1 Parent(s): 56f90b5

Create dataset.py

Browse files
Files changed (1) hide show
  1. dataset.py +110 -0
dataset.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ from torch.utils.data import Dataset
4
+ from PIL import Image
5
+ import torchvision.transforms as transforms
6
+ from torchvision.transforms.functional import InterpolationMode
7
+
8
+
9
class Mvtec(Dataset):
    """Anomaly-detection image dataset read from an MVTec-style folder tree.

    Directory layout read by this class::

        <root_dir>/<object_type>/train/good/<images>
        <root_dir>/<object_type>/test/good/<images>
        <root_dir>/<object_type>/test/<defect_type>/<images>

    Every item is a dict with the keys ``'data'`` (the transformed image),
    ``'label'`` (0 for the ``good`` folder, 1 for any other folder) and
    ``'image_path'`` (the file the image was read from).

    Args:
        root_dir: Directory that contains one sub-directory per object type.
        object_type: Name of the object sub-directory (e.g. ``'bottle'``).
        split: ``'train'`` (only ``good`` images) or ``'test'``.
        defect_type: Test split only. ``'all'`` or ``None`` keeps every
            defect folder; any other value keeps only that folder. The
            ``good`` folder is always kept.
        im_size: Only used when ``transform`` is not given. An int gives a
            square resize, a 2-element list/tuple is used as ``(h, w)``,
            ``None`` means 224x224.
        transform: Callable applied to each RGB PIL image. When omitted, a
            Lanczos resize + ``ToTensor`` + ImageNet normalisation is used.

    Raises:
        ValueError: ``object_type``/``split`` missing, ``split`` not
            ``'train'``/``'test'``, or ``im_size`` sequence not of length 2.
        FileNotFoundError: The split (or train ``good``) directory is missing.
        RuntimeError: No image file was found.
    """

    # File-name suffixes (matched case-insensitively) treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')

    def __init__(self, root_dir, object_type=None, split=None, defect_type=None, im_size=None, transform=None):
        super().__init__()
        self.root_dir = root_dir
        self.object_type = object_type
        self.split = split
        self.defect_type = defect_type  # 'all'/None or a specific defect folder (test split)
        self.im_size = im_size

        self.image_paths = []  # full path of every image, index-aligned with labels
        self.labels = []       # 0 for good, 1 for anomaly

        # `is not None` rather than truthiness: a container-like transform
        # that happens to be empty must not be mistaken for "not provided".
        if transform is not None:
            self.transform = transform
        else:
            imagenet_mean = [0.485, 0.456, 0.406]
            imagenet_std = [0.229, 0.224, 0.225]
            self.im_size = self._normalize_im_size(im_size)
            self.transform = transforms.Compose([
                transforms.Resize(self.im_size, interpolation=InterpolationMode.LANCZOS),
                transforms.ToTensor(),
                transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
            ])

        self._load_data()

        # Single output unit: labels are binary (0 normal / 1 anomaly).
        self.num_classes = 1

    @staticmethod
    def _normalize_im_size(im_size):
        """Return ``im_size`` as an ``(h, w)`` tuple (default 224x224)."""
        if im_size is None:
            return (224, 224)
        if isinstance(im_size, (list, tuple)):
            if len(im_size) != 2:
                raise ValueError(f"im_size must be an int or a (height, width) pair, got {im_size!r}")
            return tuple(im_size)
        return (im_size, im_size)

    def _add_images(self, directory, label):
        """Append every image file in ``directory`` with the given label."""
        # os.listdir order is arbitrary; sort so indices are reproducible.
        for img_name in sorted(os.listdir(directory)):
            if img_name.lower().endswith(self.IMAGE_EXTENSIONS):
                self.image_paths.append(os.path.join(directory, img_name))
                self.labels.append(label)

    def _load_data(self):
        """Populate ``image_paths`` and ``labels`` from the folder structure."""
        if self.object_type is None or self.split is None:
            raise ValueError("Both 'object_type' and 'split' must be provided.")

        # e.g. data/bottle/train or data/bottle/test
        split_path = os.path.join(self.root_dir, self.object_type, self.split)
        if not os.path.isdir(split_path):
            raise FileNotFoundError(f"Split directory not found: {split_path}")

        if self.split == 'train':
            # Training uses defect-free images only.
            good_images_path = os.path.join(split_path, 'good')
            if not os.path.isdir(good_images_path):
                raise FileNotFoundError(f"Training 'good' images directory not found: {good_images_path}")
            self._add_images(good_images_path, 0)

        elif self.split == 'test':
            # An unspecified defect_type means "no filtering", same as 'all'.
            keep_all = self.defect_type is None or self.defect_type == 'all'

            for subdir_name in sorted(os.listdir(split_path)):
                current_dir_path = os.path.join(split_path, subdir_name)
                if not os.path.isdir(current_dir_path):
                    continue
                if not keep_all and subdir_name not in ('good', self.defect_type):
                    continue  # a specific defect was requested; skip the others
                self._add_images(current_dir_path, 0 if subdir_name == 'good' else 1)

        else:
            raise ValueError(f"Invalid split: '{self.split}'. Must be 'train' or 'test'.")

        if not self.image_paths:
            raise RuntimeError(f"No images found for object_type '{self.object_type}' in '{self.split}' split.")

    def __len__(self):
        """Return the number of images that were found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, convert to RGB and transform the image at position ``idx``.

        Returns:
            dict with ``'data'``, ``'label'`` and ``'image_path'``.
        """
        img_path = self.image_paths[idx]

        # Convert every mode (L, LA, P, RGBA, ...) to RGB so the 3-channel
        # normalisation always sees three channels; `convert` loads the pixel
        # data, so the file can be closed as soon as the block exits.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        image = self.transform(image)
        label = self.labels[idx]

        return {'data': image, 'label': label, 'image_path': img_path}

    def getclasses(self):
        """Return a ``{index: name}`` dict with one entry per class index."""
        return {i: str(i) for i in range(self.num_classes)}