jpdefrutos commited on
Commit
ca253db
·
1 Parent(s): 8fb4d8e

Generalized the DataGenerator to accept two lists of input and output (wrt network) labels to fetch from the dataset files.

Browse files
DeepDeformationMapRegistration/data_generator.py CHANGED
@@ -1,10 +1,3 @@
1
- import sys, os
2
- currentdir = os.path.dirname(os.path.realpath(__file__))
3
- parentdir = os.path.dirname(currentdir)
4
- sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
5
-
6
- PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'
7
-
8
  import numpy as np
9
  from tensorflow import keras
10
  import os
@@ -17,9 +10,9 @@ from DeepDeformationMapRegistration.utils.operators import min_max_norm
17
 
18
 
19
  class DataGeneratorManager(keras.utils.Sequence):
20
- def __init__(self, dataset_path, batch_size=32, shuffle=True, num_samples=None, validation_split=None, validation_samples=None,
21
- clip_range=[0., 1.], voxelmorph=False, segmentations=False,
22
- seg_labels: dict = {'bg': 0, 'vessels': 1, 'tumour': 2, 'parenchyma': 3}):
23
  # Get the list of files
24
  self.__list_files = self.__get_dataset_files(dataset_path)
25
  self.__list_files.sort()
@@ -32,9 +25,8 @@ class DataGeneratorManager(keras.utils.Sequence):
32
 
33
  self.__validation_samples = validation_samples
34
 
35
- self.__voxelmorph = voxelmorph
36
- self.__segmentations = segmentations
37
- self.__seg_labels = seg_labels
38
 
39
  if num_samples is not None:
40
  self.__num_samples = self.__total_samples if num_samples > self.__total_samples else num_samples
@@ -93,6 +85,14 @@ class DataGeneratorManager(keras.utils.Sequence):
93
  def shuffle(self):
94
  return self.__shuffle
95
 
 
 
 
 
 
 
 
 
96
  def get_generator_idxs(self, generator_type):
97
  if generator_type == 'train':
98
  return self.train_idxs
@@ -150,18 +150,6 @@ class DataGeneratorManager(keras.utils.Sequence):
150
  else:
151
  raise ValueError('Unknown dataset type "{}". Expected "train" or "validation"'.format(type))
152
 
153
- @property
154
- def is_voxelmorph(self):
155
- return self.__voxelmorph
156
-
157
- @property
158
- def give_segmentations(self):
159
- return self.__segmentations
160
-
161
- @property
162
- def seg_labels(self):
163
- return self.__seg_labels
164
-
165
 
166
  class DataGenerator(DataGeneratorManager):
167
  def __init__(self, GeneratorManager: DataGeneratorManager, dataset_type='train'):
@@ -173,8 +161,6 @@ class DataGenerator(DataGeneratorManager):
173
  self.__manager = GeneratorManager
174
  self.__shuffle = GeneratorManager.shuffle
175
 
176
- self.__seg_labels = GeneratorManager.seg_labels
177
-
178
  self.__num_samples = len(self.__list_files)
179
  self.__internal_idxs = np.arange(self.__num_samples)
180
  # These indices are internal to the generator, they are not the same as the dataset_idxs!!
@@ -184,8 +170,8 @@ class DataGenerator(DataGeneratorManager):
184
  self.__last_batch = 0
185
  self.__batches_per_epoch = int(np.floor(len(self.__internal_idxs) / self.__batch_size))
186
 
187
- self.__voxelmorph = GeneratorManager.is_voxelmorph
188
- self.__segmentations = GeneratorManager.is_voxelmorph and GeneratorManager.give_segmentations
189
 
190
  @staticmethod
191
  def __get_dataset_files(search_path):
@@ -228,6 +214,22 @@ class DataGenerator(DataGeneratorManager):
228
  """
229
  return self.__batches_per_epoch
230
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  def __getitem__(self, index):
232
  """
233
  Generate one batch of data
