import copy
import os
import random
import shutil
import zipfile

import rawpy
from PIL import Image
import tifffile as tiff
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader, TensorDataset
from sklearn.model_selection import StratifiedShuffleSplit

# The `data/...` paths below and the `utils` imports are relative to the
# repository root. If README.md is not in the working directory, assume we are
# one level below the root and move up.
if not os.path.exists('README.md'):  # set pwd to root
    os.chdir('..')

from utils.dataset_utils import split_img, list_images_in_dir, load_image  # noqa: E402
from utils.base import np2torch, torch2np, b2_download_folder  # noqa: E402

IMAGE_FILE_TYPES = ['dng', 'png', 'tif', 'tiff']


def get_dataset(name, I_ratio=1.0):
    """Instantiate a dataset from one of its short or long names.

    Args:
        name (str): dataset identifier, e.g. 'DC', 'DS', 'M', 'DSF' or 'MRGB'
            (the long aliases / class names are accepted as well).
        I_ratio (float): scale factor selecting the `raw_scaleXXX` image folder.

    Returns:
        Dataset: the requested dataset instance.

    Raises:
        ValueError: if `name` is not a known dataset identifier.
    """
    # DroneDataset
    if name in ('DC', 'Drone', 'DroneClassification', 'DroneDatasetClassificationTiled'):
        return DroneDatasetClassificationTiled(I_ratio=I_ratio)
    if name in ('DS', 'DroneSegmentation', 'DroneDatasetSegmentationTiled'):
        return DroneDatasetSegmentationTiled(I_ratio=I_ratio)

    # MicroscopyDataset
    if name in ('M', 'Microscopy', 'MicroscopyDataset'):
        return MicroscopyDataset(I_ratio=I_ratio)

    # for testing
    if name in ('DSF', 'DroneDatasetSegmentationFull'):
        return DroneDatasetSegmentationFull(I_ratio=I_ratio)
    if name in ('MRGB', 'MicroscopyRGB', 'MicroscopyDatasetRGB'):
        return MicroscopyDatasetRGB(I_ratio=I_ratio)

    raise ValueError(name)


class ImageFolderDataset(Dataset):
    """Classification dataset of the images in `img_dir`, one label per image.

    Images are listed with `list_images_in_dir(img_dir)`; `labels[i]` is the
    label of the i-th listed image.

    Args:
        img_dir (str): path to image folder
        labels (list): one class label per image, in listing order
        transform (callable, optional): transformation applied to each image
        bits (int, optional): normalize image by dividing by 2^bits - 1
    """

    task = 'classification'

    def __init__(self, img_dir, labels, transform=None, bits=1):
        self.img_dir = img_dir
        self.labels = labels
        self.images = list_images_in_dir(img_dir)
        assert len(self.images) == len(self.labels)
        self.transform = transform
        self.bits = bits

    def __repr__(self):
        rep = f"{type(self).__name__}: ImageFolderDataset[{len(self.images)}]"
        for n, (img, label) in enumerate(zip(self.images, self.labels)):
            rep += f'\nimage: {img}\tlabel: {label}'
            if n > 10:
                rep += '\n...'
                break
        return rep

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return `(image, label)`; the image is normalized, then transformed."""
        label = self.labels[idx]
        img = load_image(self.images[idx])
        img = img / (2**self.bits - 1)
        if self.transform is not None:
            img = self.transform(img)

        # Samples must be 256x256 tiles: either a single plane or 3 x 256 x 256.
        if len(img.shape) == 2:
            assert img.shape == (256, 256), f"Invalid size for {self.images[idx]}"
        else:
            assert img.shape == (3, 256, 256), f"Invalid size for {self.images[idx]}"
        return img, label


class ImageFolderDatasetSegmentation(Dataset):
    """Creates a dataset of images in `img_dir` and corresponding masks in `mask_dir`.

    Corresponding mask files need to contain the filename of the image.
    Files are expected to be of the same filetype.

    Args:
        img_dir (str): path to image folder
        mask_dir (str): path to mask folder
        transform (callable, optional): transformation applied to the image
            only; the mask is returned untransformed
        bits (int, optional): normalize image by dividing by 2^bits - 1
    """

    task = 'segmentation'

    def __init__(self, img_dir, mask_dir, transform=None, bits=1):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.images = list_images_in_dir(img_dir)
        self.masks = list_images_in_dir(mask_dir)
        check_image_folder_consistency(self.images, self.masks)
        self.transform = transform
        self.bits = bits

    def __repr__(self):
        rep = f"{type(self).__name__}: ImageFolderDatasetSegmentation[{len(self.images)}]"
        for n, (img, mask) in enumerate(zip(self.images, self.masks)):
            rep += f'\nimage: {img}\tmask: {mask}'
            if n > 10:
                rep += '\n...'
                break
        return rep

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return `(image, mask)` with the mask binarized to float32 {0, 1}."""
        img = load_image(self.images[idx])
        mask = load_image(self.masks[idx])

        img = img / (2**self.bits - 1)
        mask = (mask > 0).astype(np.float32)

        if self.transform is not None:
            img = self.transform(img)

        return img, mask


