Generalized the DataGenerator to accept two lists of labels (inputs and outputs with respect to the network) to fetch from the dataset files.
ca253db
import os
import random

import h5py
import numpy as np
from PIL import Image
from tensorflow import keras

import DeepDeformationMapRegistration.utils.constants as C
from DeepDeformationMapRegistration.utils.operators import min_max_norm
class DataGeneratorManager(keras.utils.Sequence):
    def __init__(self, dataset_path, batch_size=32, shuffle=True,
                 num_samples=None, validation_split=None, validation_samples=None, clip_range=(0., 1.),
                 input_labels=(C.H5_MOV_IMG, C.H5_FIX_IMG), output_labels=(C.H5_FIX_IMG, 'zero_gradient')):
        # Get the list of files
        self.__list_files = self.__get_dataset_files(dataset_path)
        self.__list_files.sort()
        self.__dataset_path = dataset_path
        self.__shuffle = shuffle
        self.__total_samples = len(self.__list_files)
        self.__validation_split = validation_split
        self.__clip_range = clip_range
        self.__batch_size = batch_size
        self.__validation_samples = validation_samples

        # Tuples as defaults avoid the mutable-default-argument pitfall
        self.__input_labels = list(input_labels)
        self.__output_labels = list(output_labels)

        if num_samples is not None:
            self.__num_samples = self.__total_samples if num_samples > self.__total_samples else num_samples
        else:
            self.__num_samples = self.__total_samples

        self.__internal_idxs = np.arange(self.__num_samples)

        # Split the samples into training and validation sets
        if validation_split is None:
            self.__validation_num_samples = None
            self.__validation_idxs = list()
            if self.__shuffle:
                random.shuffle(self.__internal_idxs)
            self.__training_idxs = self.__internal_idxs
            self.__validation_generator = None
        else:
            self.__validation_num_samples = int(np.ceil(self.__num_samples * validation_split))
            if self.__shuffle:
                # replace=False so no sample is assigned to the validation set twice
                self.__validation_idxs = np.random.choice(self.__internal_idxs, self.__validation_num_samples,
                                                          replace=False)
            else:
                self.__validation_idxs = self.__internal_idxs[0:self.__validation_num_samples]
            self.__training_idxs = np.asarray([idx for idx in self.__internal_idxs
                                               if idx not in self.__validation_idxs])
            self.__validation_generator = DataGenerator(self, 'validation')

        # Build the DataGenerators
        self.__train_generator = DataGenerator(self, 'train')
        self.reshuffle_indices()
    # Exposed as properties: DataGenerator reads these as plain attributes
    @property
    def dataset_path(self):
        return self.__dataset_path

    @property
    def dataset_list_files(self):
        return self.__list_files

    @property
    def train_idxs(self):
        return self.__training_idxs

    @property
    def validation_idxs(self):
        return self.__validation_idxs

    @property
    def batch_size(self):
        return self.__batch_size

    @property
    def clip_range(self):
        return self.__clip_range

    @property
    def shuffle(self):
        return self.__shuffle

    @property
    def input_labels(self):
        return self.__input_labels

    @property
    def output_labels(self):
        return self.__output_labels

    def get_generator_idxs(self, generator_type):
        if generator_type == 'train':
            return self.train_idxs
        elif generator_type == 'validation':
            return self.validation_idxs
        else:
            raise ValueError('Invalid generator type: {}'.format(generator_type))
    @staticmethod
    def __get_dataset_files(search_path):
        """
        Get the paths of the dataset files
        :param search_path: directory to search for the .hd5 files
        :return: sorted list of file paths
        """
        file_list = list()
        for root, dirs, files in os.walk(search_path):
            for data_file in files:
                file_name, extension = os.path.splitext(data_file)
                if extension.lower() == '.hd5':
                    file_list.append(os.path.join(root, data_file))

        if not file_list:
            raise ValueError('No files found to train in {}'.format(search_path))
        file_list.sort()

        print('Found {} files in {}'.format(len(file_list), search_path))
        return file_list
    def reshuffle_indices(self):
        if self.__validation_num_samples is None:
            if self.__shuffle:
                random.shuffle(self.__internal_idxs)
            self.__training_idxs = self.__internal_idxs
        else:
            if self.__shuffle:
                self.__validation_idxs = np.random.choice(self.__internal_idxs, self.__validation_num_samples,
                                                          replace=False)
            else:
                self.__validation_idxs = self.__internal_idxs[0:self.__validation_num_samples]
            self.__training_idxs = np.asarray([idx for idx in self.__internal_idxs
                                               if idx not in self.__validation_idxs])

        # Propagate the new indices to the generators
        if self.__validation_generator is not None:
            self.__validation_generator.update_samples(self.__validation_idxs)
        self.__train_generator.update_samples(self.__training_idxs)
    def get_generator(self, type='train'):
        if type.lower() == 'train':
            return self.__train_generator
        elif type.lower() == 'validation':
            if self.__validation_generator is not None:
                return self.__validation_generator
            else:
                raise Warning('No validation generator available. Set a non-zero validation_split to build one.')
        else:
            raise ValueError('Unknown dataset type "{}". Expected "train" or "validation"'.format(type))
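
# Usage sketch (hypothetical dataset path and model; a minimal example of how
# the manager is meant to be driven, assuming a directory of .hd5 files and
# the constants C.H5_MOV_IMG / C.H5_FIX_IMG imported above):
#
#   manager = DataGeneratorManager('/path/to/dataset', batch_size=8, shuffle=True,
#                                  validation_split=0.2,
#                                  input_labels=[C.H5_MOV_IMG, C.H5_FIX_IMG],
#                                  output_labels=[C.H5_FIX_IMG, 'zero_gradient'])
#   train_gen = manager.get_generator('train')
#   val_gen = manager.get_generator('validation')
#   model.fit(train_gen, validation_data=val_gen, epochs=100)  # model built elsewhere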
class DataGenerator(DataGeneratorManager):
    def __init__(self, GeneratorManager: DataGeneratorManager, dataset_type='train'):
        self.__complete_list_files = GeneratorManager.dataset_list_files
        self.__list_files = [self.__complete_list_files[idx]
                             for idx in GeneratorManager.get_generator_idxs(dataset_type)]
        self.__batch_size = GeneratorManager.batch_size
        self.__total_samples = len(self.__list_files)
        self.__clip_range = GeneratorManager.clip_range
        self.__manager = GeneratorManager
        self.__shuffle = GeneratorManager.shuffle

        self.__num_samples = len(self.__list_files)
        self.__internal_idxs = np.arange(self.__num_samples)
        # These indices are internal to the generator; they are not the same as the dataset idxs!!

        self.__dataset_type = dataset_type

        self.__last_batch = 0
        self.__batches_per_epoch = int(np.floor(len(self.__internal_idxs) / self.__batch_size))

        self.__input_labels = GeneratorManager.input_labels
        self.__output_labels = GeneratorManager.output_labels
    def update_samples(self, new_sample_idxs):
        self.__list_files = [self.__complete_list_files[idx] for idx in new_sample_idxs]
        self.__num_samples = len(self.__list_files)
        self.__internal_idxs = np.arange(self.__num_samples)

    def on_epoch_end(self):
        """
        Executed at the end of each epoch. Reshuffles the assigned samples
        :return:
        """
        if self.__shuffle:
            random.shuffle(self.__internal_idxs)
        self.__last_batch = 0

    def __len__(self):
        """
        Number of batches per epoch
        :return:
        """
        return self.__batches_per_epoch
    def __build_list(self, data_dict, labels):
        ret_list = list()
        for label in labels:
            if label in data_dict and data_dict[label] is not None:
                if label in [C.DG_LBL_FIX_IMG, C.DG_LBL_MOV_IMG]:
                    ret_list.append(min_max_norm(data_dict[label]).astype(np.float32))
                elif label in [C.DG_LBL_FIX_PARENCHYMA, C.DG_LBL_FIX_VESSELS, C.DG_LBL_FIX_TUMOR,
                               C.DG_LBL_MOV_PARENCHYMA, C.DG_LBL_MOV_VESSELS, C.DG_LBL_MOV_TUMOR]:
                    # Binarize the segmentation masks
                    aux = data_dict[label]
                    aux[aux > 0.] = 1.
                    ret_list.append(aux)
            elif label == C.DG_LBL_ZERO_GRADS:
                ret_list.append(np.zeros([data_dict['BATCH_SIZE'], *C.DISP_MAP_SHAPE]))
        return ret_list
    def __getitem__(self, index):
        """
        Generate one batch of data
        :param index: batch index within the epoch
        :return: (inputs, outputs) tuple
        """
        idxs = self.__internal_idxs[index * self.__batch_size:(index + 1) * self.__batch_size]

        data_dict = self.__load_data(idxs)

        # https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit
        # A generator or keras.utils.Sequence returning (inputs, targets) or (inputs, targets, sample_weights).
        # The second element must match the outputs of the model, in this case (image, displacement map)
        inputs = self.__build_list(data_dict, self.__input_labels)
        outputs = self.__build_list(data_dict, self.__output_labels)

        return (inputs, outputs)
    def next_batch(self):
        if self.__last_batch >= self.__batches_per_epoch:
            raise ValueError('No more batches for this epoch')
        batch = self.__getitem__(self.__last_batch)
        self.__last_batch += 1
        return batch
    def __try_load(self, data_file, label, append_array=None):
        # Only fetch labels the model actually consumes, to avoid extra overhead
        if label in self.__input_labels or label in self.__output_labels:
            try:
                retVal = data_file[label][:]
            except KeyError:
                # That particular label is not found in the file. But this should be known by the user by now
                retVal = None

            if retVal is None:
                return append_array  # Nothing to append
            elif append_array is not None:
                return np.append(append_array, [retVal], axis=0)
            else:
                return retVal[np.newaxis, ...]
        else:
            return None
    def __load_data(self, idx_list):
        """
        Build the batch with the samples in idx_list
        :param idx_list: list/array of sample indices, or a single scalar index
        :return: dictionary mapping the H5 labels to the loaded arrays
        """
        if isinstance(idx_list, (list, np.ndarray)):
            fix_img = np.empty((0, ) + C.IMG_SHAPE)
            mov_img = np.empty((0, ) + C.IMG_SHAPE)

            fix_parench = np.empty((0, ) + C.IMG_SHAPE)
            mov_parench = np.empty((0, ) + C.IMG_SHAPE)

            fix_vessels = np.empty((0, ) + C.IMG_SHAPE)
            mov_vessels = np.empty((0, ) + C.IMG_SHAPE)

            fix_tumors = np.empty((0, ) + C.IMG_SHAPE)
            mov_tumors = np.empty((0, ) + C.IMG_SHAPE)

            disp_map = np.empty((0, ) + C.DISP_MAP_SHAPE)

            for idx in idx_list:
                data_file = h5py.File(self.__list_files[idx], 'r')

                fix_img = self.__try_load(data_file, C.H5_FIX_IMG, fix_img)
                mov_img = self.__try_load(data_file, C.H5_MOV_IMG, mov_img)

                fix_parench = self.__try_load(data_file, C.H5_FIX_PARENCHYMA_MASK, fix_parench)
                mov_parench = self.__try_load(data_file, C.H5_MOV_PARENCHYMA_MASK, mov_parench)

                fix_vessels = self.__try_load(data_file, C.H5_FIX_VESSELS_MASK, fix_vessels)
                mov_vessels = self.__try_load(data_file, C.H5_MOV_VESSELS_MASK, mov_vessels)

                fix_tumors = self.__try_load(data_file, C.H5_FIX_TUMORS_MASK, fix_tumors)
                mov_tumors = self.__try_load(data_file, C.H5_MOV_TUMORS_MASK, mov_tumors)

                disp_map = self.__try_load(data_file, C.H5_GT_DISP, disp_map)

                data_file.close()
            batch_size = len(idx_list)
        else:
            data_file = h5py.File(self.__list_files[idx_list], 'r')

            fix_img = self.__try_load(data_file, C.H5_FIX_IMG)
            mov_img = self.__try_load(data_file, C.H5_MOV_IMG)

            fix_parench = self.__try_load(data_file, C.H5_FIX_PARENCHYMA_MASK)
            mov_parench = self.__try_load(data_file, C.H5_MOV_PARENCHYMA_MASK)

            fix_vessels = self.__try_load(data_file, C.H5_FIX_VESSELS_MASK)
            mov_vessels = self.__try_load(data_file, C.H5_MOV_VESSELS_MASK)

            fix_tumors = self.__try_load(data_file, C.H5_FIX_TUMORS_MASK)
            mov_tumors = self.__try_load(data_file, C.H5_MOV_TUMORS_MASK)

            disp_map = self.__try_load(data_file, C.H5_GT_DISP)

            data_file.close()
            batch_size = 1

        data_dict = {C.H5_FIX_IMG: fix_img,
                     C.H5_FIX_TUMORS_MASK: fix_tumors,
                     C.H5_FIX_VESSELS_MASK: fix_vessels,
                     C.H5_FIX_PARENCHYMA_MASK: fix_parench,
                     C.H5_MOV_IMG: mov_img,
                     C.H5_MOV_TUMORS_MASK: mov_tumors,
                     C.H5_MOV_VESSELS_MASK: mov_vessels,
                     C.H5_MOV_PARENCHYMA_MASK: mov_parench,
                     C.H5_GT_DISP: disp_map,
                     'BATCH_SIZE': batch_size
                     }
        return data_dict
    def get_samples(self, num_samples, random=False):
        if random:
            idxs = np.random.randint(0, self.__num_samples, num_samples)
        else:
            idxs = np.arange(0, num_samples)
        data_dict = self.__load_data(idxs)
        # return X, y
        return self.__build_list(data_dict, self.__input_labels), self.__build_list(data_dict, self.__output_labels)

    def get_input_shape(self):
        data_dict = self.__load_data(0)
        ret_val = data_dict[self.__input_labels[0]].shape
        ret_val = (None, ) + ret_val[1:]
        return ret_val  # const.BATCH_SHAPE_SEGM

    def who_are_you(self):
        return self.__dataset_type

    def print_datafiles(self):
        return self.__list_files
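
# Minimal sketch of driving a DataGenerator by hand, e.g. for debugging
# (assumes 'manager' is a DataGeneratorManager built as in the example above):
#
#   gen = manager.get_generator('train')
#   print(gen.who_are_you(), len(gen), gen.get_input_shape())
#   inputs, outputs = gen.next_batch()       # sequential access
#   inputs, outputs = gen[0]                 # or random access via __getitem__
#   X, y = gen.get_samples(4, random=True)   # a few samples for quick checks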
class DataGeneratorManager2D:
    FIX_IMG_H5 = 'input/1'
    MOV_IMG_H5 = 'input/0'

    def __init__(self, h5_file_list, batch_size=32, data_split=0.7, img_size=None,
                 fix_img_tag=FIX_IMG_H5, mov_img_tag=MOV_IMG_H5, multi_loss=False):
        self.__file_list = h5_file_list  # h5py.File(h5_file, 'r')
        self.__batch_size = batch_size
        self.__data_split = data_split

        self.__initialize()

        self.__train_generator = DataGenerator2D(self.__train_file_list,
                                                 batch_size=self.__batch_size,
                                                 img_size=img_size,
                                                 fix_img_tag=fix_img_tag,
                                                 mov_img_tag=mov_img_tag,
                                                 multi_loss=multi_loss)
        self.__val_generator = DataGenerator2D(self.__val_file_list,
                                               batch_size=self.__batch_size,
                                               img_size=img_size,
                                               fix_img_tag=fix_img_tag,
                                               mov_img_tag=mov_img_tag,
                                               multi_loss=multi_loss)

    def __initialize(self):
        num_samples = len(self.__file_list)
        random.shuffle(self.__file_list)
        # data_split is the fraction of samples assigned to training
        split_idx = int(np.floor(num_samples * self.__data_split))
        self.__train_file_list = self.__file_list[0:split_idx]
        self.__val_file_list = self.__file_list[split_idx:]

    def train_generator(self):
        return self.__train_generator

    def validation_generator(self):
        return self.__val_generator
class DataGenerator2D(keras.utils.Sequence):
    FIX_IMG_H5 = 'input/1'
    MOV_IMG_H5 = 'input/0'

    def __init__(self, file_list: list, batch_size=32, img_size=None,
                 fix_img_tag=FIX_IMG_H5, mov_img_tag=MOV_IMG_H5, multi_loss=False):
        self.__file_list = file_list  # h5py.File(h5_file, 'r')
        self.__file_list.sort()
        self.__batch_size = batch_size
        self.__idx_list = np.arange(0, len(self.__file_list))
        self.__multi_loss = multi_loss

        self.__tags = {'fix_img': fix_img_tag,
                       'mov_img': mov_img_tag}

        self.__batches_seen = 0
        self.__batches_per_epoch = 0
        self.__img_size = img_size
        self.__initialize()

    def __len__(self):
        return self.__batches_per_epoch

    def __initialize(self):
        random.shuffle(self.__idx_list)
        if self.__img_size is None:
            # Infer the input shape from the first file
            f = h5py.File(self.__file_list[0], 'r')
            self.input_shape = f[self.__tags['fix_img']].shape
            f.close()
        else:
            self.input_shape = self.__img_size

        if self.__multi_loss:
            self.input_shape = (self.input_shape, (*self.input_shape[:-1], 2))

        self.__batches_per_epoch = int(np.ceil(len(self.__file_list) / self.__batch_size))
    def __load_and_preprocess(self, fh, tag):
        img = fh[tag][:]

        if (self.__img_size is not None) and (img[..., 0].shape != self.__img_size):
            im = Image.fromarray(img[..., 0])  # PIL can't handle the trailing single-channel axis
            img = np.array(im.resize(self.__img_size[:-1], Image.LANCZOS)).astype(np.float32)
            img = img[..., np.newaxis]

        if img.max() > 1. or img.min() < 0.:
            try:
                img = min_max_norm(img).astype(np.float32)
            except ValueError:
                print(fh, tag, img.shape)
                er_str = 'ERROR:\t[file]:\t{}\t[tag]:\t{}\t[img.shape]:\t{}\t'.format(fh, tag, img.shape)
                raise ValueError(er_str)

        return img.astype(np.float32)
    def __getitem__(self, idx):
        idxs = self.__idx_list[idx * self.__batch_size:(idx + 1) * self.__batch_size]
        fix_imgs, mov_imgs = self.__load_samples(idxs)
        zero_grad = np.zeros((*fix_imgs.shape[:-1], 2))

        inputs = [mov_imgs, fix_imgs]
        outputs = [fix_imgs, zero_grad]

        if self.__multi_loss:
            # Single-element tuple: the targets travel with the inputs and the
            # model is expected to add its losses internally
            return ([mov_imgs, fix_imgs, zero_grad], )
        else:
            return (inputs, outputs)
    def __load_samples(self, idx_list):
        if self.__multi_loss:
            img_shape = (0, *self.input_shape[0])
        else:
            img_shape = (0, *self.input_shape)

        fix_imgs = np.empty(img_shape)
        mov_imgs = np.empty(img_shape)

        for i in idx_list:
            f = h5py.File(self.__file_list[i], 'r')
            fix_imgs = np.append(fix_imgs, [self.__load_and_preprocess(f, self.__tags['fix_img'])], axis=0)
            mov_imgs = np.append(mov_imgs, [self.__load_and_preprocess(f, self.__tags['mov_img'])], axis=0)
            f.close()

        return fix_imgs, mov_imgs

    def on_epoch_end(self):
        np.random.shuffle(self.__idx_list)

    def get_single_sample(self):
        idx = random.randint(0, len(self.__idx_list) - 1)  # randint is inclusive on both ends
        fix, mov = self.__load_samples([idx])
        return mov, fix
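
# Usage sketch for the 2D pipeline (hypothetical file list; assumes each .h5
# file stores the moving image under 'input/0' and the fixed one under 'input/1'):
#
#   import glob
#   files = glob.glob('/path/to/2d_dataset/*.h5')
#   manager2d = DataGeneratorManager2D(files, batch_size=16, data_split=0.7,
#                                      img_size=(64, 64, 1))
#   train_gen = manager2d.train_generator()
#   val_gen = manager2d.validation_generator()
#   model.fit(train_gen, validation_data=val_gen, epochs=100)  # model built elsewhere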