import bisect

import numpy as np
import albumentations
from PIL import Image
from torch.utils.data import Dataset, ConcatDataset


class ConcatDatasetWithIndex(ConcatDataset):
    """Modified from original pytorch code to return dataset idx"""

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        # cumulative_sizes holds the running totals of the dataset lengths,
        # so bisect_right locates which constituent dataset idx falls into.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx], dataset_idx
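

# A minimal usage sketch (not part of the original file): indexing the
# concatenated dataset yields both the sample and the index of the dataset
# it came from. The datasets and shapes below are illustrative assumptions.
def _demo_concat_with_index():
    import torch
    from torch.utils.data import TensorDataset

    a = TensorDataset(torch.zeros(3, 2))  # dataset 0 with 3 samples
    b = TensorDataset(torch.ones(5, 2))   # dataset 1 with 5 samples
    combined = ConcatDatasetWithIndex([a, b])
    sample, dataset_idx = combined[4]     # fifth sample overall
    assert dataset_idx == 1               # it came from dataset `b`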


class ImagePaths(Dataset):
    def __init__(self, paths, size=None, random_crop=False, labels=None):
        self.size = size
        self.random_crop = random_crop

        self.labels = dict() if labels is None else labels
        self.labels["file_path_"] = paths
        self._length = len(paths)

        if self.size is not None and self.size > 0:
            # Resize the shorter side to `size`, then crop to size x size
            # (center crop for deterministic loading, random crop for
            # augmentation).
            self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
            if not self.random_crop:
                self.cropper = albumentations.CenterCrop(height=self.size, width=self.size)
            else:
                self.cropper = albumentations.RandomCrop(height=self.size, width=self.size)
            self.preprocessor = albumentations.Compose([self.rescaler, self.cropper])
        else:
            # No resizing requested: pass images through unchanged.
            self.preprocessor = lambda **kwargs: kwargs

    def __len__(self):
        return self._length

    def preprocess_image(self, image_path):
        image = Image.open(image_path)
        if not image.mode == "RGB":
            image = image.convert("RGB")
        image = np.array(image).astype(np.uint8)
        image = self.preprocessor(image=image)["image"]
        # Map uint8 values in [0, 255] to float32 values in [-1, 1].
        image = (image / 127.5 - 1.0).astype(np.float32)
        return image

    def __getitem__(self, i):
        example = dict()
        example["image"] = self.preprocess_image(self.labels["file_path_"][i])
        for k in self.labels:
            example[k] = self.labels[k][i]
        return example
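

# A usage sketch (the file paths are hypothetical): with size=256 each image
# is resized so its shorter side is 256, center-cropped to 256x256, and
# normalized to float32 values in [-1, 1]; labels are returned alongside.
def _demo_image_paths():
    dataset = ImagePaths(["img_0.png", "img_1.png"], size=256, random_crop=False)
    example = dataset[0]
    image = example["image"]                      # (256, 256, 3) float32
    assert -1.0 <= image.min() and image.max() <= 1.0
    assert example["file_path_"] == "img_0.png"   # labels ride along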


class NumpyPaths(ImagePaths):
    def preprocess_image(self, image_path):
        image = np.load(image_path).squeeze(0)  # 3 x 1024 x 1024
        image = np.transpose(image, (1, 2, 0))  # to H x W x C for PIL
        image = Image.fromarray(image, mode="RGB")
        image = np.array(image).astype(np.uint8)
        image = self.preprocessor(image=image)["image"]
        image = (image / 127.5 - 1.0).astype(np.float32)
        return image
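

# A usage sketch (the file name is hypothetical): NumpyPaths expects .npy
# arrays of shape (1, 3, H, W) with uint8 values, as the squeeze/transpose
# in preprocess_image implies; everything downstream matches ImagePaths.
def _demo_numpy_paths():
    arr = np.random.randint(0, 256, size=(1, 3, 1024, 1024), dtype=np.uint8)
    np.save("sample.npy", arr)
    dataset = NumpyPaths(["sample.npy"], size=256)
    example = dataset[0]
    assert example["image"].shape == (256, 256, 3)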