@@ -236,36 +238,15 @@ class DataGenerator(DataGeneratorManager):
236
  """
237
  idxs = self.__internal_idxs[index * self.__batch_size:(index + 1) * self.__batch_size]
238
 
239
- fix_img, mov_img, fix_vessels, mov_vessels, fix_tumour, mov_tumour, disp_map = self.__load_data(idxs)
240
-
241
- try:
242
- fix_img = min_max_norm(fix_img).astype(np.float32)
243
- mov_img = min_max_norm(mov_img).astype(np.float32)
244
- except ValueError:
245
- print(idxs, fix_img.shape, mov_img.shape)
246
- er_str = 'ERROR:\t[file]:\t{}\t[idx]:\t{}\t[fix_img.shape]:\t{}\t[mov_img.shape]:\t{}\t'.format(self.__list_files[idxs], idxs, fix_img.shape, mov_img.shape)
247
- raise ValueError(er_str)
248
-
249
- fix_vessels[fix_vessels > 0.] = self.__seg_labels['vessels']
250
- mov_vessels[mov_vessels > 0.] = self.__seg_labels['vessels']
251
-
252
- # fix_tumour[fix_tumour > 0.] = self.__seg_labels['tumour']
253
- # mov_tumour[mov_tumour > 0.] = self.__seg_labels['tumour']
254
 
255
  # https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit
256
  # A generator or keras.utils.Sequence returning (inputs, targets) or (inputs, targets, sample_weights)
257
  # The second element must match the outputs of the model, in this case (image, displacement map)
258
- if self.__voxelmorph:
259
- zero_grad = np.zeros([fix_img.shape[0], *C.DISP_MAP_SHAPE])
260
- if self.__segmentations:
261
- inputs = [mov_vessels, fix_vessels, mov_img, fix_img, zero_grad]
262
- outputs = [] #[fix_img, zero_grad]
263
- else:
264
- inputs = [mov_img, fix_img]
265
- outputs = [fix_img, zero_grad]
266
- return (inputs, outputs)
267
- else:
268
- return (fix_img, mov_img, fix_vessels, mov_vessels), # (None, fix_seg, fix_seg, fix_img)
269
 
270
  def next_batch(self):
271
  if self.__last_batch > self.__batches_per_epoch:
@@ -274,6 +255,24 @@ class DataGenerator(DataGeneratorManager):
274
  self.__last_batch += 1
275
  return batch
276
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
  def __load_data(self, idx_list):
278
  """
279
  Build the batch with the samples in idx_list
@@ -283,73 +282,87 @@ class DataGenerator(DataGeneratorManager):
283
  if isinstance(idx_list, (list, np.ndarray)):
284
  fix_img = np.empty((0, ) + C.IMG_SHAPE)
285
  mov_img = np.empty((0, ) + C.IMG_SHAPE)
286
- disp_map = np.empty((0, ) + C.DISP_MAP_SHAPE)
287
 
288
- # fix_segm = np.empty((0, ) + const.IMG_SHAPE)
289
- # mov_segm = np.empty((0, ) + const.IMG_SHAPE)
290
 
291
  fix_vessels = np.empty((0, ) + C.IMG_SHAPE)
292
  mov_vessels = np.empty((0, ) + C.IMG_SHAPE)
 
293
  fix_tumors = np.empty((0, ) + C.IMG_SHAPE)
294
  mov_tumors = np.empty((0, ) + C.IMG_SHAPE)
 
 
 
295
  for idx in idx_list:
296
  data_file = h5py.File(self.__list_files[idx], 'r')
297
 
298
- fix_img = np.append(fix_img, [data_file[C.H5_FIX_IMG][:]], axis=0)
299
- mov_img = np.append(mov_img, [data_file[C.H5_MOV_IMG][:]], axis=0)
 
 
 
300
 
301
- # fix_segm = np.append(fix_segm, [data_file[const.H5_FIX_PARENCHYMA_MASK][:]], axis=0)
302
- # mov_segm = np.append(mov_segm, [data_file[const.H5_MOV_PARENCHYMA_MASK][:]], axis=0)
303
 
304
- disp_map = np.append(disp_map, [data_file[C.H5_GT_DISP][:]], axis=0)
 
305
 
