|
import os |
|
import shutil |
|
import rawpy |
|
import random |
|
from PIL import Image |
|
import tifffile as tiff |
|
import zipfile |
|
|
|
import numpy as np |
|
import pandas as pd |
|
|
|
from torch.utils.data import Dataset, DataLoader, TensorDataset |
|
from sklearn.model_selection import StratifiedShuffleSplit |
|
|
|
if not os.path.exists('README.md'): |
|
os.chdir('..') |
|
|
|
from utils.dataset_utils import split_img, list_images_in_dir, load_image |
|
from utils.base import np2torch, torch2np, b2_download_folder |
|
|
|
IMAGE_FILE_TYPES = ['dng', 'png', 'tif', 'tiff'] |
|
|
|
|
|
def get_dataset(name, I_ratio=1.0): |
|
|
|
if name in ('DC', 'Drone', 'DroneClassification', 'DroneDatasetClassificationTiled'): |
|
return DroneDatasetClassificationTiled(I_ratio=I_ratio) |
|
if name in ('DS', 'DroneSegmentation', 'DroneDatasetSegmentationTiled'): |
|
return DroneDatasetSegmentationTiled(I_ratio=I_ratio) |
|
|
|
|
|
if name in ('M', 'Microscopy', 'MicroscopyDataset'): |
|
return MicroscopyDataset(I_ratio=I_ratio) |
|
|
|
|
|
if name in ('DSF', 'DroneDatasetSegmentationFull'): |
|
return DroneDatasetSegmentationFull(I_ratio=I_ratio) |
|
if name in ('MRGB', 'MicroscopyRGB', 'MicroscopyDatasetRGB'): |
|
return MicroscopyDatasetRGB(I_ratio=I_ratio) |
|
|
|
raise ValueError(name) |
|
|
|
|
|
class ImageFolderDataset(Dataset): |
|
"""Creates a dataset of images in img_dir and corresponding masks in mask_dir. |
|
Corresponding mask files need to contain the filename of the image. |
|
Files are expected to be of the same filetype. |
|
|
|
Args: |
|
img_dir (str): path to image folder |
|
mask_dir (str): path to mask folder |
|
transform (callable, optional): transformation to apply to image and mask |
|
bits (int, optional): normalize image by dividing by 2^bits - 1 |
|
""" |
|
|
|
task = 'classification' |
|
|
|
def __init__(self, img_dir, labels, transform=None, bits=1): |
|
|
|
self.img_dir = img_dir |
|
self.labels = labels |
|
|
|
self.images = list_images_in_dir(img_dir) |
|
|
|
assert len(self.images) == len(self.labels) |
|
|
|
self.transform = transform |
|
self.bits = bits |
|
|
|
def __repr__(self): |
|
rep = f"{type(self).__name__}: ImageFolderDataset[{len(self.images)}]" |
|
for n, (img, label) in enumerate(zip(self.images, self.labels)): |
|
rep += f'\nimage: {img}\tlabel: {label}' |
|
if n > 10: |
|
rep += '\n...' |
|
break |
|
return rep |
|
|
|
def __len__(self): |
|
return len(self.images) |
|
|
|
def __getitem__(self, idx): |
|
|
|
label = self.labels[idx] |
|
|
|
img = load_image(self.images[idx]) |
|
img = img / (2**self.bits - 1) |
|
if self.transform is not None: |
|
img = self.transform(img) |
|
|
|
if len(img.shape) == 2: |
|
assert img.shape == (256, 256), f"Invalid size for {self.images[idx]}" |
|
else: |
|
assert img.shape == (3, 256, 256), f"Invalid size for {self.images[idx]}" |
|
|
|
return img, label |
|
|
|
|
|
class ImageFolderDatasetSegmentation(Dataset): |
|
"""Creates a dataset of images in `img_dir` and corresponding masks in `mask_dir`. |
|
Corresponding mask files need to contain the filename of the image. |
|
Files are expected to be of the same filetype. |
|
|
|
Args: |
|
img_dir (str): path to image folder |
|
mask_dir (str): path to mask folder |
|
transform (callable, optional): transformation to apply to image and mask |
|
bits (int, optional): normalize image by dividing by 2^bits - 1 |
|
""" |
|
|
|
task = 'segmentation' |
|
|
|
def __init__(self, img_dir, mask_dir, transform=None, bits=1): |
|
|
|
self.img_dir = img_dir |
|
self.mask_dir = mask_dir |
|
|
|
self.images = list_images_in_dir(img_dir) |
|
self.masks = list_images_in_dir(mask_dir) |
|
|
|
check_image_folder_consistency(self.images, self.masks) |
|
|
|
self.transform = transform |
|
self.bits = bits |
|
|
|
def __repr__(self): |
|
rep = f"{type(self).__name__}: ImageFolderDatasetSegmentation[{len(self.images)}]" |
|
for n, (img, mask) in enumerate(zip(self.images, self.masks)): |
|
rep += f'\nimage: {img}\tmask: {mask}' |
|
if n > 10: |
|
rep += '\n...' |
|
break |
|
return rep |
|
|
|
def __len__(self): |
|
return len(self.images) |
|
|
|
def __getitem__(self, idx): |
|
|
|
img = load_image(self.images[idx]) |
|
mask = load_image(self.masks[idx]) |
|
|
|
img = img / (2**self.bits - 1) |
|
mask = (mask > 0).astype(np.float32) |
|
|
|
if self.transform is not None: |
|
img = self.transform(img) |
|
|
|
return img, mask |
|
|
|
|
|
class MultiIntensity(Dataset): |
|
"""Wrap datasets with different intesities |
|
|
|
Args: |
|
datasets (list): list of datasets to wrap |
|
""" |
|
|
|
def __init__(self, datasets): |
|
self.dataset = datasets[0] |
|
|
|
for d in range(1, len(datasets)): |
|
self.dataset.images = self.dataset.images + datasets[d].images |
|
self.dataset.labels = self.dataset.labels + datasets[d].labels |
|
|
|
def __len__(self): |
|
return len(self.dataset) |
|
|
|
def __repr__(self): |
|
return f"Subset [{len(self.dataset)}] of " + repr(self.dataset) |
|
|
|
def __getitem__(self, idx): |
|
x, y = self.dataset[idx] |
|
if self.transform is not None: |
|
x = self.transform(x) |
|
return x, y |
|
|
|
|
|
class Subset(Dataset): |
|
"""Define a subset of a dataset by only selecting given indices. |
|
|
|
Args: |
|
dataset (Dataset): full dataset |
|
indices (list): subset indices |
|
""" |
|
|
|
def __init__(self, dataset, indices=None, transform=None): |
|
self.dataset = dataset |
|
self.indices = indices if indices is not None else range(len(dataset)) |
|
self.transform = transform |
|
|
|
def __len__(self): |
|
return len(self.indices) |
|
|
|
def __repr__(self): |
|
return f"Subset [{len(self)}] of " + repr(self.dataset) |
|
|
|
def __getitem__(self, idx): |
|
x, y = self.dataset[self.indices[idx]] |
|
if self.transform is not None: |
|
x = self.transform(x) |
|
return x, y |
|
|
|
|
|
class DroneDatasetSegmentationFull(ImageFolderDatasetSegmentation): |
|
"""Dataset consisting of full-sized numpy images and masks. Images are normalized to range [0, 1]. |
|
""" |
|
|
|
black_level = [0.0625, 0.0626, 0.0625, 0.0626] |
|
white_balance = [2.86653646, 1., 1.73079425] |
|
colour_matrix = [1.50768983, -0.33571374, -0.17197604, -0.23048614, |
|
1.70698738, -0.47650126, -0.03119153, -0.32803956, 1.35923111] |
|
camera_parameters = black_level, white_balance, colour_matrix |
|
|
|
def __init__(self, I_ratio=1.0, transform=None, force_download=False, bits=16): |
|
|
|
assert I_ratio in [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 1.0] |
|
|
|
img_dir = f'data/drone/images_full/raw_scale{int(I_ratio*100):03d}' |
|
mask_dir = 'data/drone/masks_full' |
|
|
|
download_drone_dataset(force_download) |
|
|
|
super().__init__(img_dir=img_dir, mask_dir=mask_dir, transform=transform, bits=bits) |
|
|
|
|
|
class DroneDatasetSegmentationTiled(ImageFolderDatasetSegmentation): |
|
"""Dataset consisting of tiled numpy images and masks. Images are in range [0, 1] |
|
Args: |
|
tile_size (int, optional): size of the tiled images. Defaults to 256. |
|
""" |
|
|
|
camera_parameters = DroneDatasetSegmentationFull.camera_parameters |
|
|
|
def __init__(self, I_ratio=1.0, transform=None): |
|
|
|
tile_size = 256 |
|
|
|
img_dir = f'data/drone/images_tiles_{tile_size}/raw_scale{int(I_ratio*100):03d}' |
|
mask_dir = f'data/drone/masks_tiles_{tile_size}' |
|
|
|
if not os.path.exists(img_dir) or not os.path.exists(mask_dir): |
|
dataset_full = DroneDatasetSegmentationFull(I_ratio=I_ratio, bits=1) |
|
print("tiling dataset..") |
|
create_tiles_dataset(dataset_full, img_dir, mask_dir, tile_size=tile_size) |
|
|
|
super().__init__(img_dir=img_dir, mask_dir=mask_dir, transform=transform, bits=16) |
|
|
|
|
|
class DroneDatasetClassificationTiled(ImageFolderDataset): |
|
|
|
camera_parameters = DroneDatasetSegmentationFull.camera_parameters |
|
|
|
def __init__(self, I_ratio=1.0, transform=None): |
|
|
|
random_state = 72 |
|
tile_size = 256 |
|
thr = 0.01 |
|
|
|
img_dir = f'data/drone/classification/images_tiles_{tile_size}/raw_scale{int(I_ratio*100):03d}_thr_{thr}' |
|
mask_dir = f'data/drone/classification/masks_tiles_{tile_size}_thr_{thr}' |
|
df_path = f'data/drone/classification/dataset_tiles_{tile_size}_{random_state}_{thr}.csv' |
|
|
|
if not os.path.exists(img_dir) or not os.path.exists(mask_dir): |
|
dataset_full = DroneDatasetSegmentationFull(I_ratio=I_ratio, bits=1) |
|
print("tiling dataset..") |
|
create_tiles_dataset_binary(dataset_full, img_dir, mask_dir, random_state, thr, tile_size=tile_size) |
|
|
|
self.classes = ['car', 'no car'] |
|
self.df = pd.read_csv(df_path) |
|
labels = self.df['label'].to_list() |
|
|
|
super().__init__(img_dir=img_dir, labels=labels, transform=transform, bits=16) |
|
|
|
images, class_labels = read_label_csv(self.df) |
|
self.images = [os.path.join(self.img_dir, image) for image in images] |
|
self.labels = class_labels |
|
|
|
|
|
class MicroscopyDataset(ImageFolderDataset): |
|
"""MicroscopyDataset raw images |
|
|
|
Args: |
|
I_ratio (float): Original image rescaled by this factor, possible values [0.01,0.05,0.1,0.25,0.5,0.75,1.0] |
|
raw (bool): Select rgb dataset or raw dataset |
|
transform (callable, optional): transformation to apply to image and mask |
|
bits (int, optional): normalize image by dividing by 2^bits - 1 |
|
""" |
|
|
|
black_level = [9.834368023181512e-06, 9.834368023181512e-06, 9.834368023181512e-06, 9.834368023181512e-06] |
|
white_balance = [-0.6567, 1.9673, 3.5304] |
|
colour_matrix = [-2.0338, 0.0933, 0.4157, -0.0286, 2.6464, -0.0574, -0.5516, -0.0947, 2.9308] |
|
|
|
camera_parameters = black_level, white_balance, colour_matrix |
|
|
|
dataset_mean = [0.91, 0.84, 0.94] |
|
dataset_std = [0.08, 0.12, 0.05] |
|
|
|
def __init__(self, I_ratio=1.0, transform=None, bits=16, force_download=False): |
|
|
|
assert I_ratio in [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 1.0] |
|
|
|
download_microscopy_dataset(force_download=force_download) |
|
|
|
self.img_dir = f'data/microscopy/images/raw_scale{int(I_ratio*100):03d}' |
|
self.transform = transform |
|
self.bits = bits |
|
|
|
self.label_file = 'data/microscopy/labels/Ma190c_annotations.dat' |
|
|
|
self.valid_classes = ['BAS', 'EBO', 'EOS', 'KSC', 'LYA', 'LYT', 'MMZ', 'MOB', |
|
'MON', 'MYB', 'MYO', 'NGB', 'NGS', 'PMB', 'PMO', 'UNC'] |
|
|
|
self.invalid_files = ['Ma190c_lame3_zone13_composite_Mcropped_2.tiff', ] |
|
|
|
images, class_labels = read_label_file(self.label_file) |
|
|
|
|
|
self.valid_classes = [class_label for class_label in self.valid_classes |
|
if class_labels.count(class_label) > 4] |
|
|
|
|
|
images, class_labels = list(zip(*[ |
|
(image, class_label) |
|
for image, class_label in zip(images, class_labels) |
|
if class_label in self.valid_classes and image not in self.invalid_files |
|
])) |
|
|
|
self.classes = list(sorted({*class_labels})) |
|
|
|
|
|
self.images = [os.path.join(self.img_dir, image) for image in images] |
|
|
|
|
|
self.labels = [self.classes.index(class_label) for class_label in class_labels] |
|
|
|
|
|
class MicroscopyDatasetRGB(MicroscopyDataset): |
|
"""MicroscopyDataset RGB images |
|
|
|
Args: |
|
I_ratio (float): Original image rescaled by this factor, possible values [0.01,0.05,0.1,0.25,0.5,0.75,1.0] |
|
raw (bool): Select rgb dataset or raw dataset |
|
transform (callable, optional): transformation to apply to image and mask |
|
bits (int, optional): normalize image by dividing by 2^bits - 1 |
|
""" |
|
camera_parameters = None |
|
|
|
dataset_mean = None |
|
dataset_std = None |
|
|
|
def __init__(self, I_ratio=1.0, transform=None, bits=16, force_download=False): |
|
super().__init__(I_ratio=I_ratio, transform=transform, bits=bits, force_download=force_download) |
|
self.images = [image.replace('raw', 'rgb') for image in self.images] |
|
|
|
|
|
def read_label_file(label_file_path): |
|
|
|
images = [] |
|
class_labels = [] |
|
|
|
with open(label_file_path, "rb") as data: |
|
for line in data: |
|
file_name, class_label = line.decode("utf-8").split() |
|
image = file_name + '.tiff' |
|
images.append(image) |
|
class_labels.append(class_label) |
|
|
|
return images, class_labels |
|
|
|
|
|
def read_label_csv(df): |
|
|
|
images = [] |
|
class_labels = [] |
|
|
|
for file_name, label in zip(df['file name'], df['label']): |
|
image = file_name + '.tif' |
|
images.append(image) |
|
class_labels.append(int(label)) |
|
return images, class_labels |
|
|
|
|
|
def download_drone_dataset(force_download): |
|
b2_download_folder('drone/images', 'data/drone/images_full', force_download=force_download) |
|
b2_download_folder('drone/masks', 'data/drone/masks_full', force_download=force_download) |
|
unzip_drone_images() |
|
|
|
|
|
def download_microscopy_dataset(force_download): |
|
b2_download_folder('Data histopathology/WhiteCellsImages', |
|
'data/microscopy/images', force_download=force_download) |
|
b2_download_folder('Data histopathology/WhiteCellsLabels', |
|
'data/microscopy/labels', force_download=force_download) |
|
unzip_microscopy_images() |
|
|
|
|
|
def unzip_microscopy_images(): |
|
|
|
if os.path.isfile('data/microscopy/labels/.bzEmpty'): |
|
os.remove('data/microscopy/labels/.bzEmpty') |
|
|
|
for file in os.listdir('data/microscopy/images'): |
|
if file.endswith(".zip"): |
|
zip = zipfile.ZipFile(os.path.join('data/microscopy/images', file)) |
|
zip.extractall('data/microscopy/images') |
|
os.remove(os.path.join('data/microscopy/images', file)) |
|
|
|
|
|
def unzip_drone_images(): |
|
|
|
if os.path.isfile('data/drone/masks_full/.bzEmpty'): |
|
os.remove('data/drone/masks_full/.bzEmpty') |
|
|
|
for file in os.listdir('data/drone/images_full'): |
|
if file.endswith(".zip"): |
|
zip = zipfile.ZipFile(os.path.join('data/drone/images_full', file)) |
|
zip.extractall('data/drone/images_full') |
|
os.remove(os.path.join('data/drone/images_full', file)) |
|
|
|
|
|
def create_tiles_dataset(dataset, img_dir, mask_dir, tile_size=256): |
|
for folder in [img_dir, mask_dir]: |
|
if not os.path.exists(folder): |
|
os.makedirs(folder) |
|
for n, (img, mask) in enumerate(dataset): |
|
tiled_img = split_img(img, ROIs=(tile_size, tile_size), step=(tile_size, tile_size)) |
|
tiled_mask = split_img(mask, ROIs=(tile_size, tile_size), step=(tile_size, tile_size)) |
|
tiled_img, tiled_mask = class_detection(tiled_img, tiled_mask) |
|
for i, (sub_img, sub_mask) in enumerate(zip(tiled_img, tiled_mask)): |
|
tile_id = f"{n:02d}_{i:05d}" |
|
Image.fromarray(sub_img).save(os.path.join(img_dir, tile_id + '.tif')) |
|
Image.fromarray(sub_mask > 0).save(os.path.join(mask_dir, tile_id + '.png')) |
|
|
|
|
|
def create_tiles_dataset_binary(dataset, img_dir, mask_dir, random_state, thr, tile_size=256): |
|
|
|
for folder in [img_dir, mask_dir]: |
|
if not os.path.exists(folder): |
|
os.makedirs(folder) |
|
|
|
ids = [] |
|
labels = [] |
|
|
|
for n, (img, mask) in enumerate(dataset): |
|
tiled_img = split_img(img, ROIs=(tile_size, tile_size), step=(tile_size, tile_size)) |
|
tiled_mask = split_img(mask, ROIs=(tile_size, tile_size), step=(tile_size, tile_size)) |
|
|
|
X_with, X_without, Y_with, Y_without = binary_class_detection( |
|
tiled_img, tiled_mask, random_state, thr) |
|
|
|
for i, (sub_X_with, sub_Y_with) in enumerate(zip(X_with, Y_with)): |
|
tile_id = f"{n:02d}_{i:05d}" |
|
ids.append(tile_id) |
|
labels.append(0) |
|
Image.fromarray(sub_X_with).save(os.path.join(img_dir, tile_id + '.tif')) |
|
Image.fromarray(sub_Y_with > 0).save(os.path.join(mask_dir, tile_id + '.png')) |
|
for j, (sub_X_without, sub_Y_without) in enumerate(zip(X_without, Y_without)): |
|
tile_id = f"{n:02d}_{i+1+j:05d}" |
|
ids.append(tile_id) |
|
labels.append(1) |
|
Image.fromarray(sub_X_without).save(os.path.join(img_dir, tile_id + '.tif')) |
|
Image.fromarray(sub_Y_without > 0).save(os.path.join(mask_dir, tile_id + '.png')) |
|
|
|
|
|
df = pd.DataFrame({'file name': ids, 'label': labels}) |
|
|
|
df_loc = f'data/drone/classification/dataset_tiles_{tile_size}_{random_state}_{thr}.csv' |
|
df.to_csv(df_loc) |
|
|
|
return |
|
|
|
|
|
def class_detection(X, Y): |
|
"""Split dataset in images which has the class in the target |
|
|
|
Args: |
|
X (ndarray): input image |
|
Y (ndarray): target with segmentation map (images with {0,1} values where it is 1 when there is the class) |
|
Returns: |
|
X_with_class (ndarray): input regions with the selected class |
|
Y_with_class (ndarray): target regions with the selected class |
|
X_without_class (ndarray): input regions without the selected class |
|
Y_without_class (ndarray): target regions without the selected class |
|
""" |
|
|
|
with_class = [] |
|
without_class = [] |
|
for i, img in enumerate(Y): |
|
if img.mean() == 0: |
|
without_class.append(i) |
|
else: |
|
with_class.append(i) |
|
|
|
X_with_class = np.delete(X, without_class, 0) |
|
Y_with_class = np.delete(Y, without_class, 0) |
|
|
|
return X_with_class, Y_with_class |
|
|
|
|
|
def binary_class_detection(X, Y, random_seed, thr): |
|
"""Splits subimages in subimages with the selected class and without the selected class by calculating the mean of the submasks; subimages with 0 < submask.mean()<=thr are disregared |
|
|
|
|
|
|
|
Args: |
|
X (ndarray): input image |
|
Y (ndarray): target with segmentation map (images with {0,1} values where it is 1 when there is the class) |
|
thr (flaot): sub images are not considered if 0 < sub_target.mean() <= thr |
|
balanced (bool): number of returned sub images is equal for both classes if true |
|
random_seed (None or int): selection of sub images in class with more elements according to random_seed if balanced |
|
Returns: |
|
X_with_class (ndarray): input regions with the selected class |
|
Y_with_class (ndarray): target regions with the selected class |
|
X_without_class (ndarray): input regions without the selected class |
|
Y_without_class (ndarray): target regions without the selected class |
|
""" |
|
|
|
with_class = [] |
|
without_class = [] |
|
no_class = [] |
|
|
|
for i, img in enumerate(Y): |
|
m = img.mean() |
|
if m == 0: |
|
without_class.append(i) |
|
else: |
|
if m > thr: |
|
with_class.append(i) |
|
else: |
|
no_class.append(i) |
|
|
|
N = len(with_class) |
|
M = len(without_class) |
|
random.seed(random_seed) |
|
if N <= M: |
|
random.shuffle(without_class) |
|
with_class.extend(without_class[:M - N]) |
|
else: |
|
random.shuffle(with_class) |
|
without_class.extend(with_class[:N - M]) |
|
|
|
X_with_class = np.delete(X, without_class + no_class, 0) |
|
X_without_class = np.delete(X, with_class + no_class, 0) |
|
Y_with_class = np.delete(Y, without_class + no_class, 0) |
|
Y_without_class = np.delete(Y, with_class + no_class, 0) |
|
|
|
return X_with_class, X_without_class, Y_with_class, Y_without_class |
|
|
|
|
|
def make_dataloader(dataset, batch_size, shuffle=True): |
|
|
|
X, Y = dataset |
|
|
|
X, Y = np2torch(X), np2torch(Y) |
|
|
|
dataset = TensorDataset(X, Y) |
|
dataset = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle) |
|
|
|
return dataset |
|
|
|
|
|
def check_image_folder_consistency(images, masks): |
|
file_type_images = images[0].split('.')[-1].lower() |
|
file_type_masks = masks[0].split('.')[-1].lower() |
|
assert len(images) == len(masks), "images / masks length mismatch" |
|
for img_file, mask_file in zip(images, masks): |
|
img_name = img_file.split('/')[-1].split('.')[0] |
|
assert img_name in mask_file, f"image {img_file} corresponds to {mask_file}?" |
|
assert img_file.split('.')[-1].lower() == file_type_images, \ |
|
f"image file {img_file} file type mismatch. Shoule be: {file_type_images}" |
|
assert mask_file.split('.')[-1].lower() == file_type_masks, \ |
|
f"image file {mask_file} file type mismatch. Should be: {file_type_masks}" |
|
|