class MultiIntensity(Dataset):
    """Wrap classification datasets with different intensities into one dataset.

    The images and labels of all datasets are concatenated onto a shallow copy
    of the first dataset, so the datasets passed in are left unmodified.

    Args:
        datasets (list): datasets to wrap; each must expose `images` and
            `labels` lists.
        transform (callable, optional): extra transformation applied to each
            sample input, on top of the wrapped dataset's own transform.
    """

    def __init__(self, datasets, transform=None):
        # Shallow copy: the attribute reassignments below must not leak into
        # the caller's first dataset.
        self.dataset = copy.copy(datasets[0])
        for other in datasets[1:]:
            self.dataset.images = self.dataset.images + other.images
            self.dataset.labels = self.dataset.labels + other.labels
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __repr__(self):
        return f"MultiIntensity [{len(self.dataset)}] of " + repr(self.dataset)

    def __getitem__(self, idx):
        x, y = self.dataset[idx]
        if self.transform is not None:
            x = self.transform(x)
        return x, y


class Subset(Dataset):
    """Define a subset of a dataset by only selecting given indices.

    Args:
        dataset (Dataset): full dataset
        indices (list, optional): subset indices; all indices if None
        transform (callable, optional): transformation applied to each sample input
    """

    def __init__(self, dataset, indices=None, transform=None):
        self.dataset = dataset
        self.indices = indices if indices is not None else range(len(dataset))
        self.transform = transform

    def __len__(self):
        return len(self.indices)

    def __repr__(self):
        return f"Subset [{len(self)}] of " + repr(self.dataset)

    def __getitem__(self, idx):
        x, y = self.dataset[self.indices[idx]]
        if self.transform is not None:
            x = self.transform(x)
        return x, y


class DroneDatasetSegmentationFull(ImageFolderDatasetSegmentation):
    """Dataset consisting of full-sized drone images and masks.

    The images are downloaded on demand. They are divided by 2^bits - 1
    (bits=16 by default), presumably mapping them to [0, 1].
    """

    black_level = [0.0625, 0.0626, 0.0625, 0.0626]
    white_balance = [2.86653646, 1., 1.73079425]
    colour_matrix = [1.50768983, -0.33571374, -0.17197604, -0.23048614,
                     1.70698738, -0.47650126, -0.03119153, -0.32803956, 1.35923111]
    camera_parameters = black_level, white_balance, colour_matrix

    def __init__(self, I_ratio=1.0, transform=None, force_download=False, bits=16):
        assert I_ratio in [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 1.0]

        img_dir = f'data/drone/images_full/raw_scale{int(I_ratio*100):03d}'
        mask_dir = 'data/drone/masks_full'

        download_drone_dataset(force_download)  # TODO: zip files and add checksum? date?

        super().__init__(img_dir=img_dir, mask_dir=mask_dir, transform=transform, bits=bits)