306
- fix_vessels = np.append(fix_vessels, [data_file[C.H5_FIX_VESSELS_MASK][:]], axis=0)
307
- mov_vessels = np.append(mov_vessels, [data_file[C.H5_MOV_VESSELS_MASK][:]], axis=0)
308
- fix_tumors = np.append(fix_tumors, [data_file[C.H5_FIX_TUMORS_MASK][:]], axis=0)
309
- mov_tumors = np.append(mov_tumors, [data_file[C.H5_MOV_TUMORS_MASK][:]], axis=0)
310
 
311
  data_file.close()
 
312
  else:
313
  data_file = h5py.File(self.__list_files[idx_list], 'r')
314
 
315
- fix_img = np.expand_dims(data_file[C.H5_FIX_IMG][:], 0)
316
- mov_img = np.expand_dims(data_file[C.H5_MOV_IMG][:], 0)
317
 
318
- # fix_segm = np.expand_dims(data_file[const.H5_FIX_PARENCHYMA_MASK][:], 0)
319
- # mov_segm = np.expand_dims(data_file[const.H5_MOV_PARENCHYMA_MASK][:], 0)
320
 
321
- fix_vessels = np.expand_dims(data_file[C.H5_FIX_VESSELS_MASK][:], axis=0)
322
- mov_vessels = np.expand_dims(data_file[C.H5_MOV_VESSELS_MASK][:], axis=0)
323
- fix_tumors = np.expand_dims(data_file[C.H5_FIX_TUMORS_MASK][:], axis=0)
324
- mov_tumors = np.expand_dims(data_file[C.H5_MOV_TUMORS_MASK][:], axis=0)
325
 
326
- disp_map = np.expand_dims(data_file[C.H5_GT_DISP][:], 0)
 
327
 
328
- data_file.close()
329
-
330
- return fix_img, mov_img, fix_vessels, mov_vessels, fix_tumors, mov_tumors, disp_map
331
 
