jpdefrutos commited on
Commit
74c6a32
·
1 Parent(s): f42fb70

Update DeepDeformationMapRegistration package

Browse files
DeepDeformationMapRegistration/callbacks.py CHANGED
@@ -1,5 +1,7 @@
1
  import tensorflow as tf
2
  import tensorflow.keras.backend as K
 
 
3
 
4
 
5
  class RollingAverageWeighting(tf.keras.callbacks.Callback):
@@ -43,3 +45,16 @@ class RollingAverageWeighting(tf.keras.callbacks.Callback):
43
  for name, val in zip(self.loss_weights.keys(), new_weights):
44
  out_str += '{}: {:7.2f}\t'.format(name, val)
45
  print('WEIGHTS UPDATE: ' + out_str)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import tensorflow as tf
2
  import tensorflow.keras.backend as K
3
+ import logging
4
+ import numpy as np
5
 
6
 
7
  class RollingAverageWeighting(tf.keras.callbacks.Callback):
 
45
  for name, val in zip(self.loss_weights.keys(), new_weights):
46
  out_str += '{}: {:7.2f}\t'.format(name, val)
47
  print('WEIGHTS UPDATE: ' + out_str)
48
+
49
+
50
+ class UncertaintyWeightingRollingAverageCallback(tf.keras.callbacks.Callback):
51
+ def __init__(self, method, epoch_update):
52
+ super(UncertaintyWeightingRollingAverageCallback, self).__init__()
53
+ self.method = method
54
+ self.epoch_update = epoch_update
55
+
56
+ def on_epoch_end(self, epoch, logs=None):
57
+ if epoch > self.epoch_update:
58
+ self.method()
59
+ print('Calling method: '+self.method.__name__)
60
+
DeepDeformationMapRegistration/data_generator.py CHANGED
@@ -4,9 +4,16 @@ import os
4
  import h5py
5
  import random
6
  from PIL import Image
 
 
 
 
 
7
 
8
  import DeepDeformationMapRegistration.utils.constants as C
9
  from DeepDeformationMapRegistration.utils.operators import min_max_norm
 
 
10
 
11
 
12
  class DataGeneratorManager(keras.utils.Sequence):
@@ -113,7 +120,7 @@ class DataGeneratorManager(keras.utils.Sequence):
113
  file_list.sort()
114
  for data_file in files:
115
  file_name, extension = os.path.splitext(data_file)
116
- if extension.lower() == '.hd5':
117
  file_list.append(os.path.join(root, data_file))
118
 
119
  if not file_list:
@@ -230,12 +237,7 @@ class DataGenerator(DataGeneratorManager):
230
  ret_list.append(np.zeros([data_dict['BATCH_SIZE'], *C.DISP_MAP_SHAPE]))
231
  return ret_list
232
 
233
- def __getitem__(self, index):
234
- """
235
- Generate one batch of data
236
- :param index: epoch index
237
- :return:
238
- """
239
  idxs = self.__internal_idxs[index * self.__batch_size:(index + 1) * self.__batch_size]
240
 
241
  data_dict = self.__load_data(idxs)
@@ -248,6 +250,14 @@ class DataGenerator(DataGeneratorManager):
248
 
249
  return (inputs, outputs)
250
 
 
 
 
 
 
 
 
 
251
  def next_batch(self):
252
  if self.__last_batch > self.__batches_per_epoch:
253
  raise ValueError('No more batches for this epoch')
@@ -259,15 +269,15 @@ class DataGenerator(DataGeneratorManager):
259
  if label in self.__input_labels or label in self.__output_labels:
260
  # To avoid extra overhead
261
  try:
262
- retVal = data_file[label][:]
263
  except KeyError:
264
  # That particular label is not found in the file. But this should be known by the user by now
265
  retVal = None
266
 
267
  if append_array is not None and retVal is not None:
268
- return np.append(append_array, [data_file[C.H5_FIX_IMG][:]], axis=0)
269
  elif append_array is None:
270
- return retVal[np.newaxis, ...]
271
  else:
272
  return retVal # None
273
  else:
@@ -280,19 +290,22 @@ class DataGenerator(DataGeneratorManager):
280
  :return:
281
  """
282
  if isinstance(idx_list, (list, np.ndarray)):
283
- fix_img = np.empty((0, ) + C.IMG_SHAPE)
284
- mov_img = np.empty((0, ) + C.IMG_SHAPE)
 
 
 
285
 
286
- fix_parench = np.empty((0, ) + C.IMG_SHAPE)
287
- mov_parench = np.empty((0, ) + C.IMG_SHAPE)
288
 
289
- fix_vessels = np.empty((0, ) + C.IMG_SHAPE)
290
- mov_vessels = np.empty((0, ) + C.IMG_SHAPE)
291
 
292
- fix_tumors = np.empty((0, ) + C.IMG_SHAPE)
293
- mov_tumors = np.empty((0, ) + C.IMG_SHAPE)
294
 
295
- disp_map = np.empty((0, ) + C.DISP_MAP_SHAPE)
 
296
 
297
  for idx in idx_list:
298
  data_file = h5py.File(self.__list_files[idx], 'r')
@@ -306,11 +319,14 @@ class DataGenerator(DataGeneratorManager):
306
  fix_vessels = self.__try_load(data_file, C.H5_FIX_VESSELS_MASK, fix_vessels)
307
  mov_vessels = self.__try_load(data_file, C.H5_MOV_VESSELS_MASK, mov_vessels)
308
 
309
- fix_tumors = self.__try_load(data_file, C.H5_FIX_TUMORS_MASK, mov_parench)
310
- mov_tumors = self.__try_load(data_file, C.H5_MOV_TUMORS_MASK, mov_parench)
311
 
312
  disp_map = self.__try_load(data_file, C.H5_GT_DISP, disp_map)
313
 
 
 
 
314
  data_file.close()
315
  batch_size = len(idx_list)
316
  else:
@@ -330,6 +346,9 @@ class DataGenerator(DataGeneratorManager):
330
 
331
  disp_map = self.__try_load(data_file, C.H5_GT_DISP)
332
 
 
 
 
333
  data_file.close()
334
  batch_size = 1
335
 
@@ -342,11 +361,61 @@ class DataGenerator(DataGeneratorManager):
342
  C.H5_MOV_VESSELS_MASK: mov_vessels,
343
  C.H5_MOV_PARENCHYMA_MASK: mov_parench,
344
  C.H5_GT_DISP: disp_map,
 
 
345
  'BATCH_SIZE': batch_size
346
  }
347
 
348
  return data_dict
349
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350
  def get_samples(self, num_samples, random=False):
351
  if random:
352
  idxs = np.random.randint(0, self.__num_samples, num_samples)
@@ -509,3 +578,345 @@ class DataGenerator2D(keras.utils.Sequence):
509
  return mov, fix
510
 
511
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import h5py
5
  import random
6
  from PIL import Image
7
+ import nibabel as nib
8
+ from nilearn.image import resample_img
9
+ from skimage.exposure import equalize_adapthist
10
+ from scipy.ndimage import zoom
11
+ import tensorflow as tf
12
 
13
  import DeepDeformationMapRegistration.utils.constants as C
14
  from DeepDeformationMapRegistration.utils.operators import min_max_norm
15
+ from DeepDeformationMapRegistration.utils.thin_plate_splines import ThinPlateSplines
16
+ from voxelmorph.tf.layers import SpatialTransformer
17
 
18
 
19
  class DataGeneratorManager(keras.utils.Sequence):
 
120
  file_list.sort()
121
  for data_file in files:
122
  file_name, extension = os.path.splitext(data_file)
123
+ if extension.lower() == '.hd5' or '.h5':
124
  file_list.append(os.path.join(root, data_file))
125
 
126
  if not file_list:
 
237
  ret_list.append(np.zeros([data_dict['BATCH_SIZE'], *C.DISP_MAP_SHAPE]))
238
  return ret_list
239
 
240
+ def __getitem1(self, index):
 
 
 
 
 
241
  idxs = self.__internal_idxs[index * self.__batch_size:(index + 1) * self.__batch_size]
242
 
243
  data_dict = self.__load_data(idxs)
 
250
 
251
  return (inputs, outputs)
252
 
253
+ def __getitem__(self, index):
254
+ """
255
+ Generate one batch of data
256
+ :param index: epoch index
257
+ :return:
258
+ """
259
+ return self.__getitem2(index)
260
+
261
  def next_batch(self):
262
  if self.__last_batch > self.__batches_per_epoch:
263
  raise ValueError('No more batches for this epoch')
 
269
  if label in self.__input_labels or label in self.__output_labels:
270
  # To avoid extra overhead
271
  try:
272
+ retVal = data_file[label][:][np.newaxis, ...]
273
  except KeyError:
274
  # That particular label is not found in the file. But this should be known by the user by now
275
  retVal = None
276
 
277
  if append_array is not None and retVal is not None:
278
+ return np.append(append_array, retVal, axis=0)
279
  elif append_array is None:
280
+ return retVal
281
  else:
282
  return retVal # None
283
  else:
 
290
  :return:
291
  """
292
  if isinstance(idx_list, (list, np.ndarray)):
293
+ fix_img = np.empty((0, ) + C.IMG_SHAPE, np.float32)
294
+ mov_img = np.empty((0, ) + C.IMG_SHAPE, np.float32)
295
+
296
+ fix_parench = np.empty((0, ) + C.IMG_SHAPE, np.float32)
297
+ mov_parench = np.empty((0, ) + C.IMG_SHAPE, np.float32)
298
 
299
+ fix_vessels = np.empty((0, ) + C.IMG_SHAPE, np.float32)
300
+ mov_vessels = np.empty((0, ) + C.IMG_SHAPE, np.float32)
301
 
302
+ fix_tumors = np.empty((0, ) + C.IMG_SHAPE, np.float32)
303
+ mov_tumors = np.empty((0, ) + C.IMG_SHAPE, np.float32)
304
 
305
+ disp_map = np.empty((0, ) + C.DISP_MAP_SHAPE, np.float32)
 
306
 
307
+ fix_centroid = np.empty((0, 3))
308
+ mov_centroid = np.empty((0, 3))
309
 
310
  for idx in idx_list:
311
  data_file = h5py.File(self.__list_files[idx], 'r')
 
319
  fix_vessels = self.__try_load(data_file, C.H5_FIX_VESSELS_MASK, fix_vessels)
320
  mov_vessels = self.__try_load(data_file, C.H5_MOV_VESSELS_MASK, mov_vessels)
321
 
322
+ fix_tumors = self.__try_load(data_file, C.H5_FIX_TUMORS_MASK, fix_tumors)
323
+ mov_tumors = self.__try_load(data_file, C.H5_MOV_TUMORS_MASK, mov_tumors)
324
 
325
  disp_map = self.__try_load(data_file, C.H5_GT_DISP, disp_map)
326
 
327
+ fix_centroid = self.__try_load(data_file, C.H5_FIX_CENTROID, fix_centroid)
328
+ mov_centroid = self.__try_load(data_file, C.H5_MOV_CENTROID, mov_centroid)
329
+
330
  data_file.close()
331
  batch_size = len(idx_list)
332
  else:
 
346
 
347
  disp_map = self.__try_load(data_file, C.H5_GT_DISP)
348
 
349
+ fix_centroid = self.__try_load(data_file, C.H5_FIX_CENTROID)
350
+ mov_centroid = self.__try_load(data_file, C.H5_MOV_CENTROID)
351
+
352
  data_file.close()
353
  batch_size = 1
354
 
 
361
  C.H5_MOV_VESSELS_MASK: mov_vessels,
362
  C.H5_MOV_PARENCHYMA_MASK: mov_parench,
363
  C.H5_GT_DISP: disp_map,
364
+ C.H5_FIX_CENTROID: fix_centroid,
365
+ C.H5_MOV_CENTROID: mov_centroid,
366
  'BATCH_SIZE': batch_size
367
  }
368
 
369
  return data_dict
370
 
371
+ @staticmethod
372
+ def __get_data_shape(file_path, label):
373
+ f = h5py.File(file_path, 'r')
374
+ shape = f[label][:].shape
375
+ f.close()
376
+ return shape
377
+
378
+ def __load_data_by_label(self, label, idx_list):
379
+ if isinstance(idx_list, (list, np.ndarray)):
380
+ data_shape = self.__get_data_shape(self.__list_files[idx_list[0]], label)
381
+ container = np.empty((0, *data_shape), np.float32)
382
+ # if label == C.H5_GT_DISP:
383
+ # container = np.empty((0, ) + C.DISP_MAP_SHAPE, np.float32)
384
+ # elif label == C.H5_MOV_CENTROID or label == C.H5_FIX_CENTROID:
385
+ # container = np.empty((0, 3), np.float32)
386
+ # else:
387
+ # container = np.empty((0, ) + C.IMG_SHAPE, np.float32)
388
+
389
+ for idx in idx_list:
390
+ data_file = h5py.File(self.__list_files[idx], 'r')
391
+ container = self.__try_load(data_file, label, container)
392
+ data_file.close()
393
+ else:
394
+ data_file = h5py.File(self.__list_files[idx_list], 'r')
395
+ container = self.__try_load(data_file, label)
396
+ data_file.close()
397
+
398
+ return container
399
+
400
+ def __build_list2(self, label_list, file_idxs):
401
+ ret_list = list()
402
+ for label in label_list:
403
+ if label is C.DG_LBL_ZERO_GRADS:
404
+ aux = np.zeros([len(file_idxs), *C.DISP_MAP_SHAPE])
405
+ else:
406
+ aux = self.__load_data_by_label(label, file_idxs)
407
+
408
+ if label in [C.DG_LBL_MOV_IMG, C.DG_LBL_FIX_IMG]:
409
+ aux = min_max_norm(aux).astype(np.float32)
410
+ ret_list.append(aux)
411
+ return ret_list
412
+
413
+ def __getitem2(self, index):
414
+ f_indices = self.__internal_idxs[index * self.__batch_size:(index + 1) * self.__batch_size]
415
+
416
+ return self.__build_list2(self.__input_labels, f_indices), self.__build_list2(self.__output_labels, f_indices)
417
+
418
+
419
  def get_samples(self, num_samples, random=False):
420
  if random:
421
  idxs = np.random.randint(0, self.__num_samples, num_samples)
 
578
  return mov, fix
579
 
580
 
581
+ FILE_EXT = {'nifti': '.nii.gz',
582
+ 'h5': '.h5'}
583
+ CTRL_GRID = C.CoordinatesGrid()
584
+ CTRL_GRID.set_coords_grid([128]*3, [C.TPS_NUM_CTRL_PTS_PER_AXIS]*3, batches=False, norm=False, img_type=tf.float32)
585
+
586
+ FINE_GRID = C.CoordinatesGrid()
587
+ FINE_GRID.set_coords_grid([128]*3, [128]*3, batches=FINE_GRID, norm=False)
588
+
589
+ class DataGeneratorAugment(DataGeneratorManager):
590
+ def __init__(self, GeneratorManager: DataGeneratorManager, file_type='nifti', dataset_type='train'):
591
+ self.__complete_list_files = GeneratorManager.dataset_list_files
592
+ self.__list_files = [self.__complete_list_files[idx] for idx in GeneratorManager.get_generator_idxs(dataset_type)]
593
+ self.__batch_size = GeneratorManager.batch_size
594
+ self.__augm_per_sample = 10
595
+ self.__samples_per_batch = np.ceil(self.__batch_size / (self.__augm_per_sample + 1)) # B = S + S*A
596
+ self.__total_samples = len(self.__list_files)
597
+ self.__clip_range = GeneratorManager.clip_rage
598
+ self.__manager = GeneratorManager
599
+ self.__shuffle = GeneratorManager.shuffle
600
+ self.__file_extension = FILE_EXT[file_type]
601
+
602
+ self.__num_samples = len(self.__list_files)
603
+ self.__internal_idxs = np.arange(self.__num_samples)
604
+ # These indices are internal to the generator, they are not the same as the dataset_idxs!!
605
+
606
+ self.__dataset_type = dataset_type
607
+
608
+ self.__last_batch = 0
609
+ self.__batches_per_epoch = int(np.floor(len(self.__internal_idxs) / self.__batch_size))
610
+
611
+ self.__input_labels = GeneratorManager.input_labels
612
+ self.__output_labels = GeneratorManager.output_labels
613
+
614
+
615
+ def __get_dataset_files(self, search_path):
616
+ """
617
+ Get the path to the dataset files
618
+ :param search_path: dir path to search for the hd5 files
619
+ :return:
620
+ """
621
+ file_list = list()
622
+ for root, dirs, files in os.walk(search_path):
623
+ for data_file in files:
624
+ file_name, extension = os.path.splitext(data_file)
625
+ if extension.lower() == self.__file_extension:
626
+ file_list.append(os.path.join(root, data_file))
627
+
628
+ if not file_list:
629
+ raise ValueError('No files found to train in ', search_path)
630
+
631
+ print('Found {} files in {}'.format(len(file_list), search_path))
632
+ return file_list
633
+
634
+ def update_samples(self, new_sample_idxs):
635
+ self.__list_files = [self.__complete_list_files[idx] for idx in new_sample_idxs]
636
+ self.__num_samples = len(self.__list_files)
637
+ self.__internal_idxs = np.arange(self.__num_samples)
638
+
639
+ def on_epoch_end(self):
640
+ """
641
+ To be executed at the end of each epoch. Reshuffle the assigned samples
642
+ :return:
643
+ """
644
+ if self.__shuffle:
645
+ random.shuffle(self.__internal_idxs)
646
+ self.__last_batch = 0
647
+
648
+ def __len__(self):
649
+ """
650
+ Number of batches per epoch
651
+ :return:
652
+ """
653
+ return self.__batches_per_epoch
654
+
655
+ def __getitem__(self, index):
656
+ """
657
+ Generate one batch of data
658
+ :param index: epoch index
659
+ :return:
660
+ """
661
+ return self.__getitem(index)
662
+
663
+ def next_batch(self):
664
+ if self.__last_batch > self.__batches_per_epoch:
665
+ raise ValueError('No more batches for this epoch')
666
+ batch = self.__getitem__(self.__last_batch)
667
+ self.__last_batch += 1
668
+ return batch
669
+
670
+ def __try_load(self, data_file, label, append_array=None):
671
+ if label in self.__input_labels or label in self.__output_labels:
672
+ # To avoid extra overhead
673
+ try:
674
+ retVal = data_file[label][:][np.newaxis, ...]
675
+ except KeyError:
676
+ # That particular label is not found in the file. But this should be known by the user by now
677
+ retVal = None
678
+
679
+ if append_array is not None and retVal is not None:
680
+ return np.append(append_array, retVal, axis=0)
681
+ elif append_array is None:
682
+ return retVal
683
+ else:
684
+ return retVal # None
685
+ else:
686
+ return None
687
+
688
+ @staticmethod
689
+ def __get_data_shape(file_path, label):
690
+ f = h5py.File(file_path, 'r')
691
+ shape = f[label][:].shape
692
+ f.close()
693
+ return shape
694
+
695
+ def __load_data_by_label(self, label, idx_list):
696
+ if isinstance(idx_list, (list, np.ndarray)):
697
+ data_shape = self.__get_data_shape(self.__list_files[idx_list[0]], label)
698
+ container = np.empty((0, *data_shape), np.float32)
699
+ # if label == C.H5_GT_DISP:
700
+ # container = np.empty((0, ) + C.DISP_MAP_SHAPE, np.float32)
701
+ # elif label == C.H5_MOV_CENTROID or label == C.H5_FIX_CENTROID:
702
+ # container = np.empty((0, 3), np.float32)
703
+ # else:
704
+ # container = np.empty((0, ) + C.IMG_SHAPE, np.float32)
705
+
706
+ for idx in idx_list:
707
+ data_file = h5py.File(self.__list_files[idx], 'r')
708
+ container = self.__try_load(data_file, label, container)
709
+ data_file.close()
710
+ else:
711
+ data_file = h5py.File(self.__list_files[idx_list], 'r')
712
+ container = self.__try_load(data_file, label)
713
+ data_file.close()
714
+
715
+ return container
716
+
717
+ def __build_list(self, label_list, file_idxs):
718
+ ret_list = list()
719
+ for label in label_list:
720
+ if label is C.DG_LBL_ZERO_GRADS:
721
+ aux = np.zeros([len(file_idxs), *C.DISP_MAP_SHAPE])
722
+ else:
723
+ aux = self.__load_data_by_label(label, file_idxs)
724
+
725
+ if label in [C.DG_LBL_MOV_IMG, C.DG_LBL_FIX_IMG]:
726
+ aux = min_max_norm(aux).astype(np.float32)
727
+ ret_list.append(aux)
728
+ return ret_list
729
+
730
+ def __getitem(self, index):
731
+ f_indices = self.__internal_idxs[index * self.__samples_per_batch:(index + 1) * self.__samples_per_batch]
732
+ # https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit
733
+ # A generator or keras.utils.Sequence returning (inputs, targets) or (inputs, targets, sample_weights)
734
+ # The second element must match the outputs of the model, in this case (image, displacement map)
735
+ if 'h5' in self.__file_extension:
736
+ return self.__build_list(self.__input_labels, f_indices), self.__build_list(self.__output_labels, f_indices)
737
+ else:
738
+ f_list = [self.__list_files[i] for i in f_indices]
739
+ return self.__augment(f_list, 'fixed', C.H5_FIX_IMG), self.__augment(f_list, 'moving', C.H5_FIX_IMG)
740
+
741
+
742
+ def __intensity_preprocessing(self, img_data):
743
+ # Histogram normalization
744
+ processed_img = equalize_adapthist(img_data, clip_limit=0.03)
745
+ processed_img = min_max_norm(processed_img)
746
+
747
+ return processed_img
748
+
749
+
750
+ def __resize_img(self, img, output_shape):
751
+ if isinstance(output_shape, int):
752
+ output_shape = [output_shape] * len(img.shape)
753
+ # Resize
754
+ zoom_vals = np.asarray(output_shape) / np.asarray(img.shape)
755
+ return zoom(img, zoom_vals)
756
+
757
+
758
+ def __build_augmented_batch(self, f_list, mode):
759
+ for f_path in f_list:
760
+ h5_file = h5py.File(f_path, 'r')
761
+ img_nib = nib.load(h5_file[C.H5_FIX_IMG][:])
762
+ img_nib = resample_img(img_nib, np.eye(3))
763
+ try:
764
+ seg_nib = nib.load(h5_file[C.H5_FIX_SEGMENTATIONS][:])
765
+ seg_nib = resample_img(seg_nib, np.eye(3))
766
+ except FileNotFoundError:
767
+ seg_nib = None
768
+
769
+ img_nib = self.__intensity_preprocessing(img_nib)
770
+ img_nib = self.__resize_img(img_nib, 128)
771
+
772
+
773
+
774
+
775
+
776
+
777
+
778
+
779
+ def get_samples(self, num_samples, random=False):
780
+ return
781
+
782
+ def get_input_shape(self):
783
+ input_batch, _ = self.__getitem__(0)
784
+ data_dict = self.__load_data(0)
785
+
786
+ ret_val = data_dict[self.__input_labels[0]].shape
787
+ ret_val = (None, ) + ret_val[1:]
788
+ return ret_val # const.BATCH_SHAPE_SEGM
789
+
790
+ def who_are_you(self):
791
+ return self.__dataset_type
792
+
793
+ def print_datafiles(self):
794
+ return self.__list_files
795
+
796
+
797
+ def tf_graph_deform():
798
+ # Place holders
799
+ fix_img = tf.placeholder(tf.float32, [128]*3, 'fix_img')
800
+ fix_segmentations = tf.placeholder_with_default(np.zeros([128]*3), shape=[128]*3, name='fix_segmentations')
801
+ max_deformation = tf.placeholder(tf.float32, shape=(), name='max_deformation')
802
+ max_displacement = tf.placeholder(tf.float32, shape=(), name='max_displacement')
803
+ max_rotation = tf.placeholder(tf.float32, shape=(), name='max_rotation')
804
+ num_moved_points = tf.placeholder_with_default(50, shape=(), name='num_moved_points')
805
+ only_image = tf.placeholder_with_default(True, shape=(), name='only_image')
806
+
807
+ search_voxels = tf.cond(only_image,
808
+ lambda: fix_img,
809
+ lambda: fix_segmentations)
810
+
811
+ # Apply TPS deformation
812
+ # Get points in the segmentation or image, and add it to the control grid and target grid
813
+ # Indices of the points in the seaerch image with intensity greater than 0 (It would be bad if we only move the bg)
814
+ idx_points_in_label = tf.where(tf.greater(search_voxels, 0.0))
815
+
816
+ # Randomly select one of the points
817
+ random_idx = tf.random.uniform((num_moved_points,), minval=0, maxval=tf.shape(idx_points_in_label)[0], dtype=tf.int32)
818
+
819
+ disp_location = tf.gather_nd(idx_points_in_label, random_idx) # And get the coordinates
820
+ disp_location = tf.cast(disp_location, tf.float32)
821
+ # Get the coordinates of the control point displaces
822
+ rand_disp = tf.random.uniform((num_moved_points, 3), minval=-1, maxval=1, dtype=tf.float32) * max_deformation
823
+ warped_location = disp_location + rand_disp
824
+
825
+ # Add the selected locations to the control grid and the warped locations to the target grid
826
+ control_grid = tf.concat([CTRL_GRID.grid_flat(), disp_location], axis=0)
827
+ trg_grid = tf.concat([CTRL_GRID.grid_flat(), warped_location], axis=0)
828
+
829
+ # Add global affine transformation
830
+ trg_grid, aff = transform_points(trg_grid, max_displacement=max_displacement, max_rotation=max_rotation)
831
+
832
+ tps = ThinPlateSplines(control_grid, trg_grid)
833
+ def_grid = tps.interpolate(FINE_GRID.grid_flat())
834
+
835
+ disp_map = FINE_GRID.grid_flat() - def_grid
836
+ disp_map = tf.reshape(disp_map, (*FINE_GRID.shape, -1))
837
+ # disp_map = interpn(disp_map, FULL_FINE_GRID.grid)
838
+
839
+ # add the batch and channel dimensions
840
+ fix_img = tf.expand_dims(tf.expand_dims(fix_img, -1), 0)
841
+ fix_segmentations = tf.expand_dims(tf.expand_dims(fix_img, -1), 0)
842
+ disp_map = tf.cast(tf.expand_dims(disp_map, 0), tf.float32)
843
+
844
+ mov_img = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([fix_img, disp_map])
845
+ mov_segmentations = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([fix_segmentations, disp_map])
846
+
847
+ return tf.squeeze(mov_img),\
848
+ tf.squeeze(mov_segmentations),\
849
+ tf.squeeze(disp_map),\
850
+ disp_location,\
851
+ rand_disp,\
852
+ aff #, w, trg_grid, def_grid
853
+
854
+
855
+ def transform_points(points: tf.Tensor, max_displacement, max_rotation):
856
+ axis = tf.random.uniform((), 0, 3)
857
+
858
+ alpha = tf.cond(tf.less_equal(axis, 0.),
859
+ lambda: tf.random.uniform((1,), -max_rotation, max_rotation),
860
+ lambda: tf.zeros((1,), tf.float32))
861
+ beta = tf.cond(tf.less_equal(axis, 1.),
862
+ lambda: tf.random.uniform((1,), -max_rotation, max_rotation),
863
+ lambda: tf.zeros((1,), tf.float32))
864
+ gamma = tf.cond(tf.less_equal(axis, 2.),
865
+ lambda: tf.random.uniform((1,), -max_rotation, max_rotation),
866
+ lambda: tf.zeros((1,), tf.float32))
867
+
868
+ ti = tf.random.uniform((), minval=-1, maxval=1, dtype=tf.float32) * max_displacement
869
+ tj = tf.random.uniform((), minval=-1, maxval=1, dtype=tf.float32) * max_displacement
870
+ tk = tf.random.uniform((), minval=-1, maxval=1, dtype=tf.float32) * max_displacement
871
+
872
+ M = build_affine_trf(tf.convert_to_tensor(FINE_GRID.shape, tf.float32), alpha, beta, gamma, ti, tj, tk)
873
+ if points.shape.as_list()[-1] == 3:
874
+ points = tf.transpose(points)
875
+ new_pts = tf.matmul(M[:3, :3], points)
876
+ new_pts = tf.expand_dims(M[:3, -1], -1) + new_pts
877
+ return tf.transpose(new_pts), M # Remove the last row of ones
878
+
879
+
880
+ def build_affine_trf(img_size, alpha, beta, gamma, ti, tj, tk):
881
+ img_centre = tf.expand_dims(tf.divide(img_size, 2.), -1)
882
+
883
+ # Rotation matrix around the image centre
884
+ # R* = T(p) R(ang) T(-p)
885
+ # tf.cos and tf.sin expect radians
886
+ zero = tf.zeros((1,))
887
+ one = tf.ones((1,))
888
+
889
+ T = tf.convert_to_tensor([[one, zero, zero, ti],
890
+ [zero, one, zero, tj],
891
+ [zero, zero, one, tk],
892
+ [zero, zero, zero, one]], tf.float32)
893
+ T = tf.squeeze(T)
894
+
895
+ R = tf.convert_to_tensor([[tf.math.cos(gamma) * tf.math.cos(beta),
896
+ tf.math.cos(gamma) * tf.math.sin(beta) * tf.math.sin(alpha) - tf.math.sin(gamma) * tf.math.cos(alpha),
897
+ tf.math.cos(gamma) * tf.math.sin(beta) * tf.math.cos(alpha) + tf.math.sin(gamma) * tf.math.sin(alpha),
898
+ zero],
899
+ [tf.math.sin(gamma) * tf.math.cos(beta),
900
+ tf.math.sin(gamma) * tf.math.sin(beta) * tf.math.sin(gamma) + tf.math.cos(gamma) * tf.math.cos(alpha),
901
+ tf.math.sin(gamma) * tf.math.sin(beta) * tf.math.cos(gamma) - tf.math.cos(gamma) * tf.math.sin(gamma),
902
+ zero],
903
+ [-tf.math.sin(beta),
904
+ tf.math.cos(beta) * tf.math.sin(alpha),
905
+ tf.math.cos(beta) * tf.math.cos(alpha),
906
+ zero],
907
+ [zero, zero, zero, one]], tf.float32)
908
+
909
+ R = tf.squeeze(R)
910
+
911
+ Tc = tf.convert_to_tensor([[one, zero, zero, img_centre[0]],
912
+ [zero, one, zero, img_centre[1]],
913
+ [zero, zero, one, img_centre[2]],
914
+ [zero, zero, zero, one]], tf.float32)
915
+ Tc = tf.squeeze(Tc)
916
+ Tc_ = tf.convert_to_tensor([[one, zero, zero, -img_centre[0]],
917
+ [zero, one, zero, -img_centre[1]],
918
+ [zero, zero, one, -img_centre[2]],
919
+ [zero, zero, zero, one]], tf.float32)
920
+ Tc_ = tf.squeeze(Tc_)
921
+
922
+ return tf.matmul(T, tf.matmul(Tc, tf.matmul(R, Tc_)))
DeepDeformationMapRegistration/layers/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .augmentation import AugmentationLayer
2
+ from .uncertainty_weighting import UncertaintyWeighting, UncertaintyWeightingWithRollingAverage
3
+ from .upsampling import UpSampling1D, UpSampling2D, UpSampling3D
4
+ from .depthwise_conv_3d import DepthwiseConv3D
5
+ from .b_splines import interpolate_spline
DeepDeformationMapRegistration/layers/augmentation.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+
3
+ currentdir = os.path.dirname(os.path.realpath(__file__))
4
+ parentdir = os.path.dirname(currentdir)
5
+ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
6
+
7
+ PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'
8
+
9
+ import tensorflow.keras.layers as kl
10
+ import tensorflow as tf
11
+ from tensorflow.python.framework.errors import InvalidArgumentError
12
+
13
+ from DeepDeformationMapRegistration.utils.operators import soft_threshold, gaussian_kernel, sample_unique
14
+ import DeepDeformationMapRegistration.utils.constants as C
15
+ from DeepDeformationMapRegistration.utils.thin_plate_splines import ThinPlateSplines
16
+ from voxelmorph.tf.layers import SpatialTransformer
17
+
18
+
19
+ class AugmentationLayer(kl.Layer):
20
+ def __init__(self,
21
+ max_deformation,
22
+ max_displacement,
23
+ max_rotation,
24
+ num_control_points,
25
+ in_img_shape,
26
+ out_img_shape,
27
+ num_augmentations=1,
28
+ gamma_augmentation=True,
29
+ brightness_augmentation=True,
30
+ only_image=False,
31
+ only_resize=True,
32
+ return_displacement_map=False,
33
+ **kwargs):
34
+ super(AugmentationLayer, self).__init__(**kwargs)
35
+
36
+ self.max_deformation = max_deformation
37
+ self.max_displacement = max_displacement
38
+ self.max_rotation = max_rotation
39
+ self.num_control_points = num_control_points
40
+ self.num_augmentations = num_augmentations
41
+ self.in_img_shape = in_img_shape
42
+ self.out_img_shape = out_img_shape
43
+ self.only_image = only_image
44
+ self.return_disp_map = return_displacement_map
45
+
46
+ self.do_gamma_augm = gamma_augmentation
47
+ self.do_brightness_augm = brightness_augmentation
48
+
49
+ grid = C.CoordinatesGrid()
50
+ grid.set_coords_grid(in_img_shape, [C.TPS_NUM_CTRL_PTS_PER_AXIS] * 3)
51
+ self.control_grid = tf.identity(grid.grid_flat(), name='control_grid')
52
+ self.target_grid = tf.identity(grid.grid_flat(), name='target_grid')
53
+
54
+ grid.set_coords_grid(in_img_shape, in_img_shape)
55
+ self.fine_grid = tf.identity(grid.grid_flat(), 'fine_grid')
56
+
57
+ if out_img_shape is not None:
58
+ self.downsample_factor = [i // o for o, i in zip(out_img_shape, in_img_shape)]
59
+ self.img_gauss_filter = gaussian_kernel(3, 0.001, 1, 1, 3)
60
+ # self.resize_transf = tf.diag([*self.downsample_factor, 1])[:-1, :]
61
+ # self.resize_transf = tf.expand_dims(tf.reshape(self.resize_transf, [-1]), 0, name='resize_transformation') # ST expects a (12,) vector
62
+
63
+ self.augment = not only_resize
64
+
65
+ def compute_output_shape(self, input_shape):
66
+ input_shape = tf.TensorShape(input_shape).as_list()
67
+ img_shape = (input_shape[0], *self.out_img_shape, 1)
68
+ seg_shape = (input_shape[0], *self.out_img_shape, input_shape[-1] - 1)
69
+ disp_shape = (input_shape[0], *self.out_img_shape, 3)
70
+ # Expect the input to have the image and segmentations in the same tensor
71
+ if self.return_disp_map:
72
+ return (img_shape, img_shape, seg_shape, seg_shape, disp_shape)
73
+ else:
74
+ return (img_shape, img_shape, seg_shape, seg_shape)
75
+
76
+ #@tf.custom_gradient
77
+ def call(self, in_data, training=None):
78
+ # def custom_grad(in_grad):
79
+ # return tf.ones_like(in_grad)
80
+ if training is not None:
81
+ self.augment = training
82
+ return self.build_batch(in_data)# , custom_grad
83
+
84
+ def build_batch(self, fix_data: tf.Tensor):
85
+ if len(fix_data.get_shape().as_list()) < 5:
86
+ fix_data = tf.expand_dims(fix_data, axis=0) # Add Batch dimension
87
+ # fix_data = tf.tile(fix_data, (self.num_augmentations, *(1,)*4))
88
+ fix_img_batch, mov_img_batch, fix_seg_batch, mov_seg_batch, disp_map = tf.map_fn(lambda x: self.augment_sample(x),
89
+ fix_data,
90
+ dtype=(tf.float32, tf.float32, tf.float32, tf.float32, tf.float32))
91
+ # map_fn unstacks elems on axis 0
92
+ if self.return_disp_map:
93
+ return fix_img_batch, mov_img_batch, fix_seg_batch, mov_seg_batch, disp_map
94
+ else:
95
+ return fix_img_batch, mov_img_batch, fix_seg_batch, mov_seg_batch
96
+
97
+ def augment_sample(self, fix_data: tf.Tensor):
98
+ if self.only_image or not self.augment:
99
+ fix_img = fix_data
100
+ fix_segm = tf.zeros_like(fix_data, dtype=tf.float32)
101
+ else:
102
+ fix_img = fix_data[..., 0]
103
+ fix_img = tf.expand_dims(fix_img, -1)
104
+ fix_segm = fix_data[..., 1:] # We expect several segmentation masks
105
+
106
+ if self.augment:
107
+ # If we are training, do the full-fledged augmentation
108
+ fix_img = self.min_max_normalization(fix_img)
109
+
110
+ mov_img, mov_segm, disp_map = self.deform_image(tf.squeeze(fix_img), fix_segm)
111
+ mov_img = tf.expand_dims(mov_img, -1) # Add the removed channel axis
112
+
113
+ # Resample to output_shape
114
+ if self.out_img_shape is not None:
115
+ fix_img = self.downsize_image(fix_img)
116
+ mov_img = self.downsize_image(mov_img)
117
+
118
+ fix_segm = self.downsize_segmentation(fix_segm)
119
+ mov_segm = self.downsize_segmentation(mov_segm)
120
+
121
+ disp_map = self.downsize_displacement_map(disp_map)
122
+
123
+ if self.do_gamma_augm:
124
+ fix_img = self.gamma_augmentation(fix_img)
125
+ mov_img = self.gamma_augmentation(mov_img)
126
+
127
+ if self.do_brightness_augm:
128
+ fix_img = self.brightness_augmentation(fix_img)
129
+ mov_img = self.brightness_augmentation(mov_img)
130
+
131
+ else:
132
+ # During inference, just resize the input images
133
+ mov_img = tf.zeros_like(fix_img)
134
+ mov_segm = tf.zeros_like(fix_segm)
135
+
136
+ disp_map = tf.tile(tf.zeros_like(fix_img), [1, 1, 1, 1, 3])
137
+
138
+ if self.out_img_shape is not None:
139
+ fix_img = self.downsize_image(fix_img)
140
+ mov_img = self.downsize_image(mov_img)
141
+
142
+ fix_segm = self.downsize_segmentation(fix_segm)
143
+ mov_segm = self.downsize_segmentation(mov_segm)
144
+
145
+ disp_map = self.downsize_displacement_map(disp_map)
146
+
147
+ fix_img = self.min_max_normalization(fix_img)
148
+ mov_img = self.min_max_normalization(mov_img)
149
+ return fix_img, mov_img, fix_segm, mov_segm, disp_map
150
+
151
+ def downsize_image(self, img):
152
+ img = tf.expand_dims(img, axis=0)
153
+ # The filter is symmetrical along the three axes, hence there is no need for transposing the H and D dims
154
+ img = tf.nn.conv3d(img, self.img_gauss_filter, strides=[1, ] * 5, padding='SAME', data_format='NDHWC')
155
+ img = tf.layers.MaxPooling3D([1]*3, self.downsample_factor, padding='valid', data_format='channels_last')(img)
156
+
157
+ return tf.squeeze(img, axis=0)
158
+
159
+ def downsize_segmentation(self, segm):
160
+ segm = tf.expand_dims(segm, axis=0)
161
+ segm = tf.layers.MaxPooling3D([1]*3, self.downsample_factor, padding='valid', data_format='channels_last')(segm)
162
+
163
+ segm = tf.cast(segm, tf.float32)
164
+ return tf.squeeze(segm, axis=0)
165
+
166
+ def downsize_displacement_map(self, disp_map):
167
+ disp_map = tf.expand_dims(disp_map, axis=0)
168
+ # The filter is symmetrical along the three axes, hence there is no need for transposing the H and D dims
169
+ disp_map = tf.layers.AveragePooling3D([1]*3, self.downsample_factor, padding='valid', data_format='channels_last')(disp_map)
170
+
171
+ # self.downsample_factor = in_shape / out_shape, but here we need out_shape / in_shape. Hence, 1 / factor
172
+ if self.downsample_factor[0] != self.downsample_factor[1] != self.downsample_factor[2]:
173
+ # Downsize the displacement magnitude along the different axes
174
+ disp_map_x = disp_map[..., 0] * 1 / self.downsample_factor[0]
175
+ disp_map_y = disp_map[..., 1] * 1 / self.downsample_factor[1]
176
+ disp_map_z = disp_map[..., 2] * 1 / self.downsample_factor[2]
177
+
178
+ disp_map = tf.stack([disp_map_x, disp_map_y, disp_map_z], axis=-1)
179
+ else:
180
+ disp_map = disp_map * 1 / self.downsample_factor[0]
181
+
182
+ return tf.squeeze(disp_map, axis=0)
183
+
184
+ def gamma_augmentation(self, in_img: tf.Tensor):
185
+ in_img += 1e-5 # To prvent NaNs
186
+ gamma = tf.random.uniform((), 0.5, 2, tf.float32)
187
+
188
+ return tf.clip_by_value(tf.pow(in_img, gamma), 0, 1)
189
+
190
+ def brightness_augmentation(self, in_img: tf.Tensor):
191
+ c = tf.random.uniform((), 0.5, 2, tf.float32)
192
+ return tf.clip_by_value(c*in_img, 0, 1)
193
+
194
+ def min_max_normalization(self, in_img: tf.Tensor):
195
+ return tf.div(tf.subtract(in_img, tf.reduce_min(in_img)),
196
+ tf.subtract(tf.reduce_max(in_img), tf.reduce_min(in_img)))
197
+
198
+ def deform_image(self, fix_img: tf.Tensor, fix_segm: tf.Tensor):
199
+ # Get locations where the intensity > 0.0
200
+ idx_points_in_label = tf.where(tf.greater(fix_img, 0.0))
201
+
202
+ # Randomly select N points
203
+ # random_idx = tf.random.uniform((self.num_control_points,),
204
+ # minval=0, maxval=tf.shape(idx_points_in_label)[0],
205
+ # dtype=tf.int32)
206
+ #
207
+ # disp_location = tf.gather(idx_points_in_label, random_idx) # And get the coordinates
208
+ # disp_location = tf.cast(disp_location, tf.float32)
209
+ disp_location = sample_unique(idx_points_in_label, self.num_control_points, tf.float32)
210
+
211
+ # Get the coordinates of the control point displaces
212
+ rand_disp = tf.random.uniform((self.num_control_points, 3), minval=-1, maxval=1, dtype=tf.float32) * self.max_deformation
213
+ warped_location = disp_location + rand_disp
214
+
215
+ # Add the selected locations to the control grid and the warped locations to the target grid
216
+ control_grid = tf.concat([self.control_grid, disp_location], axis=0)
217
+ trg_grid = tf.concat([self.control_grid, warped_location], axis=0)
218
+
219
+ # Apply global transformation
220
+ valid_trf = False
221
+ while not valid_trf:
222
+ trg_grid, aff = self.global_transformation(trg_grid)
223
+
224
+ # Interpolate the displacement map
225
+ try:
226
+ tps = ThinPlateSplines(control_grid, trg_grid)
227
+ def_grid = tps.interpolate(self.fine_grid)
228
+ except InvalidArgumentError as err:
229
+ # If the transformation raises a non-invertible error,
230
+ # try again until we get a valid transformation
231
+ continue
232
+ else:
233
+ valid_trf = True
234
+
235
+ disp_map = self.fine_grid - def_grid
236
+ disp_map = tf.reshape(disp_map, (*self.in_img_shape, -1))
237
+
238
+ # Apply the displacement map
239
+ fix_img = tf.expand_dims(tf.expand_dims(fix_img, -1), 0)
240
+ fix_segm = tf.expand_dims(fix_segm, 0)
241
+ disp_map = tf.cast(tf.expand_dims(disp_map, 0), tf.float32)
242
+
243
+ mov_img = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([fix_img, disp_map])
244
+ mov_segm = SpatialTransformer(interp_method='nearest', indexing='ij', single_transform=False)([fix_segm, disp_map])
245
+
246
+ mov_img = tf.where(tf.is_nan(mov_img), tf.zeros_like(mov_img), mov_img)
247
+ mov_img = tf.where(tf.is_inf(mov_img), tf.zeros_like(mov_img), mov_img)
248
+
249
+ mov_segm = tf.where(tf.is_nan(mov_segm), tf.zeros_like(mov_segm), mov_segm)
250
+ mov_segm = tf.where(tf.is_inf(mov_segm), tf.zeros_like(mov_segm), mov_segm)
251
+
252
+ return tf.squeeze(mov_img), tf.squeeze(mov_segm, axis=0), tf.squeeze(disp_map, axis=0)
253
+
254
+ def global_transformation(self, points: tf.Tensor):
255
+ axis = tf.random.uniform((), 0, 3)
256
+
257
+ alpha = C.DEG_TO_RAD * tf.cond(tf.logical_and(tf.greater(axis, 0.), tf.less_equal(axis, 1.)),
258
+ lambda: tf.random.uniform((), -self.max_rotation, self.max_rotation),
259
+ lambda: tf.zeros((), tf.float32))
260
+ beta = C.DEG_TO_RAD * tf.cond(tf.logical_and(tf.greater(axis, 1.), tf.less_equal(axis, 2.)),
261
+ lambda: tf.random.uniform((), -self.max_rotation, self.max_rotation),
262
+ lambda: tf.zeros((), tf.float32))
263
+ gamma = C.DEG_TO_RAD * tf.cond(tf.logical_and(tf.greater(axis, 2.), tf.less_equal(axis, 3.)),
264
+ lambda: tf.random.uniform((), -self.max_rotation, self.max_rotation),
265
+ lambda: tf.zeros((), tf.float32))
266
+
267
+ ti = tf.random.uniform((), minval=-1, maxval=1, dtype=tf.float32) * self.max_displacement
268
+ tj = tf.random.uniform((), minval=-1, maxval=1, dtype=tf.float32) * self.max_displacement
269
+ tk = tf.random.uniform((), minval=-1, maxval=1, dtype=tf.float32) * self.max_displacement
270
+
271
+ M = self.build_affine_transformation(tf.convert_to_tensor(self.in_img_shape, tf.float32),
272
+ alpha, beta, gamma, ti, tj, tk)
273
+
274
+ points = tf.transpose(points)
275
+ new_pts = tf.matmul(M[:3, :3], points)
276
+ new_pts = tf.expand_dims(M[:3, -1], -1) + new_pts
277
+ return tf.transpose(new_pts), M
278
+
279
+ @staticmethod
280
+ def build_affine_transformation(img_shape, alpha, beta, gamma, ti, tj, tk):
281
+ img_centre = tf.divide(img_shape, 2.)
282
+
283
+ # Rotation matrix around the image centre
284
+ # R* = T(p) R(ang) T(-p)
285
+ # tf.cos and tf.sin expect radians
286
+
287
+ T = tf.convert_to_tensor([[1, 0, 0, ti],
288
+ [0, 1, 0, tj],
289
+ [0, 0, 1, tk],
290
+ [0, 0, 0, 1]], tf.float32)
291
+
292
+ Ri = tf.convert_to_tensor([[1, 0, 0, 0],
293
+ [0, tf.math.cos(alpha), -tf.math.sin(alpha), 0],
294
+ [0, tf.math.sin(alpha), tf.math.cos(alpha), 0],
295
+ [0, 0, 0, 1]], tf.float32)
296
+
297
+ Rj = tf.convert_to_tensor([[ tf.math.cos(beta), 0, tf.math.sin(beta), 0],
298
+ [0, 1, 0, 0],
299
+ [-tf.math.sin(beta), 0, tf.math.cos(beta), 0],
300
+ [0, 0, 0, 1]], tf.float32)
301
+
302
+ Rk = tf.convert_to_tensor([[tf.math.cos(gamma), -tf.math.sin(gamma), 0, 0],
303
+ [tf.math.sin(gamma), tf.math.cos(gamma), 0, 0],
304
+ [0, 0, 1, 0],
305
+ [0, 0, 0, 1]], tf.float32)
306
+
307
+ R = tf.matmul(tf.matmul(Ri, Rj), Rk)
308
+
309
+ Tc = tf.convert_to_tensor([[1, 0, 0, img_centre[0]],
310
+ [0, 1, 0, img_centre[1]],
311
+ [0, 0, 1, img_centre[2]],
312
+ [0, 0, 0, 1]], tf.float32)
313
+
314
+ Tc_ = tf.convert_to_tensor([[1, 0, 0, -img_centre[0]],
315
+ [0, 1, 0, -img_centre[1]],
316
+ [0, 0, 1, -img_centre[2]],
317
+ [0, 0, 0, 1]], tf.float32)
318
+
319
+ return tf.matmul(T, tf.matmul(Tc, tf.matmul(R, Tc_)))
320
+
321
+ def get_config(self):
322
+ config = super(AugmentationLayer, self).get_config()
323
+ return config
324
+
325
+
326
+
DeepDeformationMapRegistration/layers/b_splines.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Polyharmonic spline interpolation."""
16
+
17
+ import tensorflow as tf
18
+ from typing import Union, List
19
+ import numpy as np
20
+
21
+ Number = Union[
22
+ float,
23
+ int,
24
+ np.float16,
25
+ np.float32,
26
+ np.float64,
27
+ np.int8,
28
+ np.int16,
29
+ np.int32,
30
+ np.int64,
31
+ np.uint8,
32
+ np.uint16,
33
+ np.uint32,
34
+ np.uint64,
35
+ ]
36
+
37
+ TensorLike = Union[
38
+ List[Union[Number, list]],
39
+ tuple,
40
+ Number,
41
+ np.ndarray,
42
+ tf.Tensor,
43
+ tf.SparseTensor,
44
+ tf.Variable,
45
+ ]
46
+ FloatTensorLike = Union[tf.Tensor, float, np.float16, np.float32, np.float64]
47
+
48
+ EPSILON = 0.0000000001
49
+
50
+
51
+ def _cross_squared_distance_matrix(x: TensorLike, y: TensorLike) -> tf.Tensor:
52
+ """Pairwise squared distance between two (batch) matrices' rows (2nd dim).
53
+ Computes the pairwise distances between rows of x and rows of y.
54
+ Args:
55
+ x: `[batch_size, n, d]` float `Tensor`.
56
+ y: `[batch_size, m, d]` float `Tensor`.
57
+ Returns:
58
+ squared_dists: `[batch_size, n, m]` float `Tensor`, where
59
+ `squared_dists[b,i,j] = ||x[b,i,:] - y[b,j,:]||^2`.
60
+ """
61
+ x_norm_squared = tf.reduce_sum(tf.square(x), 2)
62
+ y_norm_squared = tf.reduce_sum(tf.square(y), 2)
63
+
64
+ # Expand so that we can broadcast.
65
+ x_norm_squared_tile = tf.expand_dims(x_norm_squared, 2)
66
+ y_norm_squared_tile = tf.expand_dims(y_norm_squared, 1)
67
+
68
+ x_y_transpose = tf.matmul(x, y, adjoint_b=True)
69
+
70
+ # squared_dists[b,i,j] = ||x_bi - y_bj||^2 =
71
+ # x_bi'x_bi- 2x_bi'x_bj + x_bj'x_bj
72
+ squared_dists = x_norm_squared_tile - 2 * x_y_transpose + y_norm_squared_tile
73
+
74
+ return squared_dists
75
+
76
+
77
+ def _pairwise_squared_distance_matrix(x: TensorLike) -> tf.Tensor:
78
+ """Pairwise squared distance among a (batch) matrix's rows (2nd dim).
79
+ This saves a bit of computation vs. using
80
+ `_cross_squared_distance_matrix(x, x)`
81
+ Args:
82
+ x: `[batch_size, n, d]` float `Tensor`.
83
+ Returns:
84
+ squared_dists: `[batch_size, n, n]` float `Tensor`, where
85
+ `squared_dists[b,i,j] = ||x[b,i,:] - x[b,j,:]||^2`.
86
+ """
87
+
88
+ x_x_transpose = tf.matmul(x, x, adjoint_b=True)
89
+ x_norm_squared = tf.linalg.diag_part(x_x_transpose)
90
+ x_norm_squared_tile = tf.expand_dims(x_norm_squared, 2)
91
+
92
+ # squared_dists[b,i,j] = ||x_bi - x_bj||^2 =
93
+ # = x_bi'x_bi- 2x_bi'x_bj + x_bj'x_bj
94
+ squared_dists = (
95
+ x_norm_squared_tile
96
+ - 2 * x_x_transpose
97
+ + tf.transpose(x_norm_squared_tile, [0, 2, 1])
98
+ )
99
+
100
+ return squared_dists
101
+
102
+
103
+ def _solve_interpolation(
104
+ train_points: TensorLike,
105
+ train_values: TensorLike,
106
+ order: int,
107
+ regularization_weight: FloatTensorLike,
108
+ ) -> TensorLike:
109
+ r"""Solve for interpolation coefficients.
110
+ Computes the coefficients of the polyharmonic interpolant for the
111
+ 'training' data defined by `(train_points, train_values)` using the kernel
112
+ $\phi$.
113
+ Args:
114
+ train_points: `[b, n, d]` interpolation centers.
115
+ train_values: `[b, n, k]` function values.
116
+ order: order of the interpolation.
117
+ regularization_weight: weight to place on smoothness regularization term.
118
+ Returns:
119
+ w: `[b, n, k]` weights on each interpolation center
120
+ v: `[b, d, k]` weights on each input dimension
121
+ Raises:
122
+ ValueError: if d or k is not fully specified.
123
+ """
124
+
125
+ # These dimensions are set dynamically at runtime.
126
+ b, n, _ = tf.unstack(tf.shape(train_points), num=3)
127
+
128
+ d = train_points.shape[-1]
129
+ if d is None:
130
+ raise ValueError(
131
+ "The dimensionality of the input points (d) must be "
132
+ "statically-inferrable."
133
+ )
134
+
135
+ k = train_values.shape[-1]
136
+ if k is None:
137
+ raise ValueError(
138
+ "The dimensionality of the output values (k) must be "
139
+ "statically-inferrable."
140
+ )
141
+
142
+ # First, rename variables so that the notation (c, f, w, v, A, B, etc.)
143
+ # follows https://en.wikipedia.org/wiki/Polyharmonic_spline.
144
+ # To account for python style guidelines we use
145
+ # matrix_a for A and matrix_b for B.
146
+
147
+ c = train_points
148
+ f = train_values
149
+
150
+ # Next, construct the linear system.
151
+ with tf.name_scope("construct_linear_system"):
152
+
153
+ matrix_a = _phi(_pairwise_squared_distance_matrix(c), order) # [b, n, n]
154
+ if regularization_weight > 0:
155
+ batch_identity_matrix = tf.expand_dims(tf.eye(n, dtype=c.dtype), 0)
156
+ matrix_a += regularization_weight * batch_identity_matrix
157
+
158
+ # Append ones to the feature values for the bias term
159
+ # in the linear model.
160
+ ones = tf.ones_like(c[..., :1], dtype=c.dtype)
161
+ matrix_b = tf.concat([c, ones], 2) # [b, n, d + 1]
162
+
163
+ # [b, n + d + 1, n]
164
+ left_block = tf.concat([matrix_a, tf.transpose(matrix_b, [0, 2, 1])], 1)
165
+
166
+ num_b_cols = matrix_b.get_shape()[2] # d + 1
167
+ lhs_zeros = tf.zeros([b, num_b_cols, num_b_cols], train_points.dtype)
168
+ right_block = tf.concat([matrix_b, lhs_zeros], 1) # [b, n + d + 1, d + 1]
169
+ lhs = tf.concat([left_block, right_block], 2) # [b, n + d + 1, n + d + 1]
170
+
171
+ rhs_zeros = tf.zeros([b, d + 1, k], train_points.dtype)
172
+ rhs = tf.concat([f, rhs_zeros], 1) # [b, n + d + 1, k]
173
+
174
+ # Then, solve the linear system and unpack the results.
175
+ with tf.name_scope("solve_linear_system"):
176
+ w_v = tf.linalg.solve(lhs, rhs)
177
+ w = w_v[:, :n, :]
178
+ v = w_v[:, n:, :]
179
+
180
+ return w, v
181
+
182
+
183
+ def _apply_interpolation(
184
+ query_points: TensorLike,
185
+ train_points: TensorLike,
186
+ w: TensorLike,
187
+ v: TensorLike,
188
+ order: int,
189
+ ) -> TensorLike:
190
+ """Apply polyharmonic interpolation model to data.
191
+ Given coefficients w and v for the interpolation model, we evaluate
192
+ interpolated function values at query_points.
193
+ Args:
194
+ query_points: `[b, m, d]` x values to evaluate the interpolation at.
195
+ train_points: `[b, n, d]` x values that act as the interpolation centers
196
+ (the c variables in the wikipedia article).
197
+ w: `[b, n, k]` weights on each interpolation center.
198
+ v: `[b, d, k]` weights on each input dimension.
199
+ order: order of the interpolation.
200
+ Returns:
201
+ Polyharmonic interpolation evaluated at points defined in `query_points`.
202
+ """
203
+
204
+ # First, compute the contribution from the rbf term.
205
+ pairwise_dists = _cross_squared_distance_matrix(query_points, train_points)
206
+ phi_pairwise_dists = _phi(pairwise_dists, order)
207
+
208
+ rbf_term = tf.matmul(phi_pairwise_dists, w)
209
+
210
+ # Then, compute the contribution from the linear term.
211
+ # Pad query_points with ones, for the bias term in the linear model.
212
+ query_points_pad = tf.concat(
213
+ [query_points, tf.ones_like(query_points[..., :1], train_points.dtype)], 2
214
+ )
215
+ linear_term = tf.matmul(query_points_pad, v)
216
+
217
+ return rbf_term + linear_term
218
+
219
+
220
+ def _phi(r: FloatTensorLike, order: int) -> FloatTensorLike:
221
+ """Coordinate-wise nonlinearity used to define the order of the
222
+ interpolation.
223
+ See https://en.wikipedia.org/wiki/Polyharmonic_spline for the definition.
224
+ Args:
225
+ r: input op.
226
+ order: interpolation order.
227
+ Returns:
228
+ `phi_k` evaluated coordinate-wise on `r`, for `k = r`.
229
+ """
230
+
231
+ # using EPSILON prevents log(0), sqrt0), etc.
232
+ # sqrt(0) is well-defined, but its gradient is not
233
+ with tf.name_scope("phi"):
234
+ if order == 1:
235
+ r = tf.maximum(r, EPSILON)
236
+ r = tf.sqrt(r)
237
+ return r
238
+ elif order == 2:
239
+ return 0.5 * r * tf.math.log(tf.maximum(r, EPSILON))
240
+ elif order == 4:
241
+ return 0.5 * tf.square(r) * tf.math.log(tf.maximum(r, EPSILON))
242
+ elif order % 2 == 0:
243
+ r = tf.maximum(r, EPSILON)
244
+ return 0.5 * tf.pow(r, 0.5 * order) * tf.math.log(r)
245
+ else:
246
+ r = tf.maximum(r, EPSILON)
247
+ return tf.pow(r, 0.5 * order)
248
+
249
+
250
+ def interpolate_spline(
251
+ train_points: TensorLike,
252
+ train_values: TensorLike,
253
+ query_points: TensorLike,
254
+ order: int,
255
+ regularization_weight: FloatTensorLike = 0.0,
256
+ name: str = "interpolate_spline",
257
+ ) -> tf.Tensor:
258
+ r"""Interpolate signal using polyharmonic interpolation.
259
+ The interpolant has the form
260
+ $$f(x) = \sum_{i = 1}^n w_i \phi(||x - c_i||) + v^T x + b.$$
261
+ This is a sum of two terms: (1) a weighted sum of radial basis function
262
+ (RBF) terms, with the centers \\(c_1, ... c_n\\), and (2) a linear term
263
+ with a bias. The \\(c_i\\) vectors are 'training' points.
264
+ In the code, b is absorbed into v
265
+ by appending 1 as a final dimension to x. The coefficients w and v are
266
+ estimated such that the interpolant exactly fits the value of the function
267
+ at the \\(c_i\\) points, the vector w is orthogonal to each \\(c_i\\),
268
+ and the vector w sums to 0. With these constraints, the coefficients
269
+ can be obtained by solving a linear system.
270
+ \\(\phi\\) is an RBF, parametrized by an interpolation
271
+ order. Using order=2 produces the well-known thin-plate spline.
272
+ We also provide the option to perform regularized interpolation. Here, the
273
+ interpolant is selected to trade off between the squared loss on the
274
+ training data and a certain measure of its curvature
275
+ ([details](https://en.wikipedia.org/wiki/Polyharmonic_spline)).
276
+ Using a regularization weight greater than zero has the effect that the
277
+ interpolant will no longer exactly fit the training data. However, it may
278
+ be less vulnerable to overfitting, particularly for high-order
279
+ interpolation.
280
+ Note the interpolation procedure is differentiable with respect to all
281
+ inputs besides the order parameter.
282
+ We support dynamically-shaped inputs, where batch_size, n, and m are None
283
+ at graph construction time. However, d and k must be known.
284
+ Args:
285
+ train_points: `[batch_size, n, d]` float `Tensor` of n d-dimensional
286
+ locations. These do not need to be regularly-spaced.
287
+ train_values: `[batch_size, n, k]` float `Tensor` of n c-dimensional
288
+ values evaluated at train_points.
289
+ query_points: `[batch_size, m, d]` `Tensor` of m d-dimensional locations
290
+ where we will output the interpolant's values.
291
+ order: order of the interpolation. Common values are 1 for
292
+ \\(\phi(r) = r\\), 2 for \\(\phi(r) = r^2 * log(r)\\)
293
+ (thin-plate spline), or 3 for \\(\phi(r) = r^3\\).
294
+ regularization_weight: weight placed on the regularization term.
295
+ This will depend substantially on the problem, and it should always be
296
+ tuned. For many problems, it is reasonable to use no regularization.
297
+ If using a non-zero value, we recommend a small value like 0.001.
298
+ name: name prefix for ops created by this function
299
+ Returns:
300
+ `[b, m, k]` float `Tensor` of query values. We use train_points and
301
+ train_values to perform polyharmonic interpolation. The query values are
302
+ the values of the interpolant evaluated at the locations specified in
303
+ query_points.
304
+ """
305
+ with tf.name_scope(name or "interpolate_spline"):
306
+ train_points = tf.convert_to_tensor(train_points)
307
+ train_values = tf.convert_to_tensor(train_values)
308
+ query_points = tf.convert_to_tensor(query_points)
309
+
310
+ # First, fit the spline to the observed data.
311
+ with tf.name_scope("solve"):
312
+ w, v = _solve_interpolation(
313
+ train_points, train_values, order, regularization_weight
314
+ )
315
+
316
+ # Then, evaluate the spline at the query locations.
317
+ with tf.name_scope("predict"):
318
+ query_values = _apply_interpolation(query_points, train_points, w, v, order)
319
+
320
+ return query_values
DeepDeformationMapRegistration/layers/depthwise_conv_3d.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SRC: https://github.com/alexandrosstergiou/keras-DepthwiseConv3D
2
+
3
+ '''
4
+ This is a modification of the SeparableConv3D code in Keras,
5
+ to perform just the Depthwise Convolution (1st step) of the
6
+ Depthwise Separable Convolution layer.
7
+ '''
8
+ from __future__ import absolute_import
9
+
10
+ from tensorflow.keras import backend as K
11
+ from tensorflow.keras import initializers
12
+ from tensorflow.keras import regularizers
13
+ from tensorflow.keras import constraints
14
+ import tensorflow.keras.utils as conv_utils
15
+ from tensorflow.keras.layers import Conv3D, InputSpec
16
+ from tensorflow.python.keras.backend import _preprocess_padding, _preprocess_conv3d_input
17
+
18
+ import tensorflow as tf
19
+
20
+
21
+ class DepthwiseConv3D(Conv3D):
22
+ """Depthwise 3D convolution.
23
+ Depth-wise part of separable convolutions consist in performing
24
+ just the first step/operation
25
+ (which acts on each input channel separately).
26
+ It does not perform the pointwise convolution (second step).
27
+ The `depth_multiplier` argument controls how many
28
+ output channels are generated per input channel in the depthwise step.
29
+ # Arguments
30
+ kernel_size: An integer or tuple/list of 3 integers, specifying the
31
+ depth, width and height of the 3D convolution window.
32
+ Can be a single integer to specify the same value for
33
+ all spatial dimensions.
34
+ strides: An integer or tuple/list of 3 integers,
35
+ specifying the strides of the convolution along the depth, width and height.
36
+ Can be a single integer to specify the same value for
37
+ all spatial dimensions.
38
+ padding: one of `"valid"` or `"same"` (case-insensitive).
39
+ depth_multiplier: The number of depthwise convolution output channels
40
+ for each input channel.
41
+ The total number of depthwise convolution output
42
+ channels will be equal to `filterss_in * depth_multiplier`.
43
+ groups: The depth size of the convolution (as a variant of the original Depthwise conv)
44
+ data_format: A string,
45
+ one of `channels_last` (default) or `channels_first`.
46
+ The ordering of the dimensions in the inputs.
47
+ `channels_last` corresponds to inputs with shape
48
+ `(batch, height, width, channels)` while `channels_first`
49
+ corresponds to inputs with shape
50
+ `(batch, channels, height, width)`.
51
+ It defaults to the `image_data_format` value found in your
52
+ Keras config file at `~/.keras/keras.json`.
53
+ If you never set it, then it will be "channels_last".
54
+ activation: Activation function to use
55
+ (see [activations](../activations.md)).
56
+ If you don't specify anything, no activation is applied
57
+ (ie. "linear" activation: `a(x) = x`).
58
+ use_bias: Boolean, whether the layer uses a bias vector.
59
+ depthwise_initializer: Initializer for the depthwise kernel matrix
60
+ (see [initializers](../initializers.md)).
61
+ bias_initializer: Initializer for the bias vector
62
+ (see [initializers](../initializers.md)).
63
+ depthwise_regularizer: Regularizer function applied to
64
+ the depthwise kernel matrix
65
+ (see [regularizer](../regularizers.md)).
66
+ bias_regularizer: Regularizer function applied to the bias vector
67
+ (see [regularizer](../regularizers.md)).
68
+ dialation_rate: List of ints.
69
+ Defines the dilation factor for each dimension in the
70
+ input. Defaults to (1,1,1)
71
+ activity_regularizer: Regularizer function applied to
72
+ the output of the layer (its "activation").
73
+ (see [regularizer](../regularizers.md)).
74
+ depthwise_constraint: Constraint function applied to
75
+ the depthwise kernel matrix
76
+ (see [constraints](../constraints.md)).
77
+ bias_constraint: Constraint function applied to the bias vector
78
+ (see [constraints](../constraints.md)).
79
+ # Input shape
80
+ 5D tensor with shape:
81
+ `(batch, depth, channels, rows, cols)` if data_format='channels_first'
82
+ or 5D tensor with shape:
83
+ `(batch, depth, rows, cols, channels)` if data_format='channels_last'.
84
+ # Output shape
85
+ 5D tensor with shape:
86
+ `(batch, filters * depth, new_depth, new_rows, new_cols)` if data_format='channels_first'
87
+ or 4D tensor with shape:
88
+ `(batch, new_depth, new_rows, new_cols, filters * depth)` if data_format='channels_last'.
89
+ `rows` and `cols` values might have changed due to padding.
90
+ """
91
+
92
+ #@legacy_depthwise_conv3d_support
93
+ def __init__(self,
94
+ kernel_size,
95
+ strides=(1, 1, 1),
96
+ padding='valid',
97
+ depth_multiplier=1,
98
+ groups=None,
99
+ data_format=None,
100
+ activation=None,
101
+ use_bias=True,
102
+ depthwise_initializer='glorot_uniform',
103
+ bias_initializer='zeros',
104
+ dilation_rate = (1, 1, 1),
105
+ depthwise_regularizer=None,
106
+ bias_regularizer=None,
107
+ activity_regularizer=None,
108
+ depthwise_constraint=None,
109
+ bias_constraint=None,
110
+ **kwargs):
111
+ super(DepthwiseConv3D, self).__init__(
112
+ filters=None,
113
+ kernel_size=kernel_size,
114
+ strides=strides,
115
+ padding=padding,
116
+ data_format=data_format,
117
+ activation=activation,
118
+ use_bias=use_bias,
119
+ bias_regularizer=bias_regularizer,
120
+ dilation_rate=dilation_rate,
121
+ activity_regularizer=activity_regularizer,
122
+ bias_constraint=bias_constraint,
123
+ **kwargs)
124
+ self.depth_multiplier = depth_multiplier
125
+ self.groups = groups
126
+ self.depthwise_initializer = initializers.get(depthwise_initializer)
127
+ self.depthwise_regularizer = regularizers.get(depthwise_regularizer)
128
+ self.depthwise_constraint = constraints.get(depthwise_constraint)
129
+ self.bias_initializer = initializers.get(bias_initializer)
130
+ self.dilation_rate = dilation_rate
131
+ self._padding = _preprocess_padding(self.padding)
132
+ self._strides = (1,) + self.strides + (1,)
133
+ self._data_format = "NDHWC"
134
+ self.input_dim = None
135
+
136
+ def build(self, input_shape):
137
+ if len(input_shape) < 5:
138
+ raise ValueError('Inputs to `DepthwiseConv3D` should have rank 5. '
139
+ 'Received input shape:', str(input_shape))
140
+ if self.data_format == 'channels_first':
141
+ channel_axis = 1
142
+ else:
143
+ channel_axis = -1
144
+ if input_shape[channel_axis] is None:
145
+ raise ValueError('The channel dimension of the inputs to '
146
+ '`DepthwiseConv3D` '
147
+ 'should be defined. Found `None`.')
148
+ self.input_dim = int(input_shape[channel_axis])
149
+
150
+ if self.groups is None:
151
+ self.groups = self.input_dim
152
+
153
+ if self.groups > self.input_dim:
154
+ raise ValueError('The number of groups cannot exceed the number of channels')
155
+
156
+ if self.input_dim % self.groups != 0:
157
+ raise ValueError('Warning! The channels dimension is not divisible by the group size chosen')
158
+
159
+ depthwise_kernel_shape = (self.kernel_size[0],
160
+ self.kernel_size[1],
161
+ self.kernel_size[2],
162
+ self.input_dim,
163
+ self.depth_multiplier)
164
+
165
+ self.depthwise_kernel = self.add_weight(
166
+ shape=depthwise_kernel_shape,
167
+ initializer=self.depthwise_initializer,
168
+ name='depthwise_kernel',
169
+ regularizer=self.depthwise_regularizer,
170
+ constraint=self.depthwise_constraint)
171
+
172
+ if self.use_bias:
173
+ self.bias = self.add_weight(shape=(self.groups * self.depth_multiplier,),
174
+ initializer=self.bias_initializer,
175
+ name='bias',
176
+ regularizer=self.bias_regularizer,
177
+ constraint=self.bias_constraint)
178
+ else:
179
+ self.bias = None
180
+ # Set input spec.
181
+ self.input_spec = InputSpec(ndim=5, axes={channel_axis: self.input_dim})
182
+ self.built = True
183
+
184
+ def call(self, inputs, training=None):
185
+ inputs = _preprocess_conv3d_input(inputs, self.data_format)
186
+
187
+ if self.data_format == 'channels_last':
188
+ dilation = (1,) + self.dilation_rate + (1,)
189
+ else:
190
+ dilation = self.dilation_rate + (1,) + (1,)
191
+
192
+ if self._data_format == 'NCDHW':
193
+ outputs = tf.concat(
194
+ [tf.nn.conv3d(inputs[0][:, i:i+self.input_dim//self.groups, :, :, :], self.depthwise_kernel[:, :, :, i:i+self.input_dim//self.groups, :],
195
+ strides=self._strides,
196
+ padding=self._padding,
197
+ dilations=dilation,
198
+ data_format=self._data_format) for i in range(0, self.input_dim, self.input_dim//self.groups)], axis=1)
199
+
200
+ else:
201
+ outputs = tf.concat(
202
+ [tf.nn.conv3d(inputs[0][:, :, :, :, i:i+self.input_dim//self.groups], self.depthwise_kernel[:, :, :, i:i+self.input_dim//self.groups, :],
203
+ strides=self._strides,
204
+ padding=self._padding,
205
+ dilations=dilation,
206
+ data_format=self._data_format) for i in range(0, self.input_dim, self.input_dim//self.groups)], axis=-1)
207
+
208
+ if self.bias is not None:
209
+ outputs = K.bias_add(
210
+ outputs,
211
+ self.bias,
212
+ data_format=self.data_format)
213
+
214
+ if self.activation is not None:
215
+ return self.activation(outputs)
216
+
217
+ return outputs
218
+
219
+ def compute_output_shape(self, input_shape):
220
+ if self.data_format == 'channels_first':
221
+ depth = input_shape[2]
222
+ rows = input_shape[3]
223
+ cols = input_shape[4]
224
+ out_filters = self.groups * self.depth_multiplier
225
+ elif self.data_format == 'channels_last':
226
+ depth = input_shape[1]
227
+ rows = input_shape[2]
228
+ cols = input_shape[3]
229
+ out_filters = self.groups * self.depth_multiplier
230
+
231
+ depth = conv_utils.conv_output_length(depth, self.kernel_size[0],
232
+ self.padding,
233
+ self.strides[0])
234
+
235
+ rows = conv_utils.conv_output_length(rows, self.kernel_size[1],
236
+ self.padding,
237
+ self.strides[1])
238
+
239
+ cols = conv_utils.conv_output_length(cols, self.kernel_size[2],
240
+ self.padding,
241
+ self.strides[2])
242
+
243
+ if self.data_format == 'channels_first':
244
+ return (input_shape[0], out_filters, depth, rows, cols)
245
+
246
+ elif self.data_format == 'channels_last':
247
+ return (input_shape[0], depth, rows, cols, out_filters)
248
+
249
+ def get_config(self):
250
+ config = super(DepthwiseConv3D, self).get_config()
251
+ config.pop('filters')
252
+ config.pop('kernel_initializer')
253
+ config.pop('kernel_regularizer')
254
+ config.pop('kernel_constraint')
255
+ config['depth_multiplier'] = self.depth_multiplier
256
+ config['depthwise_initializer'] = initializers.serialize(self.depthwise_initializer)
257
+ config['depthwise_regularizer'] = regularizers.serialize(self.depthwise_regularizer)
258
+ config['depthwise_constraint'] = constraints.serialize(self.depthwise_constraint)
259
+ return config
260
+
261
+ def __call__(self, inputs, training=True):
262
+ return self.call(inputs, training)
263
+
264
+ DepthwiseConvolution3D = DepthwiseConv3D
DeepDeformationMapRegistration/layers/uncertainty_weighting.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+
3
+ # currentdir = os.path.dirname(os.path.realpath(__file__))
4
+ # parentdir = os.path.dirname(currentdir)
5
+ # sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
6
+ #
7
+ # PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'
8
+
9
+ import tensorflow.keras.layers as kl
10
+ import tensorflow.keras.backend as K
11
+ import tensorflow as tf
12
+ import numpy as np
13
+ import random
14
+
15
+ from DeepDeformationMapRegistration.utils.operators import soft_threshold, gaussian_kernel, sample_unique
16
+ import DeepDeformationMapRegistration.utils.constants as C
17
+ from DeepDeformationMapRegistration.utils.thin_plate_splines import ThinPlateSplines
18
+ from voxelmorph.tf.layers import SpatialTransformer
19
+ from neurite.tf.utils import resize
20
+ #from cupyx.scipy.ndimage import zoom
21
+ #import cupy
22
+
23
+
24
+ class UncertaintyWeighting(kl.Layer):
25
+ def __init__(self, num_loss_fns=1, num_reg_fns=0, loss_fns: list = [tf.keras.losses.mean_squared_error],
26
+ reg_fns: list = list(), prior_loss_w=[1.], manual_loss_w=[1.], prior_reg_w=[1.], manual_reg_w=[1.],
27
+ **kwargs):
28
+ assert isinstance(loss_fns, list) and (num_loss_fns == len(loss_fns) or len(loss_fns) == 1)
29
+ assert isinstance(reg_fns, list) and (num_reg_fns == len(reg_fns))
30
+ self.num_loss = num_loss_fns
31
+ if len(loss_fns) == 1 and self.num_loss > 1:
32
+ self.loss_fns = loss_fns * self.num_loss
33
+ else:
34
+ self.loss_fns = loss_fns
35
+
36
+ if len(prior_loss_w) == 1:
37
+ self.prior_loss_w = prior_loss_w * num_loss_fns
38
+ else:
39
+ self.prior_loss_w = prior_loss_w
40
+ self.prior_loss_w = np.log(self.prior_loss_w)
41
+
42
+ if len(manual_loss_w) == 1:
43
+ self.manual_loss_w = manual_loss_w * num_loss_fns
44
+ else:
45
+ self.manual_loss_w = manual_loss_w
46
+
47
+ self.num_reg = num_reg_fns
48
+ if self.num_reg != 0:
49
+ if len(reg_fns) == 1 and self.num_reg > 1:
50
+ self.reg_fns = reg_fns * self.num_reg
51
+ else:
52
+ self.reg_fns = reg_fns
53
+
54
+ self.is_placeholder = True
55
+ if self.num_reg != 0:
56
+ if len(prior_reg_w) == 1:
57
+ self.prior_reg_w = prior_reg_w * num_reg_fns
58
+ else:
59
+ self.prior_reg_w = prior_reg_w
60
+ self.prior_reg_w = np.log(self.prior_reg_w)
61
+
62
+ if len(manual_reg_w) == 1:
63
+ self.manual_reg_w = manual_reg_w * num_reg_fns
64
+ else:
65
+ self.manual_reg_w = manual_reg_w
66
+
67
+ else:
68
+ self.prior_reg_w = list()
69
+ self.manual_reg_w = list()
70
+
71
+ super(UncertaintyWeighting, self).__init__(**kwargs)
72
+
73
+ def build(self, input_shape=None):
74
+ self.log_loss_vars = self.add_weight(name='loss_log_vars', shape=(self.num_loss,),
75
+ initializer=tf.keras.initializers.Constant(self.prior_loss_w),
76
+ trainable=True)
77
+ self.loss_weights = tf.math.softmax(self.log_loss_vars, name='SM_loss_weights')
78
+
79
+ if self.num_reg != 0:
80
+ self.log_reg_vars = self.add_weight(name='loss_reg_vars', shape=(self.num_reg,),
81
+ initializer=tf.keras.initializers.Constant(self.prior_reg_w),
82
+ trainable=True)
83
+ if self.num_reg == 1:
84
+ self.reg_weights = tf.math.exp(self.log_reg_vars, name='EXP_reg_weights')
85
+ else:
86
+ self.reg_weights = tf.math.softmax(self.log_reg_vars, name='SM_reg_weights')
87
+
88
+ super(UncertaintyWeighting, self).build(input_shape)
89
+
90
+ def multi_loss(self, ys_true, ys_pred, regs_true, regs_pred):
91
+ loss_values = list()
92
+ loss_names_loss = list()
93
+ loss_names_reg = list()
94
+
95
+ for y_true, y_pred, loss_fn, man_w in zip(ys_true, ys_pred, self.loss_fns, self.manual_loss_w):
96
+ loss_values.append(tf.keras.backend.mean(man_w * loss_fn(y_true, y_pred)))
97
+ loss_names_loss.append(loss_fn.__name__)
98
+
99
+ loss_values = tf.convert_to_tensor(loss_values, dtype=tf.float32, name="step_loss_values")
100
+ loss = tf.math.multiply(self.loss_weights, loss_values, name='step_weighted_loss')
101
+
102
+ if self.num_reg != 0:
103
+ loss_reg = list()
104
+ for reg_true, reg_pred, reg_fn, man_w in zip(regs_true, regs_pred, self.reg_fns, self.manual_reg_w):
105
+ loss_reg.append(K.mean(man_w * reg_fn(reg_true, reg_pred)))
106
+ loss_names_reg.append(reg_fn.__name__)
107
+
108
+ reg_values = tf.convert_to_tensor(loss_reg, dtype=tf.float32, name="step_reg_values")
109
+ loss = loss + tf.math.multiply(self.reg_weights, reg_values, name='step_weighted_reg')
110
+
111
+ for i, loss_name in enumerate(loss_names_loss):
112
+ self.add_metric(tf.slice(self.loss_weights, [i], [1]), name='LOSS_WEIGHT_{}_{}'.format(i, loss_name),
113
+ aggregation='mean')
114
+ self.add_metric(tf.slice(loss_values, [i], [1]), name='LOSS_VALUE_{}_{}'.format(i, loss_name),
115
+ aggregation='mean')
116
+ if self.num_reg != 0:
117
+ for i, loss_name in enumerate(loss_names_reg):
118
+ self.add_metric(tf.slice(self.reg_weights, [i], [1]), name='REG_WEIGHT_{}_{}'.format(i, loss_name),
119
+ aggregation='mean')
120
+ self.add_metric(tf.slice(reg_values, [i], [1]), name='REG_VALUE_{}_{}'.format(i, loss_name),
121
+ aggregation='mean')
122
+
123
+ return K.sum(loss)
124
+
125
+ def call(self, inputs):
126
+ ys_true = inputs[:self.num_loss]
127
+ ys_pred = inputs[self.num_loss:self.num_loss*2]
128
+ reg_true = inputs[-self.num_reg*2:-self.num_reg]
129
+ reg_pred = inputs[-self.num_reg:] # The last terms are the regularization ones which have no GT
130
+ loss = self.multi_loss(ys_true, ys_pred, reg_true, reg_pred)
131
+ self.add_loss(loss, inputs=inputs)
132
+ # We won't actually use the output, but we need something for the TF graph
133
+ return K.concatenate(inputs, -1)
134
+
135
+ def get_config(self):
136
+ base_config = super(UncertaintyWeighting, self).get_config()
137
+ base_config['num_loss_fns'] = self.num_loss
138
+ base_config['num_reg_fns'] = self.num_reg
139
+
140
+ return base_config
141
+
142
+
143
+ class UncertaintyWeightingWithRollingAverage(kl.Layer):
144
+ def __init__(self, num_loss_fns=1, num_reg_fns=0, loss_fns: list = [tf.keras.losses.mean_squared_error],
145
+ reg_fns: list = list(), prior_loss_w=[1.], manual_loss_w=[1.], prior_reg_w=[1.], manual_reg_w=[1.],
146
+ roll_avg_reference=0, # position in loss_fns of the reference loss function for the rolling avg
147
+ **kwargs):
148
+ assert isinstance(loss_fns, list) and (num_loss_fns == len(loss_fns) or len(loss_fns) == 1)
149
+ assert isinstance(reg_fns, list) and (num_reg_fns == len(reg_fns))
150
+ # Rolling average attributes
151
+ self.ref_loss = roll_avg_reference
152
+ self.compute_roll_avg = False # Toogle between computing the average of the losses or updating a know average
153
+ self.scale_factor = [1.] * num_loss_fns
154
+ self.n = 0 # Number of viewed samples
155
+ self.temp_storage = [0.] * num_loss_fns
156
+
157
+ self.num_loss = num_loss_fns
158
+ if len(loss_fns) == 1 and self.num_loss > 1:
159
+ self.loss_fns = loss_fns * self.num_loss
160
+ else:
161
+ self.loss_fns = loss_fns
162
+
163
+ if len(prior_loss_w) == 1:
164
+ self.prior_loss_w = prior_loss_w * num_loss_fns
165
+ else:
166
+ self.prior_loss_w = prior_loss_w
167
+ self.prior_loss_w = np.log(self.prior_loss_w)
168
+
169
+ if len(manual_loss_w) == 1:
170
+ self.manual_loss_w = manual_loss_w * num_loss_fns
171
+ else:
172
+ self.manual_loss_w = manual_loss_w
173
+
174
+ self.num_reg = num_reg_fns
175
+ if self.num_reg != 0:
176
+ if len(reg_fns) == 1 and self.num_reg > 1:
177
+ self.reg_fns = reg_fns * self.num_reg
178
+ else:
179
+ self.reg_fns = reg_fns
180
+
181
+ self.is_placeholder = True
182
+ if self.num_reg != 0:
183
+ if len(prior_reg_w) == 1:
184
+ self.prior_reg_w = prior_reg_w * num_reg_fns
185
+ else:
186
+ self.prior_reg_w = prior_reg_w
187
+ self.prior_reg_w = np.log(self.prior_reg_w)
188
+
189
+ if len(manual_reg_w) == 1:
190
+ self.manual_reg_w = manual_reg_w * num_reg_fns
191
+ else:
192
+ self.manual_reg_w = manual_reg_w
193
+
194
+ else:
195
+ self.prior_reg_w = list()
196
+ self.manual_reg_w = list()
197
+
198
+ super(UncertaintyWeightingWithRollingAverage, self).__init__(**kwargs)
199
+
200
+ def build(self, input_shape=None):
201
+ self.log_loss_vars = self.add_weight(name='loss_log_vars', shape=(self.num_loss,),
202
+ initializer=tf.keras.initializers.Constant(self.prior_loss_w),
203
+ trainable=True)
204
+ self.loss_weights = tf.math.softmax(self.log_loss_vars, name='SM_loss_weights')
205
+
206
+ if self.num_reg != 0:
207
+ self.log_reg_vars = self.add_weight(name='loss_reg_vars', shape=(self.num_reg,),
208
+ initializer=tf.keras.initializers.Constant(self.prior_reg_w),
209
+ trainable=True)
210
+ if self.num_reg == 1:
211
+ self.reg_weights = tf.math.exp(self.log_reg_vars, name='EXP_reg_weights')
212
+ else:
213
+ self.reg_weights = tf.math.softmax(self.log_reg_vars, name='SM_reg_weights')
214
+
215
+ super(UncertaintyWeightingWithRollingAverage, self).build(input_shape)
216
+
217
+ def store_values(self, new_loss_values):
218
+ for i, (t, v) in enumerate(zip(self.temp_storage, new_loss_values)):
219
+ self.temp_storage[i] = t + v
220
+ self.n += 1
221
+
222
+ def compute_scale_factors(self):
223
+ for i, val in enumerate(self.temp_storage):
224
+ self.scale_factor[i] = self.n / val # 1/avg
225
+
226
+ self.scale_factor[self.ref_loss] = 1.
227
+
228
+ self.temp_storage = [0.] * self.num_loss
229
+ self.n = 0
230
+
231
+ @property
232
+ def ref_on_epoch_end_function(self):
233
+ return self.compute_scale_factors
234
+
235
+ def multi_loss(self, ys_true, ys_pred, regs_true, regs_pred):
236
+ loss_values = list()
237
+ loss_names_loss = list()
238
+ loss_names_reg = list()
239
+
240
+ for y_true, y_pred, loss_fn, man_w, sf in zip(ys_true, ys_pred, self.loss_fns, self.manual_loss_w, self.scale_factor):
241
+ loss_values.append(sf * tf.keras.backend.mean(man_w * loss_fn(y_true, y_pred)))
242
+ loss_names_loss.append(loss_fn.__name__)
243
+
244
+ self.store_values(loss_values)
245
+ loss_values = tf.convert_to_tensor(loss_values, dtype=tf.float32, name="step_loss_values")
246
+ loss = tf.math.multiply(self.loss_weights, loss_values, name='step_weighted_loss')
247
+
248
+ if self.num_reg != 0:
249
+ loss_reg = list()
250
+ for reg_true, reg_pred, reg_fn, man_w in zip(regs_true, regs_pred, self.reg_fns, self.manual_reg_w):
251
+ loss_reg.append(K.mean(man_w * reg_fn(reg_true, reg_pred)))
252
+ loss_names_reg.append(reg_fn.__name__)
253
+
254
+ reg_values = tf.convert_to_tensor(loss_reg, dtype=tf.float32, name="step_reg_values")
255
+ loss = loss + tf.math.multiply(self.reg_weights, reg_values, name='step_weighted_reg')
256
+
257
+ for i, loss_name in enumerate(loss_names_loss):
258
+ self.add_metric(tf.slice(self.loss_weights, [i], [1]), name='LOSS_WEIGHT_{}_{}'.format(i, loss_name),
259
+ aggregation='mean')
260
+ self.add_metric(tf.slice(loss_values, [i], [1]), name='LOSS_VALUE_{}_{}'.format(i, loss_name),
261
+ aggregation='mean')
262
+ if self.num_reg != 0:
263
+ for i, loss_name in enumerate(loss_names_reg):
264
+ self.add_metric(tf.slice(self.reg_weights, [i], [1]), name='REG_WEIGHT_{}_{}'.format(i, loss_name),
265
+ aggregation='mean')
266
+ self.add_metric(tf.slice(reg_values, [i], [1]), name='REG_VALUE_{}_{}'.format(i, loss_name),
267
+ aggregation='mean')
268
+ sc_tf = tf.convert_to_tensor(self.scale_factor, dtype=tf.float32, name='scale_factors_tf')
269
+ self.add_metric(tf.slice(sc_tf, [i], [1]), name='SCALE_FACTOR_{}_{}'.format(i, loss_name),
270
+ aggregation='mean')
271
+
272
+ return K.sum(loss)
273
+
274
+ def call(self, inputs):
275
+ ys_true = inputs[:self.num_loss]
276
+ ys_pred = inputs[self.num_loss:self.num_loss*2]
277
+ reg_true = inputs[-self.num_reg*2:-self.num_reg]
278
+ reg_pred = inputs[-self.num_reg:] # The last terms are the regularization ones which have no GT
279
+ loss = self.multi_loss(ys_true, ys_pred, reg_true, reg_pred)
280
+ self.add_loss(loss, inputs=inputs)
281
+ # We won't actually use the output, but we need something for the TF graph
282
+ return K.concatenate(inputs, -1)
283
+
284
+ def get_config(self):
285
+ base_config = super(UncertaintyWeighting, self).get_config()
286
+ base_config['num_loss_fns'] = self.num_loss
287
+ base_config['num_reg_fns'] = self.num_reg
288
+
289
+ return base_config
290
+
291
+
292
+ def distance_map(coord1, coord2, dist, img_shape_w_channel=(64, 64, 1)):
293
+ max_dist = np.max(img_shape_w_channel)
294
+ dm_p = np.ones(img_shape_w_channel, np.float32)*max_dist
295
+ dm_n = np.ones(img_shape_w_channel, np.float32)*max_dist
296
+
297
+ for c1, c2, d in zip(coord1, coord2, dist):
298
+ dm_p[c1, c2, 0] = d if dm_p[c1, c2, 0] > d else dm_p[c1, c2]
299
+ d_n = 64. - max_dist
300
+ dm_n[c1, c2, 0] = d_n if dm_n[c1, c2, 0] > d_n else dm_n[c1, c2]
301
+
302
+ return dm_p/max_dist, dm_n/max_dist
303
+
304
+
305
+ def volume_to_ov_and_dm(in_volume: tf.Tensor):
306
+ # This one is run as a preprocessing step
307
+ def get_ov_projections_and_dm(volume):
308
+ # tf.sign returns -1, 0, 1 depending on the sign of the elements of the input (negative, zero, positive)
309
+ i, j, k, c = tf.where(volume > 0.0)
310
+ top = tf.sign(tf.reduce_sum(volume, axis=0), name='ov_top')
311
+ right = tf.sign(tf.reduce_sum(volume, axis=1), name='ov_right')
312
+ front = tf.sign(tf.reduce_sum(volume, axis=2), name='ov_front')
313
+
314
+ top_p, top_n = tf.py_func(distance_map, [j, k, i], tf.float32)
315
+ right_p, right_n = tf.py_func(distance_map, [i, k, j], tf.float32)
316
+ front_p, front_n = tf.py_func(distance_map, [i, j, k], tf.float32)
317
+
318
+ return [front, right, top], [front_p, front_n, top_p, top_n, right_p, right_n]
319
+
320
+ if len(in_volume.shape.as_list()) > 4:
321
+ return tf.map_fn(get_ov_projections_and_dm, in_volume, [tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32])
322
+ else:
323
+ return get_ov_projections_and_dm(in_volume)
324
+
325
+
326
+ def ov_and_dm_to_volume(ov_projections):
327
+ front, right, top = ov_projections
328
+
329
+ def get_volume(front: tf.Tensor, right: tf.Tensor, top: tf.Tensor):
330
+ front_shape = front.shape.as_list() # Assume (H, W, C)
331
+ top_shape = top.shape.as_list()
332
+
333
+ front_vol = tf.tile(tf.expand_dims(front, 2), [1, 1, top_shape[0], 1])
334
+ right_vol = tf.tile(tf.expand_dims(right, 1), [1, front_shape[1], 1, 1])
335
+ top_vol = tf.tile(tf.expand_dims(top, 0), [front_shape[0], 1, 1, 1])
336
+ sum = tf.add(tf.add(front_vol, right_vol), top_vol)
337
+ return soft_threshold(sum, 2., 'get_volume')
338
+
339
+ if len(front.shape.as_list()) > 3:
340
+ return tf.map_fn(lambda x: get_volume(x[0], x[1], x[2]), ov_projections, tf.float32)
341
+ else:
342
+ return get_volume(front, right, top)
343
+
344
+ # TODO: Recovering the coordinates from the distance maps to prevent artifacts
345
+ # will the gradients be backpropagated??!?!!?!?!
346
+
347
+
DeepDeformationMapRegistration/layers/upsampling.py ADDED
@@ -0,0 +1,761 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import numpy as np
3
+ from numpy import (zeros, where, diff, floor, minimum, maximum, array, concatenate, logical_or, logical_xor,
4
+ sqrt)
5
+ from tensorflow.python.framework import tensor_shape
6
+ from tensorflow.python.keras.utils import conv_utils
7
+ from tensorflow.python.keras.engine.base_layer import Layer
8
+ from tensorflow.python.keras.engine.input_spec import InputSpec
9
+ from tensorflow.python.util.tf_export import keras_export # api_export
10
+ # SRC: https://github.com/tensorflow/tensorflow/issues/46609
11
+ # import functools
12
+
13
+ # keras_export = functools.partial(api_export, 'keras') # keras_export is not defined in 1.13 but in 1.15 --> https://github.com/tensorflow/tensorflow/blob/3d6e4f24e32b5dbe0a83aaa6e9d0f6671ba41da8/tensorflow/python/util/tf_export.py
14
+
15
+ def linear_interpolate(x_fix, y_fix, x_var):
16
+ '''
17
+ Functionality:
18
+ 1D linear interpolation
19
+ Author:
20
+ Michael Osthege
21
+ Link:
22
+ https://gist.github.com/michaelosthege/e20d242bc62a434843b586c78ebce6cc
23
+ '''
24
+
25
+ x_repeat = tf.tile(x_var[:, None], (len(x_fix), ))
26
+ distances = tf.abs(x_repeat - x_fix)
27
+
28
+ x_indices = tf.searchsorted(x_fix, x_var)
29
+
30
+ weights = tf.zeros_like(distances)
31
+ idx = tf.arange(len(x_indices))
32
+ weights[idx, x_indices] = distances[idx, x_indices - 1]
33
+ weights[idx, x_indices - 1] = distances[idx, x_indices]
34
+ weights /= np.sum(weights, axis=1)[:, None]
35
+
36
+ y_var = np.dot(weights, y_fix.T)
37
+
38
+ return y_var
39
+
40
+
41
+ def cubic_interpolate(x, y, x0):
42
+ '''
43
+ Functionliaty:
44
+ 1D cubic spline interpolation
45
+ Author:
46
+ Raphael Valentin
47
+ Link:
48
+ https://stackoverflow.com/questions/31543775/how-to-perform-cubic-spline-interpolation-in-python
49
+ '''
50
+
51
+ x = np.asfarray(x)
52
+ y = np.asfarray(y)
53
+
54
+ # remove non finite values
55
+ # indexes = np.isfinite(x)
56
+ # x = x[indexes]
57
+ # y = y[indexes]
58
+
59
+ # check if sorted
60
+ if np.any(np.diff(x) < 0):
61
+ indexes = np.argsort(x)
62
+ x = x[indexes]
63
+ y = y[indexes]
64
+
65
+ size = len(x)
66
+
67
+ xdiff = np.diff(x)
68
+ ydiff = np.diff(y)
69
+
70
+ # allocate buffer matrices
71
+ Li = np.empty(size)
72
+ Li_1 = np.empty(size - 1)
73
+ z = np.empty(size)
74
+
75
+ # fill diagonals Li and Li-1 and solve [L][y] = [B]
76
+ Li[0] = sqrt(2 * xdiff[0])
77
+ Li_1[0] = 0.0
78
+ B0 = 0.0 # natural boundary
79
+ z[0] = B0 / Li[0]
80
+
81
+ for i in range(1, size - 1, 1):
82
+ Li_1[i] = xdiff[i - 1] / Li[i - 1]
83
+ Li[i] = sqrt(2 * (xdiff[i - 1] + xdiff[i]) - Li_1[i - 1] * Li_1[i - 1])
84
+ Bi = 6 * (ydiff[i] / xdiff[i] - ydiff[i - 1] / xdiff[i - 1])
85
+ z[i] = (Bi - Li_1[i - 1] * z[i - 1]) / Li[i]
86
+
87
+ i = size - 1
88
+ Li_1[i - 1] = xdiff[-1] / Li[i - 1]
89
+ Li[i] = sqrt(2 * xdiff[-1] - Li_1[i - 1] * Li_1[i - 1])
90
+ Bi = 0.0 # natural boundary
91
+ z[i] = (Bi - Li_1[i - 1] * z[i - 1]) / Li[i]
92
+
93
+ # solve [L.T][x] = [y]
94
+ i = size - 1
95
+ z[i] = z[i] / Li[i]
96
+ for i in range(size - 2, -1, -1):
97
+ z[i] = (z[i] - Li_1[i - 1] * z[i + 1]) / Li[i]
98
+
99
+ # find index
100
+ index = x.searchsorted(x0)
101
+ np.clip(index, 1, size - 1, index)
102
+
103
+ xi1, xi0 = x[index], x[index - 1]
104
+ yi1, yi0 = y[index], y[index - 1]
105
+ zi1, zi0 = z[index], z[index - 1]
106
+ hi1 = xi1 - xi0
107
+
108
+ # calculate cubic
109
+ f0 = zi0/(6*hi1)*(xi1-x0)**3 + \
110
+ zi1/(6*hi1)*(x0-xi0)**3 + \
111
+ (yi1/hi1 - zi1*hi1/6)*(x0-xi0) + \
112
+ (yi0/hi1 - zi0*hi1/6)*(xi1-x0)
113
+
114
+ return f0
115
+
116
+
117
+ def pchip_interpolate(xi, yi, x, mode="mono", verbose=False):
118
+ '''
119
+ Functionality:
120
+ 1D PCHP interpolation
121
+ Authors:
122
+ Michael Taylor <[email protected]>
123
+ Mathieu Virbel <[email protected]>
124
+ Link:
125
+ https://gist.github.com/tito/553f1135959921ce6699652bf656150d
126
+ '''
127
+
128
+ if mode not in ("mono", "quad"):
129
+ raise ValueError("Unrecognized mode string")
130
+
131
+ # Search for [xi,xi+1] interval for each x
132
+ xi = xi.astype("double")
133
+ yi = yi.astype("double")
134
+
135
+ x_index = zeros(len(x), dtype="int")
136
+ xi_steps = diff(xi)
137
+ if not all(xi_steps > 0):
138
+ raise ValueError("x-coordinates are not in increasing order.")
139
+
140
+ x_steps = diff(x)
141
+ if xi_steps.max() / xi_steps.min() < 1.000001:
142
+ # uniform input grid
143
+ if verbose:
144
+ print("pchip: uniform input grid")
145
+ xi_start = xi[0]
146
+ xi_step = (xi[-1] - xi[0]) / (len(xi) - 1)
147
+ x_index = minimum(maximum(floor((x - xi_start) / xi_step).astype(int), 0), len(xi) - 2)
148
+
149
+ # Calculate gradients d
150
+ h = (xi[-1] - xi[0]) / (len(xi) - 1)
151
+ d = zeros(len(xi), dtype="double")
152
+ if mode == "quad":
153
+ # quadratic polynomial fit
154
+ d[[0]] = (yi[1] - yi[0]) / h
155
+ d[[-1]] = (yi[-1] - yi[-2]) / h
156
+ d[1:-1] = (yi[2:] - yi[0:-2]) / 2 / h
157
+ else:
158
+ # mode=='mono', Fritsch-Carlson algorithm from fortran numerical
159
+ # recipe
160
+ delta = diff(yi) / h
161
+ d = concatenate((delta[0:1], 2 / (1 / delta[0:-1] + 1 / delta[1:]), delta[-1:]))
162
+ d[concatenate((array([False]), logical_xor(delta[0:-1] > 0, delta[1:] > 0), array([False])))] = 0
163
+ d[logical_or(concatenate((array([False]), delta == 0)), concatenate(
164
+ (delta == 0, array([False]))))] = 0
165
+ # Calculate output values y
166
+ dxxi = x - xi[x_index]
167
+ dxxid = x - xi[1 + x_index]
168
+ dxxi2 = pow(dxxi, 2)
169
+ dxxid2 = pow(dxxid, 2)
170
+ y = (2 / pow(h, 3) * (yi[x_index] * dxxid2 * (dxxi + h / 2) - yi[1 + x_index] * dxxi2 *
171
+ (dxxid - h / 2)) + 1 / pow(h, 2) *
172
+ (d[x_index] * dxxid2 * dxxi + d[1 + x_index] * dxxi2 * dxxid))
173
+ else:
174
+ # not uniform input grid
175
+ if (x_steps.max() / x_steps.min() < 1.000001 and x_steps.max() / x_steps.min() > 0.999999):
176
+ # non-uniform input grid, uniform output grid
177
+ if verbose:
178
+ print("pchip: non-uniform input grid, uniform output grid")
179
+ x_decreasing = x[-1] < x[0]
180
+ if x_decreasing:
181
+ x = x[::-1]
182
+ x_start = x[0]
183
+ x_step = (x[-1] - x[0]) / (len(x) - 1)
184
+ x_indexprev = -1
185
+ for xi_loop in range(len(xi) - 2):
186
+ x_indexcur = max(int(floor((xi[1 + xi_loop] - x_start) / x_step)), -1)
187
+ x_index[1 + x_indexprev:1 + x_indexcur] = xi_loop
188
+ x_indexprev = x_indexcur
189
+ x_index[1 + x_indexprev:] = len(xi) - 2
190
+ if x_decreasing:
191
+ x = x[::-1]
192
+ x_index = x_index[::-1]
193
+ elif all(x_steps > 0) or all(x_steps < 0):
194
+ # non-uniform input/output grids, output grid monotonic
195
+ if verbose:
196
+ print("pchip: non-uniform in/out grid, output grid monotonic")
197
+ x_decreasing = x[-1] < x[0]
198
+ if x_decreasing:
199
+ x = x[::-1]
200
+ x_len = len(x)
201
+ x_loop = 0
202
+ for xi_loop in range(len(xi) - 1):
203
+ while x_loop < x_len and x[x_loop] < xi[1 + xi_loop]:
204
+ x_index[x_loop] = xi_loop
205
+ x_loop += 1
206
+ x_index[x_loop:] = len(xi) - 2
207
+ if x_decreasing:
208
+ x = x[::-1]
209
+ x_index = x_index[::-1]
210
+ else:
211
+ # non-uniform input/output grids, output grid not monotonic
212
+ if verbose:
213
+ print("pchip: non-uniform in/out grids, " "output grid not monotonic")
214
+ for index in range(len(x)):
215
+ loc = where(x[index] < xi)[0]
216
+ if loc.size == 0:
217
+ x_index[index] = len(xi) - 2
218
+ elif loc[0] == 0:
219
+ x_index[index] = 0
220
+ else:
221
+ x_index[index] = loc[0] - 1
222
+ # Calculate gradients d
223
+ h = diff(xi)
224
+ d = zeros(len(xi), dtype="double")
225
+ delta = diff(yi) / h
226
+ if mode == "quad":
227
+ # quadratic polynomial fit
228
+ d[[0, -1]] = delta[[0, -1]]
229
+ d[1:-1] = (delta[1:] * h[0:-1] + delta[0:-1] * h[1:]) / (h[0:-1] + h[1:])
230
+ else:
231
+ # mode=='mono', Fritsch-Carlson algorithm from fortran numerical
232
+ # recipe
233
+ d = concatenate(
234
+ (delta[0:1], 3 * (h[0:-1] + h[1:]) / ((h[0:-1] + 2 * h[1:]) / delta[0:-1] +
235
+ (2 * h[0:-1] + h[1:]) / delta[1:]), delta[-1:]))
236
+ d[concatenate((array([False]), logical_xor(delta[0:-1] > 0, delta[1:] > 0), array([False])))] = 0
237
+ d[logical_or(concatenate((array([False]), delta == 0)), concatenate(
238
+ (delta == 0, array([False]))))] = 0
239
+ dxxi = x - xi[x_index]
240
+ dxxid = x - xi[1 + x_index]
241
+ dxxi2 = pow(dxxi, 2)
242
+ dxxid2 = pow(dxxid, 2)
243
+ y = (2 / pow(h[x_index], 3) *
244
+ (yi[x_index] * dxxid2 * (dxxi + h[x_index] / 2) - yi[1 + x_index] * dxxi2 *
245
+ (dxxid - h[x_index] / 2)) + 1 / pow(h[x_index], 2) *
246
+ (d[x_index] * dxxid2 * dxxi + d[1 + x_index] * dxxi2 * dxxid))
247
+ return y
248
+
249
+
250
+ def Interpolate1D(x, y, xx, method='nearest'):
251
+ '''
252
+ Functionality:
253
+ 1D interpolation with various methods
254
+ Author:
255
+ Kai Gao <[email protected]>
256
+ '''
257
+
258
+ n = len(x)
259
+ nn = len(xx)
260
+ yy = np.zeros(nn)
261
+
262
+ # Nearest neighbour interpolation
263
+ if method == 'nearest':
264
+ for i in range(0, nn):
265
+ xi = tf.argmin(tf.abs(xx[i] - x))
266
+ yy[i] = y[xi]
267
+
268
+ # Linear interpolation
269
+ elif method == 'linear':
270
+
271
+ # # slower version
272
+ # if n == 1:
273
+ # yy[:-1] = y[0]
274
+
275
+ # else:
276
+ # for i in range(0, nn):
277
+
278
+ # if xx[i] < x[0]:
279
+ # t = (xx[i] - x[0]) / (x[1] - x[0])
280
+ # yy[i] = (1.0 - t) * y[0] + t * y[1]
281
+
282
+ # elif x[n - 1] <= xx[i]:
283
+ # t = (xx[i] - x[n - 2]) / (x[n - 1] - x[n - 2])
284
+ # yy[i] = (1.0 - t) * y[n - 2] + t * y[n - 1]
285
+
286
+ # else:
287
+ # for k in range(1, n):
288
+ # if x[k - 1] <= xx[i] and xx[i] < x[k]:
289
+ # t = (xx[i] - x[k - 1]) / (x[k] - x[k - 1])
290
+ # yy[i] = (1.0 - t) * y[k - 1] + t * y[k]
291
+ # break
292
+
293
+ # # faster version
294
+ yy = linear_interpolate(x, y, xx)
295
+
296
+ # Cubic interpolation
297
+ elif method == 'cubic':
298
+ yy = cubic_interpolate(x, y, xx)
299
+
300
+ # Piecewise cubic Hermite interpolating polynomial (PCHIP)
301
+ elif method == 'pchip':
302
+ yy = pchip_interpolate(x, y, xx, mode='mono')
303
+
304
+ return yy
305
+
306
+
307
+ def Interpolate2D(x, y, f, xx, yy, method='nearest'):
308
+ '''
309
+ Functionality:
310
+ 2D interpolation implemented in a separable fashion
311
+ There are methods that do real 2D non-separable interpolation, which are
312
+ more difficult to implement.
313
+ Author:
314
+ Kai Gao <[email protected]>
315
+ '''
316
+
317
+ n1 = len(x)
318
+ n2 = len(y)
319
+ nn1 = len(xx)
320
+ nn2 = len(yy)
321
+
322
+ w = np.zeros((nn1, n2))
323
+ ff = np.zeros((nn1, nn2))
324
+
325
+ # Interpolate along the 1st dimension
326
+ for j in range(0, n2):
327
+ w[:, j] = Interpolate1D(x, f[:, j], xx, method)
328
+
329
+ # Interpolate along the 2nd dimension
330
+ for i in range(0, nn1):
331
+ ff[i, :] = Interpolate1D(y, w[i, :], yy, method)
332
+
333
+ return ff
334
+
335
+
336
+ def Interpolate3D(x, y, z, f, xx, yy, zz, method='nearest'):
337
+ '''
338
+ Functionality:
339
+ 3D interpolation implemented in a separable fashion
340
+ There are methods that do real 3D non-separable interpolation, which are
341
+ more difficult to implement.
342
+ Author:
343
+ Kai Gao <[email protected]>
344
+ '''
345
+
346
+ n1 = len(x)
347
+ n2 = len(y)
348
+ n3 = len(z)
349
+ nn1 = len(xx)
350
+ nn2 = len(yy)
351
+ nn3 = len(zz)
352
+
353
+ w1 = tf.zeros((nn1, n2, n3))
354
+ w2 = tf.zeros((nn1, nn2, n3))
355
+ ff = tf.zeros((nn1, nn2, nn3))
356
+
357
+ # Interpolate along the 1st dimension
358
+ for k in range(0, n3):
359
+ for j in range(0, n2):
360
+ w1[:, j, k] = Interpolate1D(x, f[:, j, k], xx, method)
361
+
362
+ # Interpolate along the 2nd dimension
363
+ for k in range(0, n3):
364
+ for i in range(0, nn1):
365
+ w2[i, :, k] = Interpolate1D(y, w1[i, :, k], yy, method)
366
+
367
+ # Interpolate along the 3rd dimension
368
+ for j in range(0, nn2):
369
+ for i in range(0, nn1):
370
+ ff[i, j, :] = Interpolate1D(z, w2[i, j, :], zz, method)
371
+
372
+ return ff
373
+
374
+
375
+ def UpInterpolate1D(x, size=2, interpolation='nearest', data_format='channels_first', align_corners=True):
376
+ '''
377
+ Functionality:
378
+ 1D upsampling interpolation for tf
379
+ Author:
380
+ Kai Gao <[email protected]>
381
+ '''
382
+
383
+ x = x.numpy()
384
+
385
+ if data_format == 'channels_last':
386
+ nb, nr, nh = x.shape
387
+ elif data_format == 'channels_first':
388
+ nb, nh, nr = x.shape
389
+
390
+ r = size
391
+ ir = np.linspace(0.0, nr - 1.0, num=nr)
392
+
393
+ if align_corners:
394
+ # align_corners=True assumes that values are sampled at discrete points
395
+ iir = np.linspace(0.0, nr - 1.0, num=nr * r)
396
+ else:
397
+ # aling_corners=False assumes that values are sampled at centers of discrete blocks
398
+ iir = np.linspace(0.0 - 0.5 + 0.5 / r, nr - 1.0 + 0.5 - 0.5 / r, num=nr * r)
399
+ iir = np.clip(iir, 0.0, nr - 1.0)
400
+
401
+ if data_format == 'channels_last':
402
+ xx = np.zeros((nb, nr * r, nh))
403
+ for i in range(0, nb):
404
+ for j in range(0, nh):
405
+ t = np.reshape(x[i, :, j], (nr))
406
+ xx[i, :, j] = Interpolate1D(ir, t, iir, interpolation)
407
+
408
+ elif data_format == 'channels_first':
409
+ xx = np.zeros((nb, nh, nr * r))
410
+ for i in range(0, nb):
411
+ for j in range(0, nh):
412
+ t = np.reshape(x[i, j, :], (nr))
413
+ xx[i, j, :] = Interpolate1D(ir, t, iir, interpolation)
414
+
415
+ return tf.convert_to_tensor(xx, dtype=x.dtype)
416
+
417
+
418
+ def UpInterpolate2D(x,
419
+ size=(2, 2),
420
+ interpolation='nearest',
421
+ data_format='channels_first',
422
+ align_corners=True):
423
+ '''
424
+ Functionality:
425
+ 2D upsampling interpolation for tf
426
+ Author:
427
+ Kai Gao <[email protected]>
428
+ '''
429
+
430
+ x = x.numpy()
431
+
432
+ if data_format == 'channels_last':
433
+ nb, nr, nc, nh = x.shape
434
+ elif data_format == 'channels_first':
435
+ nb, nh, nr, nc = x.shape
436
+
437
+ r = size[0]
438
+ c = size[1]
439
+ ir = np.linspace(0.0, nr - 1.0, num=nr)
440
+ ic = np.linspace(0.0, nc - 1.0, num=nc)
441
+
442
+ if align_corners:
443
+ # align_corners=True assumes that values are sampled at discrete points
444
+ iir = np.linspace(0.0, nr - 1.0, num=nr * r)
445
+ iic = np.linspace(0.0, nc - 1.0, num=nc * c)
446
+ else:
447
+ # aling_corners=False assumes that values are sampled at centers of discrete blocks
448
+ iir = np.linspace(0.0 - 0.5 + 0.5 / r, nr - 1.0 + 0.5 - 0.5 / r, num=nr * r)
449
+ iic = np.linspace(0.0 - 0.5 + 0.5 / c, nc - 1.0 + 0.5 - 0.5 / c, num=nc * c)
450
+ iir = np.clip(iir, 0.0, nr - 1.0)
451
+ iic = np.clip(iic, 0.0, nc - 1.0)
452
+
453
+ if data_format == 'channels_last':
454
+ xx = np.zeros((nb, nr * r, nc * c, nh))
455
+ for i in range(0, nb):
456
+ for j in range(0, nh):
457
+ t = np.reshape(x[i, :, :, j], (nr, nc))
458
+ xx[i, :, :, j] = Interpolate2D(ir, ic, t, iir, iic, interpolation)
459
+
460
+ elif data_format == 'channels_first':
461
+ xx = np.zeros((nb, nh, nr * r, nc * c))
462
+ for i in range(0, nb):
463
+ for j in range(0, nh):
464
+ t = np.reshape(x[i, j, :, :], (nr, nc))
465
+ xx[i, j, :, :] = Interpolate2D(ir, ic, t, iir, iic, interpolation)
466
+
467
+ return tf.convert_to_tensor(xx, dtype=x.dtype)
468
+
469
+
470
+ def UpInterpolate3D(x,
471
+ size=(2, 2, 2),
472
+ interpolation='nearest',
473
+ data_format='channels_first',
474
+ align_corners=True):
475
+ '''
476
+ Functionality:
477
+ 3D upsampling interpolation for tf
478
+ Author:
479
+ Kai Gao <[email protected]>
480
+ '''
481
+
482
+ # x = x.numpy()
483
+
484
+ if data_format == 'channels_last':
485
+ nb, nr, nc, nd, nh = tf.TensorShape(x).as_list()
486
+ elif data_format == 'channels_first':
487
+ nb, nh, nr, nc, nd = tf.TensorShape(x).as_list()
488
+
489
+ r = size[0]
490
+ c = size[1]
491
+ d = size[2]
492
+ ir = tf.linspace(0.0, nr - 1.0, num=nr)
493
+ ic = tf.linspace(0.0, nc - 1.0, num=nc)
494
+ id = tf.linspace(0.0, nd - 1.0, num=nd)
495
+
496
+ if align_corners:
497
+ # align_corners=True assumes that values are sampled at discrete points
498
+ iir = tf.linspace(0.0, nr - 1.0, num=nr * r)
499
+ iic = tf.linspace(0.0, nc - 1.0, num=nc * c)
500
+ iid = tf.linspace(0.0, nd - 1.0, num=nd * d)
501
+ else:
502
+ # aling_corners=False assumes that values are sampled at centers of discrete blocks
503
+ iir = tf.linspace(0.0 - 0.5 + 0.5 / r, nr - 1.0 + 0.5 - 0.5 / r, num=nr * r)
504
+ iic = tf.linspace(0.0 - 0.5 + 0.5 / c, nc - 1.0 + 0.5 - 0.5 / c, num=nc * c)
505
+ iid = tf.linspace(0.0 - 0.5 + 0.5 / d, nd - 1.0 + 0.5 - 0.5 / d, num=nd * d)
506
+ iir = tf.clip_by_value(iir, 0.0, nr - 1.0)
507
+ iic = tf.clip_by_value(iic, 0.0, nc - 1.0)
508
+ iid = tf.clip_by_value(iid, 0.0, nd - 1.0)
509
+
510
+ if data_format == 'channels_last':
511
+ xx = tf.zeros((nb, nr * r, nc * c, nd * d, nh))
512
+ for i in range(0, nb):
513
+ for j in range(0, nh):
514
+ t = tf.reshape(x[i, :, :, :, j], (nr, nc, nd))
515
+ xx[i, :, :, :, j] = Interpolate3D(ir, ic, id, t, iir, iic, iid, interpolation)
516
+
517
+ elif data_format == 'channels_first':
518
+ xx = tf.zeros((nb, nh, nr * r, nc * c, nd * d))
519
+ for i in range(0, nb):
520
+ for j in range(0, nh):
521
+ t = tf.reshape(x[i, j, :, :, :], (nr, nc, nd))
522
+ xx[i, j, :, :, :] = Interpolate3D(ir, ic, id, t, iir, iic, iid, interpolation)
523
+
524
+ return tf.convert_to_tensor(xx, dtype=x.dtype)
525
+
526
+
527
+ # ################################################################################
528
+ @keras_export('keras.layers.UpSampling1D')
529
+ class UpSampling1D(Layer):
530
+ """Upsampling layer for 1D inputs.
531
+ Repeats each temporal step `size` times along the time axis.
532
+ Examples:
533
+ >>> input_shape = (2, 2, 3)
534
+ >>> x = np.arange(np.prod(input_shape)).reshape(input_shape)
535
+ >>> print(x)
536
+ [[[ 0 1 2]
537
+ [ 3 4 5]]
538
+ [[ 6 7 8]
539
+ [ 9 10 11]]]
540
+ >>> y = tf.keras.layers.UpSampling1D(size=2)(x)
541
+ >>> print(y)
542
+ tf.Tensor(
543
+ [[[ 0 1 2]
544
+ [ 0 1 2]
545
+ [ 3 4 5]
546
+ [ 3 4 5]]
547
+ [[ 6 7 8]
548
+ [ 6 7 8]
549
+ [ 9 10 11]
550
+ [ 9 10 11]]], shape=(2, 4, 3), dtype=int64)
551
+ Args:
552
+ size: Integer. Upsampling factor.
553
+ Input shape:
554
+ 3D tensor with shape: `(batch_size, steps, features)`.
555
+ Output shape:
556
+ 3D tensor with shape: `(batch_size, upsampled_steps, features)`.
557
+ """
558
+ def __init__(self, size=2, data_format='None', interpolation='nearest', align_corners=True, **kwargs):
559
+ super(UpSampling1D, self).__init__(**kwargs)
560
+ self.data_format = conv_utils.normalize_data_format(data_format)
561
+ self.size = int(size)
562
+ self.input_spec = InputSpec(ndim=3)
563
+ self.interpolation = interpolation
564
+ if self.interpolation not in {'nearest', 'linear', 'cubic', 'pchip'}:
565
+ raise ValueError('`interpolation` argument should be one of `"nearest"` '
566
+ 'or `"linear"` '
567
+ 'or `"cubic"` '
568
+ 'or `"pchip"`.')
569
+ self.align_corners = align_corners
570
+
571
+ def compute_output_shape(self, input_shape):
572
+ input_shape = tf.TensorShape(input_shape).as_list()
573
+ size = self.size * input_shape[1] if input_shape[1] is not None else None
574
+ return tf.TensorShape([input_shape[0], size, input_shape[2]])
575
+
576
+ def call(self, inputs):
577
+ return UpInterpolate1D(inputs,
578
+ self.size,
579
+ data_format=self.data_format,
580
+ interpolation=self.interpolation,
581
+ align_corners=self.align_corners)
582
+
583
+ def get_config(self):
584
+ config = {'size': self.size}
585
+ base_config = super(UpSampling1D, self).get_config()
586
+ return dict(list(base_config.items()) + list(config.items()))
587
+
588
+
589
+ @keras_export('keras.layers.UpSampling2D')
590
+ class UpSampling2D(Layer):
591
+ """Upsampling layer for 2D inputs.
592
+ Repeats the rows and columns of the data
593
+ by `size[0]` and `size[1]` respectively.
594
+ Examples:
595
+ >>> input_shape = (2, 2, 1, 3)
596
+ >>> x = np.arange(np.prod(input_shape)).reshape(input_shape)
597
+ >>> print(x)
598
+ [[[[ 0 1 2]]
599
+ [[ 3 4 5]]]
600
+ [[[ 6 7 8]]
601
+ [[ 9 10 11]]]]
602
+ >>> y = tf.keras.layers.UpSampling2D(size=(1, 2))(x)
603
+ >>> print(y)
604
+ tf.Tensor(
605
+ [[[[ 0 1 2]
606
+ [ 0 1 2]]
607
+ [[ 3 4 5]
608
+ [ 3 4 5]]]
609
+ [[[ 6 7 8]
610
+ [ 6 7 8]]
611
+ [[ 9 10 11]
612
+ [ 9 10 11]]]], shape=(2, 2, 2, 3), dtype=int64)
613
+ Args:
614
+ size: Int, or tuple of 2 integers.
615
+ The upsampling factors for rows and columns.
616
+ data_format: A string,
617
+ one of `channels_last` (default) or `channels_first`.
618
+ The ordering of the dimensions in the inputs.
619
+ `channels_last` corresponds to inputs with shape
620
+ `(batch_size, height, width, channels)` while `channels_first`
621
+ corresponds to inputs with shape
622
+ `(batch_size, channels, height, width)`.
623
+ It defaults to the `image_data_format` value found in your
624
+ Keras config file at `~/.keras/keras.json`.
625
+ If you never set it, then it will be "channels_last".
626
+ interpolation: A string, one of `nearest` or `bilinear`.
627
+ Input shape:
628
+ 4D tensor with shape:
629
+ - If `data_format` is `"channels_last"`:
630
+ `(batch_size, rows, cols, channels)`
631
+ - If `data_format` is `"channels_first"`:
632
+ `(batch_size, channels, rows, cols)`
633
+ Output shape:
634
+ 4D tensor with shape:
635
+ - If `data_format` is `"channels_last"`:
636
+ `(batch_size, upsampled_rows, upsampled_cols, channels)`
637
+ - If `data_format` is `"channels_first"`:
638
+ `(batch_size, channels, upsampled_rows, upsampled_cols)`
639
+ """
640
+ def __init__(self, size=(2, 2), data_format=None, interpolation='nearest', align_corners=True, **kwargs):
641
+ super(UpSampling2D, self).__init__(**kwargs)
642
+ self.data_format = conv_utils.normalize_data_format(data_format)
643
+ self.size = conv_utils.normalize_tuple(size, 2, 'size')
644
+ self.input_spec = InputSpec(ndim=4)
645
+ self.interpolation = interpolation
646
+ if self.interpolation not in {'nearest', 'bilinear', 'linear', 'cubic', 'pchip'}:
647
+ raise ValueError('`interpolation` argument should be one of `"nearest"` '
648
+ 'or `"bilinear"` '
649
+ 'or `"linear"` '
650
+ 'or `"cubic"` '
651
+ 'or `"pchip"`.')
652
+ if self.interpolation == 'bilinear':
653
+ self.interpolation = 'linear'
654
+ self.align_corners = align_corners
655
+
656
+ def compute_output_shape(self, input_shape):
657
+ input_shape = tensor_shape.TensorShape(input_shape).as_list()
658
+ if self.data_format == 'channels_first':
659
+ height = self.size[0] * input_shape[2] if input_shape[2] is not None else None
660
+ width = self.size[1] * input_shape[3] if input_shape[3] is not None else None
661
+ return tensor_shape.TensorShape([input_shape[0], input_shape[1], height, width])
662
+ else:
663
+ height = self.size[0] * input_shape[1] if input_shape[1] is not None else None
664
+ width = self.size[1] * input_shape[2] if input_shape[2] is not None else None
665
+ return tensor_shape.TensorShape([input_shape[0], height, width, input_shape[3]])
666
+
667
+ def call(self, inputs):
668
+ return UpInterpolate2D(inputs,
669
+ self.size,
670
+ data_format=self.data_format,
671
+ interpolation=self.interpolation,
672
+ align_corners=self.align_corners)
673
+
674
+ def get_config(self):
675
+ config = {'size': self.size, 'data_format': self.data_format, 'interpolation': self.interpolation}
676
+ base_config = super(UpSampling2D, self).get_config()
677
+ return dict(list(base_config.items()) + list(config.items()))
678
+
679
+
680
+ @keras_export('keras.layers.UpSampling3D')
681
+ class UpSampling3D(Layer):
682
+ """Upsampling layer for 3D inputs.
683
+ Repeats the 1st, 2nd and 3rd dimensions
684
+ of the data by `size[0]`, `size[1]` and `size[2]` respectively.
685
+ Examples:
686
+ >>> input_shape = (2, 1, 2, 1, 3)
687
+ >>> x = tf.constant(1, shape=input_shape)
688
+ >>> y = tf.keras.layers.UpSampling3D(size=2)(x)
689
+ >>> print(y.shape)
690
+ (2, 2, 4, 2, 3)
691
+ Args:
692
+ size: Int, or tuple of 3 integers.
693
+ The upsampling factors for dim1, dim2 and dim3.
694
+ data_format: A string,
695
+ one of `channels_last` (default) or `channels_first`.
696
+ The ordering of the dimensions in the inputs.
697
+ `channels_last` corresponds to inputs with shape
698
+ `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`
699
+ while `channels_first` corresponds to inputs with shape
700
+ `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.
701
+ It defaults to the `image_data_format` value found in your
702
+ Keras config file at `~/.keras/keras.json`.
703
+ If you never set it, then it will be "channels_last".
704
+ Input shape:
705
+ 5D tensor with shape:
706
+ - If `data_format` is `"channels_last"`:
707
+ `(batch_size, dim1, dim2, dim3, channels)`
708
+ - If `data_format` is `"channels_first"`:
709
+ `(batch_size, channels, dim1, dim2, dim3)`
710
+ Output shape:
711
+ 5D tensor with shape:
712
+ - If `data_format` is `"channels_last"`:
713
+ `(batch_size, upsampled_dim1, upsampled_dim2, upsampled_dim3, channels)`
714
+ - If `data_format` is `"channels_first"`:
715
+ `(batch_size, channels, upsampled_dim1, upsampled_dim2, upsampled_dim3)`
716
+ """
717
+ def __init__(self,
718
+ size=(2, 2, 2),
719
+ data_format=None,
720
+ interpolation='nearest',
721
+ align_corners=True,
722
+ **kwargs):
723
+ super(UpSampling3D, self).__init__(**kwargs)
724
+ self.data_format = conv_utils.normalize_data_format(data_format)
725
+ self.size = conv_utils.normalize_tuple(size, 3, 'size')
726
+ self.input_spec = InputSpec(ndim=5)
727
+ self.interpolation = interpolation
728
+ if interpolation not in {'nearest', 'trilinear', 'linear', 'cubic', 'pchip'}:
729
+ raise ValueError('`interpolation` argument should be one of `"nearest"` '
730
+ 'or `"trilinear"` '
731
+ 'or `"linear"` '
732
+ 'or `"cubic"` '
733
+ 'or `"pchip"`.')
734
+ if self.interpolation == 'trilinear':
735
+ self.interpolation = 'linear'
736
+ self.align_corners = align_corners
737
+
738
+ def compute_output_shape(self, input_shape):
739
+ input_shape = tensor_shape.TensorShape(input_shape).as_list()
740
+ if self.data_format == 'channels_first':
741
+ dim1 = self.size[0] * input_shape[2] if input_shape[2] is not None else None
742
+ dim2 = self.size[1] * input_shape[3] if input_shape[3] is not None else None
743
+ dim3 = self.size[2] * input_shape[4] if input_shape[4] is not None else None
744
+ return tensor_shape.TensorShape([input_shape[0], input_shape[1], dim1, dim2, dim3])
745
+ else:
746
+ dim1 = self.size[0] * input_shape[1] if input_shape[1] is not None else None
747
+ dim2 = self.size[1] * input_shape[2] if input_shape[2] is not None else None
748
+ dim3 = self.size[2] * input_shape[3] if input_shape[3] is not None else None
749
+ return tensor_shape.TensorShape([input_shape[0], dim1, dim2, dim3, input_shape[4]])
750
+
751
+ def call(self, inputs):
752
+ return UpInterpolate3D(inputs,
753
+ self.size,
754
+ data_format=self.data_format,
755
+ interpolation=self.interpolation,
756
+ align_corners=self.align_corners)
757
+
758
+ def get_config(self):
759
+ config = {'size': self.size, 'data_format': self.data_format}
760
+ base_config = super(UpSampling3D, self).get_config()
761
+ return dict(list(base_config.items()) + list(config.items()))
DeepDeformationMapRegistration/losses.py CHANGED
@@ -1,12 +1,17 @@
1
  import tensorflow as tf
 
2
  from scipy.ndimage import generate_binary_structure
 
3
 
4
- from DeepDeformationMapRegistration.utils.operators import soft_threshold
5
  from DeepDeformationMapRegistration.utils.constants import EPS_tf
 
6
 
 
 
7
 
8
  class HausdorffDistanceErosion:
9
- def __init__(self, ndim=3, nerosion=10, value_per_channel=False):
10
  """
11
  Approximation of the Hausdorff distance based on erosion operations based on the work done by Karimi D., et al.
12
  Karimi D., et al., "Reducing the Hausdorff Distance in Medical Image Segmentation with Convolutional Neural
@@ -14,50 +19,222 @@ class HausdorffDistanceErosion:
14
 
15
  :param ndim: Dimensionality of the images
16
  :param nerosion: Number of erosion steps. Defaults to 10.
17
- :param value_per_channel: Return an array with the HD distance computed on each channel independently or the sum
18
  """
 
19
  self.ndims = ndim
20
- self.conv = getattr(tf.nn, 'conv%dd' % self.ndims)
 
 
21
  self.nerosions = nerosion
22
- self.sum_range = tf.range(0, self.ndims) if value_per_channel else None
23
-
24
- def _erode(self, in_tensor, kernel):
25
- out = 1. - tf.squeeze(self.conv(tf.expand_dims(1. - in_tensor, 0), kernel, [1] * (self.ndims + 2), 'SAME'), axis=0)
26
- return soft_threshold(out, 0.5, name='soft_thresholding')
27
 
28
- def _erode_per_channel(self, in_tensor, kernel):
29
- # In the lambda function we add a fictitious channel and then remove it, so the final shape is [1, H, W, D]
30
- er_tensor = tf.map_fn(lambda tens: tf.squeeze(self._erode(tf.expand_dims(tens, -1), kernel)),
31
- tf.transpose(in_tensor, [3, 0, 1, 2]), tf.float32) # Iterate along the channel dimension (3)
 
 
 
 
32
 
33
- return tf.transpose(er_tensor, [1, 2, 3, 0]) # move the channels back to the end
 
 
 
 
 
 
 
 
 
 
34
 
35
  def _erosion_distance_single(self, y_true, y_pred):
36
- diff = tf.math.pow(y_pred - y_true, 2)
37
  alpha = 2
38
 
39
- norm = 1 / (self.ndims * 2 + 1)
40
- kernel = generate_binary_structure(self.ndims, 1).astype(int) * norm
41
- kernel = tf.constant(kernel, tf.float32)
42
- kernel = tf.expand_dims(tf.expand_dims(kernel, -1), -1) # [H, W, D, C_in, C_out]
43
-
44
  ret = 0.
45
  for k in range(1, self.nerosions+1):
46
  er = diff
47
  # k successive erosions
48
  for j in range(k):
49
- er = self._erode_per_channel(er, kernel)
50
- ret += tf.reduce_sum(tf.multiply(er, tf.cast(tf.pow(k, alpha), tf.float32)), self.sum_range)
51
 
52
- img_vol = tf.cast(tf.reduce_prod(tf.shape(y_true)[:-1]), tf.float32) # Volume of each channel
53
- return tf.divide(ret, img_vol) # Divide by the image size
54
 
55
- def loss(self, y_true, y_pred):
 
56
  batched_dist = tf.map_fn(lambda x: self._erosion_distance_single(x[0], x[1]), (y_true, y_pred),
57
- dtype=tf.float32)
58
 
59
  return tf.reduce_mean(batched_dist)
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  class NCC:
63
  def __init__(self, in_shape, eps=EPS_tf):
@@ -69,48 +246,105 @@ class NCC:
69
  f_yp = tf.reshape(y_pred, [-1])
70
  mean_yt = tf.reduce_mean(f_yt)
71
  mean_yp = tf.reduce_mean(f_yp)
72
- std_yt = tf.math.reduce_std(f_yt)
73
- std_yp = tf.math.reduce_std(f_yp)
74
 
75
  n_f_yt = f_yt - mean_yt
76
  n_f_yp = f_yp - mean_yp
77
- numerator = tf.reduce_sum(n_f_yt * n_f_yp)
78
- denominator = std_yt * std_yp * self.__shape_size + self.__eps
 
 
79
  return tf.math.divide_no_nan(numerator, denominator)
80
 
 
81
  def loss(self, y_true, y_pred):
82
  # According to the documentation, the loss returns a scalar
83
  # Ref: https://www.tensorflow.org/api_docs/python/tf/keras/Model#compile
84
  return tf.reduce_mean(tf.map_fn(lambda x: 1 - self.ncc(x[0], x[1]), (y_true, y_pred), tf.float32))
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  class StructuralSimilarity:
88
  # Based on https://github.com/keras-team/keras-contrib/blob/master/keras_contrib/losses/dssim.py
89
- def __init__(self, k1=0.01, k2=0.03, patch_size=3, dynamic_range=1., overlap=0.0):
 
 
 
90
  """
91
  Structural (Di)Similarity Index Measure:
92
 
93
  :param k1: Internal parameter. Defaults to 0.01
94
  :param k2: Internal parameter. Defaults to 0.02
95
- :param patch_size: Size of the extracted patches
96
  :param dynamic_range: Maximum numerical intensity value (typ. 2^bits_per_pixel - 1). Defaults to 1.
 
 
 
 
97
  """
98
- self.__c1 = (k1 * dynamic_range) ** 2
99
- self.__c2 = (k2 * dynamic_range) ** 2
100
- self.__kernel_shape = [1] + [patch_size] * 3 + [1]
 
 
 
 
 
 
 
101
  stride = int(patch_size * (1 - overlap))
102
- self.__stride = [1] + [stride if stride else 1] * 3 + [1]
103
- self.__max_val = dynamic_range
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
  def __int_shape(self, x):
106
  return tf.keras.backend.int_shape(x) if tf.keras.backend.backend() == 'tensorflow' else tf.keras.backend.shape(x)
107
 
108
  def ssim(self, y_true, y_pred):
109
-
110
- patches_true = tf.extract_volume_patches(y_true, self.__kernel_shape, self.__stride, 'VALID',
111
- 'patches_true')
112
- patches_pred = tf.extract_volume_patches(y_pred, self.__kernel_shape, self.__stride, 'VALID',
113
- 'patches_pred')
 
 
 
114
 
115
  #bs, w, h, d, *c = self.__int_shape(patches_pred)
116
  #patches_true = tf.reshape(patches_true, [-1, w, h, d, tf.reduce_prod(c)])
@@ -124,15 +358,496 @@ class StructuralSimilarity:
124
  v_true = tf.math.reduce_variance(patches_true, axis=-1)
125
  v_pred = tf.math.reduce_variance(patches_pred, axis=-1)
126
 
 
 
 
 
127
  # Covariance
128
  covar = tf.reduce_mean(patches_true * patches_pred, axis=-1) - u_true * u_pred
129
 
130
  # SSIM
131
- numerator = (2 * u_true * u_pred + self.__c1) * (2 * covar + self.__c2)
132
- denominator = ((tf.square(u_true) + tf.square(u_pred) + self.__c1) * (v_pred + v_true + self.__c2))
133
- ssim = numerator / denominator
 
 
 
134
 
135
- return tf.reduce_mean(ssim)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
- def dssim(self, y_true, y_pred):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  return tf.reduce_mean((1. - self.ssim(y_true, y_pred)) / 2.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import tensorflow as tf
2
+ import tensorflow.keras.backend as K
3
  from scipy.ndimage import generate_binary_structure
4
+ from sklearn.utils.extmath import cartesian
5
 
6
+ from DeepDeformationMapRegistration.utils.operators import soft_threshold, min_max_norm, hard_threshold
7
  from DeepDeformationMapRegistration.utils.constants import EPS_tf
8
+ from DeepDeformationMapRegistration.utils.misc import function_decorator
9
 
10
+ import numpy as np
11
+ import warnings
12
 
13
  class HausdorffDistanceErosion:
14
+ def __init__(self, ndim=3, nerosion=10, im_shape: [list, tuple] = (64, 64, 64, 1), alpha=2):
15
  """
16
  Approximation of the Hausdorff distance based on erosion operations based on the work done by Karimi D., et al.
17
  Karimi D., et al., "Reducing the Hausdorff Distance in Medical Image Segmentation with Convolutional Neural
 
19
 
20
  :param ndim: Dimensionality of the images
21
  :param nerosion: Number of erosion steps. Defaults to 10.
22
+ :param alpha: Parameter to penalize large segmentations. Defaults to 2
23
  """
24
+ assert len(im_shape) == ndim + 1, "im_shape does not match with ndim. Missing channel dimension?"
25
  self.ndims = ndim
26
+ axes = np.arange(0, self.ndims).tolist()
27
+ self.before_erosion_transp = [axes[-1], *axes[:-1]] # [H, W, ..., C] -> [C, H, W, ...]
28
+ self.after_erosion_transp = [*axes[1:], axes[0]] # [C, H, W, ...] -> [H, W, ..., C]
29
  self.nerosions = nerosion
30
+ self.sum_range = tf.range(0, self.ndims)
 
 
 
 
31
 
32
+ self.im_shape = im_shape
33
+ self.im_vol = np.prod(im_shape[:-1])
34
+ kernel = generate_binary_structure(self.ndims, 1).astype(int)
35
+ kernel = kernel / np.sum(kernel)
36
+ kernel = kernel[..., np.newaxis, np.newaxis]
37
+ self.kernel = tf.convert_to_tensor(kernel, tf.float32)
38
+ self.k_alpha = [np.power(k, alpha).astype(float) for k in range(1, nerosion + 1)]
39
+ self.conv = getattr(tf.nn, 'conv%dd' % self.ndims)
40
 
41
+ def _erode(self, in_tensor):
42
+ indiv_channels = tf.split(in_tensor, self.im_shape[-1], -1)
43
+ res = list()
44
+ with tf.variable_scope('erode', reuse=tf.AUTO_REUSE):
45
+ for ch in indiv_channels:
46
+ res.append(self.conv(tf.expand_dims(ch, 0), self.kernel, [1] * (self.ndims + 2), 'SAME'))
47
+ # out = -tf.nn.max_pool3d(-tf.expand_dims(in_tensor, 0), [3]*self.ndims, [1]*self.ndims, 'SAME', name='HDE_erosion')
48
+ out = tf.concat(res, -1)
49
+ out = tf.squeeze(out, axis=0)
50
+ out = hard_threshold(out, 0.5, name='thresholding') # soft_threshold(out, 0.5, name='thresholding')
51
+ return out
52
 
53
  def _erosion_distance_single(self, y_true, y_pred):
54
+ diff = tf.math.pow(y_pred - y_true, 2, name='HDE_diff')
55
  alpha = 2
56
 
 
 
 
 
 
57
  ret = 0.
58
  for k in range(1, self.nerosions+1):
59
  er = diff
60
  # k successive erosions
61
  for j in range(k):
62
+ er = self._erode(er) # er contains the eroded version along the channels
63
+ ret += tf.reduce_sum(tf.multiply(er, self.k_alpha[k - 1]), self.sum_range, name='HDE_ret')
64
 
65
+ return tf.divide(ret, self.im_vol) # Divide by the image size
 
66
 
67
+ @function_decorator('Hausdorff_erosion__loss')
68
+ def loss(self, y_true, y_pred, name='HDE_loss'):
69
  batched_dist = tf.map_fn(lambda x: self._erosion_distance_single(x[0], x[1]), (y_true, y_pred),
70
+ dtype=tf.float32, name=name+'_map_fn')
71
 
72
  return tf.reduce_mean(batched_dist)
73
 
74
+ @function_decorator('Hausdorff_erosion__metric')
75
+ def metric(self, y_true, y_pred):
76
+ return self.loss(y_true, y_pred, name='HDE_metric')
77
+
78
+ def debug(self, y_true, y_pred):
79
+ return tf.map_fn(lambda x: self._erosion_distance_single(x[0], x[1]), (y_true, y_pred),
80
+ dtype=tf.float32, name='HDE_loss_map_fn')
81
+
82
+
83
+ # class HausdorffDiatanceConvolution:
84
+ # def __init__(self, ndim=3, im_shape: tuple = (64, 64, 64, 1), max_kernel_size=9, step_kernel_size=3, alpha=2):
85
+ # """
86
+ # Approximation of the Hausdorff distance based on erosion operations based on the work done by Karimi D., et al.
87
+ # Karimi D., et al., "Reducing the Hausdorff Distance in Medical Image Segmentation with Convolutional Neural
88
+ # Networks". IEEE Transactions on Medical Imaging, 39, 2020. DOI 10.1109/TMI.2019.2930068
89
+ #
90
+ # :param ndim: Dimensionality of the images
91
+ # :param nerosion: Number of erosion steps. Defaults to 10.
92
+ # :param alpha: Parameter to penalize large segmentations. Defaults to 2
93
+ # """
94
+ # assert len(im_shape) == ndim + 1, "im_shape does not match with ndim. Missing channel dimension?"
95
+ # self.ndims = ndim
96
+ # axes = np.arange(0, self.ndims).tolist()
97
+ # self.before_erosion_transp = [axes[-1], *axes[:-1]] # [H, W, ..., C] -> [C, H, W, ...]
98
+ # self.after_erosion_transp = [*axes[1:], axes[0]] # [C, H, W, ...] -> [H, W, ..., C]
99
+ # self.conv = getattr(tf.nn, 'conv%dd' % self.ndims)
100
+ # self.sum_range = tf.range(0, self.ndims)
101
+ #
102
+ # self.im_shape = im_shape
103
+ # self.im_vol = np.prod(im_shape[:-1])
104
+ # kernel = generate_binary_structure(self.ndims, 1).astype(int)
105
+ # self.kernel = tf.constant(kernel / np.sum(kernel), tf.float32)
106
+ # self.kernel = tf.expand_dims(tf.expand_dims(self.kernel, -1), -1) # [H, W, D, C_in, C_out]
107
+ # self.kernel = tf.tile(self.kernel, [*[1]*self.ndims, self.im_shape[-1], self.im_shape[-1]])
108
+ # self.alpha = int(alpha)
109
+ # self.radii = np.arange(1, max_kernel_size, step=step_kernel_size)
110
+ # self.radii_alpha = [np.pow(r, alpha).astype(float) for r in self.radii]
111
+ #
112
+ # def soft_diff(self, p, q):
113
+ # return tf.multiply(tf.pow(p - q, 2.), q)
114
+ #
115
+ # def body(self, y_true, y_pred):
116
+ #
117
+
118
+ """
119
+ class HausdorffDistanceErosion_2:
120
+ def __init__(self, im_shape, num_erosions, num_dimensions=3, alpha=2., loop_max_iterations=20):
121
+ self.alpha = alpha
122
+ self.ndims = num_dimensions
123
+ self.conv = getattr(tf.nn, 'conv%dd' % self.ndims)
124
+
125
+ self.iterator = tf.constant(num_erosions, name='num_erosions')
126
+ self.norm = 1 / np.prod(im_shape)
127
+ self.erosion_kernel = generate_binary_structure(self.ndims, 1).astype(float)
128
+ self.erosion_kernel /= np.sum(self.erosion_kernel)
129
+ self.erosion_kernel = tf.constant( self.erosion_kernel, tf.float32)
130
+
131
+ self.loop_max_iterations = loop_max_iterations
132
+
133
+ def erosion_sum(self, p, q, k):
134
+ er_tensor = p - q
135
+ er_tensor = tf.pow(er_tensor, 2.)
136
+
137
+ def erode(in_tensor):
138
+ # Erosion of in_tensor = Dilation of (1 - in_tensor)
139
+ return self.conv(tf.expand_dims(1. - in_tensor, 0), self.erosion_kernel, [1] * (self.ndims + 2), 'SAME')
140
+
141
+ def while_loop_body(i, in_tensor):
142
+ in_tensor = erode(in_tensor)
143
+ i -= 1
144
+ return i, in_tensor
145
+
146
+ def while_loop_condition(i, in_tensor):
147
+ return tf.less_equal(i, 1), in_tensor
148
+
149
+ er_iterator = tf.constant(k)
150
+ _, er_tensor = tf.while_loop(while_loop_condition, while_loop_body, loop_vars=[er_iterator, er_tensor],
151
+ maximum_iterations=self.loop_max_iterations)
152
+
153
+ er_tensor *= tf.pow(k, self.alpha)
154
+ return tf.reduce_sum(er_tensor)
155
+
156
+ def loss(self, y_true, y_pred):
157
+ hd_distance = tf.constant(0, name='hausdroff_distance')
158
+
159
+ def while_loop_body(i, p, q, ret):
160
+ i -= 1
161
+ return i, p, q, ret + self.erosion_sum(p, q, i)
162
+
163
+ _, _, _, hd_distance = tf.while_loop(lambda i, p, q, ret: tf.less_equal(i, 1),
164
+ while_loop_body,
165
+ loop_vars=[self.iterator, y_pred, y_true, hd_distance])
166
+ hd_distance /= self.norm
167
+ return hd_distance
168
+
169
+ """
170
+
171
+
172
+ class WeightedHausdorffDistance:
173
+ def __init__(self, input_shape, alpha=-1, threshold=0.5):
174
+ """
175
+ WARNING: Requires a insane amount of memory
176
+ :param input_shape: [H, W, D, C] or [H, W, C]
177
+ :param alpha: Parameter of the generalized mean. Ideally -inf, but then the function becomes less smooth.
178
+ :param threshold: Threshold of segmentations, used in tf.where function
179
+ """
180
+ warnings.warn("This function requires an insane amount of memory")
181
+ self.input_shape = input_shape
182
+ self.dim = len(input_shape[:-1])
183
+ self.ohe_segm = bool(input_shape[-1] > 1) # One-Hot Encoded segmentations on the channel axis
184
+ aux = np.arange(len(self.input_shape)).tolist()
185
+ self.ohe_transpose = [aux[-1], *aux[:-1]]
186
+ self.alpha = alpha
187
+ self.threshold = threshold
188
+ list_coords = [np.arange(c) for c in self.input_shape[:-1]]
189
+ self.img_loc = tf.convert_to_tensor(cartesian(list_coords), dtype=tf.float32)
190
+ self.max_dist = np.sqrt(np.sum(np.square(self.input_shape[:-1]))) # Largest diagonal
191
+
192
+ def pairwise_distance(self, A, B):
193
+ sq_norm_a = tf.reduce_sum(tf.square(A), 1)
194
+ sq_norm_b = tf.reduce_sum(tf.square(B), 1)
195
+
196
+ sq_norm_a = tf.reshape(sq_norm_a, [-1, 1])
197
+ sq_norm_b = tf.reshape(sq_norm_b, [1, -1])
198
+
199
+ return tf.sqrt(tf.maximum(sq_norm_a - 2 * tf.matmul(A, B, transpose_a=False, transpose_b=True) + sq_norm_b, 0.))
200
+
201
+ def hausdorff(self, y_true, y_pred):
202
+ if self.ohe_segm:
203
+ y_true = tf.transpose(y_true, self.ohe_transpose)
204
+ y_pred = tf.transpose(y_pred, self.ohe_transpose)
205
+ hausdorff_per_ch = tf.map_fn(lambda x: self.hausdorff_per_channel(x[0], x[1]), (y_true, y_pred), tf.float32)
206
+ return tf.reduce_mean(hausdorff_per_ch)
207
+ else:
208
+ return self.hausdorff_per_channel(y_true, y_pred)
209
+
210
+ def hausdorff_per_channel(self, y_true, y_pred):
211
+ Y = tf.cast(tf.where(y_true > self.threshold), dtype=tf.float32)
212
+ p = K.flatten(y_pred) # Flatten the predicted segmentation (activation map 'p' in d_WH)
213
+
214
+ size_Y = tf.shape(Y)[0]
215
+ S = tf.reduce_sum(p)
216
+
217
+ p = tf.squeeze(K.repeat(tf.expand_dims(p, -1), size_Y))
218
+ dist_mat = self.pairwise_distance(self.img_loc, Y)
219
+
220
+ term_1 = tf.reduce_sum(p * tf.minimum(dist_mat, 1)) / (S + EPS_tf)
221
+
222
+ term_2 = tf.minimum((dist_mat + EPS_tf) / (tf.pow(p, self.alpha) + (EPS_tf / self.max_dist)), 0.)
223
+ term_2 = tf.clip_by_value(term_2, 0., self.max_dist)
224
+ term_2 = tf.reduce_mean(term_2, axis=0)
225
+
226
+ return term_1 + term_2
227
+
228
+ @function_decorator('Weighted_Hausdorff__loss')
229
+ def loss(self, y_true, y_pred):
230
+ batch_hdist = tf.map_fn(lambda x: self.hausdorff(x[0], x[1]), (y_true, y_pred), dtype=tf.float32)
231
+
232
+ return tf.reduce_mean(batch_hdist)
233
+
234
+ @function_decorator('Weighted_Hausdorff__metric')
235
+ def metric(self, y_true, y_pred):
236
+ return self.loss(y_true, y_pred)
237
+
238
 
239
  class NCC:
240
  def __init__(self, in_shape, eps=EPS_tf):
 
246
  f_yp = tf.reshape(y_pred, [-1])
247
  mean_yt = tf.reduce_mean(f_yt)
248
  mean_yp = tf.reduce_mean(f_yp)
 
 
249
 
250
  n_f_yt = f_yt - mean_yt
251
  n_f_yp = f_yp - mean_yp
252
+ norm_yt = tf.norm(f_yt, ord='euclidean')
253
+ norm_yp = tf.norm(f_yp, ord='euclidean')
254
+ numerator = tf.reduce_sum(tf.multiply(n_f_yt, n_f_yp))
255
+ denominator = norm_yt * norm_yp + self.__eps
256
  return tf.math.divide_no_nan(numerator, denominator)
257
 
258
+ @function_decorator('NCC__loss')
259
  def loss(self, y_true, y_pred):
260
  # According to the documentation, the loss returns a scalar
261
  # Ref: https://www.tensorflow.org/api_docs/python/tf/keras/Model#compile
262
  return tf.reduce_mean(tf.map_fn(lambda x: 1 - self.ncc(x[0], x[1]), (y_true, y_pred), tf.float32))
263
 
264
+ @function_decorator('NCC__metric')
265
+ def metric(self, y_true, y_pred):
266
+ return tf.reduce_mean(tf.map_fn(lambda x: self.ncc(x[0], x[1]), (y_true, y_pred), tf.float32))
267
+
268
+
269
+ def ncc(y_true, y_pred):
270
+ y_true = K.flatten(K.cast(y_true, 'float32'))
271
+ y_pred = K.flatten(K.cast(y_pred, 'float32'))
272
+
273
+ mean_true = K.mean(y_true)
274
+ mean_pred = K.mean(y_pred)
275
+
276
+ std_true = K.std(y_true)
277
+ std_pred = K.std(y_pred)
278
+
279
+ num = K.mean((y_true - mean_true) * (y_pred - mean_pred))
280
+ den = std_true * std_pred + EPS_tf
281
+ batch_ncc = num / den
282
+
283
+ return K.mean(batch_ncc)
284
+
285
 
286
  class StructuralSimilarity:
287
  # Based on https://github.com/keras-team/keras-contrib/blob/master/keras_contrib/losses/dssim.py
288
+ def __init__(self, k1=0.01, k2=0.03,
289
+ patch_size=32, dynamic_range=1., overlap=0.0, dim=3,
290
+ alpha=1., beta=1., gamma=1.,
291
+ **kwargs):
292
  """
293
  Structural (Di)Similarity Index Measure:
294
 
295
  :param k1: Internal parameter. Defaults to 0.01
296
  :param k2: Internal parameter. Defaults to 0.02
297
+ :param patch_size: Size of the extracted patches. Defaults to 32. Recommendation: half the image size.
298
  :param dynamic_range: Maximum numerical intensity value (typ. 2^bits_per_pixel - 1). Defaults to 1.
299
+ :param overlap: Patch overlap ratio. Must be in the range [0., 1.). Defaults to 0.
300
+ :param dim: Data dimensionality. Must be {1, 2, 3}. Defaults to 3.
301
+ :param alpha, beta, gamma: Exponential parameters to balance the contribution of the luminance, contrast and
302
+ structure measures. Default to 1.
303
  """
304
+ assert (dim > 0) and (dim < 4), 'Invalid dimension. It must be 1, 2, or 3'
305
+ assert overlap < 1., 'Invalid overlap. It must be in the range [0., 1.)'
306
+ self.c1 = (k1 * dynamic_range) ** 2
307
+ self.c2 = (k2 * dynamic_range) ** 2
308
+ self.c3 = self.c2 / 2
309
+ self.alpha = tf.cast(alpha, tf.float32)
310
+ self.beta = tf.cast(beta, tf.float32)
311
+ self.gamma = tf.cast(gamma, tf.float32)
312
+
313
+ self.kernel_shape = [1] + [patch_size] * dim + [1]
314
  stride = int(patch_size * (1 - overlap))
315
+ self.stride = [1] + [stride if stride else 1] * dim + [1]
316
+ self.dim = dim
317
+ self.patch_extractor = None
318
+ self.reduce_axis = list()
319
+ if dim == 2:
320
+ self.patch_extractor = tf.extract_image_patches
321
+ self.reduce_axis = [1, 2]
322
+ elif dim == 3:
323
+ self.patch_extractor = tf.extract_volume_patches
324
+ self.reduce_axis = [1, 2, 3]
325
+ else:
326
+ raise ValueError('Invalid dimension value. Expected 2 or 3')
327
+
328
+ if patch_size == -1:
329
+ # Don't extract patches
330
+ self.dim = 1
331
+
332
+ self.L = None # Luminance
333
+ self.C = None # Contrast
334
+ self.S = None # Structure
335
 
336
  def __int_shape(self, x):
337
  return tf.keras.backend.int_shape(x) if tf.keras.backend.backend() == 'tensorflow' else tf.keras.backend.shape(x)
338
 
339
  def ssim(self, y_true, y_pred):
340
+ if self.dim > 1:
341
+ # Don't use for training. The gradient doesn't backpropagate through the patch extractors
342
+ # patches: [B, out_rows, out_cols, ..., krows*kcols*...*channels] -> out_rows * out_cols * ... = nb patches
343
+ patches_true = self.patch_extractor(y_true, ksizes=self.kernel_shape, strides=self.stride, padding='VALID', name='patches_true')
344
+ patches_pred = self.patch_extractor(y_pred, ksizes=self.kernel_shape, strides=self.stride, padding='VALID', name='patches_pred')
345
+ else:
346
+ patches_true = y_true
347
+ patches_pred = y_pred
348
 
349
  #bs, w, h, d, *c = self.__int_shape(patches_pred)
350
  #patches_true = tf.reshape(patches_true, [-1, w, h, d, tf.reduce_prod(c)])
 
358
  v_true = tf.math.reduce_variance(patches_true, axis=-1)
359
  v_pred = tf.math.reduce_variance(patches_pred, axis=-1)
360
 
361
+ # Standard dev.
362
+ s_true = tf.sqrt(v_true)
363
+ s_pred = tf.sqrt(v_pred)
364
+
365
  # Covariance
366
  covar = tf.reduce_mean(patches_true * patches_pred, axis=-1) - u_true * u_pred
367
 
368
  # SSIM
369
+ self.L = (2 * u_true * u_pred + self.c1) / (tf.square(u_true) + tf.square(u_pred) + self.c1)
370
+ self.C = (2 * s_true * s_pred + self.c2) / (v_true + v_pred + self.c2)
371
+ self.S = (covar + self.c3) / (s_true * s_pred + self.c3)
372
+ self.L = tf.reduce_mean(self.L, axis=self.reduce_axis)
373
+ self.C = tf.reduce_mean(self.C, axis=self.reduce_axis)
374
+ self.S = tf.reduce_mean(self.S, axis=self.reduce_axis)
375
 
376
+ return tf.pow(self.L, self.alpha) * tf.pow(self.C, self.beta) * tf.pow(self.S, self.gamma)
377
+
378
+ @function_decorator('SSIM__loss')
379
+ def loss(self, y_true, y_pred):
380
+ return tf.reduce_mean((1. - self.ssim(y_true, y_pred)) / 2.0)
381
+
382
+ @function_decorator('SSIM__metric')
383
+ def metric(self, y_true, y_pred):
384
+ return tf.reduce_mean(self.ssim(y_true, y_pred))
385
+
386
+
387
+ class StructuralSimilarity_simplified:
388
+ # Based on https://github.com/keras-team/keras-contrib/blob/master/keras_contrib/losses/dssim.py
389
+ def __init__(self, k1=0.01, k2=0.03,
390
+ patch_size=32, dynamic_range=1., overlap=0.0, dim=3,
391
+ alpha=1., beta=1., gamma=1.,
392
+ **kwargs):
393
+ """
394
+ Structural (Di)Similarity Index Measure:
395
+
396
+ :param k1: Internal parameter. Defaults to 0.01
397
+ :param k2: Internal parameter. Defaults to 0.02
398
+ :param patch_size: Size of the extracted patches. Defaults to 32. Recommendation: half the image size.
399
+ :param dynamic_range: Maximum numerical intensity value (typ. 2^bits_per_pixel - 1). Defaults to 1.
400
+ :param overlap: Patch overlap ratio. Must be in the range [0., 1.). Defaults to 0.
401
+ :param dim: Data dimensionality. Must be {1, 2, 3}. Defaults to 3.
402
+ :param alpha, beta, gamma: Exponential parameters to balance the contribution of the luminance, contrast and
403
+ structure measures. Default to 1.
404
+ """
405
+ assert (dim > 0) and (dim < 4), 'Invalid dimension. It must be 1, 2, or 3'
406
+ assert overlap < 1., 'Invalid overlap. It must be in the range [0., 1.)'
407
+ self.c1 = (k1 * dynamic_range) ** 2
408
+ self.c2 = (k2 * dynamic_range) ** 2
409
+ self.c3 = self.c2 / 2
410
+ self.alpha = tf.cast(alpha, tf.float32)
411
+ self.beta = tf.cast(beta, tf.float32)
412
+ self.gamma = tf.cast(gamma, tf.float32)
413
+
414
+ self.kernel_shape = [1] + [patch_size] * dim + [1]
415
+ stride = int(patch_size * (1 - overlap))
416
+ self.stride = [1] + [stride if stride else 1] * dim + [1]
417
+ self.dim = dim
418
+ self.patch_extractor = None
419
+ if dim == 2:
420
+ self.patch_extractor = tf.extract_image_patches
421
+ elif dim == 3:
422
+ self.patch_extractor = tf.extract_volume_patches
423
+
424
+ if patch_size == -1:
425
+ # Don't extract patches
426
+ self.dim = 1
427
+
428
+ self.L = None # Luminance
429
+ self.C = None # Contrast
430
+ self.S = None # Structure
431
 
432
+ def __int_shape(self, x):
433
+ return tf.keras.backend.int_shape(x) if tf.keras.backend.backend() == 'tensorflow' else tf.keras.backend.shape(x)
434
+
435
+ def ssim(self, y_true, y_pred):
436
+ if self.dim > 1:
437
+ # Don't use for training. The gradient doesn't backpropagate through the patch extractors
438
+ # patches: [B, out_rows, out_cols, ..., krows*kcols*...*channels] -> out_rows * out_cols * ... = nb patches
439
+ patches_true = self.patch_extractor(y_true, ksizes=self.kernel_shape, strides=self.stride, padding='VALID', name='patches_true')
440
+ patches_pred = self.patch_extractor(y_pred, ksizes=self.kernel_shape, strides=self.stride, padding='VALID', name='patches_pred')
441
+ else:
442
+ patches_true = y_true
443
+ patches_pred = y_pred
444
+
445
+ #bs, w, h, d, *c = self.__int_shape(patches_pred)
446
+ #patches_true = tf.reshape(patches_true, [-1, w, h, d, tf.reduce_prod(c)])
447
+ #patches_pred = tf.reshape(patches_pred, [-1, w, h, d, tf.reduce_prod(c)])
448
+
449
+ # Mean
450
+ u_true = tf.reduce_mean(patches_true, axis=-1)
451
+ u_pred = tf.reduce_mean(patches_pred, axis=-1)
452
+
453
+ # Variance
454
+ v_true = tf.math.reduce_variance(patches_true, axis=-1)
455
+ v_pred = tf.math.reduce_variance(patches_pred, axis=-1)
456
+
457
+ # Covariance
458
+ covar = tf.reduce_mean(patches_true * patches_pred, axis=-1) - u_true * u_pred
459
+
460
+ # return tf.pow(self.L, self.alpha) * tf.pow(self.C, self.beta) * tf.pow(self.S, self.gamma)
461
+ num = (2 * u_true * u_pred + self.c1) * (2 * covar + self.c2)
462
+ den = ((tf.square(u_true) + tf.square(u_pred) + self.c1) * (v_pred + v_true + self.c2))
463
+ return num / den
464
+
465
+ @function_decorator('SSIM_simple__loss')
466
+ def loss(self, y_true, y_pred):
467
  return tf.reduce_mean((1. - self.ssim(y_true, y_pred)) / 2.0)
468
+
469
+ @function_decorator('SSIM_simple__metric')
470
+ def metric(self, y_true, y_pred):
471
+ return tf.reduce_mean(self.ssim(y_true, y_pred))
472
+
473
+
474
+ class MultiScaleStructuralSimilarity(StructuralSimilarity):
475
+ def __init__(self, k1=0.01, k2=0.03, patch_size=3, dynamic_range=1., overlap=0.0, dim=3, nscales=3, alpha=1., beta=1., gamma=1.):
476
+ """
477
+ Multi Scale Structural (Di)Similarity Index Measure:
478
+ Ref: [1] https://www.cns.nyu.edu/pub/eero/wang03b.pdf
479
+
480
+ :param k1: Internal parameter. Defaults to 0.01
481
+ :param k2: Internal parameter. Defaults to 0.02
482
+ :param patch_size: Size of the extracted patches. Defaults to 32. Recommendation: half the image size.
483
+ :param dynamic_range: Maximum numerical intensity value (typ. 2^bits_per_pixel - 1). Defaults to 1.
484
+ :param overlap: Patch overlap ratio. Must be in the range [0., 1.). Defaults to 0.
485
+ :param dim: Data dimensionality. Must be {2, 3}. Defaults to 3.
486
+ :param nscales: Number of scales to analyze. Defaults to 3.
487
+ :param alpha, beta, gamma: Exponential parameters to balance the contribution of the luminance, contrast and
488
+ structure measures. Default to 1.
489
+ """
490
+ assert dim > 1, 'Cannot be used with 1-D data'
491
+ super(MultiScaleStructuralSimilarity, self).__init__(k1=k1, k2=k2, patch_size=patch_size,
492
+ dynamic_range=dynamic_range, overlap=overlap, dim=dim,
493
+ alpha=alpha, beta=beta, gamma=gamma)
494
+ self.num_scales = nscales
495
+ self.avg_pool = getattr(tf.nn, 'avg_pool%dd' % dim)
496
+ self.ds_stride = self.ds_kernel = [1] + [2]*dim + [1]
497
+
498
+ # In [1] these are set to the same value at the same scales and normalized across scales
499
+ self.alpha = self.beta = self.gamma = 1 / nscales
500
+
501
+ def _cond(self, cs_prod, scale_level, y_true, y_pred):
502
+ return tf.less_equal(scale_level, self.num_scales)
503
+
504
+ def _iteration(self, cs_prod, scale_level, y_true, y_pred):
505
+ super(MultiScaleStructuralSimilarity, self).ssim(y_true, y_pred)
506
+ cs_prod *= tf.reduce_mean(tf.pow(self.C, self.beta) * tf.pow(self.S, self.gamma))
507
+ y_true = self.avg_pool(y_true, ksize=self.ds_kernel, strides=self.ds_stride, padding='VALID')
508
+ y_pred = self.avg_pool(y_pred, ksize=self.ds_kernel, strides=self.ds_stride, padding='VALID')
509
+ scale_level += 1
510
+ return cs_prod, scale_level, y_true, y_pred,
511
+
512
+ def ssim(self, y_true, y_pred):
513
+ return self.ms_ssim(y_true, y_pred)
514
+
515
+ def ms_ssim(self, y_true, y_pred):
516
+ cs_prod = tf.constant(1.)
517
+ scale_level = tf.constant(1.)
518
+ cs_prod, *_ = tf.while_loop(self._cond,
519
+ self._iteration,
520
+ (cs_prod, scale_level, y_true, y_pred),
521
+ (cs_prod.get_shape(), scale_level.get_shape(),
522
+ tf.TensorShape(([1] + [None] * self.dim + [1])),
523
+ tf.TensorShape(([1] + [None] * self.dim + [1]))))
524
+
525
+ ms_ssim = tf.reduce_mean(tf.pow(self.L, self.alpha)) * cs_prod
526
+
527
+ return tf.reduce_mean(ms_ssim)
528
+
529
+ @function_decorator('MS_SSIM__loss')
530
+ def loss(self, y_true, y_pred):
531
+ return tf.reduce_mean((1. - self.ms_ssim(y_true, y_pred)) / 2.0)
532
+
533
+
534
+ class MultiScaleStructuralSimilarity_v2(StructuralSimilarity):
535
+ def __init__(self, k1=0.01, k2=0.03, patch_size=3, dynamic_range=1., overlap=0.0, dim=3, nscales=3, alpha=1., beta=1., gamma=1.):
536
+ """
537
+ Multi Scale Structural (Di)Similarity Index Measure:
538
+ Ref: [1] https://www.cns.nyu.edu/pub/eero/wang03b.pdf
539
+
540
+ :param k1: Internal parameter. Defaults to 0.01
541
+ :param k2: Internal parameter. Defaults to 0.02
542
+ :param patch_size: Size of the extracted patches. Defaults to 32. Recommendation: half the image size.
543
+ :param dynamic_range: Maximum numerical intensity value (typ. 2^bits_per_pixel - 1). Defaults to 1.
544
+ :param overlap: Patch overlap ratio. Must be in the range [0., 1.). Defaults to 0.
545
+ :param dim: Data dimensionality. Must be {2, 3}. Defaults to 3.
546
+ :param nscales: Number of scales to analyze. Defaults to 3.
547
+ :param alpha, beta, gamma: Exponential parameters to balance the contribution of the luminance, contrast and
548
+ structure measures. Default to 1.
549
+ """
550
+ assert dim > 1, 'Cannot be used with 1-D data'
551
+ super(MultiScaleStructuralSimilarity_v2, self).__init__(k1=k1, k2=k2, patch_size=patch_size,
552
+ dynamic_range=dynamic_range, overlap=overlap, dim=dim,
553
+ alpha=alpha, beta=beta, gamma=gamma)
554
+ self.num_scales = nscales
555
+ self.avg_pool = getattr(tf.nn, 'avg_pool%dd' % dim)
556
+ self.ds_stride = self.ds_kernel = [1] + [2]*dim + [1]
557
+
558
+ # In [1] these are set to the same value at the same scales and normalized across scales
559
+ self.alpha = self.beta = self.gamma = 1 / nscales
560
+
561
+ def _cond(self, cs_prod, scale_level, y_true, y_pred):
562
+ return tf.less_equal(scale_level, self.num_scales)
563
+
564
+ def _iteration(self, cs_prod, scale_level, y_true, y_pred):
565
+ super(MultiScaleStructuralSimilarity_v2, self).ssim(y_true, y_pred)
566
+ cs_prod *= tf.reduce_mean(tf.pow(self.C, self.beta) * tf.pow(self.S, self.gamma))
567
+ y_true = self.avg_pool(y_true, ksize=self.ds_kernel, strides=self.ds_stride, padding='VALID')
568
+ y_pred = self.avg_pool(y_pred, ksize=self.ds_kernel, strides=self.ds_stride, padding='VALID')
569
+ scale_level += 1
570
+ return cs_prod, scale_level, y_true, y_pred,
571
+
572
+ def ssim(self, y_true, y_pred):
573
+ return self.ms_ssim(y_true, y_pred)
574
+
575
+ def ms_ssim(self, y_true, y_pred):
576
+ cs_prod = tf.constant(1.)
577
+ scale_level = tf.constant(1.)
578
+ cs_prod, *_ = tf.while_loop(self._cond,
579
+ self._iteration,
580
+ (cs_prod, scale_level, y_true, y_pred),
581
+ (cs_prod.get_shape(), scale_level.get_shape(),
582
+ tf.TensorShape(([1] + [None] * self.dim + [1])),
583
+ tf.TensorShape(([1] + [None] * self.dim + [1]))))
584
+
585
+ ms_ssim = tf.reduce_mean(tf.pow(self.L, self.alpha)) * cs_prod
586
+
587
+ return tf.reduce_mean(ms_ssim)
588
+
589
+ @function_decorator('MS_SSIM_v2__loss')
590
+ def loss(self, y_true, y_pred):
591
+ return tf.reduce_mean((1. - self.ms_ssim(y_true, y_pred)) / 2.0)
592
+
593
+
594
+ class StructuralSimilarityGaussian:
595
+ # This is equivalent to StructuralSimilarity(patch_size=img_size)
596
+ def __init__(self, k1=0.01, k2=0.03, dynamic_range=1., gauss_sigma=5., dim=3, alpha=1., beta=1., gamma=1.):
597
+ """
598
+ SSIM using Gaussian filter to approximate the statistics of the images
599
+ Ref: https://www.cns.nyu.edu/pub/eero/wang03b.pdf
600
+ https://arxiv.org/pdf/1511.08861.pdf
601
+ https://github.com/NVlabs/PL4NN/blob/master/src/loss.py
602
+
603
+ :param k1: Internal parameter. Defaults to 0.01
604
+ :param k2: Internal parameter. Defaults to 0.02
605
+ :param dynamic_range: Maximum numerical intensity value (typ. 2^bits_per_pixel - 1). Defaults to 1.
606
+ :param gauss_sigma: Sigma of the Gaussian filter. Defaults to 1.5.
607
+ :param dim: Data dimensionality. Must be {2, 3}. Defaults to 3.
608
+ :param alpha, beta, gamma: Exponential parameters to balance the contribution of the luminance, contrast and
609
+ structure measures. Default to 1.
610
+ """
611
+ self.c1 = (k1 * dynamic_range) ** 2
612
+ self.c2 = (k2 * dynamic_range) ** 2
613
+ self.c3 = self.c2 / 2
614
+ self.alpha = tf.cast(alpha, tf.float32)
615
+ self.beta = tf.cast(beta, tf.float32)
616
+ self.gamma = tf.cast(gamma, tf.float32)
617
+ self.dim = dim
618
+ self.convDN = getattr(tf.nn, 'conv%dd' % dim)
619
+ self.sigma = gauss_sigma
620
+
621
+ def build_gaussian_filter(self, size, sigma, num_channels=1):
622
+ range_1d = tf.range(-(size/2) + 1, size//2 + 1)
623
+ g_1d = tf.math.exp(-1.0 * tf.pow(range_1d, 2) / (2. * tf.pow(sigma, 2)))
624
+ g_1d_expanded = tf.expand_dims(g_1d, -1)
625
+ iterator = tf.constant(1)
626
+ self.__GF = tf.while_loop(lambda iterator, g_1d: tf.less(iterator, self.dim),
627
+ lambda iterator, g_1d: (iterator + 1, tf.expand_dims(g_1d, -1) * tf.transpose(g_1d_expanded)),
628
+ [iterator, g_1d],
629
+ [iterator.get_shape(), tf.TensorShape([None]*self.dim)] # Shape invariants
630
+ )[-1]
631
+
632
+ self.__GF = tf.divide(self.__GF, tf.reduce_sum(self.__GF)) # Normalization
633
+ self.__GF = tf.reshape(self.__GF, (*[size]*self.dim, 1, 1)) # Add Ch_in and Ch_out for convolution
634
+ self.__GF = tf.tile(self.__GF, (*[1] * self.dim, num_channels, num_channels,))
635
+
636
+ def format_data(self, in_data):
637
+ ret_val = in_data
638
+ if self.dim == 3:
639
+ ret_val = tf.transpose(ret_val, [0, 3, 1, 2, 4])
640
+ return ret_val
641
+
642
+ def ssim(self, y_true, y_pred):
643
+ self.build_gaussian_filter(y_pred.shape[1], self.sigma)
644
+ y_true_tr = self.format_data(y_true)
645
+ y_pred_tr = self.format_data(y_pred)
646
+
647
+ u_true = self.convDN(y_true_tr, self.__GF, [1] * (self.dim + 2), 'SAME')
648
+ u_pred = self.convDN(y_pred_tr, self.__GF, [1] * (self.dim + 2), 'SAME')
649
+
650
+ v_true = self.convDN(tf.pow(y_true_tr, 2), self.__GF, [1] * (self.dim + 2), 'SAME') - tf.pow(u_true, 2)
651
+ v_pred = self.convDN(tf.pow(y_pred_tr, 2), self.__GF, [1] * (self.dim + 2), 'SAME') - tf.pow(u_pred, 2)
652
+ covar = self.convDN(tf.multiply(y_true_tr, y_pred_tr), self.__GF, [1] * (self.dim + 2), 'SAME') - u_true * u_pred
653
+
654
+ self.L = (2 * u_true * u_pred + self.c1) / (tf.square(u_true) + tf.square(u_pred) + self.c1)
655
+ self.C = (2 * tf.sqrt(v_true) * tf.sqrt(v_pred) + self.c2) / (v_true + v_pred + self.c2)
656
+ self.S = (covar + self.c3) / (tf.sqrt(v_true) * tf.sqrt(v_pred) + self.c3)
657
+ ssim = tf.pow(self.L, self.alpha) * tf.pow(self.C, self.beta) * tf.pow(self.S, self.gamma)
658
+
659
+ return tf.reduce_mean(ssim)
660
+
661
+ @function_decorator('SSIM_Gaus__loss')
662
+ def loss(self, y_true, y_pred):
663
+ return tf.reduce_mean((1. - self.ssim(y_true, y_pred))/2.)
664
+
665
+ @function_decorator('SSIM_Gaus__metric')
666
+ def metric(self, y_true, y_pred):
667
+ return tf.reduce_mean(self.ssim(y_true, y_pred))
668
+
669
+
670
+ class MultiScaleStructuralSimilarityGaussian(StructuralSimilarityGaussian):
671
+ def __init__(self, k1=0.01, k2=0.03, dynamic_range=1., gauss_sigma=5., dim=3, nscales=3, alpha=1., beta=1., gamma=1.):
672
+ """
673
+ Multi Scale SSIM inheriting from StructuralSimilarityGaussian classed
674
+ Ref: https://www.cns.nyu.edu/pub/eero/wang03b.pdf
675
+ https://arxiv.org/pdf/1511.08861.pdf
676
+ https://github.com/NVlabs/PL4NN/blob/master/src/loss.py
677
+
678
+ :param k1: Internal parameter. Defaults to 0.01
679
+ :param k2: Internal parameter. Defaults to 0.02
680
+ :param dynamic_range: Maximum numerical intensity value (typ. 2^bits_per_pixel - 1). Defaults to 1.
681
+ :param gauss_sigma: Sigma of the Gaussian filter. Defaults to 1.5.
682
+ :param dim: Data dimensionality. Must be {2, 3}. Defaults to 3.
683
+ :param nscales: Number of scales to analyze. Defaults to 3.
684
+ :param alpha, beta, gamma: Exponential parameters to balance the contribution of the luminance, contrast and
685
+ structure measures. Default to 1.
686
+ """
687
+ super(MultiScaleStructuralSimilarityGaussian, self).__init__(k1=k1, k2=k2, dynamic_range=dynamic_range,
688
+ gauss_sigma=gauss_sigma, dim=dim,
689
+ alpha=alpha, beta=beta, gamma=gamma)
690
+ self.__num_scales = nscales
691
+
692
+ # # If using the Gaussian approximation of the pyramid MS approach described in https://arxiv.org/pdf/1511.08861.pdf
693
+ # def build_sigma_scales(self):
694
+ # iterator = tf.constant(0)
695
+ # scales = tf.expand_dims(self.sigma, -1)
696
+ # last_sigma = scales
697
+ # self.sigma_scales = tf.while_loop(lambda iterator, last_sigma, scales: tf.less_equal(iterator, self.__num_scales),
698
+ # lambda iterator, last_sigma, scales: (iterator + 1, tf.concat([scales, last_sigma/2], 0), last_sigma/2),
699
+ # [iterator, last_sigma, scales])[-1]
700
+ #
701
+ # def build_gaussian_filters_scales(self, size):
702
+ # self.__GFS = tf.map_fn(lambda sigma: self.build_gaussian_filter(size, sigma), self.sigma, tf.float32)
703
+
704
+ def _iteration(self, cs_prod, scale_level, y_true, y_pred):
705
+ # Compute the SSIM, so CS and L have the correct value
706
+ self.ssim(y_true, y_pred)
707
+
708
+ cs_prod *= tf.reduce_mean(tf.pow(self.C, self.beta) * tf.pow(self.S, self.gamma))
709
+ scale_level += 1
710
+
711
+ # Downsample the images to half the resolution for the next iteration
712
+ y_true = tf.nn.avg_pool(y_true, [1] + [2]*self.dim + [1], [1] + [2]*self.dim + [1], 'SAME')
713
+ y_pred = tf.nn.avg_pool(y_true, [1] + [2]*self.dim + [1], [1] + [2]*self.dim + [1], 'SAME')
714
+ return cs_prod, scale_level, y_true, y_pred
715
+
716
+ def ms_ssim(self, y_true, y_pred):
717
+ scale_level = tf.constant(0.)
718
+ cs_prod = tf.constant(1.)
719
+ cs_prod, *_ = tf.while_loop(tf.less(scale_level, self.__num_scales),
720
+ self._iteration,
721
+ (cs_prod, scale_level, y_true, y_pred),
722
+ (cs_prod.get_shape(), scale_level.get_shape(),
723
+ tf.TensorShape(([1] + [None]*self.dim + [1])),
724
+ tf.TensorShape(([1] + [None]*self.dim + [1]))))
725
+ # L is taken from the last scale
726
+ return tf.reduce_mean(tf.pow(self.L, self.alfa)) * cs_prod
727
+
728
+ @function_decorator('MS_SSIM_Gaus__metric')
729
+ def loss(self, y_true, y_pred):
730
+ return tf.reduce_mean((1. - self.ms_ssim(y_true, y_pred))/2.)
731
+
732
+
733
+ class DICEScore:
734
+ def __init__(self, input_shape: list):
735
+ """
736
+ DICE Score.
737
+ :param input_shape: Shape of the input image, without the batch dimension, e.g., 2D: [H, W, C], 3D: [H, W, D, C]
738
+ """
739
+ self.axes = list(range(1, len(input_shape))) # The list will not include the channel axis [1, ..., num_dims)
740
+
741
+ def dice(self, y_true, y_pred):
742
+ numerator = 2 * tf.reduce_sum(y_true * y_pred, self.axes)
743
+ denominator = tf.reduce_sum(y_true + y_pred, self.axes)
744
+ return tf.reduce_mean(tf.div_no_nan(numerator, denominator))
745
+
746
+ @function_decorator('DICE__loss')
747
+ def loss(self, y_true, y_pred):
748
+ return 1 - 2 * tf.reduce_mean(self.dice(y_true, y_pred))
749
+
750
+ @function_decorator('DICE__metric')
751
+ def metric(self, y_true, y_pred):
752
+ return tf.reduce_mean(self.dice(y_true, y_pred))
753
+
754
+
755
+ class GeneralizedDICEScore:
756
+ def __init__(self, input_shape: list, num_labels: int=None):
757
+ """
758
+ Generalized DICE Score. Implementation based on Carole H. Sudre, et al., "Generalised DIce Overlap as a Deep
759
+ Learning Los Function for Highly Unbalanced Segmentations" https://arxiv.org/abs/1707.03237
760
+ :param input_shape: Shape of the input image, without the batch dimension, e.g., 2D: [H, W, C], 3D: [H, W, D, C]
761
+ """
762
+ if input_shape[-1] > 1:
763
+ self.flat_shape = [-1, np.prod(np.asarray(input_shape[:-1])), input_shape[-1]]
764
+ self.hot_encode = False
765
+ elif num_labels is not None:
766
+ self.flat_shape = [-1, np.prod(np.asarray(input_shape[:-1])), num_labels]
767
+ self.one_hot_enc_shape = [-1, *input_shape[:-1]]
768
+ self.hot_encode = True
769
+ warnings.warn('Differentiable one-hot encoding not yet implemented')
770
+ else:
771
+ raise ValueError('If input_shape is not one hot encoded, then num_labels must be provided')
772
+
773
+ def one_hot_encoding(self, in_img, name=''):
774
+ # TODO: Test if differentiable!
775
+ labels, indices = tf.unique(tf.reshape(in_img, [-1]), tf.int32, name=name+'_unique')
776
+ one_hot = tf.one_hot(indices, tf.size(labels), name=name + '_one_hot')
777
+ one_hot = tf.reshape(one_hot, self.one_hot_enc_shape + [tf.size(labels)], name=name + '_reshape')
778
+ one_hot = tf.slice(one_hot, [0]*len(self.one_hot_enc_shape) + [1], [-1]*(len(self.one_hot_enc_shape) + 1),
779
+ name=name + '_remove_bg')
780
+ return one_hot
781
+
782
+ def weigthed_dice(self, y_true, y_pred):
783
+ # y_true = [B, -1, L]
784
+ # y_pred = [B, -1, L]
785
+ if self.hot_encode:
786
+ y_true = self.one_hot_encoding(y_true, name='GDICE_one_hot_encoding_y_true')
787
+ y_pred = self.one_hot_encoding(y_pred, name='GDICE_one_hot_encoding_y_pred')
788
+ y_true = tf.reshape(y_true, self.flat_shape, name='GDICE_reshape_y_true') # Flatten along the volume dimensions
789
+ y_pred = tf.reshape(y_pred, self.flat_shape, name='GDICE_reshape_y_pred') # Flatten along the volume dimensions
790
+
791
+ size_y_true = tf.reduce_sum(y_true, axis=1, name='GDICE_size_y_true')
792
+ size_y_pred = tf.reduce_sum(y_pred, axis=1, name='GDICE_size_y_pred')
793
+ w = tf.div_no_nan(1., tf.pow(size_y_true, 2), name='GDICE_weight')
794
+ numerator = w * tf.reduce_sum(y_true * y_pred, axis=1)
795
+ denominator = w * (size_y_true + size_y_pred)
796
+ return tf.div_no_nan(2 * tf.reduce_sum(numerator, axis=-1), tf.reduce_sum(denominator, axis=-1))
797
+
798
+ def macro_dice(self, y_true, y_pred):
799
+ # y_true = [B, -1, L]
800
+ # y_pred = [B, -1, L]
801
+ if self.hot_encode:
802
+ y_true = self.one_hot_encoding(y_true, name='GDICE_one_hot_encoding_y_true')
803
+ y_pred = self.one_hot_encoding(y_pred, name='GDICE_one_hot_encoding_y_pred')
804
+ y_true = tf.reshape(y_true, self.flat_shape, name='GDICE_reshape_y_true') # Flatten along the volume dimensions
805
+ y_pred = tf.reshape(y_pred, self.flat_shape, name='GDICE_reshape_y_pred') # Flatten along the volume dimensions
806
+
807
+ size_y_true = tf.reduce_sum(y_true, axis=1, name='GDICE_size_y_true')
808
+ size_y_pred = tf.reduce_sum(y_pred, axis=1, name='GDICE_size_y_pred')
809
+ numerator = tf.reduce_sum(y_true * y_pred, axis=1)
810
+ denominator = (size_y_true + size_y_pred)
811
+ return tf.div_no_nan(2 * numerator, denominator)
812
+
813
+ @function_decorator('GeneralizeDICE__loss')
814
+ def loss(self, y_true, y_pred):
815
+ return 1 - tf.reduce_mean(self.weigthed_dice(y_true, y_pred))
816
+
817
+ @function_decorator('GeneralizeDICE__metric')
818
+ def metric(self, y_true, y_pred):
819
+ return tf.reduce_mean(self.weigthed_dice(y_true, y_pred))
820
+
821
+ @function_decorator('GeneralizeDICE__loss_macro')
822
+ def loss_macro(self, y_true, y_pred):
823
+ return 1 - tf.reduce_mean(self.macro_dice(y_true, y_pred))
824
+
825
+ @function_decorator('GeneralizeDICE__metric_macro')
826
+ def metric_macro(self, y_true, y_pred):
827
+ return tf.reduce_mean(self.macro_dice(y_true, y_pred))
828
+
829
+
830
+ def target_registration_error(y_true, y_pred, average=True):
831
+ '''
832
+ Target Registration Error measured as the average distance between y_true and y_pred
833
+ :param y_true: [N, D] target points
834
+ :param y_pred: [N, D] predicted points
835
+ :param average: return the average TRE or an [N,] array
836
+ :return: averate TRE or [N,] array of TRE for each point
837
+ '''
838
+ assert y_true.shape == y_pred.shape, "y_true and y_pred must have the same shape"
839
+ if average:
840
+ return tf.reduce_mean(tf.linalg.norm(y_pred - y_true, axis=1))
841
+ else:
842
+ return tf.linalg.norm(y_pred - y_true, axis=1)
843
+
844
+ # TODO: tensorflow-graphic has an implementation of Hausdorff ditance.
845
+ # However, this is not where it should and I can't find it
846
+ # def HausdorffDistance_exact(y_true, y_pred, ohe=False, name='hd_exact'):
847
+ # if ohe:
848
+ # y_true = tf.transpose(y_true, [0, 4, 1, 2, 3])
849
+ # y_pred = tf.transpose(y_pred, [0, 4, 1, 2, 3])
850
+ # y_true_coords = tf.where(y_true)
851
+ # y_pred_coords = tf.where(y_pred)
852
+ #
853
+ # return tfg_nn.loss.hausdorff_distance.evaluate(y_true_coords, y_pred_coords, name=name)
DeepDeformationMapRegistration/ms_ssim_tf.py ADDED
@@ -0,0 +1,642 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SRC: https://github.com/tensorflow/tensorflow/blob/a4dfb8d1a71385bd6d122e4f27f86dcebb96712d/tensorflow/python/ops/image_ops_impl.py
2
+ from tensorflow.python import nn_ops
3
+ from tensorflow.python import math_ops
4
+ from tensorflow.python import array_ops
5
+ from tensorflow.python.framework import dtypes
6
+ from tensorflow.python.framework import ops
7
+ from tensorflow.python.framework import constant_op
8
+ from tensorflow.python.ops import control_flow_ops
9
+ from tensorflow.python.ops import nn
10
+ from tensorflow.python.util.tf_export import tf_export
11
+ from tensorflow.python.util import dispatch
12
+ from DeepDeformationMapRegistration.utils.misc import function_decorator
13
+
14
+
15
+ @tf_export('image.convert_image_dtype')
16
+ @dispatch.add_dispatch_support
17
+ def convert_image_dtype(image, dtype, saturate=False, name=None):
18
+ """Convert `image` to `dtype`, scaling its values if needed.
19
+ The operation supports data types (for `image` and `dtype`) of
20
+ `uint8`, `uint16`, `uint32`, `uint64`, `int8`, `int16`, `int32`, `int64`,
21
+ `float16`, `float32`, `float64`, `bfloat16`.
22
+ Images that are represented using floating point values are expected to have
23
+ values in the range [0,1). Image data stored in integer data types are
24
+ expected to have values in the range `[0,MAX]`, where `MAX` is the largest
25
+ positive representable number for the data type.
26
+ This op converts between data types, scaling the values appropriately before
27
+ casting.
28
+ Usage Example:
29
+ >>> x = [[[1, 2, 3], [4, 5, 6]],
30
+ ... [[7, 8, 9], [10, 11, 12]]]
31
+ >>> x_int8 = tf.convert_to_tensor(x, dtype=tf.int8)
32
+ >>> tf.image.convert_image_dtype(x_int8, dtype=tf.float16, saturate=False)
33
+ <tf.Tensor: shape=(2, 2, 3), dtype=float16, numpy=
34
+ array([[[0.00787, 0.01575, 0.02362],
35
+ [0.0315 , 0.03937, 0.04724]],
36
+ [[0.0551 , 0.063 , 0.07086],
37
+ [0.07874, 0.0866 , 0.0945 ]]], dtype=float16)>
38
+ Converting integer types to floating point types returns normalized floating
39
+ point values in the range [0, 1); the values are normalized by the `MAX` value
40
+ of the input dtype. Consider the following two examples:
41
+ >>> a = [[[1], [2]], [[3], [4]]]
42
+ >>> a_int8 = tf.convert_to_tensor(a, dtype=tf.int8)
43
+ >>> tf.image.convert_image_dtype(a_int8, dtype=tf.float32)
44
+ <tf.Tensor: shape=(2, 2, 1), dtype=float32, numpy=
45
+ array([[[0.00787402],
46
+ [0.01574803]],
47
+ [[0.02362205],
48
+ [0.03149606]]], dtype=float32)>
49
+ >>> a_int32 = tf.convert_to_tensor(a, dtype=tf.int32)
50
+ >>> tf.image.convert_image_dtype(a_int32, dtype=tf.float32)
51
+ <tf.Tensor: shape=(2, 2, 1), dtype=float32, numpy=
52
+ array([[[4.6566129e-10],
53
+ [9.3132257e-10]],
54
+ [[1.3969839e-09],
55
+ [1.8626451e-09]]], dtype=float32)>
56
+ Despite having identical values of `a` and output dtype of `float32`, the
57
+ outputs differ due to the different input dtypes (`int8` vs. `int32`). This
58
+ is, again, because the values are normalized by the `MAX` value of the input
59
+ dtype.
60
+ Note that converting floating point values to integer type may lose precision.
61
+ In the example below, an image tensor `b` of dtype `float32` is converted to
62
+ `int8` and back to `float32`. The final output, however, is different from
63
+ the original input `b` due to precision loss.
64
+ >>> b = [[[0.12], [0.34]], [[0.56], [0.78]]]
65
+ >>> b_float32 = tf.convert_to_tensor(b, dtype=tf.float32)
66
+ >>> b_int8 = tf.image.convert_image_dtype(b_float32, dtype=tf.int8)
67
+ >>> tf.image.convert_image_dtype(b_int8, dtype=tf.float32)
68
+ <tf.Tensor: shape=(2, 2, 1), dtype=float32, numpy=
69
+ array([[[0.11811024],
70
+ [0.33858266]],
71
+ [[0.5590551 ],
72
+ [0.77952754]]], dtype=float32)>
73
+ Scaling up from an integer type (input dtype) to another integer type (output
74
+ dtype) will not map input dtype's `MAX` to output dtype's `MAX` but converting
75
+ back and forth should result in no change. For example, as shown below, the
76
+ `MAX` value of int8 (=127) is not mapped to the `MAX` value of int16 (=32,767)
77
+ but, when scaled back, we get the same, original values of `c`.
78
+ >>> c = [[[1], [2]], [[127], [127]]]
79
+ >>> c_int8 = tf.convert_to_tensor(c, dtype=tf.int8)
80
+ >>> c_int16 = tf.image.convert_image_dtype(c_int8, dtype=tf.int16)
81
+ >>> print(c_int16)
82
+ tf.Tensor(
83
+ [[[ 256]
84
+ [ 512]]
85
+ [[32512]
86
+ [32512]]], shape=(2, 2, 1), dtype=int16)
87
+ >>> c_int8_back = tf.image.convert_image_dtype(c_int16, dtype=tf.int8)
88
+ >>> print(c_int8_back)
89
+ tf.Tensor(
90
+ [[[ 1]
91
+ [ 2]]
92
+ [[127]
93
+ [127]]], shape=(2, 2, 1), dtype=int8)
94
+ Scaling down from an integer type to another integer type can be a lossy
95
+ conversion. Notice in the example below that converting `int16` to `uint8` and
96
+ back to `int16` has lost precision.
97
+ >>> d = [[[1000], [2000]], [[3000], [4000]]]
98
+ >>> d_int16 = tf.convert_to_tensor(d, dtype=tf.int16)
99
+ >>> d_uint8 = tf.image.convert_image_dtype(d_int16, dtype=tf.uint8)
100
+ >>> d_int16_back = tf.image.convert_image_dtype(d_uint8, dtype=tf.int16)
101
+ >>> print(d_int16_back)
102
+ tf.Tensor(
103
+ [[[ 896]
104
+ [1920]]
105
+ [[2944]
106
+ [3968]]], shape=(2, 2, 1), dtype=int16)
107
+ Note that converting from floating point inputs to integer types may lead to
108
+ over/underflow problems. Set saturate to `True` to avoid such problem in
109
+ problematic conversions. If enabled, saturation will clip the output into the
110
+ allowed range before performing a potentially dangerous cast (and only before
111
+ performing such a cast, i.e., when casting from a floating point to an integer
112
+ type, and when casting from a signed to an unsigned type; `saturate` has no
113
+ effect on casts between floats, or on casts that increase the type's range).
114
+ Args:
115
+ image: An image.
116
+ dtype: A `DType` to convert `image` to.
117
+ saturate: If `True`, clip the input before casting (if necessary).
118
+ name: A name for this operation (optional).
119
+ Returns:
120
+ `image`, converted to `dtype`.
121
+ Raises:
122
+ AttributeError: Raises an attribute error when dtype is neither
123
+ float nor integer
124
+ """
125
+ image = ops.convert_to_tensor(image, name='image')
126
+ dtype = dtypes.as_dtype(dtype)
127
+ if not dtype.is_floating and not dtype.is_integer:
128
+ raise AttributeError('dtype must be either floating point or integer')
129
+ if dtype == image.dtype:
130
+ return array_ops.identity(image, name=name)
131
+
132
+ with ops.name_scope(name, 'convert_image', [image]) as name:
133
+ # Both integer: use integer multiplication in the larger range
134
+ if image.dtype.is_integer and dtype.is_integer:
135
+ scale_in = image.dtype.max
136
+ scale_out = dtype.max
137
+ if scale_in > scale_out:
138
+ # Scaling down, scale first, then cast. The scaling factor will
139
+ # cause in.max to be mapped to above out.max but below out.max+1,
140
+ # so that the output is safely in the supported range.
141
+ scale = (scale_in + 1) // (scale_out + 1)
142
+ scaled = math_ops.floordiv(image, scale)
143
+
144
+ if saturate:
145
+ return math_ops.saturate_cast(scaled, dtype, name=name)
146
+ else:
147
+ return math_ops.cast(scaled, dtype, name=name)
148
+ else:
149
+ # Scaling up, cast first, then scale. The scale will not map in.max to
150
+ # out.max, but converting back and forth should result in no change.
151
+ if saturate:
152
+ cast = math_ops.saturate_cast(image, dtype)
153
+ else:
154
+ cast = math_ops.cast(image, dtype)
155
+ scale = (scale_out + 1) // (scale_in + 1)
156
+ return math_ops.multiply(cast, scale, name=name)
157
+ elif image.dtype.is_floating and dtype.is_floating:
158
+ # Both float: Just cast, no possible overflows in the allowed ranges.
159
+ # Note: We're ignoring float overflows. If your image dynamic range
160
+ # exceeds float range, you're on your own.
161
+ return math_ops.cast(image, dtype, name=name)
162
+ else:
163
+ if image.dtype.is_integer:
164
+ # Converting to float: first cast, then scale. No saturation possible.
165
+ cast = math_ops.cast(image, dtype)
166
+ scale = 1. / image.dtype.max
167
+ return math_ops.multiply(cast, scale, name=name)
168
+ else:
169
+ # Converting from float: first scale, then cast
170
+ scale = dtype.max + 0.5 # avoid rounding problems in the cast
171
+ scaled = math_ops.multiply(image, scale)
172
+ if saturate:
173
+ return math_ops.saturate_cast(scaled, dtype, name=name)
174
+ else:
175
+ return math_ops.cast(scaled, dtype, name=name)
176
+
177
+
178
+ def _verify_compatible_image_shapes(img1, img2):
179
+ """Checks if two image tensors are compatible for applying SSIM or PSNR.
180
+ This function checks if two sets of images have ranks at least 3, and if the
181
+ last three dimensions match.
182
+ Args:
183
+ img1: Tensor containing the first image batch.
184
+ img2: Tensor containing the second image batch.
185
+ Returns:
186
+ A tuple containing: the first tensor shape, the second tensor shape, and a
187
+ list of control_flow_ops.Assert() ops implementing the checks.
188
+ Raises:
189
+ ValueError: When static shape check fails.
190
+ """
191
+ shape1 = img1.get_shape().with_rank_at_least(4) # at least [H, W, D, C]
192
+ shape2 = img2.get_shape().with_rank_at_least(4) # at least [H, W, D, C]
193
+ shape1[-4:].assert_is_compatible_with(shape2[-4:])
194
+
195
+ if shape1.ndims is not None and shape2.ndims is not None:
196
+ for dim1, dim2 in zip(
197
+ reversed(shape1.dims[:-4]), reversed(shape2.dims[:-4])):
198
+ if not (dim1 == 1 or dim2 == 1 or dim1.is_compatible_with(dim2)):
199
+ raise ValueError('Two images are not compatible: %s and %s' %
200
+ (shape1, shape2))
201
+
202
+ # Now assign shape tensors.
203
+ shape1, shape2 = array_ops.shape_n([img1, img2])
204
+
205
+ # TODO(sjhwang): Check if shape1[:-4] and shape2[:-4] are broadcastable.
206
+ checks = []
207
+ checks.append(
208
+ control_flow_ops.Assert(
209
+ math_ops.greater_equal(array_ops.size(shape1), 4), [shape1, shape2],
210
+ summarize=10))
211
+ checks.append(
212
+ control_flow_ops.Assert(
213
+ math_ops.reduce_all(math_ops.equal(shape1[-4:], shape2[-4:])),
214
+ [shape1, shape2],
215
+ summarize=10))
216
+ return shape1, shape2, checks
217
+
218
+
219
+ def _ssim_helper(x, y, reducer, max_val, compensation=1.0, k1=0.01, k2=0.03):
220
+ r"""Helper function for computing SSIM.
221
+ SSIM estimates covariances with weighted sums. The default parameters
222
+ use a biased estimate of the covariance:
223
+ Suppose `reducer` is a weighted sum, then the mean estimators are
224
+ \mu_x = \sum_i w_i x_i,
225
+ \mu_y = \sum_i w_i y_i,
226
+ where w_i's are the weighted-sum weights, and covariance estimator is
227
+ cov_{xy} = \sum_i w_i (x_i - \mu_x) (y_i - \mu_y)
228
+ with assumption \sum_i w_i = 1. This covariance estimator is biased, since
229
+ E[cov_{xy}] = (1 - \sum_i w_i ^ 2) Cov(X, Y).
230
+ For SSIM measure with unbiased covariance estimators, pass as `compensation`
231
+ argument (1 - \sum_i w_i ^ 2).
232
+ Args:
233
+ x: First set of images.
234
+ y: Second set of images.
235
+ reducer: Function that computes 'local' averages from the set of images. For
236
+ non-convolutional version, this is usually tf.reduce_mean(x, [1, 2]), and
237
+ for convolutional version, this is usually tf.nn.avg_pool2d or
238
+ tf.nn.conv3d with weighted-sum kernel.
239
+ max_val: The dynamic range (i.e., the difference between the maximum
240
+ possible allowed value and the minimum allowed value).
241
+ compensation: Compensation factor. See above.
242
+ k1: Default value 0.01
243
+ k2: Default value 0.03 (SSIM is less sensitivity to K2 for lower values, so
244
+ it would be better if we took the values in the range of 0 < K2 < 0.4).
245
+ Returns:
246
+ A pair containing the luminance measure, and the contrast-structure measure.
247
+ """
248
+
249
+ c1 = (k1 * max_val)**2
250
+ c2 = (k2 * max_val)**2
251
+
252
+ # SSIM luminance measure is
253
+ # (2 * mu_x * mu_y + c1) / (mu_x ** 2 + mu_y ** 2 + c1).
254
+ mean0 = reducer(x)
255
+ mean1 = reducer(y)
256
+ num0 = mean0 * mean1 * 2.0
257
+ den0 = math_ops.square(mean0) + math_ops.square(mean1)
258
+ luminance = (num0 + c1) / (den0 + c1)
259
+
260
+ # SSIM contrast-structure measure is
261
+ # (2 * cov_{xy} + c2) / (cov_{xx} + cov_{yy} + c2).
262
+ # Note that `reducer` is a weighted sum with weight w_k, \sum_i w_i = 1, then
263
+ # cov_{xy} = \sum_i w_i (x_i - \mu_x) (y_i - \mu_y)
264
+ # = \sum_i w_i x_i y_i - (\sum_i w_i x_i) (\sum_j w_j y_j).
265
+ num1 = reducer(x * y) * 2.0
266
+ den1 = reducer(math_ops.square(x) + math_ops.square(y))
267
+ c2 *= compensation
268
+ cs = (num1 - num0 + c2) / (den1 - den0 + c2)
269
+
270
+ # SSIM score is the product of the luminance and contrast-structure measures.
271
+ return luminance, cs
272
+
273
+
274
+ def _fspecial_gauss(size, sigma):
275
+ """Function to mimic the 'fspecial' gaussian MATLAB function."""
276
+ size = ops.convert_to_tensor(size, dtypes.int32)
277
+ sigma = ops.convert_to_tensor(sigma, dtypes.float32)
278
+
279
+ coords = math_ops.cast(math_ops.range(size), sigma.dtype)
280
+ coords -= math_ops.cast(size - 1, sigma.dtype) / 2.0
281
+
282
+ g = math_ops.square(coords)
283
+ g *= -0.5 / math_ops.square(sigma)
284
+
285
+ g = array_ops.reshape(g, shape=[1, -1]) + array_ops.reshape(g, shape=[-1, 1])
286
+ g = array_ops.reshape(g, shape=[size, size, 1]) + array_ops.reshape(g, shape=[1, size, size])
287
+ g = array_ops.reshape(g, shape=[1, -1]) # For tf.nn.softmax().
288
+ g = nn_ops.softmax(g)
289
+ return array_ops.reshape(g, shape=[size, size, size, 1, 1])
290
+
291
+
292
+ def _ssim_per_channel(img1,
293
+ img2,
294
+ max_val=1.0,
295
+ filter_size=11,
296
+ filter_sigma=1.5,
297
+ k1=0.01,
298
+ k2=0.03):
299
+ """Computes SSIM index between img1 and img2 per color channel.
300
+ This function matches the standard SSIM implementation from:
301
+ Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004). Image
302
+ quality assessment: from error visibility to structural similarity. IEEE
303
+ transactions on image processing.
304
+ Details:
305
+ - 11x11 Gaussian filter of width 1.5 is used.
306
+ - k1 = 0.01, k2 = 0.03 as in the original paper.
307
+ Args:
308
+ img1: First image batch.
309
+ img2: Second image batch.
310
+ max_val: The dynamic range of the images (i.e., the difference between the
311
+ maximum the and minimum allowed values).
312
+ filter_size: Default value 11 (size of gaussian filter).
313
+ filter_sigma: Default value 1.5 (width of gaussian filter).
314
+ k1: Default value 0.01
315
+ k2: Default value 0.03 (SSIM is less sensitivity to K2 for lower values, so
316
+ it would be better if we took the values in the range of 0 < K2 < 0.4).
317
+ Returns:
318
+ A pair of tensors containing and channel-wise SSIM and contrast-structure
319
+ values. The shape is [..., channels].
320
+ """
321
+ filter_size = constant_op.constant(filter_size, dtype=dtypes.int32)
322
+ filter_sigma = constant_op.constant(filter_sigma, dtype=img1.dtype)
323
+
324
+ shape1, shape2 = array_ops.shape_n([img1, img2])
325
+ checks = [
326
+ control_flow_ops.Assert(
327
+ math_ops.reduce_all(
328
+ math_ops.greater_equal(shape1[-4:-1], filter_size)),
329
+ [shape1, filter_size],
330
+ summarize=8),
331
+ control_flow_ops.Assert(
332
+ math_ops.reduce_all(
333
+ math_ops.greater_equal(shape2[-4:-1], filter_size)),
334
+ [shape2, filter_size],
335
+ summarize=8)
336
+ ]
337
+
338
+ # Enforce the check to run before computation.
339
+ with ops.control_dependencies(checks):
340
+ img1 = array_ops.identity(img1)
341
+
342
+ # TODO(sjhwang): Try to cache kernels and compensation factor.
343
+ kernel = _fspecial_gauss(filter_size, filter_sigma)
344
+ kernel = array_ops.tile(kernel, multiples=[1, 1, 1, shape1[-1], 1])
345
+
346
+ # The correct compensation factor is `1.0 - tf.reduce_sum(tf.square(kernel))`,
347
+ # but to match MATLAB implementation of MS-SSIM, we use 1.0 instead.
348
+ compensation = 1.0
349
+
350
+ # TODO(sjhwang): Try FFT.
351
+ # TODO(sjhwang): Gaussian kernel is separable in space. Consider applying
352
+ # 1-by-n and n-by-1 Gaussian filters instead of an n-by-n filter.
353
+ def reducer(x):
354
+ shape = array_ops.shape(x)
355
+ x = array_ops.reshape(x, shape=array_ops.concat([[-1], shape[-4:]], 0))
356
+ y = nn.conv3d(x, kernel, strides=[1, 1, 1, 1, 1], padding='VALID')
357
+ return array_ops.reshape(y, array_ops.concat([shape[:-4], array_ops.shape(y)[1:]], 0))
358
+
359
+ luminance, cs = _ssim_helper(img1, img2, reducer, max_val, compensation, k1,
360
+ k2)
361
+
362
+ # Average over the second, third and the fourth from the last: height, width, depth.
363
+ axes = constant_op.constant([-4, -3, -2], dtype=dtypes.int32)
364
+ ssim_val = math_ops.reduce_mean(luminance * cs, axes)
365
+ cs = math_ops.reduce_mean(cs, axes)
366
+ return ssim_val, cs
367
+
368
+
369
+ @tf_export('image.ssim')
370
+ @dispatch.add_dispatch_support
371
+ def ssim(img1,
372
+ img2,
373
+ max_val,
374
+ filter_size=11,
375
+ filter_sigma=1.5,
376
+ k1=0.01,
377
+ k2=0.03):
378
+ """Computes SSIM index between img1 and img2.
379
+ This function is based on the standard SSIM implementation from:
380
+ Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004). Image
381
+ quality assessment: from error visibility to structural similarity. IEEE
382
+ transactions on image processing.
383
+ Note: The true SSIM is only defined on grayscale. This function does not
384
+ perform any colorspace transform. (If the input is already YUV, then it will
385
+ compute YUV SSIM average.)
386
+ Details:
387
+ - 11x11 Gaussian filter of width 1.5 is used.
388
+ - k1 = 0.01, k2 = 0.03 as in the original paper.
389
+ The image sizes must be at least 11x11 because of the filter size.
390
+ Example:
391
+ ```python
392
+ # Read images (of size 255 x 255) from file.
393
+ im1 = tf.image.decode_image(tf.io.read_file('path/to/im1.png'))
394
+ im2 = tf.image.decode_image(tf.io.read_file('path/to/im2.png'))
395
+ tf.shape(im1) # `img1.png` has 3 channels; shape is `(255, 255, 3)`
396
+ tf.shape(im2) # `img2.png` has 3 channels; shape is `(255, 255, 3)`
397
+ # Add an outer batch for each image.
398
+ im1 = tf.expand_dims(im1, axis=0)
399
+ im2 = tf.expand_dims(im2, axis=0)
400
+ # Compute SSIM over tf.uint8 Tensors.
401
+ ssim1 = tf.image.ssim(im1, im2, max_val=255, filter_size=11,
402
+ filter_sigma=1.5, k1=0.01, k2=0.03)
403
+ # Compute SSIM over tf.float32 Tensors.
404
+ im1 = tf.image.convert_image_dtype(im1, tf.float32)
405
+ im2 = tf.image.convert_image_dtype(im2, tf.float32)
406
+ ssim2 = tf.image.ssim(im1, im2, max_val=1.0, filter_size=11,
407
+ filter_sigma=1.5, k1=0.01, k2=0.03)
408
+ # ssim1 and ssim2 both have type tf.float32 and are almost equal.
409
+ ```
410
+ Args:
411
+ img1: First image batch. 4-D Tensor of shape `[batch, height, width,
412
+ channels]` with only Positive Pixel Values.
413
+ img2: Second image batch. 4-D Tensor of shape `[batch, height, width,
414
+ channels]` with only Positive Pixel Values.
415
+ max_val: The dynamic range of the images (i.e., the difference between the
416
+ maximum the and minimum allowed values).
417
+ filter_size: Default value 11 (size of gaussian filter).
418
+ filter_sigma: Default value 1.5 (width of gaussian filter).
419
+ k1: Default value 0.01
420
+ k2: Default value 0.03 (SSIM is less sensitivity to K2 for lower values, so
421
+ it would be better if we took the values in the range of 0 < K2 < 0.4).
422
+ Returns:
423
+ A tensor containing an SSIM value for each image in batch. Returned SSIM
424
+ values are in range (-1, 1], when pixel values are non-negative. Returns
425
+ a tensor with shape: broadcast(img1.shape[:-3], img2.shape[:-3]).
426
+ """
427
+ with ops.name_scope(None, 'SSIM', [img1, img2]):
428
+ # Convert to tensor if needed.
429
+ img1 = ops.convert_to_tensor(img1, name='img1')
430
+ img2 = ops.convert_to_tensor(img2, name='img2')
431
+ # Shape checking.
432
+ _, _, checks = _verify_compatible_image_shapes(img1, img2)
433
+ with ops.control_dependencies(checks):
434
+ img1 = array_ops.identity(img1)
435
+
436
+ # Need to convert the images to float32. Scale max_val accordingly so that
437
+ # SSIM is computed correctly.
438
+ max_val = math_ops.cast(max_val, img1.dtype)
439
+ max_val = convert_image_dtype(max_val, dtypes.float32)
440
+ img1 = convert_image_dtype(img1, dtypes.float32)
441
+ img2 = convert_image_dtype(img2, dtypes.float32)
442
+ ssim_per_channel, _ = _ssim_per_channel(img1, img2, max_val, filter_size,
443
+ filter_sigma, k1, k2)
444
+ # Compute average over color channels.
445
+ return math_ops.reduce_mean(ssim_per_channel, [-1])
446
+
447
+
448
+ # Default values obtained by Wang et al.
449
+ _MSSSIM_WEIGHTS = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)
450
+
451
+
452
+ @tf_export('image.ssim_multiscale')
453
+ @dispatch.add_dispatch_support
454
+ def ssim_multiscale(img1,
455
+ img2,
456
+ max_val,
457
+ power_factors=_MSSSIM_WEIGHTS,
458
+ filter_size=11,
459
+ filter_sigma=1.5,
460
+ k1=0.01,
461
+ k2=0.03):
462
+ """Computes the MS-SSIM between img1 and img2.
463
+ This function assumes that `img1` and `img2` are image batches, i.e. the last
464
+ three dimensions are [height, width, channels].
465
+ Note: The true SSIM is only defined on grayscale. This function does not
466
+ perform any colorspace transform. (If the input is already YUV, then it will
467
+ compute YUV SSIM average.)
468
+ Original paper: Wang, Zhou, Eero P. Simoncelli, and Alan C. Bovik. "Multiscale
469
+ structural similarity for image quality assessment." Signals, Systems and
470
+ Computers, 2004.
471
+ Args:
472
+ img1: First image batch with only Positive Pixel Values.
473
+ img2: Second image batch with only Positive Pixel Values. Must have the
474
+ same rank as img1.
475
+ max_val: The dynamic range of the images (i.e., the difference between the
476
+ maximum the and minimum allowed values).
477
+ power_factors: Iterable of weights for each of the scales. The number of
478
+ scales used is the length of the list. Index 0 is the unscaled
479
+ resolution's weight and each increasing scale corresponds to the image
480
+ being downsampled by 2. Defaults to (0.0448, 0.2856, 0.3001, 0.2363,
481
+ 0.1333), which are the values obtained in the original paper.
482
+ filter_size: Default value 11 (size of gaussian filter).
483
+ filter_sigma: Default value 1.5 (width of gaussian filter).
484
+ k1: Default value 0.01
485
+ k2: Default value 0.03 (SSIM is less sensitivity to K2 for lower values, so
486
+ it would be better if we took the values in the range of 0 < K2 < 0.4).
487
+ Returns:
488
+ A tensor containing an MS-SSIM value for each image in batch. The values
489
+ are in range [0, 1]. Returns a tensor with shape:
490
+ broadcast(img1.shape[:-3], img2.shape[:-3]).
491
+ """
492
+ with ops.name_scope(None, 'MS-SSIM', [img1, img2]):
493
+ # Convert to tensor if needed.
494
+ img1 = ops.convert_to_tensor(img1, name='img1')
495
+ img2 = ops.convert_to_tensor(img2, name='img2')
496
+ # Shape checking.
497
+ shape1, shape2, checks = _verify_compatible_image_shapes(img1, img2)
498
+ with ops.control_dependencies(checks):
499
+ img1 = array_ops.identity(img1)
500
+
501
+ # Need to convert the images to float32. Scale max_val accordingly so that
502
+ # SSIM is computed correctly.
503
+ max_val = math_ops.cast(max_val, img1.dtype)
504
+ max_val = convert_image_dtype(max_val, dtypes.float32)
505
+ img1 = convert_image_dtype(img1, dtypes.float32)
506
+ img2 = convert_image_dtype(img2, dtypes.float32)
507
+
508
+ imgs = [img1, img2]
509
+ shapes = [shape1, shape2]
510
+
511
+ # img1 and img2 are assumed to be a (multi-dimensional) batch of
512
+ # 4-dimensional images (height, width, depth, channels). `heads` contain the batch
513
+ # dimensions, and `tails` contain the image dimensions.
514
+ heads = [s[:-4] for s in shapes]
515
+ tails = [s[-4:] for s in shapes]
516
+
517
+ divisor = [1, 2, 2, 2, 1]
518
+ divisor_tensor = constant_op.constant(divisor[1:], dtype=dtypes.int32)
519
+
520
+ def do_pad(images, remainder):
521
+ padding = array_ops.expand_dims(remainder, -1)
522
+ padding = array_ops.pad(padding, [[1, 0], [1, 0]])
523
+ return [array_ops.pad(x, padding, mode='SYMMETRIC') for x in images]
524
+
525
+ mcs = []
526
+ for k in range(len(power_factors)):
527
+ with ops.name_scope(None, 'Scale%d' % k, imgs):
528
+ if k > 0:
529
+ # Avg pool takes rank 4 tensors. Flatten leading dimensions.
530
+ flat_imgs = [
531
+ array_ops.reshape(x, array_ops.concat([[-1], t], 0))
532
+ for x, t in zip(imgs, tails)
533
+ ]
534
+
535
+ remainder = tails[0] % divisor_tensor
536
+ need_padding = math_ops.reduce_any(math_ops.not_equal(remainder, 0))
537
+ # pylint: disable=cell-var-from-loop
538
+ padded = control_flow_ops.cond(need_padding,
539
+ lambda: do_pad(flat_imgs, remainder),
540
+ lambda: flat_imgs)
541
+ # pylint: enable=cell-var-from-loop
542
+
543
+ downscaled = [
544
+ nn_ops.avg_pool3d(
545
+ x, ksize=divisor, strides=divisor, padding='VALID', data_format='NDHWC',)
546
+ for x in padded
547
+ ]
548
+ tails = [x[1:] for x in array_ops.shape_n(downscaled)]
549
+ imgs = [
550
+ array_ops.reshape(x, array_ops.concat([h, t], 0))
551
+ for x, h, t in zip(downscaled, heads, tails)
552
+ ]
553
+
554
+ # Overwrite previous ssim value since we only need the last one.
555
+ ssim_per_channel, cs = _ssim_per_channel(
556
+ *imgs,
557
+ max_val=max_val,
558
+ filter_size=filter_size,
559
+ filter_sigma=filter_sigma,
560
+ k1=k1,
561
+ k2=k2)
562
+ mcs.append(nn_ops.relu(cs))
563
+
564
+ # Remove the cs score for the last scale. In the MS-SSIM calculation,
565
+ # we use the l(p) at the highest scale. l(p) * cs(p) is ssim(p).
566
+ mcs.pop() # Remove the cs score for the last scale.
567
+ mcs_and_ssim = array_ops.stack(
568
+ mcs + [nn_ops.relu(ssim_per_channel)], axis=-1)
569
+ # Take weighted geometric mean across the scale axis.
570
+ ms_ssim = math_ops.reduce_prod(
571
+ math_ops.pow(mcs_and_ssim, power_factors), [-1])
572
+
573
+ return math_ops.reduce_mean(ms_ssim, [-1]) # Avg over color channels.
574
+
575
+
576
+ class MultiScaleStructuralSimilarity:
577
+ def __init__(self,
578
+ max_val,
579
+ power_factors=_MSSSIM_WEIGHTS,
580
+ filter_size=11,
581
+ filter_sigma=1.5,
582
+ k1=0.01,
583
+ k2=0.03):
584
+ self.max_val = max_val
585
+ self.power_factors = power_factors
586
+ self.filter_size = int(filter_size)
587
+ self.filter_sigma = filter_sigma
588
+ self.k1 = k1
589
+ self.k2 = k2
590
+
591
+ @function_decorator('MS_SSIM__loss')
592
+ def loss(self, y_true, y_pred):
593
+ return math_ops.reduce_mean((1 - ssim_multiscale(y_true, y_pred, self.max_val, self.power_factors,
594
+ self.filter_size, self.filter_sigma, self.k1, self.k2))/2)
595
+
596
+ @function_decorator('MS_SSIM__metric')
597
+ def metric(self, y_true, y_pred):
598
+ return ssim_multiscale(y_true, y_pred, self.max_val, self.power_factors, self.filter_size, self.filter_sigma,
599
+ self.k1, self.k2)
600
+
601
+
602
+ if __name__ == '__main__':
603
+ import os
604
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
605
+ import tensorflow as tf
606
+ tf.enable_eager_execution()
607
+ import nibabel as nib
608
+ import numpy as np
609
+ from DeepDeformationMapRegistration.utils.operators import min_max_norm
610
+ from skimage.metrics import structural_similarity
611
+
612
+ img1 = nib.load('test_images/ixi_image.nii.gz')
613
+ img1 = np.asarray(img1.dataobj)
614
+ img1 = img1[np.newaxis, ..., np.newaxis] # Add Batch and Channel dimensions
615
+
616
+ img2 = nib.load('test_images/ixi_image2.nii.gz')
617
+ img2 = np.asarray(img2.dataobj)
618
+ img2 = img2[np.newaxis, ..., np.newaxis]
619
+
620
+ img1 = min_max_norm(img1)
621
+ img2 = min_max_norm(img2)
622
+
623
+ ssim_tf_1_2 = ssim(img1, img2, 1., filter_size=5)
624
+ assert ssim(img1, img1, 1., filter_size=5).numpy()[0] == 1., 'TF SSIM returned an unexpected value'
625
+ ssim_sklearn = structural_similarity(img1[0, ..., 0], img2[0, ..., 0], win_size=5)
626
+
627
+ ms_ssim_tf_1_2 = ssim_multiscale(img1, img2, 1., filter_size=5)
628
+ assert ssim_multiscale(img1, img1, 1., filter_size=5).numpy()[0] == 1., 'TF MS-SSIM returned an unexpected value'
629
+
630
+ print('SSIM TF: {}\nSSIM SKLEARN: {}\nMS SSIM TF: {}\n'.format(ssim_tf_1_2, ssim_sklearn, ms_ssim_tf_1_2))
631
+
632
+ batch_img1 = np.stack([img1, img2], axis=0)
633
+ batch_img2 = np.stack([img2, img2], axis=0)
634
+ batch_ssim_tf = ssim(batch_img1, batch_img2, 1., filter_size=5)
635
+ batch_ms_ssim_tf = ssim_multiscale(batch_img1, batch_img2, 1., filter_size=5)
636
+
637
+ print('Batch SSIM TF: {}\nBatch MS SSIM TF: {}\n'.format(batch_ssim_tf, batch_ms_ssim_tf))
638
+
639
+ img1 = img1[:, :127, :127, :127, :]
640
+ img2 = img2[:, :127, :127, :127, :]
641
+ MS_SSIM = MultiScaleStructuralSimilarity(1., filter_size=5)
642
+ print('MS SSIM Loss{}'.format(MS_SSIM.loss(img1, img2)))
DeepDeformationMapRegistration/networks.py CHANGED
@@ -8,12 +8,14 @@ PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'
8
  import tensorflow as tf
9
  import voxelmorph as vxm
10
  from voxelmorph.tf.modelio import LoadableModel, store_config_args
 
11
 
12
 
13
  class WeaklySupervised(LoadableModel):
14
 
15
  @store_config_args
16
- def __init__(self, inshape, all_labels: [list, tuple], nb_unet_features=None, int_steps=5, bidir=False, **kwargs):
 
17
  """
18
  Parameters:
19
  inshape: Input shape. e.g. (192, 192, 192)
@@ -21,34 +23,53 @@ class WeaklySupervised(LoadableModel):
21
  hot_labels: List of labels to output as one-hot maps.
22
  nb_unet_features: Unet convolutional features. See VxmDense documentation for more information.
23
  int_steps: Number of flow integration steps. The warp is non-diffeomorphic when this value is 0.
 
24
  kwargs: Forwarded to the internal VxmDense model.
25
  """
26
 
27
- fix_segm = tf.keras.Input((*inshape, len(all_labels)), name='fix_segmentations_input')
28
  mov_segm = tf.keras.Input((*inshape, len(all_labels)), name='mov_segmentations_input')
29
 
 
30
  mov_img = tf.keras.Input((*inshape, 1), name='mov_image_input')
31
 
32
- unet_input_model = tf.keras.Model(inputs=[mov_segm, fix_segm], outputs=[mov_segm, fix_segm])
33
 
34
  vxm_model = vxm.networks.VxmDense(inshape=inshape,
35
  nb_unet_features=nb_unet_features,
36
- input_model=unet_input_model,
37
  int_steps=int_steps,
38
- bidir=bidir, **kwargs)
39
-
40
- pred_img = vxm.layers.SpatialTransformer(interp_method='linear', indexing='ij', name='pred_fix_img')(
41
- [mov_img, vxm_model.references.pos_flow])
42
-
43
- inputs = [mov_segm, fix_segm, mov_img] # mov_img, mov_segm, fix_segm
44
- outputs = [pred_img] + vxm_model.outputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  self.references = LoadableModel.ReferenceContainer()
47
- self.references.pred_segm = vxm_model.outputs[0]
48
- self.references.pred_img = pred_img
49
  self.references.pos_flow = vxm_model.references.pos_flow
50
 
51
- super().__init__(inputs=inputs, outputs=outputs)
52
 
53
  def get_registration_model(self):
54
  return tf.keras.Model(self.inputs, self.references.pos_flow)
 
8
  import tensorflow as tf
9
  import voxelmorph as vxm
10
  from voxelmorph.tf.modelio import LoadableModel, store_config_args
11
+ from tensorflow.keras.layers import UpSampling3D
12
 
13
 
14
  class WeaklySupervised(LoadableModel):
15
 
16
  @store_config_args
17
+ def __init__(self, inshape, all_labels: [list, tuple], nb_unet_features=None, int_steps=5, bidir=False,
18
+ int_downsize=1, outshape=None, **kwargs):
19
  """
20
  Parameters:
21
  inshape: Input shape. e.g. (192, 192, 192)
 
23
  hot_labels: List of labels to output as one-hot maps.
24
  nb_unet_features: Unet convolutional features. See VxmDense documentation for more information.
25
  int_steps: Number of flow integration steps. The warp is non-diffeomorphic when this value is 0.
26
+ int_downsize: Dowsampling of the displacement map. Integer
27
  kwargs: Forwarded to the internal VxmDense model.
28
  """
29
 
 
30
  mov_segm = tf.keras.Input((*inshape, len(all_labels)), name='mov_segmentations_input')
31
 
32
+ fix_img = tf.keras.Input((*inshape, 1), name='fix_image_input')
33
  mov_img = tf.keras.Input((*inshape, 1), name='mov_image_input')
34
 
35
+ input_model = tf.keras.Model(inputs=[mov_img, fix_img], outputs=[mov_img, fix_img])
36
 
37
  vxm_model = vxm.networks.VxmDense(inshape=inshape,
38
  nb_unet_features=nb_unet_features,
39
+ input_model=input_model,
40
  int_steps=int_steps,
41
+ bidir=bidir,
42
+ int_downsize=int_downsize,
43
+ **kwargs)
44
+
45
+ pred_segm = vxm.layers.SpatialTransformer(interp_method='linear', indexing='ij', name='interp_segm')(
46
+ [mov_segm, vxm_model.references.pos_flow])
47
+
48
+ inputs = [mov_img, fix_img, mov_segm] # mov_img, mov_segm, fix_segm
49
+ model_outputs = vxm_model.outputs
50
+ if outshape is not None:
51
+ scale_factors = [o//i for i, o in zip(inshape, outshape)]
52
+ upsampling_layer = UpSampling3D(scale_factors) # Doesn't perform trilinear, only nearest
53
+ # Image
54
+ model_outputs[0] = upsampling_layer(model_outputs[0])
55
+ # Segmentation
56
+ pred_segm = upsampling_layer(pred_segm)
57
+ # Displacement map
58
+ model_outputs[1] = upsampling_layer(scale_factors)(model_outputs[1])
59
+ model_outputs[1] = tf.multiply(model_outputs[1], tf.cast(scale_factors, model_outputs[1].dtype))
60
+
61
+ # Just renaming
62
+ pred_fix_image = tf.identity(model_outputs[0], name='pred_fix_image')
63
+ pred_dm = tf.identity(model_outputs[1], name='pred_dm')
64
+ pred_segm = tf.identity(pred_segm, name='pred_fix_segm')
65
+ outputs = [pred_fix_image, pred_segm, pred_dm]
66
 
67
  self.references = LoadableModel.ReferenceContainer()
68
+ self.references.pred_segm = pred_segm
69
+ self.references.pred_img = vxm_model.outputs[0]
70
  self.references.pos_flow = vxm_model.references.pos_flow
71
 
72
+ super(WeaklySupervised, self).__init__(inputs=inputs, outputs=outputs)
73
 
74
  def get_registration_model(self):
75
  return tf.keras.Model(self.inputs, self.references.pos_flow)
DeepDeformationMapRegistration/utils/acummulated_optimizer.py CHANGED
@@ -1,5 +1,9 @@
1
- from tensorflow.keras.optimizers import Optimizer
2
- from tensorflow.keras import backend as K
 
 
 
 
3
 
4
 
5
  class AccumOptimizer(Optimizer):
@@ -14,7 +18,7 @@ class AccumOptimizer(Optimizer):
14
  a new keras optimizer.
15
  """
16
  def __init__(self, optimizer, steps_per_update=1, **kwargs):
17
- super(AccumOptimizer, self).__init__(**kwargs)
18
  self.optimizer = optimizer
19
  with K.name_scope(self.__class__.__name__):
20
  self.steps_per_update = steps_per_update
@@ -55,3 +59,134 @@ class AccumOptimizer(Optimizer):
55
  config = self.optimizer.get_config()
56
  K.set_value(self.iterations, iterations)
57
  return config
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tensorflow.python.keras.optimizers import Optimizer
2
+ from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2
3
+ from tensorflow.python import ops, math_ops, state_ops, control_flow_ops
4
+ from tensorflow.python.keras import backend as K
5
+ from tensorflow.python.keras import backend_config
6
+ import tensorflow as tf
7
 
8
 
9
  class AccumOptimizer(Optimizer):
 
18
  a new keras optimizer.
19
  """
20
  def __init__(self, optimizer, steps_per_update=1, **kwargs):
21
+ super(AccumOptimizer, self).__init__(name='AccumOptimizer', **kwargs)
22
  self.optimizer = optimizer
23
  with K.name_scope(self.__class__.__name__):
24
  self.steps_per_update = steps_per_update
 
59
  config = self.optimizer.get_config()
60
  K.set_value(self.iterations, iterations)
61
  return config
62
+
63
+
64
+ __all__ = ['AdamAccumulated']
65
+
66
+
67
+ # SRC: https://github.com/CyberZHG/keras-gradient-accumulation/blob/master/keras_gradient_accumulation/optimizer_v2.py
68
+ class AdamAccumulated(OptimizerV2):
69
+ """Optimizer that implements the Adam algorithm with gradient accumulation."""
70
+
71
+ def __init__(self,
72
+ accumulation_steps,
73
+ learning_rate=0.001,
74
+ beta_1=0.9,
75
+ beta_2=0.999,
76
+ epsilon=1e-7,
77
+ amsgrad=False,
78
+ name='Adam',
79
+ **kwargs):
80
+ r"""Construct a new Adam optimizer.
81
+ Args:
82
+ accumulation_steps: An integer. Update gradient in every accumulation steps.
83
+ learning_rate: A Tensor or a floating point value. The learning rate.
84
+ beta_1: A float value or a constant float tensor. The exponential decay
85
+ rate for the 1st moment estimates.
86
+ beta_2: A float value or a constant float tensor. The exponential decay
87
+ rate for the 2nd moment estimates.
88
+ epsilon: A small constant for numerical stability. This epsilon is
89
+ "epsilon hat" in the Kingma and Ba paper (in the formula just before
90
+ Section 2.1), not the epsilon in Algorithm 1 of the paper.
91
+ amsgrad: boolean. Whether to apply AMSGrad variant of this algorithm from
92
+ the paper "On the Convergence of Adam and beyond".
93
+ name: Optional name for the operations created when applying gradients.
94
+ Defaults to "Adam". @compatibility(eager) When eager execution is
95
+ enabled, `learning_rate`, `beta_1`, `beta_2`, and `epsilon` can each be
96
+ a callable that takes no arguments and returns the actual value to use.
97
+ This can be useful for changing these values across different
98
+ invocations of optimizer functions. @end_compatibility
99
+ **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`,
100
+ `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is clip
101
+ gradients by value, `decay` is included for backward compatibility to
102
+ allow time inverse decay of learning rate. `lr` is included for backward
103
+ compatibility, recommended to use `learning_rate` instead.
104
+ """
105
+
106
+ super(AdamAccumulated, self).__init__(name, **kwargs)
107
+ self._set_hyper('accumulation_steps', accumulation_steps)
108
+ self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
109
+ self._set_hyper('decay', self._initial_decay)
110
+ self._set_hyper('beta_1', beta_1)
111
+ self._set_hyper('beta_2', beta_2)
112
+ self.epsilon = epsilon or backend_config.epsilon()
113
+ self.amsgrad = amsgrad
114
+
115
+ def _create_slots(self, var_list):
116
+ for var in var_list:
117
+ self.add_slot(var, 'g')
118
+ for var in var_list:
119
+ self.add_slot(var, 'm')
120
+ for var in var_list:
121
+ self.add_slot(var, 'v')
122
+ if self.amsgrad:
123
+ for var in var_list:
124
+ self.add_slot(var, 'vhat')
125
+
126
+ def set_weights(self, weights):
127
+ params = self.weights
128
+ num_vars = int((len(params) - 1) / 2)
129
+ if len(weights) == 3 * num_vars + 1:
130
+ weights = weights[:len(params)]
131
+ super(AdamAccumulated, self).set_weights(weights)
132
+
133
+ def _resource_apply_dense(self, grad, var):
134
+ var_dtype = var.dtype.base_dtype
135
+ lr_t = self._decayed_lr(var_dtype)
136
+ beta_1_t = self._get_hyper('beta_1', var_dtype)
137
+ beta_2_t = self._get_hyper('beta_2', var_dtype)
138
+ accumulation_steps = self._get_hyper('accumulation_steps', 'int64')
139
+ update_cond = tf.equal((self.iterations + 1) % accumulation_steps, 0)
140
+ sub_step = self.iterations % accumulation_steps + 1
141
+ local_step = math_ops.cast(self.iterations // accumulation_steps + 1, var_dtype)
142
+ beta_1_power = math_ops.pow(beta_1_t, local_step)
143
+ beta_2_power = math_ops.pow(beta_2_t, local_step)
144
+ epsilon_t = ops.convert_to_tensor(self.epsilon, var_dtype)
145
+ lr = (lr_t * math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power))
146
+ lr = tf.where(update_cond, lr, 0.0)
147
+
148
+ g = self.get_slot(var, 'g')
149
+ g_a = grad / math_ops.cast(accumulation_steps, var_dtype)
150
+ g_t = tf.where(tf.equal(sub_step, 1),
151
+ g_a,
152
+ g + (g_a - g) / math_ops.cast(sub_step, var_dtype))
153
+ g_t = state_ops.assign(g, g_t, use_locking=self._use_locking)
154
+
155
+ m = self.get_slot(var, 'm')
156
+ m_t = tf.where(update_cond, m * beta_1_t + g_t * (1 - beta_1_t), m)
157
+ m_t = state_ops.assign(m, m_t, use_locking=self._use_locking)
158
+
159
+ v = self.get_slot(var, 'v')
160
+ v_t = tf.where(update_cond, v * beta_2_t + (g_t * g_t) * (1 - beta_2_t), v)
161
+ v_t = state_ops.assign(v, v_t, use_locking=self._use_locking)
162
+
163
+ if not self.amsgrad:
164
+ v_sqrt = math_ops.sqrt(v_t)
165
+ var_update = state_ops.assign_sub(
166
+ var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking)
167
+ return control_flow_ops.group(*[var_update, m_t, v_t])
168
+ else:
169
+ v_hat = self.get_slot(var, 'vhat')
170
+ v_hat_t = tf.where(update_cond, math_ops.maximum(v_hat, v_t), v_hat)
171
+ with ops.control_dependencies([v_hat_t]):
172
+ v_hat_t = state_ops.assign(
173
+ v_hat, v_hat_t, use_locking=self._use_locking)
174
+ v_hat_sqrt = math_ops.sqrt(v_hat_t)
175
+ var_update = state_ops.assign_sub(
176
+ var,
177
+ lr * m_t / (v_hat_sqrt + epsilon_t),
178
+ use_locking=self._use_locking)
179
+ return control_flow_ops.group(*[var_update, m_t, v_t, v_hat_t])
180
+
181
+ def get_config(self):
182
+ config = super(AdamAccumulated, self).get_config()
183
+ config.update({
184
+ 'accumulation_steps': self._serialize_hyperparameter('accumulation_steps'),
185
+ 'learning_rate': self._serialize_hyperparameter('learning_rate'),
186
+ 'decay': self._serialize_hyperparameter('decay'),
187
+ 'beta_1': self._serialize_hyperparameter('beta_1'),
188
+ 'beta_2': self._serialize_hyperparameter('beta_2'),
189
+ 'epsilon': self.epsilon,
190
+ 'amsgrad': self.amsgrad,
191
+ })
192
+ return config
DeepDeformationMapRegistration/utils/constants.py CHANGED
@@ -52,6 +52,8 @@ H5_MOV_TUMORS_MASK = 'input/{}'.format(MOVING_TUMORS_MASK)
52
  H5_FIX_TUMORS_MASK = 'input/{}'.format(FIXED_TUMORS_MASK)
53
  H5_FIX_SEGMENTATIONS = 'input/{}'.format(FIXED_SEGMENTATIONS)
54
  H5_MOV_SEGMENTATIONS = 'input/{}'.format(MOVING_SEGMENTATIONS)
 
 
55
 
56
  H5_GT_DISP = 'output/{}'.format(DISP_MAP_GT)
57
  H5_GT_IMG = 'output/{}'.format(PRED_IMG_GT)
@@ -64,6 +66,7 @@ MAX_ANGLE = 45.0 # degrees
64
  MAX_FLIPS = 2 # Axes to flip over
65
  NUM_ROTATIONS = 5
66
  MAX_WORKERS = 10
 
67
 
68
  # Labels to pass to the input_labels and output_labels parameter of DataGeneratorManager
69
  DG_LBL_FIX_IMG = H5_FIX_IMG
@@ -75,6 +78,8 @@ DG_LBL_MOV_VESSELS = H5_MOV_VESSELS_MASK
75
  DG_LBL_MOV_PARENCHYMA = H5_MOV_PARENCHYMA_MASK
76
  DG_LBL_MOV_TUMOR = H5_MOV_TUMORS_MASK
77
  DG_LBL_ZERO_GRADS = 'zero_gradients'
 
 
78
 
79
  # Training constants
80
  MODEL = 'unet'
@@ -91,6 +96,8 @@ DATA_FORMAT = 'channels_last' # or 'channels_fist'
91
  DATA_DIR = './data'
92
  MODEL_CHECKPOINT = './model_checkpoint'
93
  BATCH_SIZE = 8
 
 
94
  EPOCHS = 100
95
  SAVE_EPOCH = EPOCHS // 10 # Epoch when to save the model
96
  VERBOSE_EPOCH = EPOCHS // 10
@@ -99,7 +106,6 @@ VALIDATION_ERR_LIMIT_COUNTER = 10 # Number of successive times the validation e
99
  VALIDATION_ERR_LIMIT_COUNTER_BACKUP = 10
100
  THRESHOLD = 0.5 # Threshold to select the centerline in the interpolated images
101
  RESTORE_TRAINING = True # look for previously saved models to resume training
102
- EARLY_STOP_PATIENCE = 10
103
  LOG_FIELD_NAMES = ['time', 'epoch', 'step',
104
  'training_loss_mean', 'training_loss_std',
105
  'training_loss1_mean', 'training_loss1_std',
@@ -362,10 +368,13 @@ COORDS_GRID = CoordinatesGrid()
362
  class VisualizationParameters:
363
  def __init__(self):
364
  self.__scale = None # See https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.axes.Axes.quiver.html
365
- self.__spacing = 5
366
 
367
  def set_spacing(self, img_shape: tf.TensorShape):
368
- self.__spacing = int(5 * np.log(img_shape[W]))
 
 
 
369
 
370
  @property
371
  def spacing(self):
@@ -495,3 +504,19 @@ MANUAL_W = [1.] * len(PRIOR_W)
495
 
496
  REG_PRIOR_W = [1e-3]
497
  REG_MANUAL_W = [1.] * len(REG_PRIOR_W)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  H5_FIX_TUMORS_MASK = 'input/{}'.format(FIXED_TUMORS_MASK)
53
  H5_FIX_SEGMENTATIONS = 'input/{}'.format(FIXED_SEGMENTATIONS)
54
  H5_MOV_SEGMENTATIONS = 'input/{}'.format(MOVING_SEGMENTATIONS)
55
+ H5_FIX_CENTROID = 'input/fix_centroid'
56
+ H5_MOV_CENTROID = 'input/mov_centroid'
57
 
58
  H5_GT_DISP = 'output/{}'.format(DISP_MAP_GT)
59
  H5_GT_IMG = 'output/{}'.format(PRED_IMG_GT)
 
66
  MAX_FLIPS = 2 # Axes to flip over
67
  NUM_ROTATIONS = 5
68
  MAX_WORKERS = 10
69
+ DEG_TO_RAD = np.pi/180.
70
 
71
  # Labels to pass to the input_labels and output_labels parameter of DataGeneratorManager
72
  DG_LBL_FIX_IMG = H5_FIX_IMG
 
78
  DG_LBL_MOV_PARENCHYMA = H5_MOV_PARENCHYMA_MASK
79
  DG_LBL_MOV_TUMOR = H5_MOV_TUMORS_MASK
80
  DG_LBL_ZERO_GRADS = 'zero_gradients'
81
+ DG_LBL_FIX_CENTROID = H5_FIX_CENTROID
82
+ DG_LBL_MOV_CENTROID = H5_MOV_CENTROID
83
 
84
  # Training constants
85
  MODEL = 'unet'
 
96
  DATA_DIR = './data'
97
  MODEL_CHECKPOINT = './model_checkpoint'
98
  BATCH_SIZE = 8
99
+ ACCUM_GRADIENT_STEP = 1
100
+ EARLY_STOP_PATIENCE = 30 # Weights are updated every ACCUM_GRADIENT_STEPth step
101
  EPOCHS = 100
102
  SAVE_EPOCH = EPOCHS // 10 # Epoch when to save the model
103
  VERBOSE_EPOCH = EPOCHS // 10
 
106
  VALIDATION_ERR_LIMIT_COUNTER_BACKUP = 10
107
  THRESHOLD = 0.5 # Threshold to select the centerline in the interpolated images
108
  RESTORE_TRAINING = True # look for previously saved models to resume training
 
109
  LOG_FIELD_NAMES = ['time', 'epoch', 'step',
110
  'training_loss_mean', 'training_loss_std',
111
  'training_loss1_mean', 'training_loss1_std',
 
368
  class VisualizationParameters:
369
  def __init__(self):
370
  self.__scale = None # See https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.axes.Axes.quiver.html
371
+ self.__spacing = 15
372
 
373
  def set_spacing(self, img_shape: tf.TensorShape):
374
+ if isinstance(img_shape, tf.TensorShape):
375
+ self.__spacing = int(5 * np.log(img_shape[W]))
376
+ else:
377
+ self.__spacing = img_shape
378
 
379
  @property
380
  def spacing(self):
 
504
 
505
  REG_PRIOR_W = [1e-3]
506
  REG_MANUAL_W = [1.] * len(REG_PRIOR_W)
507
+
508
+ # Constants for augmentation layer
509
+ # .../T1/training/zoom_factors.csv contain the scale factors of all the training samples from isotropic to 128x128x128
510
+ # The augmentation values will be scaled using the average+std
511
+ IXI_DATASET_iso_to_cubic_scales = np.asarray([0.655491 + 0.039223, 0.496783 + 0.029349, 0.499691 + 0.028155])
512
+ MAX_AUG_DISP_ISOT = 30
513
+ MAX_AUG_DEF_ISOT = 6
514
+ MAX_AUG_DISP = np.max(MAX_AUG_DISP_ISOT * IXI_DATASET_iso_to_cubic_scales) # Scaled displacements
515
+ MAX_AUG_DEF = np.max(MAX_AUG_DEF_ISOT * IXI_DATASET_iso_to_cubic_scales) # Scaled deformations
516
+ MAX_AUG_ANGLE = np.max([np.arctan(np.tan(10*np.pi/180) * IXI_DATASET_iso_to_cubic_scales[1] / IXI_DATASET_iso_to_cubic_scales[0]) * 180 / np.pi,
517
+ np.arctan(np.tan(10*np.pi/180) * IXI_DATASET_iso_to_cubic_scales[2] / IXI_DATASET_iso_to_cubic_scales[1]) * 180 / np.pi,
518
+ np.arctan(np.tan(10*np.pi/180) * IXI_DATASET_iso_to_cubic_scales[2] / IXI_DATASET_iso_to_cubic_scales[0]) * 180 / np.pi]) # Scaled angles
519
+ GAMMA_AUGMENTATION = False
520
+ BRIGHTNESS_AUGMENTATION = False
521
+ NUM_CONTROL_PTS_AUG = 10
522
+ NUM_AUGMENTATIONS = 1
DeepDeformationMapRegistration/utils/misc.py CHANGED
@@ -1,11 +1,19 @@
1
  import os
2
  import errno
 
 
 
 
 
 
 
3
 
4
- def try_mkdir(dir):
 
5
  try:
6
  os.makedirs(dir)
7
  except OSError as err:
8
- if err.errno == errno.EEXIST:
9
  print("Directory " + dir + " already exists")
10
  else:
11
  raise ValueError("Can't create dir " + dir)
@@ -23,3 +31,120 @@ def function_decorator(new_name):
23
  func.__name__ = new_name
24
  return func
25
  return decorator
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import errno
3
+ import shutil
4
+ import numpy as np
5
+ from scipy.interpolate import griddata, Rbf, LinearNDInterpolator, NearestNDInterpolator
6
+ from skimage.measure import regionprops
7
+ from DeepDeformationMapRegistration.layers.b_splines import interpolate_spline
8
+ from DeepDeformationMapRegistration.utils.thin_plate_splines import ThinPlateSplines
9
+ from tensorflow import squeeze
10
 
11
+
12
+ def try_mkdir(dir, verbose=True):
13
  try:
14
  os.makedirs(dir)
15
  except OSError as err:
16
+ if err.errno == errno.EEXIST and verbose:
17
  print("Directory " + dir + " already exists")
18
  else:
19
  raise ValueError("Can't create dir " + dir)
 
31
  func.__name__ = new_name
32
  return func
33
  return decorator
34
+
35
+
36
+ class DatasetCopy:
37
+ def __init__(self, dataset_location, copy_location=None, verbose=True):
38
+ self.__copy_loc = os.path.join(os.getcwd(), 'temp_dataset') if copy_location is None else copy_location
39
+ self.__dst_loc = dataset_location
40
+ self.__verbose = verbose
41
+
42
+ def copy_dataset(self):
43
+ shutil.copytree(self.__dst_loc, self.__copy_loc)
44
+ if self.__verbose:
45
+ print('{} copied to {}'.format(self.__dst_loc, self.__copy_loc))
46
+ return self.__copy_loc
47
+
48
+ def delete_temp(self):
49
+ shutil.rmtree(self.__copy_loc)
50
+ if self.__verbose:
51
+ print('Deleted: ', self.__copy_loc)
52
+
53
+
54
+ class DisplacementMapInterpolator:
55
+ def __init__(self,
56
+ image_shape=[64, 64, 64],
57
+ method='rbf'):
58
+ assert method in ['rbf', 'griddata', 'tf', 'tps'], "Method must be 'rbf' or 'griddata'"
59
+ self.method = method
60
+ self.image_shape = image_shape
61
+
62
+ self.grid = self.__regular_grid()
63
+
64
+ def __regular_grid(self):
65
+ xx = np.linspace(0, self.image_shape[0], self.image_shape[0], endpoint=False, dtype=np.uint16)
66
+ yy = np.linspace(0, self.image_shape[0], self.image_shape[0], endpoint=False, dtype=np.uint16)
67
+ zz = np.linspace(0, self.image_shape[0], self.image_shape[0], endpoint=False, dtype=np.uint16)
68
+
69
+ xx, yy, zz = np.meshgrid(xx, yy, zz)
70
+
71
+ return np.stack([xx.flatten(), yy.flatten(), zz.flatten()], axis=0).T
72
+
73
+ def __call__(self, disp_map, interp_points, backwards=False):
74
+ disp_map = disp_map.reshape([-1, 3])
75
+ grid_pts = self.grid.copy()
76
+ if backwards:
77
+ grid_pts = np.add(grid_pts, disp_map).astype(np.float32)
78
+ disp_map *= -1
79
+
80
+ if self.method == 'rbf':
81
+ interpolator = Rbf(grid_pts[:, 0], grid_pts[:, 1], grid_pts[:, 2], disp_map[:, :],
82
+ method='thin_plate', mode='N-D')
83
+ disp = interpolator(interp_points)
84
+ elif self.method == 'griddata':
85
+ linear_interp = LinearNDInterpolator(grid_pts, disp_map)
86
+ disp = linear_interp(interp_points).copy()
87
+ del linear_interp
88
+
89
+ if np.any(np.isnan(disp)):
90
+ # It might happen (though it shouldn't) that the interpolation point is outside the convex hull of grid points.
91
+ # in this situation, linear interpolation fails and will put NaN. Nearest can give a value, so we are going to
92
+ # substitute those unexpected NaNs with the nearest value. Unexpected == not in interp_points
93
+ nan_disp_idx = set(np.unique(np.argwhere(np.isnan(disp))[:, 0]))
94
+ nan_interp_pts_idx = set(np.unique(np.argwhere(np.isnan(interp_points))[:, 0]))
95
+ idx = nan_disp_idx - nan_interp_pts_idx if len(nan_disp_idx) > len(nan_interp_pts_idx) else nan_interp_pts_idx - nan_disp_idx
96
+ idx = list(idx)
97
+ if len(idx):
98
+ # We have unexpected NaNs
99
+ near_interp = NearestNDInterpolator(grid_pts, disp_map)
100
+ near_disp = near_interp(interp_points[idx, ...]).copy()
101
+ del near_interp
102
+ for n, i in enumerate(idx):
103
+ disp[i, ...] = near_disp[n, ...]
104
+ elif self.method == 'tf':
105
+ # Order: 1 -> linear, 2 -> thin plate, 3 -> cubic
106
+ disp = squeeze(interpolate_spline(grid_pts[np.newaxis, ...][::4, :], # Batch axis
107
+ disp_map[np.newaxis, ...][::4, :],
108
+ interp_points[np.newaxis, ...], order=2), axis=0)
109
+ else:
110
+ tps_interp = ThinPlateSplines(grid_pts[::8, :], self.grid.copy().astype(np.float32)[::8, :])
111
+ disp = tps_interp.interpolate(interp_points).eval()
112
+ del tps_interp
113
+
114
+ return disp
115
+
116
+
117
+ def get_segmentations_centroids(segmentations, ohe=True, expected_lbls=range(0, 28), missing_centroid=[np.nan]*3, brain_study=True):
118
+ segmentations = np.squeeze(segmentations)
119
+ if ohe:
120
+ segmentations = np.sum(segmentations, axis=-1).astype(np.uint8)
121
+ missing_lbls = set(expected_lbls) - set(np.unique(segmentations))
122
+ if brain_study:
123
+ segmentations += np.ones_like(segmentations) # Regionsprops neglect the label 0. But we need it, so offset all labels by 1
124
+ else:
125
+ missing_lbls = set(expected_lbls) - set(np.unique(segmentations))
126
+
127
+ seg_props = regionprops(segmentations)
128
+ centroids = np.asarray([c.centroid for c in seg_props]).astype(np.float32)
129
+
130
+ for lbl in missing_lbls:
131
+ idx = expected_lbls.index(lbl)
132
+ centroids = np.insert(centroids, idx, missing_centroid, axis=0)
133
+ return centroids.copy(), missing_lbls
134
+
135
+
136
+ def segmentation_ohe_to_cardinal(segmentation):
137
+ cpy = segmentation.copy()
138
+ for lbl in range(segmentation.shape[-1]):
139
+ cpy[..., lbl] *= (lbl + 1)
140
+ # Add the Background
141
+ cpy = np.concatenate([np.zeros(segmentation.shape[:-1])[..., np.newaxis], cpy], axis=-1)
142
+ return np.argmax(cpy, axis=-1)[..., np.newaxis]
143
+
144
+
145
+ def segmentation_cardinal_to_ohe(segmentation):
146
+ # Keep in mind that we don't handle the overlap between the segmentations!
147
+ cpy = np.tile(np.zeros_like(segmentation), (1, 1, 1, len(np.unique(segmentation)[1:])))
148
+ for ch, lbl in enumerate(np.unique(segmentation)[1:]):
149
+ cpy[segmentation == lbl, ch] = 1
150
+ return cpy
DeepDeformationMapRegistration/utils/nifti_utils.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import nibabel as nb
3
+ import numpy as np
4
+ import zipfile
5
+
6
+
7
+ TEMP_UNZIP_PATH = '/mnt/EncryptedData1/Users/javier/ext_datasets/LITS17/temp'
8
+ NII_EXTENSION = '.nii'
9
+
10
+
11
+ def save_nifti(data, save_path, header=None, verbose=True):
12
+ if header is None:
13
+ data_nifti = nb.Nifti1Image(data, affine=np.eye(4))
14
+ else:
15
+ data_nifti = nb.Nifti1Image(data, affine=None, header=header)
16
+
17
+ data_nifti.header.get_xyzt_units()
18
+ try:
19
+ nb.save(data_nifti, save_path) # Save as NiBabel file
20
+ if verbose:
21
+ print('Saved {}'.format(save_path))
22
+ except ValueError:
23
+ print('Could not save {}'.format(save_path))
24
+
25
+
26
+ def unzip_nii_file(file_path):
27
+ file_dir, file_name = os.path.split(file_path)
28
+ file_name = file_name.split('.zip')[0]
29
+
30
+ dest_path = os.path.join(TEMP_UNZIP_PATH, file_name)
31
+ zipfile.ZipFile(file_path).extractall(TEMP_UNZIP_PATH)
32
+
33
+ if not os.path.exists(dest_path):
34
+ print('ERR: File {} not unzip-ed!'.format(file_path))
35
+ dest_path = None
36
+ return dest_path
37
+
38
+
39
+ def delete_temp(file_path, verbose=False):
40
+ assert NII_EXTENSION in file_path
41
+ os.remove(file_path)
42
+ if verbose:
43
+ print('Deleted file: ', file_path)
DeepDeformationMapRegistration/utils/operators.py CHANGED
@@ -11,14 +11,23 @@ def min_max_norm(img: np.ndarray, out_max_val=1.):
11
  return out_img * out_max_val
12
 
13
 
14
- def soft_threshold(x, threshold, name=None):
15
  # https://www.tensorflow.org/probability/api_docs/python/tfp/math/soft_threshold
16
- with tf.name_scope(name or 'soft_threshold'):
 
 
 
17
  x = tf.convert_to_tensor(x, name='x')
18
  threshold = tf.convert_to_tensor(threshold, dtype=x.dtype, name='threshold')
19
  return tf.sign(x) * tf.maximum(tf.abs(x) - threshold, 0.)
20
 
21
 
 
 
 
 
 
 
22
  def binary_activation(x):
23
  # https://stackoverflow.com/questions/37743574/hard-limiting-threshold-activation-function-in-tensorflow
24
  cond = tf.less(x, tf.zeros(tf.shape(x)))
@@ -26,3 +35,31 @@ def binary_activation(x):
26
 
27
  return out
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  return out_img * out_max_val
12
 
13
 
14
+ def soft_threshold(x, threshold, name='soft_threshold'):
15
  # https://www.tensorflow.org/probability/api_docs/python/tfp/math/soft_threshold
16
+ # Foucart S., Rauhut H. (2013) Basic Algorithms. In: A Mathematical Introduction to Compressive Sensing.
17
+ # Applied and Numerical Harmonic Analysis. Birkhäuser, New York, NY. https://doi.org/10.1007/978-0-8176-4948-7_3
18
+ # Chapter 3, page 72
19
+ with tf.name_scope(name):
20
  x = tf.convert_to_tensor(x, name='x')
21
  threshold = tf.convert_to_tensor(threshold, dtype=x.dtype, name='threshold')
22
  return tf.sign(x) * tf.maximum(tf.abs(x) - threshold, 0.)
23
 
24
 
25
+ def hard_threshold(x, threshold, name='hard_threshold'):
26
+ with tf.name_scope(name):
27
+ threshold = tf.convert_to_tensor(threshold, dtype=x.dtype, name='threshold')
28
+ return tf.sign(tf.maximum(tf.abs(x) - threshold, 0.))
29
+
30
+
31
  def binary_activation(x):
32
  # https://stackoverflow.com/questions/37743574/hard-limiting-threshold-activation-function-in-tensorflow
33
  cond = tf.less(x, tf.zeros(tf.shape(x)))
 
35
 
36
  return out
37
 
38
+
39
+ def gaussian_kernel(kernel_size, sigma, in_ch, out_ch, dim, dtype=tf.float32):
40
+ # SRC: https://stackoverflow.com/questions/59286171/gaussian-blur-image-in-dataset-pipeline-in-tensorflow
41
+ x = tf.range(-kernel_size // 2 + 1, kernel_size // 2 + 1, dtype=dtype)
42
+ g = tf.math.exp(-(tf.pow(x, 2) / (2 * tf.pow(tf.cast(sigma, dtype), 2))))
43
+
44
+ g_kernel = tf.identity(g)
45
+ g_kernel = tf.tensordot(g_kernel, g, 0)
46
+ g_kernel = tf.tensordot(g_kernel, g, 0)
47
+
48
+ # i = tf.constant(0)
49
+ # cond = lambda i, g_kern: tf.less(i, dim - 1)
50
+ # mult_kern = lambda i, g_kern: [tf.add(i, 1), tf.tensordot(g_kern, g, 0)]
51
+ # _, g_kernel = tf.while_loop(cond, mult_kern,
52
+ # loop_vars=[i, g_kernel],
53
+ # shape_invariants=[i.get_shape(), tf.TensorShape([kernel_size, None, None])])
54
+
55
+ g_kernel = g_kernel / tf.reduce_sum(g_kernel)
56
+ g_kernel = tf.expand_dims(tf.expand_dims(g_kernel, axis=-1), axis=-1)
57
+ return tf.tile(g_kernel, (*(1,)*dim, in_ch, out_ch))
58
+
59
+
60
+ def sample_unique(population, samples, tout=tf.int32):
61
+ # src: https://github.com/tensorflow/tensorflow/issues/9260#issuecomment-437875125
62
+ z = -tf.log(-tf.log(tf.random_uniform((tf.shape(population)[0],), 0, 1)))
63
+ _, indices = tf.nn.top_k(z, samples)
64
+ ret_val = tf.gather(population, indices)
65
+ return tf.cast(ret_val, tout)
DeepDeformationMapRegistration/utils/thin_plate_splines.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import tensorflow as tf
3
+
4
+
5
+ class ThinPlateSplines:
6
+ def __init__(self, ctrl_pts: tf.Tensor, target_pts: tf.Tensor, reg=0.0):
7
+ """
8
+
9
+ :param ctrl_pts: [N, d] tensor of control d-dimensional points
10
+ :param target_pts: [N, d] tensor of target d-dimensional points
11
+ :param reg: regularization coefficient
12
+ """
13
+ self.__ctrl_pts = ctrl_pts
14
+ self.__target_pts = target_pts
15
+ self.__reg = reg
16
+ self.__num_ctrl_pts = ctrl_pts.shape[0]
17
+ self.__dim = ctrl_pts.shape[1]
18
+
19
+ self.__compute_coeffs()
20
+ # self.__aff_params = self.__coeffs[self.__num_ctrl_pts:, ...] # Affine parameters of the TPS
21
+ self.__non_aff_paramms = self.__coeffs[:self.__num_ctrl_pts, ...] # Non-affine parameters of he TPS
22
+
23
+ def __compute_coeffs(self):
24
+ target_pts_aug = tf.concat([self.__target_pts,
25
+ tf.zeros([self.__dim + 1, self.__dim], dtype=self.__target_pts.dtype)],
26
+ axis=0)
27
+
28
+ # T = self.__make_T()
29
+ T_i = tf.cast(tf.linalg.inv(self.__make_T()), target_pts_aug.dtype)
30
+ self.__coeffs = tf.cast(tf.matmul(T_i, target_pts_aug), tf.float32)
31
+
32
+ def __make_T(self):
33
+ # cp: [K x 2] control points
34
+ # T: [(num_pts+dim+1) x (num_pts+dim+1)]
35
+ num_pts = self.__ctrl_pts.shape[0]
36
+
37
+ P = tf.concat([tf.ones([self.__num_ctrl_pts, 1], dtype=tf.float32), self.__ctrl_pts], axis=1)
38
+ zeros = np.zeros([self.__dim + 1, self.__dim + 1], dtype=np.float)
39
+ self.__K = self.__U_dist(self.__ctrl_pts)
40
+ alfa = tf.reduce_mean(self.__K)
41
+
42
+ self.__K = self.__K + tf.ones_like(self.__K) * tf.pow(alfa, 2) * self.__reg
43
+
44
+ # top = tf.concat([self.__K, P], axis=1)
45
+ # bottom = tf.concat([tf.transpose(P), zeros], axis=1)
46
+
47
+ return tf.concat([tf.concat([self.__K, P], axis=1), tf.concat([tf.transpose(P), zeros], axis=1)], axis=0)
48
+
49
+ def __U_dist(self, ctrl_pts, int_pts=None):
50
+ if int_pts is None:
51
+ dist = self.__pairwise_distance_equal(ctrl_pts) # Already squared!
52
+ else:
53
+ dist = self.__pairwise_distance_different(ctrl_pts, int_pts) # Already squared!
54
+
55
+
56
+ # U(x, y) = p_w_dist(x, y)^2 * log(p_w_dist(x, y)) (dist() > =0); 0 otw
57
+ if ctrl_pts.shape[-1] == 2:
58
+ u_dist = dist * tf.math.log(dist + 1e-6)
59
+ else:
60
+ # Src: https://github.com/vaipatel/morphops/blob/master/morphops/tps.py
61
+ # In particular, if k = 2, then U(r) = r^2 * log(r^2), else U(r) = r
62
+ u_dist = tf.sqrt(dist)
63
+ # tf.matrix_set_diag(u_dist, tf.constant(0, dtype=dist_sq.dtype))
64
+ # reg_term = self.__reg * tf.pow(alfa, 2) * tf.eye(self.__num_ctrl_pts)
65
+
66
+ return u_dist # + reg_term
67
+
68
+ def __pairwise_distance_sq(self, pts_a, pts_b):
69
+ with tf.variable_scope('pairwise_distance'):
70
+ if np.all(pts_a == pts_b):
71
+ # This implementation works better when doing the pairwise distance os a single set of points
72
+ pts_a_ = tf.reshape(pts_a, [-1, 1, 3])
73
+ pts_b_ = tf.reshape(pts_b, [1, -1, 3])
74
+ dist = tf.reduce_sum(tf.square(pts_a_ - pts_b_), 2) # squared pairwise distance
75
+ else:
76
+ # PwD^2= A_norm^2 - 2*A*B' + B_norm^2
77
+ pts_a_ = tf.reduce_sum(tf.square(pts_a), 1)
78
+ pts_b_ = tf.reduce_sum(tf.square(pts_b), 1)
79
+
80
+ pts_a_ = tf.expand_dims(pts_a_, 1)
81
+ pts_b_ = tf.expand_dims(pts_b_, 0)
82
+
83
+ pts_a_pts_b_ = tf.matmul(pts_a, pts_b, adjoint_b=True)
84
+
85
+ dist = pts_a_ - 2 * pts_a_pts_b_ + pts_b_
86
+
87
+ return tf.cast(dist, tf.float32)
88
+
89
+ @staticmethod
90
+ def __pairwise_distance_equal(pts):
91
+ # This implementation works better when doing the pairwise distance os a single set of points
92
+ dist = tf.reduce_sum(tf.square(tf.reshape(pts, [-1, 1, 3]) - tf.reshape(pts, [1, -1, 3])), 2) # squared pairwise distance
93
+ return tf.cast(dist, tf.float32)
94
+
95
+ @staticmethod
96
+ def __pairwise_distance_different(pts_a, pts_b):
97
+ pts_a_ = tf.reduce_sum(tf.square(pts_a), 1)
98
+ pts_b_ = tf.reduce_sum(tf.square(pts_b), 1)
99
+
100
+ pts_a_ = tf.expand_dims(pts_a_, 1)
101
+ pts_b_ = tf.expand_dims(pts_b_, 0)
102
+
103
+ pts_a_pts_b_ = tf.matmul(pts_a, pts_b, adjoint_b=True)
104
+
105
+ dist = pts_a_ - 2 * pts_a_pts_b_ + pts_b_
106
+ return tf.cast(dist, tf.float32)
107
+
108
+ def __lift_pts(self, int_pts: tf.Tensor, num_pts):
109
+ # int_pts: [N x 2], input points
110
+ # cp: [K x 2], control points
111
+ # pLift: [N x (3+K)], lifted input points
112
+
113
+ # u_dist = self.__U_dist(int_pts, self.__ctrl_pts)
114
+
115
+ int_pts_lift = tf.concat([self.__U_dist(int_pts, self.__ctrl_pts),
116
+ tf.ones([num_pts, 1], dtype=tf.float32),
117
+ int_pts], axis=1)
118
+ return int_pts_lift
119
+
120
+ @property
121
+ def bending_energy(self):
122
+ aux = tf.matmul(self.__non_aff_paramms, self.__K, transpose_a=True)
123
+ return tf.matmul(aux, self.__non_aff_paramms)
124
+
125
+ def interpolate(self, int_points): #, num_pts):
126
+ """
127
+
128
+ :param int_points: [K, d] flattened d-points of a mesh
129
+ :return:
130
+ """
131
+ num_pts = tf.shape(int_points)[0]
132
+ int_points_lift = self.__lift_pts(int_points, num_pts)
133
+
134
+ return tf.matmul(int_points_lift, self.__coeffs)
135
+
136
+ def __call__(self, int_points, num_pts, **kwargs):
137
+ return self.interpolate(int_points) # , num_pts)
138
+
139
+
140
+ def thin_plate_splines_batch(ctrl_pts: tf.Tensor, target_pts: tf.Tensor, int_pts: tf.Tensor, reg=0.0):
141
+ _batches = ctrl_pts.shape[0]
142
+
143
+ if tf.get_default_session() is not None:
144
+ print('DEBUG TIME')
145
+
146
+ def tps_sample(in_data):
147
+ cp, tp, ip = in_data
148
+ # _num_pts = ip.shape[0]
149
+ tps = ThinPlateSplines(cp, tp, reg)
150
+ interp = tps.interpolate(ip) # , _num_pts)
151
+ return interp
152
+
153
+ return tf.map_fn(tps_sample, elems=(ctrl_pts, target_pts, int_pts), dtype=tf.float32)
154
+
DeepDeformationMapRegistration/utils/visualization.py CHANGED
@@ -1,5 +1,5 @@
1
- # import matplotlib
2
- # matplotlib.use('TkAgg')
3
  import matplotlib.pyplot as plt
4
  from mpl_toolkits.mplot3d import Axes3D
5
  import matplotlib.colors as mcolors
@@ -8,7 +8,7 @@ from matplotlib import cm
8
  from mpl_toolkits.axes_grid1 import make_axes_locatable
9
  import tensorflow as tf
10
  import numpy as np
11
- import pyVesselRegistration_constants as const
12
  from skimage.exposure import rescale_intensity
13
  import scipy.misc as scpmisc
14
  import os
@@ -175,11 +175,11 @@ def plot_training(list_imgs: [np.ndarray], affine_transf=True, filename='img', f
175
  fig.clear()
176
  plt.figure(fig.number)
177
  else:
178
- fig = plt.figure(dpi=const.DPI)
179
 
180
  ax_fix = fig.add_subplot(231)
181
  im_fix = ax_fix.imshow(list_imgs[0][:, :, 0])
182
- ax_fix.set_title('Fix image', fontsize=const.FONT_SIZE)
183
  ax_fix.tick_params(axis='both',
184
  which='both',
185
  bottom=False,
@@ -188,7 +188,7 @@ def plot_training(list_imgs: [np.ndarray], affine_transf=True, filename='img', f
188
  labelbottom=False)
189
  ax_mov = fig.add_subplot(232)
190
  im_mov = ax_mov.imshow(list_imgs[1][:, :, 0])
191
- ax_mov.set_title('Moving image', fontsize=const.FONT_SIZE)
192
  ax_mov.tick_params(axis='both',
193
  which='both',
194
  bottom=False,
@@ -198,7 +198,7 @@ def plot_training(list_imgs: [np.ndarray], affine_transf=True, filename='img', f
198
 
199
  ax_pred_im = fig.add_subplot(233)
200
  im_pred_im = ax_pred_im.imshow(list_imgs[2][:, :, 0])
201
- ax_pred_im.set_title('Prediction', fontsize=const.FONT_SIZE)
202
  ax_pred_im.tick_params(axis='both',
203
  which='both',
204
  bottom=False,
@@ -228,8 +228,8 @@ def plot_training(list_imgs: [np.ndarray], affine_transf=True, filename='img', f
228
  else:
229
  cx, cy, dx, dy, s = _prepare_quiver_map(list_imgs[3])
230
  im_pred_disp = ax_pred_disp.imshow(s, interpolation='none', aspect='equal')
231
- ax_pred_disp.quiver(cx, cy, dx, dy, scale=const.QUIVER_PARAMS.arrow_scale)
232
- ax_pred_disp.set_title('Pred disp map', fontsize=const.FONT_SIZE)
233
  ax_pred_disp.tick_params(axis='both',
234
  which='both',
235
  bottom=False,
@@ -259,8 +259,8 @@ def plot_training(list_imgs: [np.ndarray], affine_transf=True, filename='img', f
259
  else:
260
  cx, cy, dx, dy, s = _prepare_quiver_map(list_imgs[4])
261
  im_gt_disp = ax_gt_disp.imshow(s, interpolation='none', aspect='equal')
262
- ax_gt_disp.quiver(cx, cy, dx, dy, scale=const.QUIVER_PARAMS.arrow_scale)
263
- ax_gt_disp.set_title('GT disp map', fontsize=const.FONT_SIZE)
264
  ax_gt_disp.tick_params(axis='both',
265
  which='both',
266
  bottom=False,
@@ -276,7 +276,7 @@ def plot_training(list_imgs: [np.ndarray], affine_transf=True, filename='img', f
276
 
277
  if filename is not None:
278
  plt.savefig(filename, format='png') # Call before show
279
- if not const.REMOTE:
280
  plt.show()
281
  else:
282
  plt.close()
@@ -288,7 +288,7 @@ def save_centreline_img(img, title, filename, fig=None):
288
  fig.clear()
289
  plt.figure(fig.number)
290
  else:
291
- fig = plt.figure(dpi=const.DPI)
292
 
293
  dim = len(img.shape[:-1])
294
 
@@ -321,19 +321,19 @@ def save_centreline_img(img, title, filename, fig=None):
321
  plt.close()
322
 
323
 
324
- def save_disp_map_img(disp_map, title, filename, affine_transf=False, fig=None):
325
  if fig is not None:
326
  fig.clear()
327
  plt.figure(fig.number)
328
  else:
329
- fig = plt.figure(dpi=const.DPI)
330
-
331
  dim = disp_map.shape[-1]
332
 
333
  if dim == 2:
334
  ax_x = fig.add_subplot(131)
335
  ax_x.set_title('H displacement')
336
- im_x = ax_x.imshow(disp_map[..., const.H_DISP])
337
  ax_x.tick_params(axis='both',
338
  which='both',
339
  bottom=False,
@@ -344,7 +344,7 @@ def save_disp_map_img(disp_map, title, filename, affine_transf=False, fig=None):
344
 
345
  ax_y = fig.add_subplot(132)
346
  ax_y.set_title('W displacement')
347
- im_y = ax_y.imshow(disp_map[..., const.W_DISP])
348
  ax_y.tick_params(axis='both',
349
  which='both',
350
  bottom=False,
@@ -373,8 +373,8 @@ def save_disp_map_img(disp_map, title, filename, affine_transf=False, fig=None):
373
  else:
374
  c, d, s = _prepare_quiver_map(disp_map, dim=dim)
375
  im = ax.imshow(s, interpolation='none', aspect='equal')
376
- ax.quiver(c[const.H_DISP], c[const.W_DISP], d[const.H_DISP], d[const.W_DISP],
377
- scale=const.QUIVER_PARAMS.arrow_scale)
378
  cb = _set_colorbar(fig, ax, im, False)
379
  ax.set_title('Displacement map')
380
  ax.tick_params(axis='both',
@@ -387,9 +387,8 @@ def save_disp_map_img(disp_map, title, filename, affine_transf=False, fig=None):
387
  else:
388
  ax = fig.add_subplot(111, projection='3d')
389
  c, d, s = _prepare_quiver_map(disp_map[0, ...], dim=dim)
390
- ax.quiver(c[const.H_DISP], c[const.W_DISP], c[const.D_DISP], d[const.H_DISP], d[const.W_DISP], d[const.D_DISP],
391
- norm=True)
392
- _square_3d_plot(np.arange(0, 63), np.arange(0, 63), np.arange(0, 63), ax)
393
  fig.suptitle('Displacement map')
394
  ax.tick_params(axis='both', # Same parameters as in 2D https://matplotlib.org/api/_as_gen/mpl_toolkits.mplot3d.axes3d.Axes3D.html
395
  which='both',
@@ -397,10 +396,14 @@ def save_disp_map_img(disp_map, title, filename, affine_transf=False, fig=None):
397
  left=False,
398
  labelleft=False,
399
  labelbottom=False)
 
400
  fig.suptitle(title)
401
 
402
  plt.savefig(filename, format='png')
 
 
403
  plt.close()
 
404
 
405
 
406
  def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, filename='img', fig=None,
@@ -409,16 +412,16 @@ def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, fi
409
  fig.clear()
410
  plt.figure(fig.number)
411
  else:
412
- fig = plt.figure(dpi=const.DPI)
413
 
414
  dim = len(list_imgs[0].shape[:-1])
415
 
416
  if dim == 2:
417
  # TRAINING
418
  ax_input = fig.add_subplot(241)
419
- ax_input.set_ylabel(title_first_row, fontsize=const.FONT_SIZE)
420
  im_fix = ax_input.imshow(list_imgs[0][:, :, 0])
421
- ax_input.set_title('Fix image', fontsize=const.FONT_SIZE)
422
  ax_input.tick_params(axis='both',
423
  which='both',
424
  bottom=False,
@@ -427,7 +430,7 @@ def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, fi
427
  labelbottom=False)
428
  ax_mov = fig.add_subplot(242)
429
  im_mov = ax_mov.imshow(list_imgs[1][:, :, 0])
430
- ax_mov.set_title('Moving image', fontsize=const.FONT_SIZE)
431
  ax_mov.tick_params(axis='both',
432
  which='both',
433
  bottom=False,
@@ -437,7 +440,7 @@ def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, fi
437
 
438
  ax_pred_im = fig.add_subplot(244)
439
  im_pred_im = ax_pred_im.imshow(list_imgs[2][:, :, 0])
440
- ax_pred_im.set_title('Predicted fix image', fontsize=const.FONT_SIZE)
441
  ax_pred_im.tick_params(axis='both',
442
  which='both',
443
  bottom=False,
@@ -467,8 +470,8 @@ def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, fi
467
  else:
468
  cx, cy, dx, dy, s = _prepare_quiver_map(list_imgs[3], dim=dim)
469
  im_pred_disp = ax_pred_disp.imshow(s, interpolation='none', aspect='equal')
470
- ax_pred_disp.quiver(cx, cy, dx, dy, scale=const.QUIVER_PARAMS.arrow_scale)
471
- ax_pred_disp.set_title('Pred disp map', fontsize=const.FONT_SIZE)
472
  ax_pred_disp.tick_params(axis='both',
473
  which='both',
474
  bottom=False,
@@ -478,9 +481,9 @@ def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, fi
478
 
479
  # VALIDATION
480
  axinput_val = fig.add_subplot(245)
481
- axinput_val.set_ylabel(title_second_row, fontsize=const.FONT_SIZE)
482
  im_fix_val = axinput_val.imshow(list_imgs[4][:, :, 0])
483
- axinput_val.set_title('Fix image', fontsize=const.FONT_SIZE)
484
  axinput_val.tick_params(axis='both',
485
  which='both',
486
  bottom=False,
@@ -489,7 +492,7 @@ def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, fi
489
  labelbottom=False)
490
  ax_mov_val = fig.add_subplot(246)
491
  im_mov_val = ax_mov_val.imshow(list_imgs[5][:, :, 0])
492
- ax_mov_val.set_title('Moving image', fontsize=const.FONT_SIZE)
493
  ax_mov_val.tick_params(axis='both',
494
  which='both',
495
  bottom=False,
@@ -499,7 +502,7 @@ def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, fi
499
 
500
  ax_pred_im_val = fig.add_subplot(248)
501
  im_pred_im_val = ax_pred_im_val.imshow(list_imgs[6][:, :, 0])
502
- ax_pred_im_val.set_title('Predicted fix image', fontsize=const.FONT_SIZE)
503
  ax_pred_im_val.tick_params(axis='both',
504
  which='both',
505
  bottom=False,
@@ -529,8 +532,8 @@ def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, fi
529
  else:
530
  c, d, s = _prepare_quiver_map(list_imgs[7], dim=dim)
531
  im_pred_disp_val = ax_pred_disp_val.imshow(s, interpolation='none', aspect='equal')
532
- ax_pred_disp_val.quiver(c[0], c[1], d[0], d[1], scale=const.QUIVER_PARAMS.arrow_scale)
533
- ax_pred_disp_val.set_title('Pred disp map', fontsize=const.FONT_SIZE)
534
  ax_pred_disp_val.tick_params(axis='both',
535
  which='both',
536
  bottom=False,
@@ -552,10 +555,10 @@ def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, fi
552
  # 3D
553
  # TRAINING
554
  ax_input = fig.add_subplot(231, projection='3d')
555
- ax_input.set_ylabel(title_first_row, fontsize=const.FONT_SIZE)
556
  im_fix = ax_input.voxels(list_imgs[0][..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed')
557
  im_mov = ax_input.voxels(list_imgs[1][..., 0] > 0.0, facecolors='blue', edgecolors='blue', label='Moving')
558
- ax_input.set_title('Fix image', fontsize=const.FONT_SIZE)
559
  ax_input.tick_params(axis='both',
560
  which='both',
561
  bottom=False,
@@ -566,7 +569,7 @@ def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, fi
566
  ax_pred_im = fig.add_subplot(232, projection='3d')
567
  im_pred_im = ax_input.voxels(list_imgs[2][..., 0] > 0.0, facecolors='green', edgecolors='green', label='Prediction')
568
  im_fix = ax_input.voxels(list_imgs[0][..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed')
569
- ax_pred_im.set_title('Predicted fix image', fontsize=const.FONT_SIZE)
570
  ax_pred_im.tick_params(axis='both',
571
  which='both',
572
  bottom=False,
@@ -578,9 +581,9 @@ def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, fi
578
 
579
  c, d, s = _prepare_quiver_map(list_imgs[3], dim=dim)
580
  im_pred_disp = ax_pred_disp.imshow(s, interpolation='none', aspect='equal')
581
- ax_pred_disp.quiver(c[const.H_DISP], c[const.W_DISP], c[const.D_DISP],
582
- d[const.H_DISP], d[const.W_DISP], d[const.D_DISP], scale=const.QUIVER_PARAMS.arrow_scale)
583
- ax_pred_disp.set_title('Pred disp map', fontsize=const.FONT_SIZE)
584
  ax_pred_disp.tick_params(axis='both',
585
  which='both',
586
  bottom=False,
@@ -590,10 +593,10 @@ def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, fi
590
 
591
  # VALIDATION
592
  axinput_val = fig.add_subplot(234, projection='3d')
593
- axinput_val.set_ylabel(title_second_row, fontsize=const.FONT_SIZE)
594
  im_fix_val = ax_input.voxels(list_imgs[4][..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed (val)')
595
  im_mov_val = ax_input.voxels(list_imgs[5][..., 0] > 0.0, facecolors='blue', edgecolors='blue', label='Moving (val)')
596
- axinput_val.set_title('Fix image', fontsize=const.FONT_SIZE)
597
  axinput_val.tick_params(axis='both',
598
  which='both',
599
  bottom=False,
@@ -604,7 +607,7 @@ def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, fi
604
  ax_pred_im_val = fig.add_subplot(235, projection='3d')
605
  im_pred_im_val = ax_input.voxels(list_imgs[2][..., 0] > 0.0, facecolors='green', edgecolors='green', label='Prediction (val)')
606
  im_fix_val = ax_input.voxels(list_imgs[0][..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed (val)')
607
- ax_pred_im_val.set_title('Predicted fix image', fontsize=const.FONT_SIZE)
608
  ax_pred_im_val.tick_params(axis='both',
609
  which='both',
610
  bottom=False,
@@ -615,10 +618,10 @@ def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, fi
615
  ax_pred_disp_val = fig.add_subplot(236, projection='3d')
616
  c, d, s = _prepare_quiver_map(list_imgs[7], dim=dim)
617
  im_pred_disp_val = ax_pred_disp_val.imshow(s, interpolation='none', aspect='equal')
618
- ax_pred_disp_val.quiver(c[const.H_DISP], c[const.W_DISP], c[const.D_DISP],
619
- d[const.H_DISP], d[const.W_DISP], d[const.D_DISP],
620
- scale=const.QUIVER_PARAMS.arrow_scale)
621
- ax_pred_disp_val.set_title('Pred disp map', fontsize=const.FONT_SIZE)
622
  ax_pred_disp_val.tick_params(axis='both',
623
  which='both',
624
  bottom=False,
@@ -628,7 +631,7 @@ def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, fi
628
 
629
  if filename is not None:
630
  plt.savefig(filename, format='png') # Call before show
631
- if not const.REMOTE:
632
  plt.show()
633
  else:
634
  plt.close()
@@ -644,21 +647,21 @@ def _set_colorbar(fig, ax, im, drawedges=True):
644
  return im_cb
645
 
646
 
647
- def _prepare_quiver_map(disp_map: np.ndarray, dim=2, spc=const.QUIVER_PARAMS.spacing):
648
  if isinstance(disp_map, tf.Tensor):
649
  if tf.executing_eagerly():
650
  disp_map = disp_map.numpy()
651
  else:
652
  disp_map = disp_map.eval()
653
- dx = disp_map[..., const.H_DISP]
654
- dy = disp_map[..., const.W_DISP]
655
  if dim > 2:
656
- dz = disp_map[..., const.D_DISP]
657
 
658
- img_size_x = disp_map.shape[const.H_DISP]
659
- img_size_y = disp_map.shape[const.W_DISP]
660
  if dim > 2:
661
- img_size_z = disp_map.shape[const.D_DISP]
662
 
663
  if dim > 2:
664
  s = np.sqrt(np.square(dx) + np.square(dy) + np.square(dz))
@@ -728,7 +731,7 @@ def plot_input_data(fix_img, mov_img, img_size=(64, 64), title=None, filename=No
728
 
729
  if filename is not None:
730
  plt.savefig(filename, format='png') # Call before show
731
- if not const.REMOTE:
732
  plt.show()
733
  else:
734
  plt.close()
@@ -807,29 +810,42 @@ def plot_dataset_3d(img_sets):
807
  return fig
808
 
809
 
810
- def plot_predictions(fix_img_batch, mov_img_batch, disp_map_batch, pred_img_batch, filename='predictions', fig=None):
811
  num_rows = fix_img_batch.shape[0]
812
- img_size = fix_img_batch.shape[1:3]
 
813
  if fig is not None:
814
  fig.clear()
815
  plt.figure(fig.number)
816
- ax = fig.add_subplot(nrows=num_rows, ncols=4, dpi=const.DPI)
817
  else:
818
- fig, ax = plt.subplots(nrows=num_rows, ncols=4, dpi=const.DPI)
 
 
 
 
 
 
 
 
 
 
 
 
819
 
820
  for row in range(num_rows):
821
- fix_img = fix_img_batch[row, :, :, 0]
822
- mov_img = mov_img_batch[row, :, :, 0]
823
- disp_map = disp_map_batch[row, :, :, :]
824
- pred_img = pred_img_batch[row, :, :, 0]
825
- ax[row, 0].imshow(fix_img)
826
  ax[row, 0].tick_params(axis='both',
827
  which='both',
828
  bottom=False,
829
  left=False,
830
  labelleft=False,
831
  labelbottom=False)
832
- ax[row, 1].imshow(mov_img)
833
  ax[row, 1].tick_params(axis='both',
834
  which='both',
835
  bottom=False,
@@ -837,11 +853,12 @@ def plot_predictions(fix_img_batch, mov_img_batch, disp_map_batch, pred_img_batc
837
  labelleft=False,
838
  labelbottom=False)
839
 
840
- cx, cy, dx, dy, s = _prepare_quiver_map(disp_map)
 
 
841
  disp_map_color = _prepare_colormap(disp_map)
842
- ax[row, 2].imshow(disp_map_color, interpolation='none', aspect='equal')
843
- ax[row, 2].quiver(cx.eval(), cy.eval(), dx.eval(), dy.eval(), units='xy', scale=const.QUIVER_PARAMS.arrow_scale)
844
- ax[row, 2].figure.set_size_inches(img_size)
845
  ax[row, 2].tick_params(axis='both',
846
  which='both',
847
  bottom=False,
@@ -849,6 +866,8 @@ def plot_predictions(fix_img_batch, mov_img_batch, disp_map_batch, pred_img_batc
849
  labelleft=False,
850
  labelbottom=False)
851
 
 
 
852
  ax[row, 3].tick_params(axis='both',
853
  which='both',
854
  bottom=False,
@@ -856,18 +875,26 @@ def plot_predictions(fix_img_batch, mov_img_batch, disp_map_batch, pred_img_batc
856
  labelleft=False,
857
  labelbottom=False)
858
 
859
- plt.axis('off')
860
- ax[0, 0].set_title('Fixed img ($I_f$)', fontsize=const.FONT_SIZE)
861
- ax[0, 1].set_title('Moving img ($I_m$)', fontsize=const.FONT_SIZE)
862
- ax[0, 2].set_title('Displacement map ($\delta$)', fontsize=const.FONT_SIZE)
863
- ax[0, 3].set_title('Updated $I_m$', fontsize=const.FONT_SIZE)
 
 
864
 
 
 
 
 
 
 
 
865
  if filename is not None:
866
  plt.savefig(filename, format='png') # Call before show
867
- if not const.REMOTE:
868
  plt.show()
869
- else:
870
- plt.close()
871
  return fig
872
 
873
 
@@ -876,7 +903,7 @@ def inspect_disp_map_generation(fix_img, mov_img, disp_map, filename=None, fig=N
876
  fig.clear()
877
  plt.figure(fig.number)
878
  else:
879
- fig = plt.figure(dpi=const.DPI)
880
 
881
  ax0 = fig.add_subplot(221)
882
  im0 = ax0.imshow(fix_img[..., 0])
@@ -900,7 +927,7 @@ def inspect_disp_map_generation(fix_img, mov_img, disp_map, filename=None, fig=N
900
  ax2 = fig.add_subplot(223)
901
  im2 = ax2.imshow(s, interpolation='none', aspect='equal')
902
 
903
- ax2.quiver(cx, cy, dx, dy, scale=const.QUIVER_PARAMS.arrow_scale)
904
  # ax2.figure.set_size_inches(img_size)
905
  ax2.tick_params(axis='both',
906
  which='both',
@@ -912,7 +939,7 @@ def inspect_disp_map_generation(fix_img, mov_img, disp_map, filename=None, fig=N
912
  ax3 = fig.add_subplot(224)
913
  dif = fix_img[..., 0] - mov_img[..., 0]
914
  im3 = ax3.imshow(dif)
915
- ax3.quiver(cx, cy, dx, dy, scale=const.QUIVER_PARAMS.arrow_scale)
916
  ax3.tick_params(axis='both',
917
  which='both',
918
  bottom=False,
@@ -921,10 +948,10 @@ def inspect_disp_map_generation(fix_img, mov_img, disp_map, filename=None, fig=N
921
  labelbottom=False)
922
 
923
  plt.axis('off')
924
- ax0.set_title('Fixed img ($I_f$)', fontsize=const.FONT_SIZE)
925
- ax1.set_title('Moving img ($I_m$)', fontsize=const.FONT_SIZE)
926
- ax2.set_title('Displacement map', fontsize=const.FONT_SIZE)
927
- ax3.set_title('Fix - Mov', fontsize=const.FONT_SIZE)
928
 
929
  im0_cb = _set_colorbar(fig, ax0, im0, False)
930
  im1_cb = _set_colorbar(fig, ax1, im1, False)
@@ -933,7 +960,7 @@ def inspect_disp_map_generation(fix_img, mov_img, disp_map, filename=None, fig=N
933
 
934
  if filename is not None:
935
  plt.savefig(filename, format='png') # Call before show
936
- if not const.REMOTE:
937
  plt.show()
938
  else:
939
  plt.close()
@@ -950,7 +977,7 @@ def inspect_displacement_grid(ctrl_coords, dense_coords, target_coords, disp_coo
950
  fig = plt.figure()
951
 
952
  ax_grid = fig.add_subplot(231)
953
- ax_grid.set_title('Grids', fontsize=const.FONT_SIZE)
954
  ax_grid.scatter(ctrl_coords[:, 0], ctrl_coords[:, 1], marker='+', c='r', s=20)
955
  ax_grid.scatter(dense_coords[:, 0], dense_coords[:, 1], marker='.', c='r', s=1)
956
  ax_grid.tick_params(axis='both',
@@ -966,7 +993,7 @@ def inspect_displacement_grid(ctrl_coords, dense_coords, target_coords, disp_coo
966
  ax_grid.set_aspect('equal')
967
 
968
  ax_disp = fig.add_subplot(232)
969
- ax_disp.set_title('Displacement map', fontsize=const.FONT_SIZE)
970
  cx, cy, dx, dy, s = _prepare_quiver_map(disp_map)
971
  ax_disp.imshow(s, interpolation='none', aspect='equal')
972
  ax_disp.tick_params(axis='both',
@@ -977,7 +1004,7 @@ def inspect_displacement_grid(ctrl_coords, dense_coords, target_coords, disp_coo
977
  labelbottom=False)
978
 
979
  ax_mask = fig.add_subplot(233)
980
- ax_mask.set_title('Mask', fontsize=const.FONT_SIZE)
981
  ax_mask.imshow(mask)
982
  ax_mask.tick_params(axis='both',
983
  which='both',
@@ -987,7 +1014,7 @@ def inspect_displacement_grid(ctrl_coords, dense_coords, target_coords, disp_coo
987
  labelbottom=False)
988
 
989
  ax_fix = fig.add_subplot(234)
990
- ax_fix.set_title('Fix image', fontsize=const.FONT_SIZE)
991
  ax_fix.imshow(fix_img[..., 0])
992
  ax_fix.tick_params(axis='both',
993
  which='both',
@@ -997,7 +1024,7 @@ def inspect_displacement_grid(ctrl_coords, dense_coords, target_coords, disp_coo
997
  labelbottom=False)
998
 
999
  ax_mov = fig.add_subplot(235)
1000
- ax_mov.set_title('Moving image', fontsize=const.FONT_SIZE)
1001
  ax_mov.imshow(mov_img[..., 0])
1002
  ax_mov.tick_params(axis='both',
1003
  which='both',
@@ -1007,7 +1034,7 @@ def inspect_displacement_grid(ctrl_coords, dense_coords, target_coords, disp_coo
1007
  labelbottom=False)
1008
 
1009
  ax_dif = fig.add_subplot(236)
1010
- ax_dif.set_title('Fix - Moving image', fontsize=const.FONT_SIZE)
1011
  ax_dif.imshow(fix_img[..., 0] - mov_img[..., 0], cmap=cmap_bin)
1012
  ax_dif.tick_params(axis='both',
1013
  which='both',
@@ -1023,7 +1050,7 @@ def inspect_displacement_grid(ctrl_coords, dense_coords, target_coords, disp_coo
1023
 
1024
  if filename is not None:
1025
  plt.savefig(filename, format='png') # Call before show
1026
- if not const.REMOTE:
1027
  plt.show()
1028
 
1029
  return fig
@@ -1037,10 +1064,10 @@ def compare_disp_maps(disp_m_f, disp_f_m, fix_img, mov_img, filename=None, fig=N
1037
  fig = plt.figure()
1038
 
1039
  ax_d_m_f = fig.add_subplot(131)
1040
- ax_d_m_f.set_title('Disp M->F', fontsize=const.FONT_SIZE)
1041
  cx, cy, dx, dy, s = _prepare_quiver_map(disp_m_f)
1042
  ax_d_m_f.imshow(s, interpolation='none', aspect='equal')
1043
- ax_d_m_f.quiver(cx, cy, dx, dy, scale=const.QUIVER_PARAMS.arrow_scale)
1044
  ax_d_m_f.tick_params(axis='both',
1045
  which='both',
1046
  bottom=False,
@@ -1049,9 +1076,9 @@ def compare_disp_maps(disp_m_f, disp_f_m, fix_img, mov_img, filename=None, fig=N
1049
  labelbottom=False)
1050
 
1051
  ax_d_f_m = fig.add_subplot(132)
1052
- ax_d_f_m.set_title('Disp F->M', fontsize=const.FONT_SIZE)
1053
  cx, cy, dx, dy, s = _prepare_quiver_map(disp_f_m)
1054
- ax_d_f_m.quiver(cx, cy, dx, dy, scale=const.QUIVER_PARAMS.arrow_scale)
1055
  ax_d_f_m.imshow(s, interpolation='none', aspect='equal')
1056
  ax_d_f_m.tick_params(axis='both',
1057
  which='both',
@@ -1061,7 +1088,7 @@ def compare_disp_maps(disp_m_f, disp_f_m, fix_img, mov_img, filename=None, fig=N
1061
  labelbottom=False)
1062
 
1063
  ax_dif = fig.add_subplot(133)
1064
- ax_dif.set_title('Fix - Moving image', fontsize=const.FONT_SIZE)
1065
  ax_dif.imshow(fix_img[..., 0] - mov_img[..., 0], cmap=cmap_bin)
1066
  ax_dif.tick_params(axis='both',
1067
  which='both',
@@ -1078,7 +1105,7 @@ def compare_disp_maps(disp_m_f, disp_f_m, fix_img, mov_img, filename=None, fig=N
1078
 
1079
  if filename is not None:
1080
  plt.savefig(filename, format='png') # Call before show
1081
- if not const.REMOTE:
1082
  plt.show()
1083
  else:
1084
  plt.close()
@@ -1104,7 +1131,7 @@ def plot_train_step(list_imgs: [np.ndarray], fig_title='TRAINING', dest_folder='
1104
  fig.tight_layout(pad=5.0)
1105
  ax = fig.add_subplot(2, num_cols, 1, projection='3d')
1106
  ax.voxels(list_imgs[0][0, ..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed')
1107
- ax.set_title('Fix image', fontsize=const.FONT_SIZE)
1108
  _square_3d_plot(np.arange(0, 63), np.arange(0, 63), np.arange(0, 63), ax)
1109
 
1110
  for i in range(2, num_preds+2):
@@ -1116,23 +1143,23 @@ def plot_train_step(list_imgs: [np.ndarray], fig_title='TRAINING', dest_folder='
1116
 
1117
  ax = fig.add_subplot(2, num_cols, num_preds+2, projection='3d')
1118
  ax.voxels(list_imgs[1][0, ..., 0] > 0.0, facecolors='blue', edgecolors='blue', label='Moving')
1119
- ax.set_title('Fix image', fontsize=const.FONT_SIZE)
1120
  _square_3d_plot(np.arange(0, 63), np.arange(0, 63), np.arange(0, 63), ax)
1121
 
1122
  for i in range(num_preds+2, 2 * num_preds + 2):
1123
  ax = fig.add_subplot(2, num_cols, i + 1, projection='3d')
1124
  c, d, s = _prepare_quiver_map(list_imgs[i][0, ...], dim=3)
1125
- ax.quiver(c[const.H_DISP], c[const.W_DISP], c[const.D_DISP],
1126
- d[const.H_DISP], d[const.W_DISP], d[const.D_DISP],
1127
  norm=True)
1128
  ax.set_title('Disp. #{}'.format(i - 5))
1129
  _square_3d_plot(np.arange(0, 63), np.arange(0, 63), np.arange(0, 63), ax)
1130
 
1131
- fig.suptitle(fig_title, fontsize=const.FONT_SIZE)
1132
 
1133
  if save_file:
1134
  plt.savefig(os.path.join(dest_folder, fig_title+'.png'), format='png') # Call before show
1135
- if not const.REMOTE:
1136
  plt.show()
1137
  else:
1138
  plt.close()
@@ -1149,3 +1176,66 @@ def _square_3d_plot(X, Y, Z, ax):
1149
  ax.set_ylim(mid_y - max_range, mid_y + max_range)
1150
  ax.set_zlim(mid_z - max_range, mid_z + max_range)
1151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib
2
+ #matplotlib.use('TkAgg')
3
  import matplotlib.pyplot as plt
4
  from mpl_toolkits.mplot3d import Axes3D
5
  import matplotlib.colors as mcolors
 
8
  from mpl_toolkits.axes_grid1 import make_axes_locatable
9
  import tensorflow as tf
10
  import numpy as np
11
+ import DeepDeformationMapRegistration.utils.constants as C
12
  from skimage.exposure import rescale_intensity
13
  import scipy.misc as scpmisc
14
  import os
 
175
  fig.clear()
176
  plt.figure(fig.number)
177
  else:
178
+ fig = plt.figure(dpi=C.DPI)
179
 
180
  ax_fix = fig.add_subplot(231)
181
  im_fix = ax_fix.imshow(list_imgs[0][:, :, 0])
182
+ ax_fix.set_title('Fix image', fontsize=C.FONT_SIZE)
183
  ax_fix.tick_params(axis='both',
184
  which='both',
185
  bottom=False,
 
188
  labelbottom=False)
189
  ax_mov = fig.add_subplot(232)
190
  im_mov = ax_mov.imshow(list_imgs[1][:, :, 0])
191
+ ax_mov.set_title('Moving image', fontsize=C.FONT_SIZE)
192
  ax_mov.tick_params(axis='both',
193
  which='both',
194
  bottom=False,
 
198
 
199
  ax_pred_im = fig.add_subplot(233)
200
  im_pred_im = ax_pred_im.imshow(list_imgs[2][:, :, 0])
201
+ ax_pred_im.set_title('Prediction', fontsize=C.FONT_SIZE)
202
  ax_pred_im.tick_params(axis='both',
203
  which='both',
204
  bottom=False,
 
228
  else:
229
  cx, cy, dx, dy, s = _prepare_quiver_map(list_imgs[3])
230
  im_pred_disp = ax_pred_disp.imshow(s, interpolation='none', aspect='equal')
231
+ ax_pred_disp.quiver(cx, cy, dx, dy, scale=C.QUIVER_PARAMS.arrow_scale)
232
+ ax_pred_disp.set_title('Pred disp map', fontsize=C.FONT_SIZE)
233
  ax_pred_disp.tick_params(axis='both',
234
  which='both',
235
  bottom=False,
 
259
  else:
260
  cx, cy, dx, dy, s = _prepare_quiver_map(list_imgs[4])
261
  im_gt_disp = ax_gt_disp.imshow(s, interpolation='none', aspect='equal')
262
+ ax_gt_disp.quiver(cx, cy, dx, dy, scale=C.QUIVER_PARAMS.arrow_scale)
263
+ ax_gt_disp.set_title('GT disp map', fontsize=C.FONT_SIZE)
264
  ax_gt_disp.tick_params(axis='both',
265
  which='both',
266
  bottom=False,
 
276
 
277
  if filename is not None:
278
  plt.savefig(filename, format='png') # Call before show
279
+ if not C.REMOTE:
280
  plt.show()
281
  else:
282
  plt.close()
 
288
  fig.clear()
289
  plt.figure(fig.number)
290
  else:
291
+ fig = plt.figure(dpi=C.DPI)
292
 
293
  dim = len(img.shape[:-1])
294
 
 
321
  plt.close()
322
 
323
 
324
+ def save_disp_map_img(disp_map, title, filename, affine_transf=False, fig=None, show=False):
325
  if fig is not None:
326
  fig.clear()
327
  plt.figure(fig.number)
328
  else:
329
+ fig = plt.figure(dpi=C.DPI)
330
+ dim_h, dim_w, dim_d = disp_map.shape[1:-1]
331
  dim = disp_map.shape[-1]
332
 
333
  if dim == 2:
334
  ax_x = fig.add_subplot(131)
335
  ax_x.set_title('H displacement')
336
+ im_x = ax_x.imshow(disp_map[..., C.H_DISP])
337
  ax_x.tick_params(axis='both',
338
  which='both',
339
  bottom=False,
 
344
 
345
  ax_y = fig.add_subplot(132)
346
  ax_y.set_title('W displacement')
347
+ im_y = ax_y.imshow(disp_map[..., C.W_DISP])
348
  ax_y.tick_params(axis='both',
349
  which='both',
350
  bottom=False,
 
373
  else:
374
  c, d, s = _prepare_quiver_map(disp_map, dim=dim)
375
  im = ax.imshow(s, interpolation='none', aspect='equal')
376
+ ax.quiver(c[C.H_DISP], c[C.W_DISP], d[C.H_DISP], d[C.W_DISP],
377
+ scale=C.QUIVER_PARAMS.arrow_scale)
378
  cb = _set_colorbar(fig, ax, im, False)
379
  ax.set_title('Displacement map')
380
  ax.tick_params(axis='both',
 
387
  else:
388
  ax = fig.add_subplot(111, projection='3d')
389
  c, d, s = _prepare_quiver_map(disp_map[0, ...], dim=dim)
390
+ ax.quiver(c[C.H_DISP], c[C.W_DISP], c[C.D_DISP], d[C.H_DISP], d[C.W_DISP], d[C.D_DISP])
391
+ _square_3d_plot(np.arange(0, dim_h-1), np.arange(0, dim_w-1), np.arange(0, dim_d-1), ax)
 
392
  fig.suptitle('Displacement map')
393
  ax.tick_params(axis='both', # Same parameters as in 2D https://matplotlib.org/api/_as_gen/mpl_toolkits.mplot3d.axes3d.Axes3D.html
394
  which='both',
 
396
  left=False,
397
  labelleft=False,
398
  labelbottom=False)
399
+ add_axes_arrows_3d(ax, xyz_label=['R', 'A', 'S'])
400
  fig.suptitle(title)
401
 
402
  plt.savefig(filename, format='png')
403
+ if show:
404
+ plt.show()
405
  plt.close()
406
+ return fig
407
 
408
 
409
  def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, filename='img', fig=None,
 
412
  fig.clear()
413
  plt.figure(fig.number)
414
  else:
415
+ fig = plt.figure(dpi=C.DPI)
416
 
417
  dim = len(list_imgs[0].shape[:-1])
418
 
419
  if dim == 2:
420
  # TRAINING
421
  ax_input = fig.add_subplot(241)
422
+ ax_input.set_ylabel(title_first_row, fontsize=C.FONT_SIZE)
423
  im_fix = ax_input.imshow(list_imgs[0][:, :, 0])
424
+ ax_input.set_title('Fix image', fontsize=C.FONT_SIZE)
425
  ax_input.tick_params(axis='both',
426
  which='both',
427
  bottom=False,
 
430
  labelbottom=False)
431
  ax_mov = fig.add_subplot(242)
432
  im_mov = ax_mov.imshow(list_imgs[1][:, :, 0])
433
+ ax_mov.set_title('Moving image', fontsize=C.FONT_SIZE)
434
  ax_mov.tick_params(axis='both',
435
  which='both',
436
  bottom=False,
 
440
 
441
  ax_pred_im = fig.add_subplot(244)
442
  im_pred_im = ax_pred_im.imshow(list_imgs[2][:, :, 0])
443
+ ax_pred_im.set_title('Predicted fix image', fontsize=C.FONT_SIZE)
444
  ax_pred_im.tick_params(axis='both',
445
  which='both',
446
  bottom=False,
 
470
  else:
471
  cx, cy, dx, dy, s = _prepare_quiver_map(list_imgs[3], dim=dim)
472
  im_pred_disp = ax_pred_disp.imshow(s, interpolation='none', aspect='equal')
473
+ ax_pred_disp.quiver(cx, cy, dx, dy, scale=C.QUIVER_PARAMS.arrow_scale)
474
+ ax_pred_disp.set_title('Pred disp map', fontsize=C.FONT_SIZE)
475
  ax_pred_disp.tick_params(axis='both',
476
  which='both',
477
  bottom=False,
 
481
 
482
  # VALIDATION
483
  axinput_val = fig.add_subplot(245)
484
+ axinput_val.set_ylabel(title_second_row, fontsize=C.FONT_SIZE)
485
  im_fix_val = axinput_val.imshow(list_imgs[4][:, :, 0])
486
+ axinput_val.set_title('Fix image', fontsize=C.FONT_SIZE)
487
  axinput_val.tick_params(axis='both',
488
  which='both',
489
  bottom=False,
 
492
  labelbottom=False)
493
  ax_mov_val = fig.add_subplot(246)
494
  im_mov_val = ax_mov_val.imshow(list_imgs[5][:, :, 0])
495
+ ax_mov_val.set_title('Moving image', fontsize=C.FONT_SIZE)
496
  ax_mov_val.tick_params(axis='both',
497
  which='both',
498
  bottom=False,
 
502
 
503
  ax_pred_im_val = fig.add_subplot(248)
504
  im_pred_im_val = ax_pred_im_val.imshow(list_imgs[6][:, :, 0])
505
+ ax_pred_im_val.set_title('Predicted fix image', fontsize=C.FONT_SIZE)
506
  ax_pred_im_val.tick_params(axis='both',
507
  which='both',
508
  bottom=False,
 
532
  else:
533
  c, d, s = _prepare_quiver_map(list_imgs[7], dim=dim)
534
  im_pred_disp_val = ax_pred_disp_val.imshow(s, interpolation='none', aspect='equal')
535
+ ax_pred_disp_val.quiver(c[0], c[1], d[0], d[1], scale=C.QUIVER_PARAMS.arrow_scale)
536
+ ax_pred_disp_val.set_title('Pred disp map', fontsize=C.FONT_SIZE)
537
  ax_pred_disp_val.tick_params(axis='both',
538
  which='both',
539
  bottom=False,
 
555
  # 3D
556
  # TRAINING
557
  ax_input = fig.add_subplot(231, projection='3d')
558
+ ax_input.set_ylabel(title_first_row, fontsize=C.FONT_SIZE)
559
  im_fix = ax_input.voxels(list_imgs[0][..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed')
560
  im_mov = ax_input.voxels(list_imgs[1][..., 0] > 0.0, facecolors='blue', edgecolors='blue', label='Moving')
561
+ ax_input.set_title('Fix image', fontsize=C.FONT_SIZE)
562
  ax_input.tick_params(axis='both',
563
  which='both',
564
  bottom=False,
 
569
  ax_pred_im = fig.add_subplot(232, projection='3d')
570
  im_pred_im = ax_input.voxels(list_imgs[2][..., 0] > 0.0, facecolors='green', edgecolors='green', label='Prediction')
571
  im_fix = ax_input.voxels(list_imgs[0][..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed')
572
+ ax_pred_im.set_title('Predicted fix image', fontsize=C.FONT_SIZE)
573
  ax_pred_im.tick_params(axis='both',
574
  which='both',
575
  bottom=False,
 
581
 
582
  c, d, s = _prepare_quiver_map(list_imgs[3], dim=dim)
583
  im_pred_disp = ax_pred_disp.imshow(s, interpolation='none', aspect='equal')
584
+ ax_pred_disp.quiver(c[C.H_DISP], c[C.W_DISP], c[C.D_DISP],
585
+ d[C.H_DISP], d[C.W_DISP], d[C.D_DISP], scale=C.QUIVER_PARAMS.arrow_scale)
586
+ ax_pred_disp.set_title('Pred disp map', fontsize=C.FONT_SIZE)
587
  ax_pred_disp.tick_params(axis='both',
588
  which='both',
589
  bottom=False,
 
593
 
594
  # VALIDATION
595
  axinput_val = fig.add_subplot(234, projection='3d')
596
+ axinput_val.set_ylabel(title_second_row, fontsize=C.FONT_SIZE)
597
  im_fix_val = ax_input.voxels(list_imgs[4][..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed (val)')
598
  im_mov_val = ax_input.voxels(list_imgs[5][..., 0] > 0.0, facecolors='blue', edgecolors='blue', label='Moving (val)')
599
+ axinput_val.set_title('Fix image', fontsize=C.FONT_SIZE)
600
  axinput_val.tick_params(axis='both',
601
  which='both',
602
  bottom=False,
 
607
  ax_pred_im_val = fig.add_subplot(235, projection='3d')
608
  im_pred_im_val = ax_input.voxels(list_imgs[2][..., 0] > 0.0, facecolors='green', edgecolors='green', label='Prediction (val)')
609
  im_fix_val = ax_input.voxels(list_imgs[0][..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed (val)')
610
+ ax_pred_im_val.set_title('Predicted fix image', fontsize=C.FONT_SIZE)
611
  ax_pred_im_val.tick_params(axis='both',
612
  which='both',
613
  bottom=False,
 
618
  ax_pred_disp_val = fig.add_subplot(236, projection='3d')
619
  c, d, s = _prepare_quiver_map(list_imgs[7], dim=dim)
620
  im_pred_disp_val = ax_pred_disp_val.imshow(s, interpolation='none', aspect='equal')
621
+ ax_pred_disp_val.quiver(c[C.H_DISP], c[C.W_DISP], c[C.D_DISP],
622
+ d[C.H_DISP], d[C.W_DISP], d[C.D_DISP],
623
+ scale=C.QUIVER_PARAMS.arrow_scale)
624
+ ax_pred_disp_val.set_title('Pred disp map', fontsize=C.FONT_SIZE)
625
  ax_pred_disp_val.tick_params(axis='both',
626
  which='both',
627
  bottom=False,
 
631
 
632
  if filename is not None:
633
  plt.savefig(filename, format='png') # Call before show
634
+ if not C.REMOTE:
635
  plt.show()
636
  else:
637
  plt.close()
 
647
  return im_cb
648
 
649
 
650
+ def _prepare_quiver_map(disp_map: np.ndarray, dim=2, spc=C.QUIVER_PARAMS.spacing):
651
  if isinstance(disp_map, tf.Tensor):
652
  if tf.executing_eagerly():
653
  disp_map = disp_map.numpy()
654
  else:
655
  disp_map = disp_map.eval()
656
+ dx = disp_map[..., C.H_DISP]
657
+ dy = disp_map[..., C.W_DISP]
658
  if dim > 2:
659
+ dz = disp_map[..., C.D_DISP]
660
 
661
+ img_size_x = disp_map.shape[C.H_DISP]
662
+ img_size_y = disp_map.shape[C.W_DISP]
663
  if dim > 2:
664
+ img_size_z = disp_map.shape[C.D_DISP]
665
 
666
  if dim > 2:
667
  s = np.sqrt(np.square(dx) + np.square(dy) + np.square(dz))
 
731
 
732
  if filename is not None:
733
  plt.savefig(filename, format='png') # Call before show
734
+ if not C.REMOTE:
735
  plt.show()
736
  else:
737
  plt.close()
 
810
  return fig
811
 
812
 
813
+ def plot_predictions(fix_img_batch, mov_img_batch, disp_map_batch, pred_img_batch, filename='predictions', fig=None, show=False):
814
  num_rows = fix_img_batch.shape[0]
815
+ img_dim = len(fix_img_batch.shape) - 2
816
+ img_size = fix_img_batch.shape[1:-1]
817
  if fig is not None:
818
  fig.clear()
819
  plt.figure(fig.number)
820
+ ax = fig.add_subplot(nrows=num_rows, ncols=5, figsize=(10, 3*num_rows), dpi=C.DPI)
821
  else:
822
+ fig, ax = plt.subplots(nrows=num_rows, ncols=5, figsize=(10, 3*num_rows), dpi=C.DPI)
823
+ if num_rows == 1:
824
+ ax = ax[np.newaxis, ...]
825
+
826
+ if img_dim == 3: # Extract slices from the images
827
+ selected_slice = img_size[0] // 2
828
+ fix_img_batch = fix_img_batch[:, selected_slice, ...]
829
+ mov_img_batch = mov_img_batch[:, selected_slice, ...]
830
+ pred_img_batch = pred_img_batch[:, selected_slice, ...]
831
+ disp_map_batch = disp_map_batch[:, selected_slice, ..., 1:] # Only the sagittal and longitudinal axes
832
+ img_size = fix_img_batch.shape[1:-1]
833
+ elif img_dim != 2:
834
+ raise ValueError('Images have a bad shape: {}'.format(fix_img_batch.shape))
835
 
836
  for row in range(num_rows):
837
+ fix_img = fix_img_batch[row, :, :, 0].transpose()
838
+ mov_img = mov_img_batch[row, :, :, 0].transpose()
839
+ disp_map = disp_map_batch[row, :, :, :].transpose((1, 0, 2))
840
+ pred_img = pred_img_batch[row, :, :, 0].transpose()
841
+ ax[row, 0].imshow(fix_img, origin='lower')
842
  ax[row, 0].tick_params(axis='both',
843
  which='both',
844
  bottom=False,
845
  left=False,
846
  labelleft=False,
847
  labelbottom=False)
848
+ ax[row, 1].imshow(mov_img, origin='lower')
849
  ax[row, 1].tick_params(axis='both',
850
  which='both',
851
  bottom=False,
 
853
  labelleft=False,
854
  labelbottom=False)
855
 
856
+ c, d, s = _prepare_quiver_map(disp_map, spc=5)
857
+ cx, cy = c
858
+ dx, dy = d
859
  disp_map_color = _prepare_colormap(disp_map)
860
+ ax[row, 2].imshow(disp_map_color, interpolation='none', aspect='equal', origin='lower')
861
+ ax[row, 2].quiver(cx, cy, dx, dy, units='dots', scale=1)
 
862
  ax[row, 2].tick_params(axis='both',
863
  which='both',
864
  bottom=False,
 
866
  labelleft=False,
867
  labelbottom=False)
868
 
869
+ ax[row, 3].imshow(mov_img, origin='lower')
870
+ ax[row, 3].quiver(cx, cy, dx, dy, units='dots', scale=1, color='w')
871
  ax[row, 3].tick_params(axis='both',
872
  which='both',
873
  bottom=False,
 
875
  labelleft=False,
876
  labelbottom=False)
877
 
878
+ ax[row, 4].imshow(pred_img, origin='lower')
879
+ ax[row, 4].tick_params(axis='both',
880
+ which='both',
881
+ bottom=False,
882
+ left=False,
883
+ labelleft=False,
884
+ labelbottom=False)
885
 
886
+ plt.axis('off')
887
+ ax[0, 0].set_title('Fixed img ($I_f$)', fontsize=C.FONT_SIZE)
888
+ ax[0, 1].set_title('Moving img ($I_m$)', fontsize=C.FONT_SIZE)
889
+ ax[0, 2].set_title('Backwards\ndisp .map ($\delta$)', fontsize=C.FONT_SIZE)
890
+ ax[0, 3].set_title('Disp. map over $I_m$', fontsize=C.FONT_SIZE)
891
+ ax[0, 4].set_title('Predicted $I_m$', fontsize=C.FONT_SIZE)
892
+ plt.tight_layout()
893
  if filename is not None:
894
  plt.savefig(filename, format='png') # Call before show
895
+ if show:
896
  plt.show()
897
+ plt.close()
 
898
  return fig
899
 
900
 
 
903
  fig.clear()
904
  plt.figure(fig.number)
905
  else:
906
+ fig = plt.figure(dpi=C.DPI)
907
 
908
  ax0 = fig.add_subplot(221)
909
  im0 = ax0.imshow(fix_img[..., 0])
 
927
  ax2 = fig.add_subplot(223)
928
  im2 = ax2.imshow(s, interpolation='none', aspect='equal')
929
 
930
+ ax2.quiver(cx, cy, dx, dy, scale=C.QUIVER_PARAMS.arrow_scale)
931
  # ax2.figure.set_size_inches(img_size)
932
  ax2.tick_params(axis='both',
933
  which='both',
 
939
  ax3 = fig.add_subplot(224)
940
  dif = fix_img[..., 0] - mov_img[..., 0]
941
  im3 = ax3.imshow(dif)
942
+ ax3.quiver(cx, cy, dx, dy, scale=C.QUIVER_PARAMS.arrow_scale)
943
  ax3.tick_params(axis='both',
944
  which='both',
945
  bottom=False,
 
948
  labelbottom=False)
949
 
950
  plt.axis('off')
951
+ ax0.set_title('Fixed img ($I_f$)', fontsize=C.FONT_SIZE)
952
+ ax1.set_title('Moving img ($I_m$)', fontsize=C.FONT_SIZE)
953
+ ax2.set_title('Displacement map', fontsize=C.FONT_SIZE)
954
+ ax3.set_title('Fix - Mov', fontsize=C.FONT_SIZE)
955
 
956
  im0_cb = _set_colorbar(fig, ax0, im0, False)
957
  im1_cb = _set_colorbar(fig, ax1, im1, False)
 
960
 
961
  if filename is not None:
962
  plt.savefig(filename, format='png') # Call before show
963
+ if not C.REMOTE:
964
  plt.show()
965
  else:
966
  plt.close()
 
977
  fig = plt.figure()
978
 
979
  ax_grid = fig.add_subplot(231)
980
+ ax_grid.set_title('Grids', fontsize=C.FONT_SIZE)
981
  ax_grid.scatter(ctrl_coords[:, 0], ctrl_coords[:, 1], marker='+', c='r', s=20)
982
  ax_grid.scatter(dense_coords[:, 0], dense_coords[:, 1], marker='.', c='r', s=1)
983
  ax_grid.tick_params(axis='both',
 
993
  ax_grid.set_aspect('equal')
994
 
995
  ax_disp = fig.add_subplot(232)
996
+ ax_disp.set_title('Displacement map', fontsize=C.FONT_SIZE)
997
  cx, cy, dx, dy, s = _prepare_quiver_map(disp_map)
998
  ax_disp.imshow(s, interpolation='none', aspect='equal')
999
  ax_disp.tick_params(axis='both',
 
1004
  labelbottom=False)
1005
 
1006
  ax_mask = fig.add_subplot(233)
1007
+ ax_mask.set_title('Mask', fontsize=C.FONT_SIZE)
1008
  ax_mask.imshow(mask)
1009
  ax_mask.tick_params(axis='both',
1010
  which='both',
 
1014
  labelbottom=False)
1015
 
1016
  ax_fix = fig.add_subplot(234)
1017
+ ax_fix.set_title('Fix image', fontsize=C.FONT_SIZE)
1018
  ax_fix.imshow(fix_img[..., 0])
1019
  ax_fix.tick_params(axis='both',
1020
  which='both',
 
1024
  labelbottom=False)
1025
 
1026
  ax_mov = fig.add_subplot(235)
1027
+ ax_mov.set_title('Moving image', fontsize=C.FONT_SIZE)
1028
  ax_mov.imshow(mov_img[..., 0])
1029
  ax_mov.tick_params(axis='both',
1030
  which='both',
 
1034
  labelbottom=False)
1035
 
1036
  ax_dif = fig.add_subplot(236)
1037
+ ax_dif.set_title('Fix - Moving image', fontsize=C.FONT_SIZE)
1038
  ax_dif.imshow(fix_img[..., 0] - mov_img[..., 0], cmap=cmap_bin)
1039
  ax_dif.tick_params(axis='both',
1040
  which='both',
 
1050
 
1051
  if filename is not None:
1052
  plt.savefig(filename, format='png') # Call before show
1053
+ if not C.REMOTE:
1054
  plt.show()
1055
 
1056
  return fig
 
1064
  fig = plt.figure()
1065
 
1066
  ax_d_m_f = fig.add_subplot(131)
1067
+ ax_d_m_f.set_title('Disp M->F', fontsize=C.FONT_SIZE)
1068
  cx, cy, dx, dy, s = _prepare_quiver_map(disp_m_f)
1069
  ax_d_m_f.imshow(s, interpolation='none', aspect='equal')
1070
+ ax_d_m_f.quiver(cx, cy, dx, dy, scale=C.QUIVER_PARAMS.arrow_scale)
1071
  ax_d_m_f.tick_params(axis='both',
1072
  which='both',
1073
  bottom=False,
 
1076
  labelbottom=False)
1077
 
1078
  ax_d_f_m = fig.add_subplot(132)
1079
+ ax_d_f_m.set_title('Disp F->M', fontsize=C.FONT_SIZE)
1080
  cx, cy, dx, dy, s = _prepare_quiver_map(disp_f_m)
1081
+ ax_d_f_m.quiver(cx, cy, dx, dy, scale=C.QUIVER_PARAMS.arrow_scale)
1082
  ax_d_f_m.imshow(s, interpolation='none', aspect='equal')
1083
  ax_d_f_m.tick_params(axis='both',
1084
  which='both',
 
1088
  labelbottom=False)
1089
 
1090
  ax_dif = fig.add_subplot(133)
1091
+ ax_dif.set_title('Fix - Moving image', fontsize=C.FONT_SIZE)
1092
  ax_dif.imshow(fix_img[..., 0] - mov_img[..., 0], cmap=cmap_bin)
1093
  ax_dif.tick_params(axis='both',
1094
  which='both',
 
1105
 
1106
  if filename is not None:
1107
  plt.savefig(filename, format='png') # Call before show
1108
+ if not C.REMOTE:
1109
  plt.show()
1110
  else:
1111
  plt.close()
 
1131
  fig.tight_layout(pad=5.0)
1132
  ax = fig.add_subplot(2, num_cols, 1, projection='3d')
1133
  ax.voxels(list_imgs[0][0, ..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed')
1134
+ ax.set_title('Fix image', fontsize=C.FONT_SIZE)
1135
  _square_3d_plot(np.arange(0, 63), np.arange(0, 63), np.arange(0, 63), ax)
1136
 
1137
  for i in range(2, num_preds+2):
 
1143
 
1144
  ax = fig.add_subplot(2, num_cols, num_preds+2, projection='3d')
1145
  ax.voxels(list_imgs[1][0, ..., 0] > 0.0, facecolors='blue', edgecolors='blue', label='Moving')
1146
+ ax.set_title('Fix image', fontsize=C.FONT_SIZE)
1147
  _square_3d_plot(np.arange(0, 63), np.arange(0, 63), np.arange(0, 63), ax)
1148
 
1149
  for i in range(num_preds+2, 2 * num_preds + 2):
1150
  ax = fig.add_subplot(2, num_cols, i + 1, projection='3d')
1151
  c, d, s = _prepare_quiver_map(list_imgs[i][0, ...], dim=3)
1152
+ ax.quiver(c[C.H_DISP], c[C.W_DISP], c[C.D_DISP],
1153
+ d[C.H_DISP], d[C.W_DISP], d[C.D_DISP],
1154
  norm=True)
1155
  ax.set_title('Disp. #{}'.format(i - 5))
1156
  _square_3d_plot(np.arange(0, 63), np.arange(0, 63), np.arange(0, 63), ax)
1157
 
1158
+ fig.suptitle(fig_title, fontsize=C.FONT_SIZE)
1159
 
1160
  if save_file:
1161
  plt.savefig(os.path.join(dest_folder, fig_title+'.png'), format='png') # Call before show
1162
+ if not C.REMOTE:
1163
  plt.show()
1164
  else:
1165
  plt.close()
 
1176
  ax.set_ylim(mid_y - max_range, mid_y + max_range)
1177
  ax.set_zlim(mid_z - max_range, mid_z + max_range)
1178
 
1179
+
1180
+ def remove_tick_labels(ax, project_3d=False):
1181
+ ax.set_xticklabels([])
1182
+ ax.set_yticklabels([])
1183
+ if project_3d:
1184
+ ax.set_zticklabels([])
1185
+ return ax
1186
+
1187
+
1188
+ def add_axes_arrows_3d(ax, arrow_length=10, xyz_colours=['r', 'g', 'b'], xyz_label=['X', 'Y', 'Z'], dist_arrow_text=3):
1189
+ x_limits = ax.get_xlim3d()
1190
+ y_limits = ax.get_ylim3d()
1191
+ z_limits = ax.get_zlim3d()
1192
+
1193
+ x_len = x_limits[1] - x_limits[0]
1194
+ y_len = y_limits[1] - y_limits[0]
1195
+ z_len = z_limits[1] - z_limits[0]
1196
+
1197
+ ax.quiver(x_limits[0], y_limits[0], z_limits[0], x_len, 0, 0, color=xyz_colours[0], arrow_length_ratio=0) # (init_loc, end_loc, params)
1198
+ ax.quiver(x_limits[0], y_limits[0], z_limits[0], 0, y_len, 0, color=xyz_colours[1], arrow_length_ratio=0)
1199
+ ax.quiver(x_limits[0], y_limits[0], z_limits[0], 0, 0, z_len, color=xyz_colours[2], arrow_length_ratio=0)
1200
+
1201
+ # X axis
1202
+ ax.quiver(x_limits[1], y_limits[0], z_limits[0], arrow_length, 0, 0, color=xyz_colours[0])
1203
+ ax.text(x_limits[1] + arrow_length + dist_arrow_text, y_limits[0], z_limits[0], xyz_label[0], fontsize=20, ha='right', va='top')
1204
+
1205
+ # Y axis
1206
+ ax.quiver(x_limits[0], y_limits[1], z_limits[0], 0, arrow_length, 0, color=xyz_colours[1])
1207
+ ax.text(x_limits[0], y_limits[1] + arrow_length + dist_arrow_text, z_limits[0], xyz_label[0], fontsize=20, ha='left', va='top')
1208
+
1209
+ # Z axis
1210
+ ax.quiver(x_limits[0], y_limits[0], z_limits[1], 0, 0, arrow_length, color=xyz_colours[2])
1211
+ ax.text(x_limits[0], y_limits[0], z_limits[1] + arrow_length + dist_arrow_text, xyz_label[0], fontsize=20, ha='center', va='bottom')
1212
+
1213
+ return ax
1214
+
1215
+
1216
+ def add_axes_arrows_2d(ax, arrow_length=10, xy_colour=['r', 'g'], xy_label=['X', 'Y']):
1217
+ x_limits = list(ax.get_xlim())
1218
+ y_limits = list(ax.get_ylim())
1219
+ origin = (x_limits[0], y_limits[1])
1220
+
1221
+ ax.annotate('', xy=(origin[0] + arrow_length, origin[1]), xytext=origin,
1222
+ arrowprops=dict(headlength=8., headwidth=10., width=5., color=xy_colour[0]))
1223
+ ax.annotate('', xy=(origin[0], origin[1] + arrow_length), xytext=origin,
1224
+ arrowprops=dict(headlength=8., headwidth=10., width=5., color=xy_colour[0]))
1225
+
1226
+ ax.text(origin[0] + arrow_length, origin[1], xy_label[0], fontsize=25, ha='left', va='bottom')
1227
+ ax.text(origin[0] - 1, origin[1] + arrow_length, xy_label[1], fontsize=25, ha='right', va='top')
1228
+
1229
+ return ax
1230
+
1231
+
1232
+ def set_axes_size(w,h, ax=None):
1233
+ """ w, h: width, height in inches """
1234
+ if not ax: ax=plt.gca()
1235
+ l = ax.figure.subplotpars.left
1236
+ r = ax.figure.subplotpars.right
1237
+ t = ax.figure.subplotpars.top
1238
+ b = ax.figure.subplotpars.bottom
1239
+ figw = float(w)/(r-l)
1240
+ figh = float(h)/(t-b)
1241
+ ax.figure.set_size_inches(figw, figh)