class DroneDatasetSegmentationTiled(ImageFolderDatasetSegmentation):
    """Dataset consisting of tiled drone images and masks.

    The tiles are generated from `DroneDatasetSegmentationFull` on first use.
    They are cached under `data/drone/images_tiles_<tile_size>` and
    `data/drone/masks_tiles_<tile_size>`. Only tiles that contain the class
    are kept. Images are divided by 2^16 - 1 when loaded.

    Args:
        I_ratio (float, optional): scale factor selecting the raw image folder.
        transform (callable, optional): transformation applied to the image.
        tile_size (int, optional): size of the tiled images. Defaults to 256.
    """

    camera_parameters = DroneDatasetSegmentationFull.camera_parameters

    def __init__(self, I_ratio=1.0, transform=None, tile_size=256):
        img_dir = f'data/drone/images_tiles_{tile_size}/raw_scale{int(I_ratio*100):03d}'
        mask_dir = f'data/drone/masks_tiles_{tile_size}'

        if not os.path.exists(img_dir) or not os.path.exists(mask_dir):
            dataset_full = DroneDatasetSegmentationFull(I_ratio=I_ratio, bits=1)
            print("tiling dataset..")
            create_tiles_dataset(dataset_full, img_dir, mask_dir, tile_size=tile_size)

        super().__init__(img_dir=img_dir, mask_dir=mask_dir, transform=transform, bits=16)


class DroneDatasetClassificationTiled(ImageFolderDataset):
    """Balanced binary classification dataset of 256x256 drone tiles.

    Label 0 ('car') marks tiles whose mask mean exceeds `thr = 0.01`. Label 1
    ('no car') marks tiles with an empty mask. Tiles, masks and the label CSV
    are generated on first use by `create_tiles_dataset_binary`.

    Args:
        I_ratio (float, optional): scale factor selecting the raw image folder.
        transform (callable, optional): transformation applied to each image.
    """

    camera_parameters = DroneDatasetSegmentationFull.camera_parameters

    def __init__(self, I_ratio=1.0, transform=None):
        random_state = 72
        tile_size = 256
        thr = 0.01

        img_dir = f'data/drone/classification/images_tiles_{tile_size}/raw_scale{int(I_ratio*100):03d}_thr_{thr}'
        mask_dir = f'data/drone/classification/masks_tiles_{tile_size}_thr_{thr}'
        df_path = f'data/drone/classification/dataset_tiles_{tile_size}_{random_state}_{thr}.csv'

        # Regenerate when any of the three artefacts is missing; the CSV is
        # needed below, so a missing CSV alone must also trigger tiling.
        if not (os.path.exists(img_dir) and os.path.exists(mask_dir) and os.path.exists(df_path)):
            dataset_full = DroneDatasetSegmentationFull(I_ratio=I_ratio, bits=1)
            print("tiling dataset..")
            create_tiles_dataset_binary(dataset_full, img_dir, mask_dir, random_state, thr, tile_size=tile_size)

        self.classes = ['car', 'no car']
        self.df = pd.read_csv(df_path)
        labels = self.df['label'].to_list()
        super().__init__(img_dir=img_dir, labels=labels, transform=transform, bits=16)

        # Re-pair images and labels in CSV order. The folder listing used by
        # the parent only served the length consistency check above.
        images, class_labels = read_label_csv(self.df)
        self.images = [os.path.join(self.img_dir, image) for image in images]
        self.labels = class_labels