332
- def get_single_sample(self):
333
- fix_img, mov_img, fix_segm, mov_segm, _ = self.__load_data(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
  # return X, y
335
- return np.expand_dims(np.concatenate([mov_img, fix_img, mov_segm, mov_segm], axis=-1), axis=0)
336
-
337
- def get_random_sample(self, num_samples):
338
- idxs = np.random.randint(0, self.__num_samples, num_samples)
339
- fix_img, mov_img, fix_segm, mov_segm, disp_map = self.__load_data(idxs)
340
-
341
- return (fix_img, mov_img, fix_segm, mov_segm, disp_map), [self.__list_files[f] for f in idxs]
342
 
343
  def get_input_shape(self):
344
  input_batch, _ = self.__getitem__(0)
345
- if self.__voxelmorph:
346
- ret_val = list(input_batch[0].shape)
347
- ret_val[-1] = 2
348
- ret_val = (None, ) + tuple(ret_val[1:])
349
- else:
350
- ret_val = input_batch.shape
351
- ret_val = (None, ) + ret_val[1:]
352
- return ret_val # const.BATCH_SHAPE_SEGM
353
 
354
  def who_are_you(self):
355
  return self.__dataset_type
@@ -361,6 +374,7 @@ class DataGenerator(DataGeneratorManager):
361
  class DataGeneratorManager2D:
362
  FIX_IMG_H5 = 'input/1'
363
  MOV_IMG_H5 = 'input/0'
 
364
  def __init__(self, h5_file_list, batch_size=32, data_split=0.7, img_size=None,
365
  fix_img_tag=FIX_IMG_H5, mov_img_tag=MOV_IMG_H5, multi_loss=False):
366
  self.__file_list = h5_file_list #h5py.File(h5_file, 'r')
 
 
 
 
 
 
 
 
1
  import numpy as np
2
  from tensorflow import keras
3
  import os
 
10
 
11
 
12
  class DataGeneratorManager(keras.utils.Sequence):
13
+ def __init__(self, dataset_path, batch_size=32, shuffle=True,
14
+ num_samples=None, validation_split=None, validation_samples=None, clip_range=[0., 1.],
15
+ input_labels=[C.H5_MOV_IMG, C.H5_FIX_IMG], output_labels=[C.H5_FIX_IMG, 'zero_gradient']):
16
  # Get the list of files
17
  self.__list_files = self.__get_dataset_files(dataset_path)
18
  self.__list_files.sort()
 
25
 
26
  self.__validation_samples = validation_samples
27
 
28
+ self.__input_labels = input_labels
29
+ self.__output_labels = output_labels
 
30
 
31
  if num_samples is not None:
32
  self.__num_samples = self.__total_samples if num_samples > self.__total_samples else num_samples
 
85
  def shuffle(self):
86
  return self.__shuffle
87
 
88
+ @property
89
+ def input_labels(self):
90
+ return self.__input_labels
91
+
92
+ @property
93
+ def output_labels(self):
94
+ return self.__output_labels
95
+
96
  def get_generator_idxs(self, generator_type):
97
  if generator_type == 'train':
98
  return self.train_idxs
 
150
  else:
151
  raise ValueError('Unknown dataset type "{}". Expected "train" or "validation"'.format(type))
152
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
  class DataGenerator(DataGeneratorManager):
155
  def __init__(self, GeneratorManager: DataGeneratorManager, dataset_type='train'):
 
161
  self.__manager = GeneratorManager
162
  self.__shuffle = GeneratorManager.shuffle
163
 
 
 
164
  self.__num_samples = len(self.__list_files)
165
  self.__internal_idxs = np.arange(self.__num_samples)
166
  # These indices are internal to the generator, they are not the same as the dataset_idxs!!
 
170
  self.__last_batch = 0
171
  self.__batches_per_epoch = int(np.floor(len(self.__internal_idxs) / self.__batch_size))
172
 
173
+ self.__input_labels = GeneratorManager.input_labels
174
+ self.__output_labels = GeneratorManager.output_labels
175
 
176
  @staticmethod
177
  def __get_dataset_files(search_path):
 
214
  """
215
  return self.__batches_per_epoch
216
 
217
+ @staticmethod
218
+ def __build_list(data_dict, labels):
219
+ ret_list = list()
220
+ for label in labels:
221
+ if label in data_dict.keys():
222
+ if label in [C.DG_LBL_FIX_IMG, C.DG_LBL_MOV_IMG]:
223
+ ret_list.append(min_max_norm(data_dict[label]).astype(np.float32))
224
+ elif label in [C.DG_LBL_FIX_PARENCHYMA, C.DG_LBL_FIX_VESSELS, C.DG_LBL_FIX_TUMOR,
225
+ C.DG_LBL_MOV_PARENCHYMA, C.DG_LBL_MOV_VESSELS, C.DG_LBL_MOV_TUMOR]:
226
+ aux = data_dict[label]
227
+ aux[aux > 0.] = 1.
228
+ ret_list.append(aux)
229
+ elif label == C.DG_LBL_ZERO_GRADS:
230
+ ret_list.append(np.zeros([data_dict['BATCH_SIZE'], *C.DISP_MAP_SHAPE]))
231
+ return ret_list
232
+
233
  def __getitem__(self, index):
234
  """
235
  Generate one batch of data
 
238
  """
239
  idxs = self.__internal_idxs[index * self.__batch_size:(index + 1) * self.__batch_size]
240
 
241
+ data_dict = self.__load_data(idxs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
 
243
  # https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit
244
  # A generator or keras.utils.Sequence returning (inputs, targets) or (inputs, targets, sample_weights)
245
  # The second element must match the outputs of the model, in this case (image, displacement map)
246
+ inputs = self.__build_list(data_dict, self.__input_labels)
247
+ outputs = self.__build_list(data_dict, self.__output_labels)
248
+
249
+ return (inputs, outputs)
 
 
 
 
 
 
 
250
 
251
  def next_batch(self):
252
  if self.__last_batch > self.__batches_per_epoch:
 
255
  self.__last_batch += 1
256
  return batch
257
 
258
+ def __try_load(self, data_file, label, append_array=None):
259
+ if label in self.__input_labels or label in self.__output_labels:
260
+ # To avoid extra overhead
261
+ try:
262
+ retVal = data_file[label][:]
263
+ except KeyError:
264
+ # That particular label is not found in the file. But this should be known by the user by now
265
+ retVal = None
266
+
267
+ if append_array is not None and retVal is not None:
268
+ return np.append(append_array, [data_file[C.H5_FIX_IMG][:]], axis=0)
269
+ elif append_array is None:
270
+ return retVal[np.newaxis, ...]
271
+ else:
272
+ return retVal # None
273
+ else:
274
+ return None
275
+
276
  def __load_data(self, idx_list):
277
  """
278
  Build the batch with the samples in idx_list
 
282
  if isinstance(idx_list, (list, np.ndarray)):
283
  fix_img = np.empty((0, ) + C.IMG_SHAPE)
284
  mov_img = np.empty((0, ) + C.IMG_SHAPE)
 
285
 
286
+ fix_parench = np.empty((0, ) + C.IMG_SHAPE)
287
+ mov_parench = np.empty((0, ) + C.IMG_SHAPE)
288
 
289
  fix_vessels = np.empty((0, ) + C.IMG_SHAPE)
290
  mov_vessels = np.empty((0, ) + C.IMG_SHAPE)
291
+
292
  fix_tumors = np.empty((0, ) + C.IMG_SHAPE)
293
  mov_tumors = np.empty((0, ) + C.IMG_SHAPE)
294
+
295
+ disp_map = np.empty((0, ) + C.DISP_MAP_SHAPE)
296
+
297
  for idx in idx_list:
298
  data_file = h5py.File(self.__list_files[idx], 'r')
299
 
300
+ fix_img = self.__try_load(data_file, C.H5_FIX_IMG, fix_img)
301
+ mov_img = self.__try_load(data_file, C.H5_MOV_IMG, mov_img)
302
+
303
+ fix_parench = self.__try_load(data_file, C.H5_FIX_PARENCHYMA_MASK, fix_parench)
304
+ mov_parench = self.__try_load(data_file, C.H5_MOV_PARENCHYMA_MASK, mov_parench)
305
 
306
+ fix_vessels = self.__try_load(data_file, C.H5_FIX_VESSELS_MASK, fix_vessels)
307
+ mov_vessels = self.__try_load(data_file, C.H5_MOV_VESSELS_MASK, mov_vessels)
308
 
309
+ fix_tumors = self.__try_load(data_file, C.H5_FIX_TUMORS_MASK, mov_parench)
310
+ mov_tumors = self.__try_load(data_file, C.H5_MOV_TUMORS_MASK, mov_parench)
311
 
312
+ disp_map = self.__try_load(data_file, C.H5_GT_DISP, disp_map)
 
 
 
313
 
314
  data_file.close()
315
+ batch_size = len(idx_list)
316
  else:
317
  data_file = h5py.File(self.__list_files[idx_list], 'r')
318
 
319
+ fix_img = self.__try_load(data_file, C.H5_FIX_IMG)
320
+ mov_img = self.__try_load(data_file, C.H5_MOV_IMG)
321
 
322
+ fix_parench = self.__try_load(data_file, C.H5_FIX_PARENCHYMA_MASK)
323
+ mov_parench = self.__try_load(data_file, C.H5_MOV_PARENCHYMA_MASK)
324
 
325
+ fix_vessels = self.__try_load(data_file, C.H5_FIX_VESSELS_MASK)
326
+ mov_vessels = self.__try_load(data_file, C.H5_MOV_VESSELS_MASK)
 
 
327
 
328
+ fix_tumors = self.__try_load(data_file, C.H5_FIX_TUMORS_MASK)
329
+ mov_tumors = self.__try_load(data_file, C.H5_MOV_TUMORS_MASK)
330
 
331
+ disp_map = self.__try_load(data_file, C.H5_GT_DISP)
 
 
332
 
333
+ data_file.close()
334
+ batch_size = 1
335
+
336
+ data_dict = {C.H5_FIX_IMG: fix_img,
337
+ C.H5_FIX_TUMORS_MASK: fix_tumors,
338
+ C.H5_FIX_VESSELS_MASK: fix_vessels,
339
+ C.H5_FIX_PARENCHYMA_MASK: fix_parench,
340
+ C.H5_MOV_IMG: mov_img,
341
+ C.H5_MOV_TUMORS_MASK: mov_tumors,
342
+ C.H5_MOV_VESSELS_MASK: mov_vessels,
343
+ C.H5_MOV_PARENCHYMA_MASK: mov_parench,
344
+ C.H5_GT_DISP: disp_map,
345
+ 'BATCH_SIZE': batch_size
346
+ }
347
+
348
+ return data_dict
349
+
350
+ def get_samples(self, num_samples, random=False):
351
+ if random:
352
+ idxs = np.random.randint(0, self.__num_samples, num_samples)
353
+ else:
354
+ idxs = np.arange(0, num_samples)
355
+ data_dict = self.__load_data(idxs)
356
  # return X, y
357
+ return self.__build_list(data_dict, self.__input_labels), self.__build_list(data_dict, self.__output_labels)
 
 
 
 
 
 
358
 
359
  def get_input_shape(self):
360
  input_batch, _ = self.__getitem__(0)
361
+ data_dict = self.__load_data(0)
362
+
363
+ ret_val = data_dict[self.__input_labels[0]].shape
364
+ ret_val = (None, ) + ret_val[1:]
365
+ return ret_val # const.BATCH_SHAPE_SEGM
 
 
 
366
 
367
  def who_are_you(self):
368
  return self.__dataset_type
 
374
  class DataGeneratorManager2D:
375
  FIX_IMG_H5 = 'input/1'
376
  MOV_IMG_H5 = 'input/0'
377
+
378
  def __init__(self, h5_file_list, batch_size=32, data_split=0.7, img_size=None,
379
  fix_img_tag=FIX_IMG_H5, mov_img_tag=MOV_IMG_H5, multi_loss=False):
380
  self.__file_list = h5_file_list #h5py.File(h5_file, 'r')
DeepDeformationMapRegistration/utils/constants.py CHANGED
@@ -65,6 +65,17 @@ MAX_FLIPS = 2 # Axes to flip over
65
  NUM_ROTATIONS = 5
66
  MAX_WORKERS = 10
67
 
 
 
 
 
 
 
 
 
 
 
 
68
  # Training constants
69
  MODEL = 'unet'
70
  BATCH_NORM = False
@@ -478,10 +489,9 @@ EPS_1_tf = tf.constant(EPS_1)
478
  # LDDMM
479
  GAUSSIAN_KERNEL_SHAPE = (8, 8, 8)
480
 
481
- # Constants for MultiLoss layer
482
  PRIOR_W = [1., 1 / 60, 1.]
483
  MANUAL_W = [1.] * len(PRIOR_W)
484
 
485
  REG_PRIOR_W = [1e-3]
486
  REG_MANUAL_W = [1.] * len(REG_PRIOR_W)
487
-
 
65
  NUM_ROTATIONS = 5
66
  MAX_WORKERS = 10
67
 
68
+ # Labels to pass to the input_labels and output_labels parameter of DataGeneratorManager
69
+ DG_LBL_FIX_IMG = H5_FIX_IMG
70
+ DG_LBL_FIX_VESSELS = H5_FIX_VESSELS_MASK
71
+ DG_LBL_FIX_PARENCHYMA = H5_FIX_PARENCHYMA_MASK
72
+ DG_LBL_FIX_TUMOR = H5_FIX_TUMORS_MASK
73
+ DG_LBL_MOV_IMG = H5_MOV_IMG
74
+ DG_LBL_MOV_VESSELS = H5_MOV_VESSELS_MASK
75
+ DG_LBL_MOV_PARENCHYMA = H5_MOV_PARENCHYMA_MASK
76
+ DG_LBL_MOV_TUMOR = H5_MOV_TUMORS_MASK
77
+ DG_LBL_ZERO_GRADS = 'zero_gradients'
78
+
79
  # Training constants
80
  MODEL = 'unet'
81
  BATCH_NORM = False
 
489
  # LDDMM
490
  GAUSSIAN_KERNEL_SHAPE = (8, 8, 8)
491
 
492
+ # Constants for Unsupervised Learning layer
493
  PRIOR_W = [1., 1 / 60, 1.]
494
  MANUAL_W = [1.] * len(PRIOR_W)
495
 
496
  REG_PRIOR_W = [1e-3]
497
  REG_MANUAL_W = [1.] * len(REG_PRIOR_W)