|
import os
import random
import sys
import time
import warnings
from collections import defaultdict

import h5py
import numpy as np
import nibabel as nib
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.utils import Sequence
# Iterator is only referenced by the legacy generator kept in the quoted block at
# the bottom of this file.
from tensorflow.python.keras.preprocessing.image import Iterator
from PIL import Image
from nilearn.image import resample_img
from scipy.ndimage import zoom
from skimage.exposure import equalize_adapthist

import ddmr.utils.constants as C
from ddmr.utils.operators import min_max_norm
from ddmr.utils.misc import segmentation_cardinal_to_ohe
from ddmr.utils.thin_plate_splines import ThinPlateSplines
from voxelmorph.tf.layers import SpatialTransformer
from Brain_study.format_dataset import SEGMENTATION_NR2LBL_LUT, SEGMENTATION_LBL2NR_LUT, SEGMENTATION_LOC
|
class BatchGenerator:
    def __init__(self,
                 directory,
                 batch_size,
                 shuffle=True,
                 split=0.7,
                 combine_segmentations=True,
                 labels=None,
                 directory_val=None,
                 return_isotropic_shape=False):
        self.file_directory = directory
        self.batch_size = batch_size
        self.combine_segmentations = combine_segmentations
        self.labels = ['all'] if labels is None else labels  # avoid a mutable default argument
        self.shuffle = shuffle
        self.split = split
        self.return_isotropic_shape = return_isotropic_shape

        if directory_val is None:
            # Single directory: split it into training and validation subsets.
            self.file_list = [os.path.join(directory, f) for f in os.listdir(directory) if f.endswith(('h5', 'hd5'))]
            if self.shuffle:
                random.shuffle(self.file_list)
            else:
                self.file_list.sort()
            self.num_samples = len(self.file_list)
            training_samples = self.file_list[:int(self.num_samples * self.split)]

            self.train_iter = BatchIterator(training_samples, batch_size, shuffle, combine_segmentations,
                                            self.labels, return_isotropic_shape=return_isotropic_shape)
            if self.split < 1.:
                # Slice instead of using set arithmetic so the split is deterministic.
                validation_samples = self.file_list[int(self.num_samples * self.split):]
                self.validation_iter = BatchIterator(validation_samples, batch_size, shuffle, combine_segmentations,
                                                     ['all'], validation=True,
                                                     return_isotropic_shape=return_isotropic_shape)
            else:
                self.validation_iter = None
        else:
            # Separate validation directory provided.
            training_samples = [os.path.join(directory, f) for f in os.listdir(directory) if f.endswith(('h5', 'hd5'))]
            if self.shuffle:
                random.shuffle(training_samples)
            else:
                training_samples.sort()

            validation_samples = [os.path.join(directory_val, f) for f in os.listdir(directory_val)
                                  if f.endswith(('h5', 'hd5'))]
            if self.shuffle:
                random.shuffle(validation_samples)
            else:
                validation_samples.sort()

            self.num_samples = len(training_samples) + len(validation_samples)
            self.file_list = training_samples + validation_samples

            self.train_iter = BatchIterator(training_samples, batch_size, shuffle, combine_segmentations, self.labels)
            self.validation_iter = BatchIterator(validation_samples, batch_size, shuffle, combine_segmentations,
                                                 self.labels, validation=True)

    def get_train_generator(self):
        return self.train_iter

    def get_validation_generator(self):
        if self.validation_iter is not None:
            return self.validation_iter
        else:
            raise ValueError('No validation iterator. Split must be < 1.0')

    def get_file_list(self):
        return self.file_list

    def get_data_shape(self):
        return self.train_iter.get_data_shape()
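
# Usage sketch (assumes a directory of preprocessed .h5 volumes, each holding
# 'image', 'segmentation' and 'isotropic_shape' datasets as produced by
# Brain_study.format_dataset; the path below is hypothetical):
#
#   data_gen = BatchGenerator('/data/train_h5', batch_size=2, split=0.8)
#   train_iter = data_gen.get_train_generator()
#   val_iter = data_gen.get_validation_generator()
#   batch_shape, image_shape, segm_shape = data_gen.get_data_shape()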
|
|
|
|
|
ALL_LABELS = {2., 3., 4., 6., 8., 9., 11., 12., 14., 16., 20., 23., 29., 33., 39., 53., 67., 76., 102., 203., 210.,
              211., 218., 219., 232., 233., 254., 255.}
# Sort before enumerating so the label -> channel mapping is deterministic
# (iteration order over a set is not guaranteed).
ALL_LABELS_LOC = {label: loc for loc, label in enumerate(sorted(ALL_LABELS))}
|
|
|
|
|
class BatchIterator(Sequence):
    def __init__(self, file_list, batch_size, shuffle, combine_segmentations=True, labels=None,
                 zero_grads=(64, 64, 64, 3), validation=False, sequential_labels=True,
                 return_isotropic_shape=False, **kwargs):
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.file_list = file_list
        self.combine_segmentations = combine_segmentations
        self.labels = ['all'] if labels is None else labels  # avoid a mutable default argument
        self.zero_grads = np.zeros(zero_grads)
        self.idx_list = np.arange(0, len(self.file_list))
        self.validation = validation
        self.sequential_labels = sequential_labels
        self.return_isotropic_shape = return_isotropic_shape
        self._initialize()
        self.shuffle_samples()
|
|
|
    def _initialize(self):
        if isinstance(self.labels[0], str) and self.labels[0].lower() != 'none':
            if self.labels[0] != 'all':
                # A specific set of label names was requested: translate names to numbers.
                self.labels = [SEGMENTATION_LBL2NR_LUT[lbl] for lbl in self.labels]
                if not self.sequential_labels:
                    self.labels = [SEGMENTATION_LOC[lbl] for lbl in self.labels]
                    self.labels_dict = lambda x: SEGMENTATION_LOC[x] if x in self.labels else 0
                else:
                    self.labels_dict = lambda x: ALL_LABELS_LOC[x] if x in self.labels else 0
            else:
                # 'all': keep every label.
                if self.sequential_labels:
                    self.labels = list(set(SEGMENTATION_LOC.values()))
                    self.labels_dict = lambda x: SEGMENTATION_LOC[x] if x else 0
                else:
                    self.labels = list(ALL_LABELS)
                    self.labels_dict = lambda x: ALL_LABELS_LOC[x] if x in self.labels else 0
        elif hasattr(self.labels[0], 'lower') and self.labels[0].lower() == 'none':
            # No segmentations requested.
            self.labels_dict = dict()
        else:
            assert np.all([isinstance(lbl, (int, float)) for lbl in self.labels]), \
                'Labels must be given as str names, or as int/float label numbers'

        # The last batch may be smaller than batch_size, hence the ceiling division.
        self.num_steps = len(self.file_list) // self.batch_size + (1 if len(self.file_list) % self.batch_size else 0)

        # Infer the data shapes from the first file in the list.
        with h5py.File(self.file_list[0], 'r') as f:
            self.image_shape = list(f['image'][:].shape)
        self.segm_shape = self.image_shape.copy()
        self.segm_shape[-1] = len(self.labels) if not self.combine_segmentations else 1

        self.batch_shape = self.image_shape.copy()
        self.batch_shape[-1] = self.image_shape[-1] + self.segm_shape[-1]
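        # Worked example (hypothetical numbers): for 128x128x128 volumes with one
        # image channel and 28 labels (combine_segmentations=False):
        #   image_shape = [128, 128, 128, 1]
        #   segm_shape  = [128, 128, 128, 28]
        #   batch_shape = [128, 128, 128, 29]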
|
|
|
    def shuffle_samples(self):
        np.random.shuffle(self.idx_list)

    def __len__(self):
        return self.num_steps
|
|
|
    def _filter_segmentations(self, segm, segm_labels):
        if self.combine_segmentations:
            warnings.warn('Cannot select labels when the combine_segmentations option is active')
        if self.labels[0] != 'all':
            if set(self.labels).issubset(set(segm_labels)):
                # All requested labels are present: just select/reorder the channels.
                idx = [ALL_LABELS_LOC[l] for l in self.labels]
                segm = segm[..., idx]
            else:
                # Only some requested labels are present: copy those and zero-fill the rest.
                idx = [ALL_LABELS_LOC[l] for l in list(set(self.labels).intersection(set(segm_labels)))]
                aux = segm.copy()
                segm = np.zeros(self.segm_shape)
                segm[..., :len(idx)] = aux[..., idx]

        return segm
|
|
|
    def _load_sample(self, file_path):
        with h5py.File(file_path, 'r') as f:
            img = f['image'][:]
            segm = f['segmentation'][:]
            isot_shape = f['isotropic_shape'][:]

        if not self.combine_segmentations:
            if self.sequential_labels:
                # One-hot encode the label map; cast to int so it can index np.eye.
                segm = np.squeeze(np.eye(len(self.labels))[segm.astype(int)])
            else:
                lbls_list = list(ALL_LABELS) if self.labels[0] == 'all' else self.labels
                segm = segmentation_cardinal_to_ohe(segm, lbls_list)

        img = np.asarray(img, dtype=np.float32)
        segm = np.asarray(segm, dtype=np.float32)
        # Segmentations are returned unless labels == ['none'] (and always during validation).
        if not isinstance(self.labels[0], str) or self.labels[0].lower() != 'none' or self.validation:
            if self.validation:
                ret_val = np.concatenate([img, segm], axis=-1), (img, segm, self.zero_grads), isot_shape
            else:
                ret_val = np.concatenate([img, segm], axis=-1), (img, self.zero_grads), isot_shape
        else:
            ret_val = img, (img, self.zero_grads), isot_shape
        return ret_val
|
|
|
    def __getitem__(self, idx):
        in_batch = list()
        isotropic_shape = list()

        batch_idxs = self.idx_list[idx * self.batch_size:(idx + 1) * self.batch_size]
        file_list = [self.file_list[i] for i in batch_idxs]

        for f in file_list:
            b, _, isot_shape = self._load_sample(f)
            if self.return_isotropic_shape:
                isotropic_shape.append(isot_shape)
            in_batch.append(b)

        in_batch = np.asarray(in_batch, dtype=np.float32)
        ret_val = (in_batch, in_batch)
        if self.return_isotropic_shape:
            # np.int was removed in NumPy >= 1.24; use the builtin int instead.
            isotropic_shape = np.asarray(isotropic_shape, dtype=int)
            ret_val += (isotropic_shape,)

        return ret_val
|
|
|
    def __iter__(self):
        """Create a generator that iterates over the Sequence."""
        for item in (self[i] for i in range(len(self))):
            yield item

    def get_data_shape(self):
        return self.batch_shape, self.image_shape, self.segm_shape

    def on_epoch_end(self):
        self.shuffle_samples()

    def get_segmentation_labels(self):
        if self.combine_segmentations:
            labels = [1]
        else:
            labels = self.labels
        return labels
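
# Usage sketch: BatchIterator is a keras.utils.Sequence, so it can be passed
# straight to Model.fit. The model below is hypothetical; the point is that the
# iterator yields (inputs, targets) pairs where both are the concatenated
# image+segmentation batch:
#
#   train_iter = BatchIterator(train_files, batch_size=2, shuffle=True)
#   val_iter = BatchIterator(val_files, batch_size=2, shuffle=False, validation=True)
#   model.fit(train_iter, validation_data=val_iter, epochs=10)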
|
|
|
|
|
|
|
|
|
|
|
|
|
'''
# Legacy generator kept for reference only (note the enclosing triple quotes:
# this block is not executed). It predates the Sequence-based iterator above and
# would additionally need `from typing import List, Union` plus a `Transform`
# class to run.
def get_size(obj, seen=None):
    """Recursively finds size of objects"""
    size = sys.getsizeof(obj)
    if seen is None:
        seen = set()
    obj_id = id(obj)
    if obj_id in seen:
        return 0
    # Important: mark as seen *before* entering recursion to gracefully handle
    # self-referential objects
    seen.add(obj_id)
    if isinstance(obj, dict):
        size += sum([get_size(v, seen) for v in obj.values()])
        size += sum([get_size(k, seen) for k in obj.keys()])
    elif hasattr(obj, '__dict__'):
        size += get_size(obj.__dict__, seen)
    elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)):
        size += sum([get_size(i, seen) for i in obj])
    return size


class BatchIterator(Iterator):
    def __init__(self, generator, file_list, input_shape, output_shape, batch_size, shuffle, all_files_in_batch):
        self.file_list = file_list
        self.generator = generator
        self.input_shape = input_shape
        self.nr_of_inputs = len(input_shape)
        self.output_shape = output_shape
        self.nr_of_outputs = len(output_shape)
        self.all_files_in_batch = all_files_in_batch
        self.preload_to_memory = False
        self.file_cache = {}
        self.max_cache_size = 10*1024
        self.verbose = False
        if self.preload_to_memory:
            for filename, file_index in self.file_list:
                file = h5py.File(filename, 'r')
                inputs = {}
                for name, data in file['input'].items():
                    inputs[name] = np.copy(data)
                self.file_cache[filename] = {'input': inputs, 'output': np.copy(file['output'])}
                file.close()
                if get_size(self.file_cache) / (1024*1024) >= self.max_cache_size:
                    print('File cache has reached limit of', self.max_cache_size, 'MBs')
                    break
        epoch_size = len(file_list)
        if all_files_in_batch:
            epoch_size = len(file_list) * 10
        super(BatchIterator, self).__init__(epoch_size, batch_size, shuffle, None)

    def _get_sample(self, index):
        filename, file_index = self.file_list[index]
        if filename in self.file_cache:
            # NOTE: the cache path assumes the same {'input': ..., 'output': ...}
            # mapping layout as the h5 file itself.
            file = self.file_cache[filename]
        else:
            file = h5py.File(filename, 'r')
        inputs = []
        outputs = []
        for name, data in file['input'].items():
            inputs.append(data[file_index, :])
        for name, data in file['output'].items():
            if len(data.shape) > 1:
                outputs.append(data[file_index, :])
            else:
                outputs.append(data[file_index])
        if filename not in self.file_cache:
            file.close()
        return inputs, outputs

    def _get_random_sample_in_file(self, file_index):
        filename = self.file_list[file_index]
        file = h5py.File(filename, 'r')
        x = file['output/0']
        sample = np.random.randint(0, x.shape[0])
        inputs = []
        outputs = []
        for name, data in file['input'].items():
            inputs.append(data[sample, :])
        for name, data in file['output'].items():
            outputs.append(data[sample, :])  # index by the drawn sample, not by file_index
        file.close()
        return inputs, outputs

    def next(self):
        with self.lock:
            index_array = next(self.index_generator)
        return self._get_batches_of_transformed_samples(index_array)

    def _get_batches_of_transformed_samples(self, index_array):
        start_batch = time.time()
        batches_x = []
        batches_y = []
        for input_index in range(self.nr_of_inputs):
            batches_x.append(np.zeros(tuple([len(index_array)] + list(self.input_shape[input_index]))))
        for output_index in range(self.nr_of_outputs):
            batches_y.append(np.zeros(tuple([len(index_array)] + list(self.output_shape[output_index]))))

        timings_sampling = np.zeros(len(index_array))
        timings_transform = np.zeros(len(index_array))
        for batch_index, sample_index in enumerate(index_array):
            # Have to copy here in order to not modify original data
            start = time.time()
            if self.all_files_in_batch:
                input, output = self._get_random_sample_in_file(batch_index)
            else:
                input, output = self._get_sample(sample_index)
            timings_sampling[batch_index] = time.time() - start
            start = time.time()
            input, output = self.generator.transform(input, output)
            timings_transform[batch_index] = time.time() - start

            for input_index in range(self.nr_of_inputs):
                batches_x[input_index][batch_index] = input[input_index]
            for output_index in range(self.nr_of_outputs):
                batches_y[output_index][batch_index] = output[output_index]

        elapsed = time.time() - start_batch
        if self.verbose:
            print('Time to prepare batch:', round(elapsed, 3), 'seconds')
            print('Sampling mean:', round(timings_sampling.mean(), 3), 'seconds')
            print('Transform mean:', round(timings_transform.mean(), 3), 'seconds')

        return batches_x, batches_y


CLASSIFICATION = 'classification'
SEGMENTATION = 'segmentation'


class BatchGenerator:
    def __init__(self, filelist, all_files_in_batch=False):
        self.methods = []
        self.args = []
        self.crop_width_to = None
        self.image_list = []
        self.input_shape = []
        self.output_shape = []
        self.all_files_in_batch = all_files_in_batch
        self.transforms = []

        if all_files_in_batch:
            file = h5py.File(filelist[0], 'r')
            for name, data in file['input'].items():
                self.input_shape.append(data.shape[1:])
            for name, data in file['output'].items():
                self.output_shape.append(data.shape[1:])
            file.close()
            self.image_list = filelist
            return

        # Go through filelist
        first = True
        for filename in filelist:
            samples = None
            # Open file to see how many samples it has
            file = h5py.File(filename, 'r')
            for name, data in file['input'].items():
                if first:
                    self.input_shape.append(data.shape[1:])
                samples = data.shape[0]
            for name, data in file['output'].items():
                if first:
                    self.output_shape.append(data.shape[1:])
                if samples != data.shape[0]:
                    raise ValueError('Input and output sample counts differ in ' + filename)
            if len(self.output_shape) == 1:
                self.problem_type = CLASSIFICATION
            else:
                self.problem_type = SEGMENTATION

            file.close()
            if samples is None:
                raise ValueError('No input data found in ' + filename)
            # Append a tuple to image_list for each image consisting of filename and index
            print(filename, samples)
            for i in range(samples):
                self.image_list.append((filename, i))
            first = False

        print('Image generator with', len(self.image_list), 'image samples created')

    def flow(self, batch_size, shuffle=True):
        return BatchIterator(self, self.image_list, self.input_shape, self.output_shape, batch_size, shuffle,
                             self.all_files_in_batch)

    def transform(self, inputs, outputs):
        for input_indices, output_indices, transform in self.transforms:
            transform.randomize()
            inputs, outputs = transform.transform_all(inputs, outputs, input_indices, output_indices)
        return inputs, outputs

    def add_transform(self, input_indices: Union[int, List[int], None], output_indices: Union[int, List[int], None],
                      transform: Transform):
        if type(input_indices) is int:
            input_indices = [input_indices]
        if type(output_indices) is int:
            output_indices = [output_indices]

        self.transforms.append((
            input_indices,
            output_indices,
            transform
        ))

    def get_size(self):
        if self.all_files_in_batch:
            return 10*len(self.image_list)
        else:
            return len(self.image_list)
'''
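

if __name__ == '__main__':
    # Minimal smoke test; a sketch under an assumed data layout. The path below is
    # hypothetical and must point at a directory of .h5 files holding 'image',
    # 'segmentation' and 'isotropic_shape' datasets.
    data_gen = BatchGenerator('/path/to/h5_dataset', batch_size=2, split=0.8)
    train_iter = data_gen.get_train_generator()
    inputs, targets = train_iter[0]
    print('First training batch:', inputs.shape, targets.shape)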
|
|