class MicroscopyDataset(ImageFolderDataset):
    """MicroscopyDataset raw images.

    Classes with 4 or fewer samples, and files listed in `invalid_files`, are
    dropped. The labels are then re-indexed into `self.classes` (sorted class
    names).

    Args:
        I_ratio (float): Original image rescaled by this factor, possible values [0.01,0.05,0.1,0.25,0.5,0.75,1.0]
        transform (callable, optional): transformation applied to each image
        bits (int, optional): normalize image by dividing by 2^bits - 1
        force_download (bool, optional): re-download the data even if present

    Raises:
        ValueError: if no labelled image survives the filtering.
    """

    black_level = [9.834368023181512e-06, 9.834368023181512e-06, 9.834368023181512e-06, 9.834368023181512e-06]
    white_balance = [-0.6567, 1.9673, 3.5304]
    colour_matrix = [-2.0338, 0.0933, 0.4157, -0.0286, 2.6464, -0.0574, -0.5516, -0.0947, 2.9308]
    camera_parameters = black_level, white_balance, colour_matrix
    dataset_mean = [0.91, 0.84, 0.94]
    dataset_std = [0.08, 0.12, 0.05]

    def __init__(self, I_ratio=1.0, transform=None, bits=16, force_download=False):
        assert I_ratio in [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 1.0]

        download_microscopy_dataset(force_download=force_download)

        self.img_dir = f'data/microscopy/images/raw_scale{int(I_ratio*100):03d}'
        self.transform = transform
        self.bits = bits

        self.label_file = 'data/microscopy/labels/Ma190c_annotations.dat'

        self.valid_classes = ['BAS', 'EBO', 'EOS', 'KSC', 'LYA', 'LYT', 'MMZ', 'MOB',
                              'MON', 'MYB', 'MYO', 'NGB', 'NGS', 'PMB', 'PMO', 'UNC']

        self.invalid_files = ['Ma190c_lame3_zone13_composite_Mcropped_2.tiff', ]

        images, class_labels = read_label_file(self.label_file)

        # filter classes with low appearance
        self.valid_classes = [class_label for class_label in self.valid_classes
                              if class_labels.count(class_label) > 4]

        # remove invalid classes and invalid files from (images, class_labels)
        samples = [
            (image, class_label)
            for image, class_label in zip(images, class_labels)
            if class_label in self.valid_classes and image not in self.invalid_files
        ]
        if not samples:
            raise ValueError(f"no labelled images left in {self.label_file} after filtering")
        images, class_labels = zip(*samples)

        self.classes = sorted(set(class_labels))

        # store full path
        self.images = [os.path.join(self.img_dir, image) for image in images]

        # reindex labels
        class_to_idx = {class_label: i for i, class_label in enumerate(self.classes)}
        self.labels = [class_to_idx[class_label] for class_label in class_labels]


class MicroscopyDatasetRGB(MicroscopyDataset):
    """MicroscopyDataset RGB images.

    Uses the labels and filtering of `MicroscopyDataset` but reads the images
    from the matching `rgb_scaleXXX` folder instead of `raw_scaleXXX`.

    Args:
        I_ratio (float): Original image rescaled by this factor, possible values [0.01,0.05,0.1,0.25,0.5,0.75,1.0]
        transform (callable, optional): transformation applied to each image
        bits (int, optional): normalize image by dividing by 2^bits - 1
        force_download (bool, optional): re-download the data even if present
    """

    camera_parameters = None
    dataset_mean = None
    dataset_std = None

    def __init__(self, I_ratio=1.0, transform=None, bits=16, force_download=False):
        super().__init__(I_ratio=I_ratio, transform=transform, bits=bits, force_download=force_download)
        # Swap only the folder ('raw_scaleXXX' -> 'rgb_scaleXXX'), never the file names.
        raw_dir = self.img_dir
        rgb_dir = raw_dir.replace('raw', 'rgb')
        self.img_dir = rgb_dir
        self.images = [os.path.join(rgb_dir, os.path.relpath(image, raw_dir)) for image in self.images]


def read_label_file(label_file_path):
    """Parse a whitespace-separated `<file name> <class label>` annotation file.

    Blank lines are skipped. '.tiff' is appended to every file name.

    Returns:
        tuple[list[str], list[str]]: image file names and their class labels.
    """
    images = []
    class_labels = []

    with open(label_file_path, "rb") as data:
        for line in data:
            fields = line.decode("utf-8").split()
            if not fields:  # tolerate empty / trailing blank lines
                continue
            file_name, class_label = fields
            images.append(file_name + '.tiff')
            class_labels.append(class_label)

    return images, class_labels


def read_label_csv(df):
    """Read image names ('file name' + '.tif') and integer labels from `df`.

    Returns:
        tuple[list[str], list[int]]: image file names and their labels.
    """
    images = []
    class_labels = []

    for file_name, label in zip(df['file name'], df['label']):
        images.append(file_name + '.tif')
        class_labels.append(int(label))
    return images, class_labels


def download_drone_dataset(force_download):
    """Download the drone images and masks, then extract the image archives."""
    b2_download_folder('drone/images', 'data/drone/images_full', force_download=force_download)
    b2_download_folder('drone/masks', 'data/drone/masks_full', force_download=force_download)
    unzip_drone_images()


def download_microscopy_dataset(force_download):
    """Download the microscopy images and labels, then extract the image archives."""
    b2_download_folder('Data histopathology/WhiteCellsImages',
                       'data/microscopy/images', force_download=force_download)
    b2_download_folder('Data histopathology/WhiteCellsLabels',
                       'data/microscopy/labels', force_download=force_download)
    unzip_microscopy_images()


def _unzip_folder(folder, empty_marker):
    """Delete `empty_marker` if present, then extract and remove every '.zip' in `folder`."""
    # NOTE(review): '.bzEmpty' looks like a placeholder left by the B2 download; confirm.
    if os.path.isfile(empty_marker):
        os.remove(empty_marker)

    for file in os.listdir(folder):
        if file.endswith(".zip"):
            archive_path = os.path.join(folder, file)
            # Close the archive before deleting it (required on Windows).
            with zipfile.ZipFile(archive_path) as archive:
                archive.extractall(folder)
            os.remove(archive_path)


def unzip_microscopy_images():
    """Extract the downloaded microscopy image archives in place."""
    _unzip_folder('data/microscopy/images', 'data/microscopy/labels/.bzEmpty')


def unzip_drone_images():
    """Extract the downloaded drone image archives in place."""
    _unzip_folder('data/drone/images_full', 'data/drone/masks_full/.bzEmpty')


def create_tiles_dataset(dataset, img_dir, mask_dir, tile_size=256):
    """Tile every (image, mask) pair of `dataset` and save the tiles that contain the class.

    Images are saved as `<img_dir>/<nn>_<iiiii>.tif`. Binarized masks are
    saved as `<mask_dir>/<nn>_<iiiii>.png`.
    """
    for folder in [img_dir, mask_dir]:
        os.makedirs(folder, exist_ok=True)

    for n, (img, mask) in enumerate(dataset):
        tiled_img = split_img(img, ROIs=(tile_size, tile_size), step=(tile_size, tile_size))
        tiled_mask = split_img(mask, ROIs=(tile_size, tile_size), step=(tile_size, tile_size))
        tiled_img, tiled_mask = class_detection(tiled_img, tiled_mask)  # Remove images without cars in it
        for i, (sub_img, sub_mask) in enumerate(zip(tiled_img, tiled_mask)):
            tile_id = f"{n:02d}_{i:05d}"
            Image.fromarray(sub_img).save(os.path.join(img_dir, tile_id + '.tif'))
            Image.fromarray(sub_mask > 0).save(os.path.join(mask_dir, tile_id + '.png'))


def create_tiles_dataset_binary(dataset, img_dir, mask_dir, random_state, thr, tile_size=256):
    """Tile `dataset` into a class-balanced binary dataset and write its label CSV.

    Tiles containing the class get label 0; tiles without it get label 1 (see
    `binary_class_detection`). Per image `n`, tile ids `<nn>_<iiiii>` number
    the with-class tiles first and continue with the without-class tiles.
    The CSV with columns 'file name' and 'label' is written to
    `data/drone/classification/dataset_tiles_<tile_size>_<random_state>_<thr>.csv`.
    """
    for folder in [img_dir, mask_dir]:
        os.makedirs(folder, exist_ok=True)

    ids = []
    labels = []
    for n, (img, mask) in enumerate(dataset):
        tiled_img = split_img(img, ROIs=(tile_size, tile_size), step=(tile_size, tile_size))
        tiled_mask = split_img(mask, ROIs=(tile_size, tile_size), step=(tile_size, tile_size))
        X_with, X_without, Y_with, Y_without = binary_class_detection(
            tiled_img, tiled_mask, random_state, thr)  # creates balanced arrays with class and without class

        for i, (sub_X_with, sub_Y_with) in enumerate(zip(X_with, Y_with)):
            tile_id = f"{n:02d}_{i:05d}"
            ids.append(tile_id)
            labels.append(0)
            Image.fromarray(sub_X_with).save(os.path.join(img_dir, tile_id + '.tif'))
            Image.fromarray(sub_Y_with > 0).save(os.path.join(mask_dir, tile_id + '.png'))

        # Continue numbering after the with-class tiles. Computing the offset
        # from len(X_with) also works when this image had no with-class tile.
        offset = len(X_with)
        for j, (sub_X_without, sub_Y_without) in enumerate(zip(X_without, Y_without)):
            tile_id = f"{n:02d}_{offset + j:05d}"
            ids.append(tile_id)
            labels.append(1)
            Image.fromarray(sub_X_without).save(os.path.join(img_dir, tile_id + '.tif'))
            Image.fromarray(sub_Y_without > 0).save(os.path.join(mask_dir, tile_id + '.png'))

    df = pd.DataFrame({'file name': ids, 'label': labels})
    df_loc = f'data/drone/classification/dataset_tiles_{tile_size}_{random_state}_{thr}.csv'
    df.to_csv(df_loc)


def class_detection(X, Y):
    """Keep only the sub-images whose target contains the class.

    Args:
        X (ndarray): input sub-images, indexed along axis 0
        Y (ndarray): target segmentation maps, one per sub-image ({0,1} values,
            1 where the class is present)

    Returns:
        X_with_class (ndarray): input regions with the selected class
        Y_with_class (ndarray): target regions with the selected class
    """
    without_class = [i for i, img in enumerate(Y) if img.mean() == 0]

    X_with_class = np.delete(X, without_class, 0)
    Y_with_class = np.delete(Y, without_class, 0)

    return X_with_class, Y_with_class


def binary_class_detection(X, Y, random_seed, thr):
    """Split sub-images into equally many with-class and without-class samples.

    A sub-image is "with class" if its target mean is > thr and "without
    class" if the mean is exactly 0. Sub-images with 0 < mean <= thr are
    disregarded. The larger group is randomly subsampled (seeded by
    `random_seed`) down to the size of the smaller one.

    Args:
        X (ndarray): input sub-images, indexed along axis 0
        Y (ndarray): target segmentation maps ({0,1} values, 1 where the class is)
        random_seed (None or int): seed for choosing which surplus sub-images to drop
        thr (float): sub-images are not considered if 0 < sub_target.mean() <= thr

    Returns:
        X_with_class (ndarray): input regions with the selected class
        X_without_class (ndarray): input regions without the selected class
        Y_with_class (ndarray): target regions with the selected class
        Y_without_class (ndarray): target regions without the selected class
    """
    with_class = []
    without_class = []
    no_class = []

    for i, img in enumerate(Y):
        m = img.mean()
        if m == 0:
            without_class.append(i)
        elif m > thr:
            with_class.append(i)
        else:
            no_class.append(i)

    N = len(with_class)
    M = len(without_class)
    # Local generator: same sequence as seeding the global `random` module,
    # without clobbering the global RNG state for the caller.
    rng = random.Random(random_seed)
    # Each list below doubles as "indices to delete from the other group", so
    # moving the surplus of the larger group into the other list drops it.
    if N <= M:
        rng.shuffle(without_class)
        with_class.extend(without_class[:M - N])
    else:
        rng.shuffle(with_class)
        without_class.extend(with_class[:N - M])

    X_with_class = np.delete(X, without_class + no_class, 0)
    X_without_class = np.delete(X, with_class + no_class, 0)
    Y_with_class = np.delete(Y, without_class + no_class, 0)
    Y_without_class = np.delete(Y, with_class + no_class, 0)

    return X_with_class, X_without_class, Y_with_class, Y_without_class


def make_dataloader(dataset, batch_size, shuffle=True):
    """Wrap an `(X, Y)` pair of numpy arrays in a torch `DataLoader`."""
    X, Y = dataset
    X, Y = np2torch(X), np2torch(Y)

    tensor_dataset = TensorDataset(X, Y)
    return DataLoader(tensor_dataset, batch_size=batch_size, shuffle=shuffle)


def check_image_folder_consistency(images, masks):
    """Assert that `images` and `masks` pair up one-to-one.

    The lists must have equal length. Each mask path must contain its image's
    base name (up to the first '.'). All images must share the first image's
    file extension, and all masks the first mask's.

    Raises:
        AssertionError: on any inconsistency.
    """
    assert len(images) == len(masks), "images / masks length mismatch"
    if not images:
        return

    file_type_images = images[0].split('.')[-1].lower()
    file_type_masks = masks[0].split('.')[-1].lower()
    for img_file, mask_file in zip(images, masks):
        img_name = os.path.basename(img_file).split('.')[0]
        assert img_name in mask_file, f"image {img_file} corresponds to {mask_file}?"
        assert img_file.split('.')[-1].lower() == file_type_images, \
            f"image file {img_file} file type mismatch. Should be: {file_type_images}"
        assert mask_file.split('.')[-1].lower() == file_type_masks, \
            f"mask file {mask_file} file type mismatch. Should be: {file_type_masks}"