Commit 74c6a32 · Parent(s): f42fb70
Update DeepDeformationMapRegistration package

Changed files:
- DeepDeformationMapRegistration/callbacks.py +15 -0
- DeepDeformationMapRegistration/data_generator.py +432 -21
- DeepDeformationMapRegistration/layers/__init__.py +5 -0
- DeepDeformationMapRegistration/layers/augmentation.py +326 -0
- DeepDeformationMapRegistration/layers/b_splines.py +320 -0
- DeepDeformationMapRegistration/layers/depthwise_conv_3d.py +264 -0
- DeepDeformationMapRegistration/layers/uncertainty_weighting.py +347 -0
- DeepDeformationMapRegistration/layers/upsampling.py +761 -0
- DeepDeformationMapRegistration/losses.py +762 -47
- DeepDeformationMapRegistration/ms_ssim_tf.py +642 -0
- DeepDeformationMapRegistration/networks.py +35 -14
- DeepDeformationMapRegistration/utils/acummulated_optimizer.py +138 -3
- DeepDeformationMapRegistration/utils/constants.py +28 -3
- DeepDeformationMapRegistration/utils/misc.py +127 -2
- DeepDeformationMapRegistration/utils/nifti_utils.py +43 -0
- DeepDeformationMapRegistration/utils/operators.py +39 -2
- DeepDeformationMapRegistration/utils/thin_plate_splines.py +154 -0
- DeepDeformationMapRegistration/utils/visualization.py +197 -107
DeepDeformationMapRegistration/callbacks.py
CHANGED
@@ -1,5 +1,7 @@
 import tensorflow as tf
 import tensorflow.keras.backend as K
+import logging
+import numpy as np
 
 
 class RollingAverageWeighting(tf.keras.callbacks.Callback):
@@ -43,3 +45,16 @@ class RollingAverageWeighting(tf.keras.callbacks.Callback):
         for name, val in zip(self.loss_weights.keys(), new_weights):
             out_str += '{}: {:7.2f}\t'.format(name, val)
         print('WEIGHTS UPDATE: ' + out_str)
+
+
+class UncertaintyWeightingRollingAverageCallback(tf.keras.callbacks.Callback):
+    def __init__(self, method, epoch_update):
+        super(UncertaintyWeightingRollingAverageCallback, self).__init__()
+        self.method = method
+        self.epoch_update = epoch_update
+
+    def on_epoch_end(self, epoch, logs=None):
+        if epoch > self.epoch_update:
+            self.method()
+            print('Calling method: ' + self.method.__name__)
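The new callback simply invokes a stored, argument-free callable once per epoch, starting after the epoch given by `epoch_update`. A minimal wiring sketch (hypothetical names: `model` and `update_weights` are stand-ins, not part of this commit; `update_weights` is assumed to be a bound method that takes no arguments):

    # Hypothetical usage sketch for UncertaintyWeightingRollingAverageCallback.
    callback = UncertaintyWeightingRollingAverageCallback(method=update_weights, epoch_update=10)
    model.fit(train_generator, epochs=100, callbacks=[callback])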
DeepDeformationMapRegistration/data_generator.py
CHANGED
@@ -4,9 +4,16 @@ import os
 import h5py
 import random
 from PIL import Image
+import nibabel as nib
+from nilearn.image import resample_img
+from skimage.exposure import equalize_adapthist
+from scipy.ndimage import zoom
+import tensorflow as tf
 
 import DeepDeformationMapRegistration.utils.constants as C
 from DeepDeformationMapRegistration.utils.operators import min_max_norm
+from DeepDeformationMapRegistration.utils.thin_plate_splines import ThinPlateSplines
+from voxelmorph.tf.layers import SpatialTransformer
 
 
 class DataGeneratorManager(keras.utils.Sequence):
@@ -113,7 +120,7 @@ class DataGeneratorManager(keras.utils.Sequence):
         file_list.sort()
         for data_file in files:
             file_name, extension = os.path.splitext(data_file)
-            if extension.lower() == '.hd5':
+            if extension.lower() == '.hd5' or '.h5':
                 file_list.append(os.path.join(root, data_file))
 
         if not file_list:
@@ -230,12 +237,7 @@ class DataGenerator(DataGeneratorManager):
             ret_list.append(np.zeros([data_dict['BATCH_SIZE'], *C.DISP_MAP_SHAPE]))
         return ret_list
 
-    def __getitem__(self, index):
-        """
-        Generate one batch of data
-        :param index: epoch index
-        :return:
-        """
+    def __getitem1(self, index):
         idxs = self.__internal_idxs[index * self.__batch_size:(index + 1) * self.__batch_size]
 
         data_dict = self.__load_data(idxs)
@@ -248,6 +250,14 @@ class DataGenerator(DataGeneratorManager):
 
         return (inputs, outputs)
 
+    def __getitem__(self, index):
+        """
+        Generate one batch of data
+        :param index: epoch index
+        :return:
+        """
+        return self.__getitem2(index)
+
     def next_batch(self):
         if self.__last_batch > self.__batches_per_epoch:
             raise ValueError('No more batches for this epoch')
@@ -259,15 +269,15 @@ class DataGenerator(DataGeneratorManager):
         if label in self.__input_labels or label in self.__output_labels:
             # To avoid extra overhead
             try:
-                retVal = data_file[label][:]
+                retVal = data_file[label][:][np.newaxis, ...]
             except KeyError:
                 # That particular label is not found in the file. But this should be known by the user by now
                 retVal = None
 
             if append_array is not None and retVal is not None:
-                return np.append(append_array,
+                return np.append(append_array, retVal, axis=0)
             elif append_array is None:
                 return retVal
             else:
                 return retVal  # None
         else:
@@ -280,19 +290,22 @@ class DataGenerator(DataGeneratorManager):
         :return:
         """
         if isinstance(idx_list, (list, np.ndarray)):
-            fix_img = np.empty((0, ) + C.IMG_SHAPE)
-            mov_img = np.empty((0, ) + C.IMG_SHAPE)
-
-            fix_parench = np.empty((0, ) + C.IMG_SHAPE)
-            mov_parench = np.empty((0, ) + C.IMG_SHAPE)
-
-            fix_vessels = np.empty((0, ) + C.IMG_SHAPE)
-            mov_vessels = np.empty((0, ) + C.IMG_SHAPE)
-
-            fix_tumors = np.empty((0, ) + C.IMG_SHAPE)
-            mov_tumors = np.empty((0, ) + C.IMG_SHAPE)
-
-            disp_map = np.empty((0, ) + C.DISP_MAP_SHAPE)
+            fix_img = np.empty((0, ) + C.IMG_SHAPE, np.float32)
+            mov_img = np.empty((0, ) + C.IMG_SHAPE, np.float32)
+
+            fix_parench = np.empty((0, ) + C.IMG_SHAPE, np.float32)
+            mov_parench = np.empty((0, ) + C.IMG_SHAPE, np.float32)
+
+            fix_vessels = np.empty((0, ) + C.IMG_SHAPE, np.float32)
+            mov_vessels = np.empty((0, ) + C.IMG_SHAPE, np.float32)
+
+            fix_tumors = np.empty((0, ) + C.IMG_SHAPE, np.float32)
+            mov_tumors = np.empty((0, ) + C.IMG_SHAPE, np.float32)
+
+            disp_map = np.empty((0, ) + C.DISP_MAP_SHAPE, np.float32)
+
+            fix_centroid = np.empty((0, 3))
+            mov_centroid = np.empty((0, 3))
 
             for idx in idx_list:
                 data_file = h5py.File(self.__list_files[idx], 'r')
@@ -306,11 +319,14 @@ class DataGenerator(DataGeneratorManager):
                 fix_vessels = self.__try_load(data_file, C.H5_FIX_VESSELS_MASK, fix_vessels)
                 mov_vessels = self.__try_load(data_file, C.H5_MOV_VESSELS_MASK, mov_vessels)
 
-                fix_tumors = self.__try_load(data_file, C.H5_FIX_TUMORS_MASK,
-                mov_tumors = self.__try_load(data_file, C.H5_MOV_TUMORS_MASK,
+                fix_tumors = self.__try_load(data_file, C.H5_FIX_TUMORS_MASK, fix_tumors)
+                mov_tumors = self.__try_load(data_file, C.H5_MOV_TUMORS_MASK, mov_tumors)
 
                 disp_map = self.__try_load(data_file, C.H5_GT_DISP, disp_map)
 
+                fix_centroid = self.__try_load(data_file, C.H5_FIX_CENTROID, fix_centroid)
+                mov_centroid = self.__try_load(data_file, C.H5_MOV_CENTROID, mov_centroid)
+
                 data_file.close()
             batch_size = len(idx_list)
         else:
@@ -330,6 +346,9 @@ class DataGenerator(DataGeneratorManager):
 
             disp_map = self.__try_load(data_file, C.H5_GT_DISP)
 
+            fix_centroid = self.__try_load(data_file, C.H5_FIX_CENTROID)
+            mov_centroid = self.__try_load(data_file, C.H5_MOV_CENTROID)
+
             data_file.close()
             batch_size = 1
 
@@ -342,11 +361,61 @@ class DataGenerator(DataGeneratorManager):
                      C.H5_MOV_VESSELS_MASK: mov_vessels,
                      C.H5_MOV_PARENCHYMA_MASK: mov_parench,
                      C.H5_GT_DISP: disp_map,
+                     C.H5_FIX_CENTROID: fix_centroid,
+                     C.H5_MOV_CENTROID: mov_centroid,
                      'BATCH_SIZE': batch_size
                      }
 
         return data_dict
 
+    @staticmethod
+    def __get_data_shape(file_path, label):
+        f = h5py.File(file_path, 'r')
+        shape = f[label][:].shape
+        f.close()
+        return shape
+
+    def __load_data_by_label(self, label, idx_list):
+        if isinstance(idx_list, (list, np.ndarray)):
+            data_shape = self.__get_data_shape(self.__list_files[idx_list[0]], label)
+            container = np.empty((0, *data_shape), np.float32)
+            # if label == C.H5_GT_DISP:
+            #     container = np.empty((0, ) + C.DISP_MAP_SHAPE, np.float32)
+            # elif label == C.H5_MOV_CENTROID or label == C.H5_FIX_CENTROID:
+            #     container = np.empty((0, 3), np.float32)
+            # else:
+            #     container = np.empty((0, ) + C.IMG_SHAPE, np.float32)
+
+            for idx in idx_list:
+                data_file = h5py.File(self.__list_files[idx], 'r')
+                container = self.__try_load(data_file, label, container)
+                data_file.close()
+        else:
+            data_file = h5py.File(self.__list_files[idx_list], 'r')
+            container = self.__try_load(data_file, label)
+            data_file.close()
+
+        return container
+
+    def __build_list2(self, label_list, file_idxs):
+        ret_list = list()
+        for label in label_list:
+            if label is C.DG_LBL_ZERO_GRADS:
+                aux = np.zeros([len(file_idxs), *C.DISP_MAP_SHAPE])
+            else:
+                aux = self.__load_data_by_label(label, file_idxs)
+
+            if label in [C.DG_LBL_MOV_IMG, C.DG_LBL_FIX_IMG]:
+                aux = min_max_norm(aux).astype(np.float32)
+            ret_list.append(aux)
+        return ret_list
+
+    def __getitem2(self, index):
+        f_indices = self.__internal_idxs[index * self.__batch_size:(index + 1) * self.__batch_size]
+
+        return self.__build_list2(self.__input_labels, f_indices), self.__build_list2(self.__output_labels, f_indices)
+
+
     def get_samples(self, num_samples, random=False):
         if random:
             idxs = np.random.randint(0, self.__num_samples, num_samples)
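Note: the widened extension check introduced above, `extension.lower() == '.hd5' or '.h5'`, is always truthy in Python, because the non-empty string '.h5' is evaluated on its own rather than compared; every file under the search path will therefore be collected regardless of its extension. If the intent is to accept both extensions, a minimal corrected sketch would be:

    if extension.lower() in ('.hd5', '.h5'):
        file_list.append(os.path.join(root, data_file))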
@@ -509,3 +578,345 @@ class DataGenerator2D(keras.utils.Sequence):
         return mov, fix
 
 
+FILE_EXT = {'nifti': '.nii.gz',
+            'h5': '.h5'}
+CTRL_GRID = C.CoordinatesGrid()
+CTRL_GRID.set_coords_grid([128]*3, [C.TPS_NUM_CTRL_PTS_PER_AXIS]*3, batches=False, norm=False, img_type=tf.float32)
+
+FINE_GRID = C.CoordinatesGrid()
+FINE_GRID.set_coords_grid([128]*3, [128]*3, batches=FINE_GRID, norm=False)
+
+
+class DataGeneratorAugment(DataGeneratorManager):
+    def __init__(self, GeneratorManager: DataGeneratorManager, file_type='nifti', dataset_type='train'):
+        self.__complete_list_files = GeneratorManager.dataset_list_files
+        self.__list_files = [self.__complete_list_files[idx] for idx in GeneratorManager.get_generator_idxs(dataset_type)]
+        self.__batch_size = GeneratorManager.batch_size
+        self.__augm_per_sample = 10
+        self.__samples_per_batch = np.ceil(self.__batch_size / (self.__augm_per_sample + 1))  # B = S + S*A
+        self.__total_samples = len(self.__list_files)
+        self.__clip_range = GeneratorManager.clip_rage
+        self.__manager = GeneratorManager
+        self.__shuffle = GeneratorManager.shuffle
+        self.__file_extension = FILE_EXT[file_type]
+
+        self.__num_samples = len(self.__list_files)
+        self.__internal_idxs = np.arange(self.__num_samples)
+        # These indices are internal to the generator, they are not the same as the dataset_idxs!!
+
+        self.__dataset_type = dataset_type
+
+        self.__last_batch = 0
+        self.__batches_per_epoch = int(np.floor(len(self.__internal_idxs) / self.__batch_size))
+
+        self.__input_labels = GeneratorManager.input_labels
+        self.__output_labels = GeneratorManager.output_labels
+
+    def __get_dataset_files(self, search_path):
+        """
+        Get the path to the dataset files
+        :param search_path: dir path to search for the hd5 files
+        :return:
+        """
+        file_list = list()
+        for root, dirs, files in os.walk(search_path):
+            for data_file in files:
+                file_name, extension = os.path.splitext(data_file)
+                if extension.lower() == self.__file_extension:
+                    file_list.append(os.path.join(root, data_file))
+
+        if not file_list:
+            raise ValueError('No files found to train in ', search_path)
+
+        print('Found {} files in {}'.format(len(file_list), search_path))
+        return file_list
+
+    def update_samples(self, new_sample_idxs):
+        self.__list_files = [self.__complete_list_files[idx] for idx in new_sample_idxs]
+        self.__num_samples = len(self.__list_files)
+        self.__internal_idxs = np.arange(self.__num_samples)
+
+    def on_epoch_end(self):
+        """
+        To be executed at the end of each epoch. Reshuffle the assigned samples
+        :return:
+        """
+        if self.__shuffle:
+            random.shuffle(self.__internal_idxs)
+        self.__last_batch = 0
+
+    def __len__(self):
+        """
+        Number of batches per epoch
+        :return:
+        """
+        return self.__batches_per_epoch
+
+    def __getitem__(self, index):
+        """
+        Generate one batch of data
+        :param index: epoch index
+        :return:
+        """
+        return self.__getitem(index)
+
+    def next_batch(self):
+        if self.__last_batch > self.__batches_per_epoch:
+            raise ValueError('No more batches for this epoch')
+        batch = self.__getitem__(self.__last_batch)
+        self.__last_batch += 1
+        return batch
+
+    def __try_load(self, data_file, label, append_array=None):
+        if label in self.__input_labels or label in self.__output_labels:
+            # To avoid extra overhead
+            try:
+                retVal = data_file[label][:][np.newaxis, ...]
+            except KeyError:
+                # That particular label is not found in the file. But this should be known by the user by now
+                retVal = None
+
+            if append_array is not None and retVal is not None:
+                return np.append(append_array, retVal, axis=0)
+            elif append_array is None:
+                return retVal
+            else:
+                return retVal  # None
+        else:
+            return None
+
+    @staticmethod
+    def __get_data_shape(file_path, label):
+        f = h5py.File(file_path, 'r')
+        shape = f[label][:].shape
+        f.close()
+        return shape
+
+    def __load_data_by_label(self, label, idx_list):
+        if isinstance(idx_list, (list, np.ndarray)):
+            data_shape = self.__get_data_shape(self.__list_files[idx_list[0]], label)
+            container = np.empty((0, *data_shape), np.float32)
+            # if label == C.H5_GT_DISP:
+            #     container = np.empty((0, ) + C.DISP_MAP_SHAPE, np.float32)
+            # elif label == C.H5_MOV_CENTROID or label == C.H5_FIX_CENTROID:
+            #     container = np.empty((0, 3), np.float32)
+            # else:
+            #     container = np.empty((0, ) + C.IMG_SHAPE, np.float32)
+
+            for idx in idx_list:
+                data_file = h5py.File(self.__list_files[idx], 'r')
+                container = self.__try_load(data_file, label, container)
+                data_file.close()
+        else:
+            data_file = h5py.File(self.__list_files[idx_list], 'r')
+            container = self.__try_load(data_file, label)
+            data_file.close()
+
+        return container
+
+    def __build_list(self, label_list, file_idxs):
+        ret_list = list()
+        for label in label_list:
+            if label is C.DG_LBL_ZERO_GRADS:
+                aux = np.zeros([len(file_idxs), *C.DISP_MAP_SHAPE])
+            else:
+                aux = self.__load_data_by_label(label, file_idxs)
+
+            if label in [C.DG_LBL_MOV_IMG, C.DG_LBL_FIX_IMG]:
+                aux = min_max_norm(aux).astype(np.float32)
+            ret_list.append(aux)
+        return ret_list
+
+    def __getitem(self, index):
+        f_indices = self.__internal_idxs[index * self.__samples_per_batch:(index + 1) * self.__samples_per_batch]
+        # https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit
+        # A generator or keras.utils.Sequence returning (inputs, targets) or (inputs, targets, sample_weights)
+        # The second element must match the outputs of the model, in this case (image, displacement map)
+        if 'h5' in self.__file_extension:
+            return self.__build_list(self.__input_labels, f_indices), self.__build_list(self.__output_labels, f_indices)
+        else:
+            f_list = [self.__list_files[i] for i in f_indices]
+            return self.__augment(f_list, 'fixed', C.H5_FIX_IMG), self.__augment(f_list, 'moving', C.H5_FIX_IMG)
+
+    def __intensity_preprocessing(self, img_data):
+        # Histogram normalization
+        processed_img = equalize_adapthist(img_data, clip_limit=0.03)
+        processed_img = min_max_norm(processed_img)
+
+        return processed_img
+
+    def __resize_img(self, img, output_shape):
+        if isinstance(output_shape, int):
+            output_shape = [output_shape] * len(img.shape)
+        # Resize
+        zoom_vals = np.asarray(output_shape) / np.asarray(img.shape)
+        return zoom(img, zoom_vals)
+
+    def __build_augmented_batch(self, f_list, mode):
+        for f_path in f_list:
+            h5_file = h5py.File(f_path, 'r')
+            img_nib = nib.load(h5_file[C.H5_FIX_IMG][:])
+            img_nib = resample_img(img_nib, np.eye(3))
+            try:
+                seg_nib = nib.load(h5_file[C.H5_FIX_SEGMENTATIONS][:])
+                seg_nib = resample_img(seg_nib, np.eye(3))
+            except FileNotFoundError:
+                seg_nib = None
+
+            img_nib = self.__intensity_preprocessing(img_nib)
+            img_nib = self.__resize_img(img_nib, 128)
+
+    def get_samples(self, num_samples, random=False):
+        return
+
+    def get_input_shape(self):
+        input_batch, _ = self.__getitem__(0)
+        data_dict = self.__load_data(0)
+
+        ret_val = data_dict[self.__input_labels[0]].shape
+        ret_val = (None, ) + ret_val[1:]
+        return ret_val  # const.BATCH_SHAPE_SEGM
+
+    def who_are_you(self):
+        return self.__dataset_type
+
+    def print_datafiles(self):
+        return self.__list_files
+
+
+def tf_graph_deform():
+    # Place holders
+    fix_img = tf.placeholder(tf.float32, [128]*3, 'fix_img')
+    fix_segmentations = tf.placeholder_with_default(np.zeros([128]*3), shape=[128]*3, name='fix_segmentations')
+    max_deformation = tf.placeholder(tf.float32, shape=(), name='max_deformation')
+    max_displacement = tf.placeholder(tf.float32, shape=(), name='max_displacement')
+    max_rotation = tf.placeholder(tf.float32, shape=(), name='max_rotation')
+    num_moved_points = tf.placeholder_with_default(50, shape=(), name='num_moved_points')
+    only_image = tf.placeholder_with_default(True, shape=(), name='only_image')
+
+    search_voxels = tf.cond(only_image,
+                            lambda: fix_img,
+                            lambda: fix_segmentations)
+
+    # Apply TPS deformation
+    # Get points in the segmentation or image, and add it to the control grid and target grid
+    # Indices of the points in the search image with intensity greater than 0 (it would be bad if we only move the bg)
+    idx_points_in_label = tf.where(tf.greater(search_voxels, 0.0))
+
+    # Randomly select one of the points
+    random_idx = tf.random.uniform((num_moved_points,), minval=0, maxval=tf.shape(idx_points_in_label)[0], dtype=tf.int32)
+
+    disp_location = tf.gather_nd(idx_points_in_label, random_idx)  # And get the coordinates
+    disp_location = tf.cast(disp_location, tf.float32)
+    # Get the coordinates of the displaced control points
+    rand_disp = tf.random.uniform((num_moved_points, 3), minval=-1, maxval=1, dtype=tf.float32) * max_deformation
+    warped_location = disp_location + rand_disp
+
+    # Add the selected locations to the control grid and the warped locations to the target grid
+    control_grid = tf.concat([CTRL_GRID.grid_flat(), disp_location], axis=0)
+    trg_grid = tf.concat([CTRL_GRID.grid_flat(), warped_location], axis=0)
+
+    # Add global affine transformation
+    trg_grid, aff = transform_points(trg_grid, max_displacement=max_displacement, max_rotation=max_rotation)
+
+    tps = ThinPlateSplines(control_grid, trg_grid)
+    def_grid = tps.interpolate(FINE_GRID.grid_flat())
+
+    disp_map = FINE_GRID.grid_flat() - def_grid
+    disp_map = tf.reshape(disp_map, (*FINE_GRID.shape, -1))
+    # disp_map = interpn(disp_map, FULL_FINE_GRID.grid)
+
+    # add the batch and channel dimensions
+    fix_img = tf.expand_dims(tf.expand_dims(fix_img, -1), 0)
+    fix_segmentations = tf.expand_dims(tf.expand_dims(fix_img, -1), 0)
+    disp_map = tf.cast(tf.expand_dims(disp_map, 0), tf.float32)
+
+    mov_img = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([fix_img, disp_map])
+    mov_segmentations = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([fix_segmentations, disp_map])
+
+    return tf.squeeze(mov_img),\
+           tf.squeeze(mov_segmentations),\
+           tf.squeeze(disp_map),\
+           disp_location,\
+           rand_disp,\
+           aff  # , w, trg_grid, def_grid
+
+
+def transform_points(points: tf.Tensor, max_displacement, max_rotation):
+    axis = tf.random.uniform((), 0, 3)
+
+    alpha = tf.cond(tf.less_equal(axis, 0.),
+                    lambda: tf.random.uniform((1,), -max_rotation, max_rotation),
+                    lambda: tf.zeros((1,), tf.float32))
+    beta = tf.cond(tf.less_equal(axis, 1.),
+                   lambda: tf.random.uniform((1,), -max_rotation, max_rotation),
+                   lambda: tf.zeros((1,), tf.float32))
+    gamma = tf.cond(tf.less_equal(axis, 2.),
+                    lambda: tf.random.uniform((1,), -max_rotation, max_rotation),
+                    lambda: tf.zeros((1,), tf.float32))
+
+    ti = tf.random.uniform((), minval=-1, maxval=1, dtype=tf.float32) * max_displacement
+    tj = tf.random.uniform((), minval=-1, maxval=1, dtype=tf.float32) * max_displacement
+    tk = tf.random.uniform((), minval=-1, maxval=1, dtype=tf.float32) * max_displacement
+
+    M = build_affine_trf(tf.convert_to_tensor(FINE_GRID.shape, tf.float32), alpha, beta, gamma, ti, tj, tk)
+    if points.shape.as_list()[-1] == 3:
+        points = tf.transpose(points)
+    new_pts = tf.matmul(M[:3, :3], points)
+    new_pts = tf.expand_dims(M[:3, -1], -1) + new_pts
+    return tf.transpose(new_pts), M  # Remove the last row of ones
+
+
+def build_affine_trf(img_size, alpha, beta, gamma, ti, tj, tk):
+    img_centre = tf.expand_dims(tf.divide(img_size, 2.), -1)
+
+    # Rotation matrix around the image centre
+    # R* = T(p) R(ang) T(-p)
+    # tf.cos and tf.sin expect radians
+    zero = tf.zeros((1,))
+    one = tf.ones((1,))
+
+    T = tf.convert_to_tensor([[one, zero, zero, ti],
+                              [zero, one, zero, tj],
+                              [zero, zero, one, tk],
+                              [zero, zero, zero, one]], tf.float32)
+    T = tf.squeeze(T)
+
+    R = tf.convert_to_tensor([[tf.math.cos(gamma) * tf.math.cos(beta),
+                               tf.math.cos(gamma) * tf.math.sin(beta) * tf.math.sin(alpha) - tf.math.sin(gamma) * tf.math.cos(alpha),
+                               tf.math.cos(gamma) * tf.math.sin(beta) * tf.math.cos(alpha) + tf.math.sin(gamma) * tf.math.sin(alpha),
+                               zero],
+                              [tf.math.sin(gamma) * tf.math.cos(beta),
+                               tf.math.sin(gamma) * tf.math.sin(beta) * tf.math.sin(gamma) + tf.math.cos(gamma) * tf.math.cos(alpha),
+                               tf.math.sin(gamma) * tf.math.sin(beta) * tf.math.cos(gamma) - tf.math.cos(gamma) * tf.math.sin(gamma),
+                               zero],
+                              [-tf.math.sin(beta),
+                               tf.math.cos(beta) * tf.math.sin(alpha),
+                               tf.math.cos(beta) * tf.math.cos(alpha),
+                               zero],
+                              [zero, zero, zero, one]], tf.float32)
+
+    R = tf.squeeze(R)
+
+    Tc = tf.convert_to_tensor([[one, zero, zero, img_centre[0]],
+                               [zero, one, zero, img_centre[1]],
+                               [zero, zero, one, img_centre[2]],
+                               [zero, zero, zero, one]], tf.float32)
+    Tc = tf.squeeze(Tc)
+    Tc_ = tf.convert_to_tensor([[one, zero, zero, -img_centre[0]],
+                                [zero, one, zero, -img_centre[1]],
+                                [zero, zero, one, -img_centre[2]],
+                                [zero, zero, zero, one]], tf.float32)
+    Tc_ = tf.squeeze(Tc_)
+
+    return tf.matmul(T, tf.matmul(Tc, tf.matmul(R, Tc_)))
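Two observations on the added module-level code, offered as hedged notes rather than fixes to the commit. First, `tf_graph_deform` targets the TF1 graph API (`tf.placeholder` / `tf.placeholder_with_default`), so it needs graph mode or `tf.compat.v1` rather than TF2 eager execution; relatedly, `FINE_GRID.set_coords_grid(..., batches=FINE_GRID, ...)` passes the grid object itself as `batches`, where `batches=False` (as in the `CTRL_GRID` call) looks like the intended value. Second, `tf.gather_nd(idx_points_in_label, random_idx)` passes a rank-1 tensor of scalar indices, but `gather_nd` expects index tuples in its last dimension; selecting whole coordinate rows by scalar index is normally spelled with `tf.gather`, which is also what the commented-out selection in `AugmentationLayer.deform_image` below uses:

    # Sketch of the row selection with tf.gather (same variables as in tf_graph_deform above)
    disp_location = tf.gather(idx_points_in_label, random_idx)  # [num_moved_points, 3] coordinate rows
    disp_location = tf.cast(disp_location, tf.float32)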
DeepDeformationMapRegistration/layers/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .augmentation import AugmentationLayer
from .uncertainty_weighting import UncertaintyWeighting, UncertaintyWeightingWithRollingAverage
from .upsampling import UpSampling1D, UpSampling2D, UpSampling3D
from .depthwise_conv_3d import DepthwiseConv3D
from .b_splines import interpolate_spline
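These re-exports let downstream code import the custom layers from the subpackage root, e.g. (hypothetical consumer code, not part of this commit):

    from DeepDeformationMapRegistration.layers import AugmentationLayer, interpolate_spline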
DeepDeformationMapRegistration/layers/augmentation.py
ADDED
@@ -0,0 +1,326 @@
import os, sys

currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir)  # PYTHON > 3.3 does not allow relative referencing

PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'

import tensorflow.keras.layers as kl
import tensorflow as tf
from tensorflow.python.framework.errors import InvalidArgumentError

from DeepDeformationMapRegistration.utils.operators import soft_threshold, gaussian_kernel, sample_unique
import DeepDeformationMapRegistration.utils.constants as C
from DeepDeformationMapRegistration.utils.thin_plate_splines import ThinPlateSplines
from voxelmorph.tf.layers import SpatialTransformer


class AugmentationLayer(kl.Layer):
    def __init__(self,
                 max_deformation,
                 max_displacement,
                 max_rotation,
                 num_control_points,
                 in_img_shape,
                 out_img_shape,
                 num_augmentations=1,
                 gamma_augmentation=True,
                 brightness_augmentation=True,
                 only_image=False,
                 only_resize=True,
                 return_displacement_map=False,
                 **kwargs):
        super(AugmentationLayer, self).__init__(**kwargs)

        self.max_deformation = max_deformation
        self.max_displacement = max_displacement
        self.max_rotation = max_rotation
        self.num_control_points = num_control_points
        self.num_augmentations = num_augmentations
        self.in_img_shape = in_img_shape
        self.out_img_shape = out_img_shape
        self.only_image = only_image
        self.return_disp_map = return_displacement_map

        self.do_gamma_augm = gamma_augmentation
        self.do_brightness_augm = brightness_augmentation

        grid = C.CoordinatesGrid()
        grid.set_coords_grid(in_img_shape, [C.TPS_NUM_CTRL_PTS_PER_AXIS] * 3)
        self.control_grid = tf.identity(grid.grid_flat(), name='control_grid')
        self.target_grid = tf.identity(grid.grid_flat(), name='target_grid')

        grid.set_coords_grid(in_img_shape, in_img_shape)
        self.fine_grid = tf.identity(grid.grid_flat(), 'fine_grid')

        if out_img_shape is not None:
            self.downsample_factor = [i // o for o, i in zip(out_img_shape, in_img_shape)]
            self.img_gauss_filter = gaussian_kernel(3, 0.001, 1, 1, 3)
            # self.resize_transf = tf.diag([*self.downsample_factor, 1])[:-1, :]
            # self.resize_transf = tf.expand_dims(tf.reshape(self.resize_transf, [-1]), 0, name='resize_transformation')  # ST expects a (12,) vector

        self.augment = not only_resize

    def compute_output_shape(self, input_shape):
        input_shape = tf.TensorShape(input_shape).as_list()
        img_shape = (input_shape[0], *self.out_img_shape, 1)
        seg_shape = (input_shape[0], *self.out_img_shape, input_shape[-1] - 1)
        disp_shape = (input_shape[0], *self.out_img_shape, 3)
        # Expect the input to have the image and segmentations in the same tensor
        if self.return_disp_map:
            return (img_shape, img_shape, seg_shape, seg_shape, disp_shape)
        else:
            return (img_shape, img_shape, seg_shape, seg_shape)

    # @tf.custom_gradient
    def call(self, in_data, training=None):
        # def custom_grad(in_grad):
        #     return tf.ones_like(in_grad)
        if training is not None:
            self.augment = training
        return self.build_batch(in_data)  # , custom_grad

    def build_batch(self, fix_data: tf.Tensor):
        if len(fix_data.get_shape().as_list()) < 5:
            fix_data = tf.expand_dims(fix_data, axis=0)  # Add Batch dimension
        # fix_data = tf.tile(fix_data, (self.num_augmentations, *(1,)*4))
        fix_img_batch, mov_img_batch, fix_seg_batch, mov_seg_batch, disp_map = tf.map_fn(
            lambda x: self.augment_sample(x),
            fix_data,
            dtype=(tf.float32, tf.float32, tf.float32, tf.float32, tf.float32))
        # map_fn unstacks elems on axis 0
        if self.return_disp_map:
            return fix_img_batch, mov_img_batch, fix_seg_batch, mov_seg_batch, disp_map
        else:
            return fix_img_batch, mov_img_batch, fix_seg_batch, mov_seg_batch

    def augment_sample(self, fix_data: tf.Tensor):
        if self.only_image or not self.augment:
            fix_img = fix_data
            fix_segm = tf.zeros_like(fix_data, dtype=tf.float32)
        else:
            fix_img = fix_data[..., 0]
            fix_img = tf.expand_dims(fix_img, -1)
            fix_segm = fix_data[..., 1:]  # We expect several segmentation masks

        if self.augment:
            # If we are training, do the full-fledged augmentation
            fix_img = self.min_max_normalization(fix_img)

            mov_img, mov_segm, disp_map = self.deform_image(tf.squeeze(fix_img), fix_segm)
            mov_img = tf.expand_dims(mov_img, -1)  # Add the removed channel axis

            # Resample to output_shape
            if self.out_img_shape is not None:
                fix_img = self.downsize_image(fix_img)
                mov_img = self.downsize_image(mov_img)

                fix_segm = self.downsize_segmentation(fix_segm)
                mov_segm = self.downsize_segmentation(mov_segm)

                disp_map = self.downsize_displacement_map(disp_map)

            if self.do_gamma_augm:
                fix_img = self.gamma_augmentation(fix_img)
                mov_img = self.gamma_augmentation(mov_img)

            if self.do_brightness_augm:
                fix_img = self.brightness_augmentation(fix_img)
                mov_img = self.brightness_augmentation(mov_img)

        else:
            # During inference, just resize the input images
            mov_img = tf.zeros_like(fix_img)
            mov_segm = tf.zeros_like(fix_segm)

            disp_map = tf.tile(tf.zeros_like(fix_img), [1, 1, 1, 1, 3])

            if self.out_img_shape is not None:
                fix_img = self.downsize_image(fix_img)
                mov_img = self.downsize_image(mov_img)

                fix_segm = self.downsize_segmentation(fix_segm)
                mov_segm = self.downsize_segmentation(mov_segm)

                disp_map = self.downsize_displacement_map(disp_map)

        fix_img = self.min_max_normalization(fix_img)
        mov_img = self.min_max_normalization(mov_img)
        return fix_img, mov_img, fix_segm, mov_segm, disp_map

    def downsize_image(self, img):
        img = tf.expand_dims(img, axis=0)
        # The filter is symmetrical along the three axes, hence there is no need for transposing the H and D dims
        img = tf.nn.conv3d(img, self.img_gauss_filter, strides=[1, ] * 5, padding='SAME', data_format='NDHWC')
        img = tf.layers.MaxPooling3D([1]*3, self.downsample_factor, padding='valid', data_format='channels_last')(img)

        return tf.squeeze(img, axis=0)

    def downsize_segmentation(self, segm):
        segm = tf.expand_dims(segm, axis=0)
        segm = tf.layers.MaxPooling3D([1]*3, self.downsample_factor, padding='valid', data_format='channels_last')(segm)

        segm = tf.cast(segm, tf.float32)
        return tf.squeeze(segm, axis=0)

    def downsize_displacement_map(self, disp_map):
        disp_map = tf.expand_dims(disp_map, axis=0)
        # The filter is symmetrical along the three axes, hence there is no need for transposing the H and D dims
        disp_map = tf.layers.AveragePooling3D([1]*3, self.downsample_factor, padding='valid', data_format='channels_last')(disp_map)

        # self.downsample_factor = in_shape / out_shape, but here we need out_shape / in_shape. Hence, 1 / factor
        if self.downsample_factor[0] != self.downsample_factor[1] != self.downsample_factor[2]:
            # Downsize the displacement magnitude along the different axes
            disp_map_x = disp_map[..., 0] * 1 / self.downsample_factor[0]
            disp_map_y = disp_map[..., 1] * 1 / self.downsample_factor[1]
            disp_map_z = disp_map[..., 2] * 1 / self.downsample_factor[2]

            disp_map = tf.stack([disp_map_x, disp_map_y, disp_map_z], axis=-1)
        else:
            disp_map = disp_map * 1 / self.downsample_factor[0]

        return tf.squeeze(disp_map, axis=0)

    def gamma_augmentation(self, in_img: tf.Tensor):
        in_img += 1e-5  # To prevent NaNs
        gamma = tf.random.uniform((), 0.5, 2, tf.float32)

        return tf.clip_by_value(tf.pow(in_img, gamma), 0, 1)

    def brightness_augmentation(self, in_img: tf.Tensor):
        c = tf.random.uniform((), 0.5, 2, tf.float32)
        return tf.clip_by_value(c*in_img, 0, 1)

    def min_max_normalization(self, in_img: tf.Tensor):
        return tf.div(tf.subtract(in_img, tf.reduce_min(in_img)),
                      tf.subtract(tf.reduce_max(in_img), tf.reduce_min(in_img)))

    def deform_image(self, fix_img: tf.Tensor, fix_segm: tf.Tensor):
        # Get locations where the intensity > 0.0
        idx_points_in_label = tf.where(tf.greater(fix_img, 0.0))

        # Randomly select N points
        # random_idx = tf.random.uniform((self.num_control_points,),
        #                                minval=0, maxval=tf.shape(idx_points_in_label)[0],
        #                                dtype=tf.int32)
        #
        # disp_location = tf.gather(idx_points_in_label, random_idx)  # And get the coordinates
        # disp_location = tf.cast(disp_location, tf.float32)
        disp_location = sample_unique(idx_points_in_label, self.num_control_points, tf.float32)

        # Get the coordinates of the displaced control points
        rand_disp = tf.random.uniform((self.num_control_points, 3), minval=-1, maxval=1, dtype=tf.float32) * self.max_deformation
        warped_location = disp_location + rand_disp

        # Add the selected locations to the control grid and the warped locations to the target grid
        control_grid = tf.concat([self.control_grid, disp_location], axis=0)
        trg_grid = tf.concat([self.control_grid, warped_location], axis=0)

        # Apply global transformation
        valid_trf = False
        while not valid_trf:
            trg_grid, aff = self.global_transformation(trg_grid)

            # Interpolate the displacement map
            try:
                tps = ThinPlateSplines(control_grid, trg_grid)
                def_grid = tps.interpolate(self.fine_grid)
            except InvalidArgumentError as err:
                # If the transformation raises a non-invertible error,
                # try again until we get a valid transformation
                continue
            else:
                valid_trf = True

        disp_map = self.fine_grid - def_grid
        disp_map = tf.reshape(disp_map, (*self.in_img_shape, -1))

        # Apply the displacement map
        fix_img = tf.expand_dims(tf.expand_dims(fix_img, -1), 0)
        fix_segm = tf.expand_dims(fix_segm, 0)
        disp_map = tf.cast(tf.expand_dims(disp_map, 0), tf.float32)

        mov_img = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([fix_img, disp_map])
        mov_segm = SpatialTransformer(interp_method='nearest', indexing='ij', single_transform=False)([fix_segm, disp_map])

        mov_img = tf.where(tf.is_nan(mov_img), tf.zeros_like(mov_img), mov_img)
        mov_img = tf.where(tf.is_inf(mov_img), tf.zeros_like(mov_img), mov_img)

        mov_segm = tf.where(tf.is_nan(mov_segm), tf.zeros_like(mov_segm), mov_segm)
        mov_segm = tf.where(tf.is_inf(mov_segm), tf.zeros_like(mov_segm), mov_segm)

        return tf.squeeze(mov_img), tf.squeeze(mov_segm, axis=0), tf.squeeze(disp_map, axis=0)

    def global_transformation(self, points: tf.Tensor):
        axis = tf.random.uniform((), 0, 3)

        alpha = C.DEG_TO_RAD * tf.cond(tf.logical_and(tf.greater(axis, 0.), tf.less_equal(axis, 1.)),
                                       lambda: tf.random.uniform((), -self.max_rotation, self.max_rotation),
                                       lambda: tf.zeros((), tf.float32))
        beta = C.DEG_TO_RAD * tf.cond(tf.logical_and(tf.greater(axis, 1.), tf.less_equal(axis, 2.)),
                                      lambda: tf.random.uniform((), -self.max_rotation, self.max_rotation),
                                      lambda: tf.zeros((), tf.float32))
        gamma = C.DEG_TO_RAD * tf.cond(tf.logical_and(tf.greater(axis, 2.), tf.less_equal(axis, 3.)),
                                       lambda: tf.random.uniform((), -self.max_rotation, self.max_rotation),
                                       lambda: tf.zeros((), tf.float32))

        ti = tf.random.uniform((), minval=-1, maxval=1, dtype=tf.float32) * self.max_displacement
        tj = tf.random.uniform((), minval=-1, maxval=1, dtype=tf.float32) * self.max_displacement
        tk = tf.random.uniform((), minval=-1, maxval=1, dtype=tf.float32) * self.max_displacement

        M = self.build_affine_transformation(tf.convert_to_tensor(self.in_img_shape, tf.float32),
                                             alpha, beta, gamma, ti, tj, tk)

        points = tf.transpose(points)
        new_pts = tf.matmul(M[:3, :3], points)
        new_pts = tf.expand_dims(M[:3, -1], -1) + new_pts
        return tf.transpose(new_pts), M

    @staticmethod
    def build_affine_transformation(img_shape, alpha, beta, gamma, ti, tj, tk):
        img_centre = tf.divide(img_shape, 2.)

        # Rotation matrix around the image centre
        # R* = T(p) R(ang) T(-p)
        # tf.cos and tf.sin expect radians

        T = tf.convert_to_tensor([[1, 0, 0, ti],
                                  [0, 1, 0, tj],
                                  [0, 0, 1, tk],
                                  [0, 0, 0, 1]], tf.float32)

        Ri = tf.convert_to_tensor([[1, 0, 0, 0],
                                   [0, tf.math.cos(alpha), -tf.math.sin(alpha), 0],
                                   [0, tf.math.sin(alpha), tf.math.cos(alpha), 0],
                                   [0, 0, 0, 1]], tf.float32)

        Rj = tf.convert_to_tensor([[ tf.math.cos(beta), 0, tf.math.sin(beta), 0],
                                   [0, 1, 0, 0],
                                   [-tf.math.sin(beta), 0, tf.math.cos(beta), 0],
                                   [0, 0, 0, 1]], tf.float32)

        Rk = tf.convert_to_tensor([[tf.math.cos(gamma), -tf.math.sin(gamma), 0, 0],
                                   [tf.math.sin(gamma), tf.math.cos(gamma), 0, 0],
                                   [0, 0, 1, 0],
                                   [0, 0, 0, 1]], tf.float32)

        R = tf.matmul(tf.matmul(Ri, Rj), Rk)

        Tc = tf.convert_to_tensor([[1, 0, 0, img_centre[0]],
                                   [0, 1, 0, img_centre[1]],
                                   [0, 0, 1, img_centre[2]],
                                   [0, 0, 0, 1]], tf.float32)

        Tc_ = tf.convert_to_tensor([[1, 0, 0, -img_centre[0]],
                                    [0, 1, 0, -img_centre[1]],
                                    [0, 0, 1, -img_centre[2]],
                                    [0, 0, 0, 1]], tf.float32)

        return tf.matmul(T, tf.matmul(Tc, tf.matmul(R, Tc_)))

    def get_config(self):
        config = super(AugmentationLayer, self).get_config()
        return config
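A minimal construction sketch for the layer (all shapes and limits here are illustrative assumptions, not values from the commit; the input tensor is expected to carry the image in channel 0 and segmentation masks in the remaining channels):

    import tensorflow as tf

    # Hypothetical shapes: 128^3 input volume with one segmentation channel, downsampled to 64^3.
    in_data = tf.keras.Input(shape=(128, 128, 128, 2))
    augm = AugmentationLayer(max_deformation=12., max_displacement=10., max_rotation=15.,
                             num_control_points=100,
                             in_img_shape=[128, 128, 128], out_img_shape=[64, 64, 64],
                             only_resize=False, return_displacement_map=True)
    # With return_displacement_map=True the layer yields five tensors.
    fix_img, mov_img, fix_seg, mov_seg, disp_map = augm(in_data)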
DeepDeformationMapRegistration/layers/b_splines.py
ADDED
@@ -0,0 +1,320 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Polyharmonic spline interpolation."""

import tensorflow as tf
from typing import Union, List
import numpy as np

Number = Union[
    float,
    int,
    np.float16,
    np.float32,
    np.float64,
    np.int8,
    np.int16,
    np.int32,
    np.int64,
    np.uint8,
    np.uint16,
    np.uint32,
    np.uint64,
]

TensorLike = Union[
    List[Union[Number, list]],
    tuple,
    Number,
    np.ndarray,
    tf.Tensor,
    tf.SparseTensor,
    tf.Variable,
]
FloatTensorLike = Union[tf.Tensor, float, np.float16, np.float32, np.float64]

EPSILON = 0.0000000001


def _cross_squared_distance_matrix(x: TensorLike, y: TensorLike) -> tf.Tensor:
    """Pairwise squared distance between two (batch) matrices' rows (2nd dim).
    Computes the pairwise distances between rows of x and rows of y.
    Args:
      x: `[batch_size, n, d]` float `Tensor`.
      y: `[batch_size, m, d]` float `Tensor`.
    Returns:
      squared_dists: `[batch_size, n, m]` float `Tensor`, where
        `squared_dists[b,i,j] = ||x[b,i,:] - y[b,j,:]||^2`.
    """
    x_norm_squared = tf.reduce_sum(tf.square(x), 2)
    y_norm_squared = tf.reduce_sum(tf.square(y), 2)

    # Expand so that we can broadcast.
    x_norm_squared_tile = tf.expand_dims(x_norm_squared, 2)
    y_norm_squared_tile = tf.expand_dims(y_norm_squared, 1)

    x_y_transpose = tf.matmul(x, y, adjoint_b=True)

    # squared_dists[b,i,j] = ||x_bi - y_bj||^2 =
    # x_bi'x_bi - 2x_bi'x_bj + x_bj'x_bj
    squared_dists = x_norm_squared_tile - 2 * x_y_transpose + y_norm_squared_tile

    return squared_dists


def _pairwise_squared_distance_matrix(x: TensorLike) -> tf.Tensor:
    """Pairwise squared distance among a (batch) matrix's rows (2nd dim).
    This saves a bit of computation vs. using
    `_cross_squared_distance_matrix(x, x)`
    Args:
      x: `[batch_size, n, d]` float `Tensor`.
    Returns:
      squared_dists: `[batch_size, n, n]` float `Tensor`, where
        `squared_dists[b,i,j] = ||x[b,i,:] - x[b,j,:]||^2`.
    """

    x_x_transpose = tf.matmul(x, x, adjoint_b=True)
    x_norm_squared = tf.linalg.diag_part(x_x_transpose)
    x_norm_squared_tile = tf.expand_dims(x_norm_squared, 2)

    # squared_dists[b,i,j] = ||x_bi - x_bj||^2 =
    # = x_bi'x_bi - 2x_bi'x_bj + x_bj'x_bj
    squared_dists = (
        x_norm_squared_tile
        - 2 * x_x_transpose
        + tf.transpose(x_norm_squared_tile, [0, 2, 1])
    )

    return squared_dists


def _solve_interpolation(
    train_points: TensorLike,
    train_values: TensorLike,
    order: int,
    regularization_weight: FloatTensorLike,
) -> TensorLike:
    r"""Solve for interpolation coefficients.
    Computes the coefficients of the polyharmonic interpolant for the
    'training' data defined by `(train_points, train_values)` using the kernel
    $\phi$.
    Args:
      train_points: `[b, n, d]` interpolation centers.
      train_values: `[b, n, k]` function values.
      order: order of the interpolation.
      regularization_weight: weight to place on smoothness regularization term.
    Returns:
      w: `[b, n, k]` weights on each interpolation center
      v: `[b, d, k]` weights on each input dimension
    Raises:
      ValueError: if d or k is not fully specified.
    """

    # These dimensions are set dynamically at runtime.
    b, n, _ = tf.unstack(tf.shape(train_points), num=3)

    d = train_points.shape[-1]
    if d is None:
        raise ValueError(
            "The dimensionality of the input points (d) must be "
            "statically-inferrable."
        )

    k = train_values.shape[-1]
    if k is None:
        raise ValueError(
            "The dimensionality of the output values (k) must be "
            "statically-inferrable."
        )

    # First, rename variables so that the notation (c, f, w, v, A, B, etc.)
    # follows https://en.wikipedia.org/wiki/Polyharmonic_spline.
    # To account for python style guidelines we use
    # matrix_a for A and matrix_b for B.

    c = train_points
    f = train_values

    # Next, construct the linear system.
    with tf.name_scope("construct_linear_system"):

        matrix_a = _phi(_pairwise_squared_distance_matrix(c), order)  # [b, n, n]
        if regularization_weight > 0:
            batch_identity_matrix = tf.expand_dims(tf.eye(n, dtype=c.dtype), 0)
            matrix_a += regularization_weight * batch_identity_matrix

        # Append ones to the feature values for the bias term
        # in the linear model.
        ones = tf.ones_like(c[..., :1], dtype=c.dtype)
        matrix_b = tf.concat([c, ones], 2)  # [b, n, d + 1]

        # [b, n + d + 1, n]
        left_block = tf.concat([matrix_a, tf.transpose(matrix_b, [0, 2, 1])], 1)

        num_b_cols = matrix_b.get_shape()[2]  # d + 1
        lhs_zeros = tf.zeros([b, num_b_cols, num_b_cols], train_points.dtype)
        right_block = tf.concat([matrix_b, lhs_zeros], 1)  # [b, n + d + 1, d + 1]
        lhs = tf.concat([left_block, right_block], 2)  # [b, n + d + 1, n + d + 1]

        rhs_zeros = tf.zeros([b, d + 1, k], train_points.dtype)
        rhs = tf.concat([f, rhs_zeros], 1)  # [b, n + d + 1, k]

    # Then, solve the linear system and unpack the results.
    with tf.name_scope("solve_linear_system"):
        w_v = tf.linalg.solve(lhs, rhs)
        w = w_v[:, :n, :]
        v = w_v[:, n:, :]

    return w, v


def _apply_interpolation(
    query_points: TensorLike,
    train_points: TensorLike,
    w: TensorLike,
    v: TensorLike,
    order: int,
) -> TensorLike:
    """Apply polyharmonic interpolation model to data.
    Given coefficients w and v for the interpolation model, we evaluate
    interpolated function values at query_points.
    Args:
      query_points: `[b, m, d]` x values to evaluate the interpolation at.
      train_points: `[b, n, d]` x values that act as the interpolation centers
        (the c variables in the wikipedia article).
      w: `[b, n, k]` weights on each interpolation center.
      v: `[b, d, k]` weights on each input dimension.
      order: order of the interpolation.
    Returns:
      Polyharmonic interpolation evaluated at points defined in `query_points`.
    """

    # First, compute the contribution from the rbf term.
    pairwise_dists = _cross_squared_distance_matrix(query_points, train_points)
    phi_pairwise_dists = _phi(pairwise_dists, order)

    rbf_term = tf.matmul(phi_pairwise_dists, w)

    # Then, compute the contribution from the linear term.
    # Pad query_points with ones, for the bias term in the linear model.
    query_points_pad = tf.concat(
        [query_points, tf.ones_like(query_points[..., :1], train_points.dtype)], 2
    )
    linear_term = tf.matmul(query_points_pad, v)

    return rbf_term + linear_term


def _phi(r: FloatTensorLike, order: int) -> FloatTensorLike:
    """Coordinate-wise nonlinearity used to define the order of the
    interpolation.
    See https://en.wikipedia.org/wiki/Polyharmonic_spline for the definition.
    Args:
      r: input op.
      order: interpolation order.
    Returns:
      `phi_k` evaluated coordinate-wise on `r`, for `k = r`.
    """

    # using EPSILON prevents log(0), sqrt(0), etc.
    # sqrt(0) is well-defined, but its gradient is not
    with tf.name_scope("phi"):
        if order == 1:
            r = tf.maximum(r, EPSILON)
            r = tf.sqrt(r)
            return r
        elif order == 2:
            return 0.5 * r * tf.math.log(tf.maximum(r, EPSILON))
        elif order == 4:
            return 0.5 * tf.square(r) * tf.math.log(tf.maximum(r, EPSILON))
        elif order % 2 == 0:
            r = tf.maximum(r, EPSILON)
            return 0.5 * tf.pow(r, 0.5 * order) * tf.math.log(r)
        else:
            r = tf.maximum(r, EPSILON)
            return tf.pow(r, 0.5 * order)
250 |
+
def interpolate_spline(
|
251 |
+
train_points: TensorLike,
|
252 |
+
train_values: TensorLike,
|
253 |
+
query_points: TensorLike,
|
254 |
+
order: int,
|
255 |
+
regularization_weight: FloatTensorLike = 0.0,
|
256 |
+
name: str = "interpolate_spline",
|
257 |
+
) -> tf.Tensor:
|
258 |
+
r"""Interpolate signal using polyharmonic interpolation.
|
259 |
+
The interpolant has the form
|
260 |
+
$$f(x) = \sum_{i = 1}^n w_i \phi(||x - c_i||) + v^T x + b.$$
|
261 |
+
This is a sum of two terms: (1) a weighted sum of radial basis function
|
262 |
+
(RBF) terms, with the centers \\(c_1, ... c_n\\), and (2) a linear term
|
263 |
+
with a bias. The \\(c_i\\) vectors are 'training' points.
|
264 |
+
In the code, b is absorbed into v
|
265 |
+
by appending 1 as a final dimension to x. The coefficients w and v are
|
266 |
+
estimated such that the interpolant exactly fits the value of the function
|
267 |
+
at the \\(c_i\\) points, the vector w is orthogonal to each \\(c_i\\),
|
268 |
+
and the vector w sums to 0. With these constraints, the coefficients
|
269 |
+
can be obtained by solving a linear system.
|
270 |
+
\\(\phi\\) is an RBF, parametrized by an interpolation
|
271 |
+
order. Using order=2 produces the well-known thin-plate spline.
|
272 |
+
We also provide the option to perform regularized interpolation. Here, the
|
273 |
+
interpolant is selected to trade off between the squared loss on the
|
274 |
+
training data and a certain measure of its curvature
|
275 |
+
([details](https://en.wikipedia.org/wiki/Polyharmonic_spline)).
|
276 |
+
Using a regularization weight greater than zero has the effect that the
|
277 |
+
interpolant will no longer exactly fit the training data. However, it may
|
278 |
+
be less vulnerable to overfitting, particularly for high-order
|
279 |
+
interpolation.
|
280 |
+
Note the interpolation procedure is differentiable with respect to all
|
281 |
+
inputs besides the order parameter.
|
282 |
+
We support dynamically-shaped inputs, where batch_size, n, and m are None
|
283 |
+
at graph construction time. However, d and k must be known.
|
284 |
+
Args:
|
285 |
+
train_points: `[batch_size, n, d]` float `Tensor` of n d-dimensional
|
286 |
+
locations. These do not need to be regularly-spaced.
|
287 |
+
train_values: `[batch_size, n, k]` float `Tensor` of n c-dimensional
|
288 |
+
values evaluated at train_points.
|
289 |
+
query_points: `[batch_size, m, d]` `Tensor` of m d-dimensional locations
|
290 |
+
where we will output the interpolant's values.
|
291 |
+
order: order of the interpolation. Common values are 1 for
|
292 |
+
\\(\phi(r) = r\\), 2 for \\(\phi(r) = r^2 * log(r)\\)
|
293 |
+
(thin-plate spline), or 3 for \\(\phi(r) = r^3\\).
|
294 |
+
regularization_weight: weight placed on the regularization term.
|
295 |
+
This will depend substantially on the problem, and it should always be
|
296 |
+
tuned. For many problems, it is reasonable to use no regularization.
|
297 |
+
If using a non-zero value, we recommend a small value like 0.001.
|
298 |
+
name: name prefix for ops created by this function
|
299 |
+
Returns:
|
300 |
+
`[b, m, k]` float `Tensor` of query values. We use train_points and
|
301 |
+
train_values to perform polyharmonic interpolation. The query values are
|
302 |
+
the values of the interpolant evaluated at the locations specified in
|
303 |
+
query_points.
|
304 |
+
"""
|
305 |
+
with tf.name_scope(name or "interpolate_spline"):
|
306 |
+
train_points = tf.convert_to_tensor(train_points)
|
307 |
+
train_values = tf.convert_to_tensor(train_values)
|
308 |
+
query_points = tf.convert_to_tensor(query_points)
|
309 |
+
|
310 |
+
# First, fit the spline to the observed data.
|
311 |
+
with tf.name_scope("solve"):
|
312 |
+
w, v = _solve_interpolation(
|
313 |
+
train_points, train_values, order, regularization_weight
|
314 |
+
)
|
315 |
+
|
316 |
+
# Then, evaluate the spline at the query locations.
|
317 |
+
with tf.name_scope("predict"):
|
318 |
+
query_values = _apply_interpolation(query_points, train_points, w, v, order)
|
319 |
+
|
320 |
+
return query_values
|
DeepDeformationMapRegistration/layers/depthwise_conv_3d.py
ADDED
@@ -0,0 +1,264 @@
# SRC: https://github.com/alexandrosstergiou/keras-DepthwiseConv3D

'''
This is a modification of the SeparableConv3D code in Keras,
to perform just the Depthwise Convolution (1st step) of the
Depthwise Separable Convolution layer.
'''
from __future__ import absolute_import

from tensorflow.keras import backend as K
from tensorflow.keras import initializers
from tensorflow.keras import regularizers
from tensorflow.keras import constraints
# conv_output_length lives in the private keras utils, not in tensorflow.keras.utils
from tensorflow.python.keras.utils import conv_utils
from tensorflow.keras.layers import Conv3D, InputSpec
from tensorflow.python.keras.backend import _preprocess_padding, _preprocess_conv3d_input

import tensorflow as tf


class DepthwiseConv3D(Conv3D):
    """Depthwise 3D convolution.

    The depthwise part of a separable convolution consists of performing
    just the first step/operation
    (which acts on each input channel separately).
    It does not perform the pointwise convolution (second step).
    The `depth_multiplier` argument controls how many
    output channels are generated per input channel in the depthwise step.

    # Arguments
        kernel_size: An integer or tuple/list of 3 integers, specifying the
            depth, width and height of the 3D convolution window.
            Can be a single integer to specify the same value for
            all spatial dimensions.
        strides: An integer or tuple/list of 3 integers,
            specifying the strides of the convolution along the depth, width and height.
            Can be a single integer to specify the same value for
            all spatial dimensions.
        padding: one of `"valid"` or `"same"` (case-insensitive).
        depth_multiplier: The number of depthwise convolution output channels
            for each input channel.
            The total number of depthwise convolution output
            channels will be equal to `filters_in * depth_multiplier`.
        groups: The depth size of the convolution (as a variant of the original Depthwise conv)
        data_format: A string,
            one of `channels_last` (default) or `channels_first`.
            The ordering of the dimensions in the inputs.
            `channels_last` corresponds to inputs with shape
            `(batch, height, width, channels)` while `channels_first`
            corresponds to inputs with shape
            `(batch, channels, height, width)`.
            It defaults to the `image_data_format` value found in your
            Keras config file at `~/.keras/keras.json`.
            If you never set it, then it will be "channels_last".
        activation: Activation function to use
            (see [activations](../activations.md)).
            If you don't specify anything, no activation is applied
            (ie. "linear" activation: `a(x) = x`).
        use_bias: Boolean, whether the layer uses a bias vector.
        depthwise_initializer: Initializer for the depthwise kernel matrix
            (see [initializers](../initializers.md)).
        bias_initializer: Initializer for the bias vector
            (see [initializers](../initializers.md)).
        depthwise_regularizer: Regularizer function applied to
            the depthwise kernel matrix
            (see [regularizer](../regularizers.md)).
        bias_regularizer: Regularizer function applied to the bias vector
            (see [regularizer](../regularizers.md)).
        dilation_rate: List of ints.
            Defines the dilation factor for each dimension in the
            input. Defaults to (1, 1, 1)
        activity_regularizer: Regularizer function applied to
            the output of the layer (its "activation").
            (see [regularizer](../regularizers.md)).
        depthwise_constraint: Constraint function applied to
            the depthwise kernel matrix
            (see [constraints](../constraints.md)).
        bias_constraint: Constraint function applied to the bias vector
            (see [constraints](../constraints.md)).

    # Input shape
        5D tensor with shape:
        `(batch, depth, channels, rows, cols)` if data_format='channels_first'
        or 5D tensor with shape:
        `(batch, depth, rows, cols, channels)` if data_format='channels_last'.

    # Output shape
        5D tensor with shape:
        `(batch, filters * depth, new_depth, new_rows, new_cols)` if data_format='channels_first'
        or 5D tensor with shape:
        `(batch, new_depth, new_rows, new_cols, filters * depth)` if data_format='channels_last'.
        `rows` and `cols` values might have changed due to padding.
    """

    #@legacy_depthwise_conv3d_support
    def __init__(self,
                 kernel_size,
                 strides=(1, 1, 1),
                 padding='valid',
                 depth_multiplier=1,
                 groups=None,
                 data_format=None,
                 activation=None,
                 use_bias=True,
                 depthwise_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 dilation_rate=(1, 1, 1),
                 depthwise_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 depthwise_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        super(DepthwiseConv3D, self).__init__(
            filters=None,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
            activation=activation,
            use_bias=use_bias,
            bias_regularizer=bias_regularizer,
            dilation_rate=dilation_rate,
            activity_regularizer=activity_regularizer,
            bias_constraint=bias_constraint,
            **kwargs)
        self.depth_multiplier = depth_multiplier
        self.groups = groups
        self.depthwise_initializer = initializers.get(depthwise_initializer)
        self.depthwise_regularizer = regularizers.get(depthwise_regularizer)
        self.depthwise_constraint = constraints.get(depthwise_constraint)
        self.bias_initializer = initializers.get(bias_initializer)
        self.dilation_rate = dilation_rate
        self._padding = _preprocess_padding(self.padding)
        self._strides = (1,) + self.strides + (1,)
        self._data_format = "NDHWC"
        self.input_dim = None

    def build(self, input_shape):
        if len(input_shape) < 5:
            raise ValueError('Inputs to `DepthwiseConv3D` should have rank 5. '
                             'Received input shape:', str(input_shape))
        if self.data_format == 'channels_first':
            channel_axis = 1
        else:
            channel_axis = -1
        if input_shape[channel_axis] is None:
            raise ValueError('The channel dimension of the inputs to '
                             '`DepthwiseConv3D` '
                             'should be defined. Found `None`.')
        self.input_dim = int(input_shape[channel_axis])

        if self.groups is None:
            self.groups = self.input_dim

        if self.groups > self.input_dim:
            raise ValueError('The number of groups cannot exceed the number of channels')

        if self.input_dim % self.groups != 0:
            raise ValueError('Warning! The channels dimension is not divisible by the group size chosen')

        depthwise_kernel_shape = (self.kernel_size[0],
                                  self.kernel_size[1],
                                  self.kernel_size[2],
                                  self.input_dim,
                                  self.depth_multiplier)

        self.depthwise_kernel = self.add_weight(
            shape=depthwise_kernel_shape,
            initializer=self.depthwise_initializer,
            name='depthwise_kernel',
            regularizer=self.depthwise_regularizer,
            constraint=self.depthwise_constraint)

        if self.use_bias:
            self.bias = self.add_weight(shape=(self.groups * self.depth_multiplier,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        # Set input spec.
        self.input_spec = InputSpec(ndim=5, axes={channel_axis: self.input_dim})
        self.built = True

    def call(self, inputs, training=None):
        # _preprocess_conv3d_input returns a (tensor, tf_data_format) pair,
        # hence the inputs[0] indexing below.
        inputs = _preprocess_conv3d_input(inputs, self.data_format)

        if self.data_format == 'channels_last':
            dilation = (1,) + self.dilation_rate + (1,)
        else:
            dilation = self.dilation_rate + (1,) + (1,)

        # Apply one conv3d per channel group and concatenate along the channel axis.
        if self._data_format == 'NCDHW':
            outputs = tf.concat(
                [tf.nn.conv3d(inputs[0][:, i:i+self.input_dim//self.groups, :, :, :], self.depthwise_kernel[:, :, :, i:i+self.input_dim//self.groups, :],
                              strides=self._strides,
                              padding=self._padding,
                              dilations=dilation,
                              data_format=self._data_format) for i in range(0, self.input_dim, self.input_dim//self.groups)], axis=1)

        else:
            outputs = tf.concat(
                [tf.nn.conv3d(inputs[0][:, :, :, :, i:i+self.input_dim//self.groups], self.depthwise_kernel[:, :, :, i:i+self.input_dim//self.groups, :],
                              strides=self._strides,
                              padding=self._padding,
                              dilations=dilation,
                              data_format=self._data_format) for i in range(0, self.input_dim, self.input_dim//self.groups)], axis=-1)

        if self.bias is not None:
            outputs = K.bias_add(
                outputs,
                self.bias,
                data_format=self.data_format)

        if self.activation is not None:
            return self.activation(outputs)

        return outputs

    def compute_output_shape(self, input_shape):
        if self.data_format == 'channels_first':
            depth = input_shape[2]
            rows = input_shape[3]
            cols = input_shape[4]
            out_filters = self.groups * self.depth_multiplier
        elif self.data_format == 'channels_last':
            depth = input_shape[1]
            rows = input_shape[2]
            cols = input_shape[3]
            out_filters = self.groups * self.depth_multiplier

        depth = conv_utils.conv_output_length(depth, self.kernel_size[0],
                                              self.padding,
                                              self.strides[0])

        rows = conv_utils.conv_output_length(rows, self.kernel_size[1],
                                             self.padding,
                                             self.strides[1])

        cols = conv_utils.conv_output_length(cols, self.kernel_size[2],
                                             self.padding,
                                             self.strides[2])

        if self.data_format == 'channels_first':
            return (input_shape[0], out_filters, depth, rows, cols)

        elif self.data_format == 'channels_last':
            return (input_shape[0], depth, rows, cols, out_filters)

    def get_config(self):
        config = super(DepthwiseConv3D, self).get_config()
        config.pop('filters')
        config.pop('kernel_initializer')
        config.pop('kernel_regularizer')
        config.pop('kernel_constraint')
        config['depth_multiplier'] = self.depth_multiplier
        config['depthwise_initializer'] = initializers.serialize(self.depthwise_initializer)
        config['depthwise_regularizer'] = regularizers.serialize(self.depthwise_regularizer)
        config['depthwise_constraint'] = constraints.serialize(self.depthwise_constraint)
        return config

    def __call__(self, inputs, training=True):
        return self.call(inputs, training)

DepthwiseConvolution3D = DepthwiseConv3D
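A short usage sketch (hypothetical shapes; not part of the commit). Note that the `__call__` override above bypasses Keras' automatic build step, so the layer is built explicitly before use:

import numpy as np
import tensorflow as tf

x = tf.constant(np.random.rand(2, 16, 16, 16, 4), dtype=tf.float32)   # NDHWC volume
layer = DepthwiseConv3D(kernel_size=3, padding='same', depth_multiplier=1)
layer.build(x.shape)   # __call__ is overridden, so build() is not triggered automatically
y = layer(x)           # -> shape (2, 16, 16, 16, 4): one kernel per input channel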
DeepDeformationMapRegistration/layers/uncertainty_weighting.py
ADDED
@@ -0,0 +1,347 @@
import os, sys

# currentdir = os.path.dirname(os.path.realpath(__file__))
# parentdir = os.path.dirname(currentdir)
# sys.path.append(parentdir)  # PYTHON > 3.3 does not allow relative referencing
#
# PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'

import tensorflow.keras.layers as kl
import tensorflow.keras.backend as K
import tensorflow as tf
import numpy as np
import random

from DeepDeformationMapRegistration.utils.operators import soft_threshold, gaussian_kernel, sample_unique
import DeepDeformationMapRegistration.utils.constants as C
from DeepDeformationMapRegistration.utils.thin_plate_splines import ThinPlateSplines
from voxelmorph.tf.layers import SpatialTransformer
from neurite.tf.utils import resize
# from cupyx.scipy.ndimage import zoom
# import cupy


class UncertaintyWeighting(kl.Layer):
    def __init__(self, num_loss_fns=1, num_reg_fns=0, loss_fns: list = [tf.keras.losses.mean_squared_error],
                 reg_fns: list = list(), prior_loss_w=[1.], manual_loss_w=[1.], prior_reg_w=[1.], manual_reg_w=[1.],
                 **kwargs):
        assert isinstance(loss_fns, list) and (num_loss_fns == len(loss_fns) or len(loss_fns) == 1)
        assert isinstance(reg_fns, list) and (num_reg_fns == len(reg_fns))
        self.num_loss = num_loss_fns
        if len(loss_fns) == 1 and self.num_loss > 1:
            self.loss_fns = loss_fns * self.num_loss
        else:
            self.loss_fns = loss_fns

        if len(prior_loss_w) == 1:
            self.prior_loss_w = prior_loss_w * num_loss_fns
        else:
            self.prior_loss_w = prior_loss_w
        self.prior_loss_w = np.log(self.prior_loss_w)

        if len(manual_loss_w) == 1:
            self.manual_loss_w = manual_loss_w * num_loss_fns
        else:
            self.manual_loss_w = manual_loss_w

        self.num_reg = num_reg_fns
        if self.num_reg != 0:
            if len(reg_fns) == 1 and self.num_reg > 1:
                self.reg_fns = reg_fns * self.num_reg
            else:
                self.reg_fns = reg_fns

        self.is_placeholder = True
        if self.num_reg != 0:
            if len(prior_reg_w) == 1:
                self.prior_reg_w = prior_reg_w * num_reg_fns
            else:
                self.prior_reg_w = prior_reg_w
            self.prior_reg_w = np.log(self.prior_reg_w)

            if len(manual_reg_w) == 1:
                self.manual_reg_w = manual_reg_w * num_reg_fns
            else:
                self.manual_reg_w = manual_reg_w

        else:
            self.prior_reg_w = list()
            self.manual_reg_w = list()

        super(UncertaintyWeighting, self).__init__(**kwargs)

    def build(self, input_shape=None):
        self.log_loss_vars = self.add_weight(name='loss_log_vars', shape=(self.num_loss,),
                                             initializer=tf.keras.initializers.Constant(self.prior_loss_w),
                                             trainable=True)
        self.loss_weights = tf.math.softmax(self.log_loss_vars, name='SM_loss_weights')

        if self.num_reg != 0:
            self.log_reg_vars = self.add_weight(name='loss_reg_vars', shape=(self.num_reg,),
                                                initializer=tf.keras.initializers.Constant(self.prior_reg_w),
                                                trainable=True)
            if self.num_reg == 1:
                self.reg_weights = tf.math.exp(self.log_reg_vars, name='EXP_reg_weights')
            else:
                self.reg_weights = tf.math.softmax(self.log_reg_vars, name='SM_reg_weights')

        super(UncertaintyWeighting, self).build(input_shape)

    def multi_loss(self, ys_true, ys_pred, regs_true, regs_pred):
        loss_values = list()
        loss_names_loss = list()
        loss_names_reg = list()

        for y_true, y_pred, loss_fn, man_w in zip(ys_true, ys_pred, self.loss_fns, self.manual_loss_w):
            loss_values.append(tf.keras.backend.mean(man_w * loss_fn(y_true, y_pred)))
            loss_names_loss.append(loss_fn.__name__)

        loss_values = tf.convert_to_tensor(loss_values, dtype=tf.float32, name="step_loss_values")
        loss = tf.math.multiply(self.loss_weights, loss_values, name='step_weighted_loss')

        if self.num_reg != 0:
            loss_reg = list()
            for reg_true, reg_pred, reg_fn, man_w in zip(regs_true, regs_pred, self.reg_fns, self.manual_reg_w):
                loss_reg.append(K.mean(man_w * reg_fn(reg_true, reg_pred)))
                loss_names_reg.append(reg_fn.__name__)

            reg_values = tf.convert_to_tensor(loss_reg, dtype=tf.float32, name="step_reg_values")
            loss = loss + tf.math.multiply(self.reg_weights, reg_values, name='step_weighted_reg')

        for i, loss_name in enumerate(loss_names_loss):
            self.add_metric(tf.slice(self.loss_weights, [i], [1]), name='LOSS_WEIGHT_{}_{}'.format(i, loss_name),
                            aggregation='mean')
            self.add_metric(tf.slice(loss_values, [i], [1]), name='LOSS_VALUE_{}_{}'.format(i, loss_name),
                            aggregation='mean')
        if self.num_reg != 0:
            for i, loss_name in enumerate(loss_names_reg):
                self.add_metric(tf.slice(self.reg_weights, [i], [1]), name='REG_WEIGHT_{}_{}'.format(i, loss_name),
                                aggregation='mean')
                self.add_metric(tf.slice(reg_values, [i], [1]), name='REG_VALUE_{}_{}'.format(i, loss_name),
                                aggregation='mean')

        return K.sum(loss)

    def call(self, inputs):
        ys_true = inputs[:self.num_loss]
        ys_pred = inputs[self.num_loss:self.num_loss*2]
        reg_true = inputs[-self.num_reg*2:-self.num_reg]
        reg_pred = inputs[-self.num_reg:]  # The last terms are the regularization ones which have no GT
        loss = self.multi_loss(ys_true, ys_pred, reg_true, reg_pred)
        self.add_loss(loss, inputs=inputs)
        # We won't actually use the output, but we need something for the TF graph
        return K.concatenate(inputs, -1)

    def get_config(self):
        base_config = super(UncertaintyWeighting, self).get_config()
        base_config['num_loss_fns'] = self.num_loss
        base_config['num_reg_fns'] = self.num_reg

        return base_config

+
class UncertaintyWeightingWithRollingAverage(kl.Layer):
|
144 |
+
def __init__(self, num_loss_fns=1, num_reg_fns=0, loss_fns: list = [tf.keras.losses.mean_squared_error],
|
145 |
+
reg_fns: list = list(), prior_loss_w=[1.], manual_loss_w=[1.], prior_reg_w=[1.], manual_reg_w=[1.],
|
146 |
+
roll_avg_reference=0, # position in loss_fns of the reference loss function for the rolling avg
|
147 |
+
**kwargs):
|
148 |
+
assert isinstance(loss_fns, list) and (num_loss_fns == len(loss_fns) or len(loss_fns) == 1)
|
149 |
+
assert isinstance(reg_fns, list) and (num_reg_fns == len(reg_fns))
|
150 |
+
# Rolling average attributes
|
151 |
+
self.ref_loss = roll_avg_reference
|
152 |
+
self.compute_roll_avg = False # Toogle between computing the average of the losses or updating a know average
|
153 |
+
self.scale_factor = [1.] * num_loss_fns
|
154 |
+
self.n = 0 # Number of viewed samples
|
155 |
+
self.temp_storage = [0.] * num_loss_fns
|
156 |
+
|
157 |
+
self.num_loss = num_loss_fns
|
158 |
+
if len(loss_fns) == 1 and self.num_loss > 1:
|
159 |
+
self.loss_fns = loss_fns * self.num_loss
|
160 |
+
else:
|
161 |
+
self.loss_fns = loss_fns
|
162 |
+
|
163 |
+
if len(prior_loss_w) == 1:
|
164 |
+
self.prior_loss_w = prior_loss_w * num_loss_fns
|
165 |
+
else:
|
166 |
+
self.prior_loss_w = prior_loss_w
|
167 |
+
self.prior_loss_w = np.log(self.prior_loss_w)
|
168 |
+
|
169 |
+
if len(manual_loss_w) == 1:
|
170 |
+
self.manual_loss_w = manual_loss_w * num_loss_fns
|
171 |
+
else:
|
172 |
+
self.manual_loss_w = manual_loss_w
|
173 |
+
|
174 |
+
self.num_reg = num_reg_fns
|
175 |
+
if self.num_reg != 0:
|
176 |
+
if len(reg_fns) == 1 and self.num_reg > 1:
|
177 |
+
self.reg_fns = reg_fns * self.num_reg
|
178 |
+
else:
|
179 |
+
self.reg_fns = reg_fns
|
180 |
+
|
181 |
+
self.is_placeholder = True
|
182 |
+
if self.num_reg != 0:
|
183 |
+
if len(prior_reg_w) == 1:
|
184 |
+
self.prior_reg_w = prior_reg_w * num_reg_fns
|
185 |
+
else:
|
186 |
+
self.prior_reg_w = prior_reg_w
|
187 |
+
self.prior_reg_w = np.log(self.prior_reg_w)
|
188 |
+
|
189 |
+
if len(manual_reg_w) == 1:
|
190 |
+
self.manual_reg_w = manual_reg_w * num_reg_fns
|
191 |
+
else:
|
192 |
+
self.manual_reg_w = manual_reg_w
|
193 |
+
|
194 |
+
else:
|
195 |
+
self.prior_reg_w = list()
|
196 |
+
self.manual_reg_w = list()
|
197 |
+
|
198 |
+
super(UncertaintyWeightingWithRollingAverage, self).__init__(**kwargs)
|
199 |
+
|
200 |
+
def build(self, input_shape=None):
|
201 |
+
self.log_loss_vars = self.add_weight(name='loss_log_vars', shape=(self.num_loss,),
|
202 |
+
initializer=tf.keras.initializers.Constant(self.prior_loss_w),
|
203 |
+
trainable=True)
|
204 |
+
self.loss_weights = tf.math.softmax(self.log_loss_vars, name='SM_loss_weights')
|
205 |
+
|
206 |
+
if self.num_reg != 0:
|
207 |
+
self.log_reg_vars = self.add_weight(name='loss_reg_vars', shape=(self.num_reg,),
|
208 |
+
initializer=tf.keras.initializers.Constant(self.prior_reg_w),
|
209 |
+
trainable=True)
|
210 |
+
if self.num_reg == 1:
|
211 |
+
self.reg_weights = tf.math.exp(self.log_reg_vars, name='EXP_reg_weights')
|
212 |
+
else:
|
213 |
+
self.reg_weights = tf.math.softmax(self.log_reg_vars, name='SM_reg_weights')
|
214 |
+
|
215 |
+
super(UncertaintyWeightingWithRollingAverage, self).build(input_shape)
|
216 |
+
|
217 |
+
def store_values(self, new_loss_values):
|
218 |
+
for i, (t, v) in enumerate(zip(self.temp_storage, new_loss_values)):
|
219 |
+
self.temp_storage[i] = t + v
|
220 |
+
self.n += 1
|
221 |
+
|
222 |
+
def compute_scale_factors(self):
|
223 |
+
for i, val in enumerate(self.temp_storage):
|
224 |
+
self.scale_factor[i] = self.n / val # 1/avg
|
225 |
+
|
226 |
+
self.scale_factor[self.ref_loss] = 1.
|
227 |
+
|
228 |
+
self.temp_storage = [0.] * self.num_loss
|
229 |
+
self.n = 0
|
230 |
+
|
231 |
+
@property
|
232 |
+
def ref_on_epoch_end_function(self):
|
233 |
+
return self.compute_scale_factors
|
234 |
+
|
235 |
+
def multi_loss(self, ys_true, ys_pred, regs_true, regs_pred):
|
236 |
+
loss_values = list()
|
237 |
+
loss_names_loss = list()
|
238 |
+
loss_names_reg = list()
|
239 |
+
|
240 |
+
for y_true, y_pred, loss_fn, man_w, sf in zip(ys_true, ys_pred, self.loss_fns, self.manual_loss_w, self.scale_factor):
|
241 |
+
loss_values.append(sf * tf.keras.backend.mean(man_w * loss_fn(y_true, y_pred)))
|
242 |
+
loss_names_loss.append(loss_fn.__name__)
|
243 |
+
|
244 |
+
self.store_values(loss_values)
|
245 |
+
loss_values = tf.convert_to_tensor(loss_values, dtype=tf.float32, name="step_loss_values")
|
246 |
+
loss = tf.math.multiply(self.loss_weights, loss_values, name='step_weighted_loss')
|
247 |
+
|
248 |
+
if self.num_reg != 0:
|
249 |
+
loss_reg = list()
|
250 |
+
for reg_true, reg_pred, reg_fn, man_w in zip(regs_true, regs_pred, self.reg_fns, self.manual_reg_w):
|
251 |
+
loss_reg.append(K.mean(man_w * reg_fn(reg_true, reg_pred)))
|
252 |
+
loss_names_reg.append(reg_fn.__name__)
|
253 |
+
|
254 |
+
reg_values = tf.convert_to_tensor(loss_reg, dtype=tf.float32, name="step_reg_values")
|
255 |
+
loss = loss + tf.math.multiply(self.reg_weights, reg_values, name='step_weighted_reg')
|
256 |
+
|
257 |
+
for i, loss_name in enumerate(loss_names_loss):
|
258 |
+
self.add_metric(tf.slice(self.loss_weights, [i], [1]), name='LOSS_WEIGHT_{}_{}'.format(i, loss_name),
|
259 |
+
aggregation='mean')
|
260 |
+
self.add_metric(tf.slice(loss_values, [i], [1]), name='LOSS_VALUE_{}_{}'.format(i, loss_name),
|
261 |
+
aggregation='mean')
|
262 |
+
if self.num_reg != 0:
|
263 |
+
for i, loss_name in enumerate(loss_names_reg):
|
264 |
+
self.add_metric(tf.slice(self.reg_weights, [i], [1]), name='REG_WEIGHT_{}_{}'.format(i, loss_name),
|
265 |
+
aggregation='mean')
|
266 |
+
self.add_metric(tf.slice(reg_values, [i], [1]), name='REG_VALUE_{}_{}'.format(i, loss_name),
|
267 |
+
aggregation='mean')
|
268 |
+
sc_tf = tf.convert_to_tensor(self.scale_factor, dtype=tf.float32, name='scale_factors_tf')
|
269 |
+
self.add_metric(tf.slice(sc_tf, [i], [1]), name='SCALE_FACTOR_{}_{}'.format(i, loss_name),
|
270 |
+
aggregation='mean')
|
271 |
+
|
272 |
+
return K.sum(loss)
|
273 |
+
|
274 |
+
def call(self, inputs):
|
275 |
+
ys_true = inputs[:self.num_loss]
|
276 |
+
ys_pred = inputs[self.num_loss:self.num_loss*2]
|
277 |
+
reg_true = inputs[-self.num_reg*2:-self.num_reg]
|
278 |
+
reg_pred = inputs[-self.num_reg:] # The last terms are the regularization ones which have no GT
|
279 |
+
loss = self.multi_loss(ys_true, ys_pred, reg_true, reg_pred)
|
280 |
+
self.add_loss(loss, inputs=inputs)
|
281 |
+
# We won't actually use the output, but we need something for the TF graph
|
282 |
+
return K.concatenate(inputs, -1)
|
283 |
+
|
284 |
+
def get_config(self):
|
285 |
+
base_config = super(UncertaintyWeighting, self).get_config()
|
286 |
+
base_config['num_loss_fns'] = self.num_loss
|
287 |
+
base_config['num_reg_fns'] = self.num_reg
|
288 |
+
|
289 |
+
return base_config
|
290 |
+
|
291 |
+
|
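The running averages accumulated by store_values are only folded into the scale factors when compute_scale_factors runs, so the layer is meant to be paired with an epoch-end hook. One possible wiring (illustrative only; `uw_layer` and the callback choice are assumptions, not part of the commit):

import tensorflow as tf

uw_layer = UncertaintyWeightingWithRollingAverage(
    num_loss_fns=2,
    loss_fns=[tf.keras.losses.mean_squared_error,
              tf.keras.losses.mean_absolute_error])

# Trigger the scale-factor update at the end of every epoch; the property
# ref_on_epoch_end_function exposes compute_scale_factors for this purpose.
update_cb = tf.keras.callbacks.LambdaCallback(
    on_epoch_end=lambda epoch, logs: uw_layer.ref_on_epoch_end_function())
# model.fit(..., callbacks=[update_cb])
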
def distance_map(coord1, coord2, dist, img_shape_w_channel=(64, 64, 1)):
    max_dist = np.max(img_shape_w_channel)
    dm_p = np.ones(img_shape_w_channel, np.float32) * max_dist
    dm_n = np.ones(img_shape_w_channel, np.float32) * max_dist

    for c1, c2, d in zip(coord1, coord2, dist):
        dm_p[c1, c2, 0] = d if dm_p[c1, c2, 0] > d else dm_p[c1, c2, 0]  # keep the closest point (the original else-branch dropped the channel index)
        d_n = max_dist - d  # complementary distance (the original computed the constant 64. - max_dist, which ignores d)
        dm_n[c1, c2, 0] = d_n if dm_n[c1, c2, 0] > d_n else dm_n[c1, c2, 0]

    return dm_p/max_dist, dm_n/max_dist


def volume_to_ov_and_dm(in_volume: tf.Tensor):
    # This one is run as a preprocessing step
    def get_ov_projections_and_dm(volume):
        # tf.sign returns -1, 0, 1 depending on the sign of the elements of the input (negative, zero, positive)
        # tf.where yields an [N, 4] tensor, so the coordinate columns are unstacked
        # (the original unpacked the tensor directly, which splits along axis 0 instead).
        i, j, k, c = tf.unstack(tf.where(volume > 0.0), axis=-1)
        top = tf.sign(tf.reduce_sum(volume, axis=0), name='ov_top')
        right = tf.sign(tf.reduce_sum(volume, axis=1), name='ov_right')
        front = tf.sign(tf.reduce_sum(volume, axis=2), name='ov_front')

        # distance_map returns two arrays, so Tout lists two dtypes.
        top_p, top_n = tf.py_func(distance_map, [j, k, i], [tf.float32, tf.float32])
        right_p, right_n = tf.py_func(distance_map, [i, k, j], [tf.float32, tf.float32])
        front_p, front_n = tf.py_func(distance_map, [i, j, k], [tf.float32, tf.float32])

        return [front, right, top], [front_p, front_n, top_p, top_n, right_p, right_n]

    if len(in_volume.shape.as_list()) > 4:
        return tf.map_fn(get_ov_projections_and_dm, in_volume, [tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32])
    else:
        return get_ov_projections_and_dm(in_volume)


def ov_and_dm_to_volume(ov_projections):
    front, right, top = ov_projections

    def get_volume(front: tf.Tensor, right: tf.Tensor, top: tf.Tensor):
        front_shape = front.shape.as_list()  # Assume (H, W, C)
        top_shape = top.shape.as_list()

        front_vol = tf.tile(tf.expand_dims(front, 2), [1, 1, top_shape[0], 1])
        right_vol = tf.tile(tf.expand_dims(right, 1), [1, front_shape[1], 1, 1])
        top_vol = tf.tile(tf.expand_dims(top, 0), [front_shape[0], 1, 1, 1])
        total = tf.add(tf.add(front_vol, right_vol), top_vol)  # renamed from `sum` to avoid shadowing the builtin
        return soft_threshold(total, 2., 'get_volume')

    if len(front.shape.as_list()) > 3:
        return tf.map_fn(lambda x: get_volume(x[0], x[1], x[2]), ov_projections, tf.float32)
    else:
        return get_volume(front, right, top)

# TODO: Recovering the coordinates from the distance maps to prevent artifacts
#  will the gradients be backpropagated??!?!!?!?!
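For intuition, a small sketch of distance_map on hypothetical coordinates (three foreground voxels projected onto the default 64x64 plane; the arrays are made up for illustration):

import numpy as np

rows  = np.array([10, 20, 30])
cols  = np.array([12, 22, 32])
depth = np.array([5., 40., 7.])            # distance of each voxel to the plane
dm_pos, dm_neg = distance_map(rows, cols, depth)
# dm_pos, dm_neg: (64, 64, 1) maps normalised to [0, 1]; untouched pixels stay at 1.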
DeepDeformationMapRegistration/layers/upsampling.py
ADDED
@@ -0,0 +1,761 @@
import tensorflow as tf
import numpy as np
from numpy import (zeros, where, diff, floor, minimum, maximum, array, concatenate, logical_or, logical_xor,
                   sqrt)
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras.utils import conv_utils
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.util.tf_export import keras_export  # api_export
# SRC: https://github.com/tensorflow/tensorflow/issues/46609
# import functools

# keras_export = functools.partial(api_export, 'keras')  # keras_export is not defined in 1.13 but in 1.15 --> https://github.com/tensorflow/tensorflow/blob/3d6e4f24e32b5dbe0a83aaa6e9d0f6671ba41da8/tensorflow/python/util/tf_export.py

def linear_interpolate(x_fix, y_fix, x_var):
    '''
    Functionality:
        1D linear interpolation
    Author:
        Michael Osthege
    Link:
        https://gist.github.com/michaelosthege/e20d242bc62a434843b586c78ebce6cc
    '''

    # NOTE: kept in NumPy, as in the original gist: tf.Tensors do not support
    # the item assignment used below (and tf.arange does not exist), and the
    # callers in this module pass NumPy arrays anyway.
    x_repeat = np.tile(x_var[:, None], (len(x_fix),))
    distances = np.abs(x_repeat - x_fix)

    x_indices = np.searchsorted(x_fix, x_var)

    weights = np.zeros_like(distances)
    idx = np.arange(len(x_indices))
    weights[idx, x_indices] = distances[idx, x_indices - 1]
    weights[idx, x_indices - 1] = distances[idx, x_indices]
    weights /= np.sum(weights, axis=1)[:, None]

    y_var = np.dot(weights, y_fix.T)

    return y_var


def cubic_interpolate(x, y, x0):
    '''
    Functionality:
        1D cubic spline interpolation
    Author:
        Raphael Valentin
    Link:
        https://stackoverflow.com/questions/31543775/how-to-perform-cubic-spline-interpolation-in-python
    '''

    x = np.asfarray(x)
    y = np.asfarray(y)

    # remove non finite values
    # indexes = np.isfinite(x)
    # x = x[indexes]
    # y = y[indexes]

    # check if sorted
    if np.any(np.diff(x) < 0):
        indexes = np.argsort(x)
        x = x[indexes]
        y = y[indexes]

    size = len(x)

    xdiff = np.diff(x)
    ydiff = np.diff(y)

    # allocate buffer matrices
    Li = np.empty(size)
    Li_1 = np.empty(size - 1)
    z = np.empty(size)

    # fill diagonals Li and Li-1 and solve [L][y] = [B]
    Li[0] = sqrt(2 * xdiff[0])
    Li_1[0] = 0.0
    B0 = 0.0  # natural boundary
    z[0] = B0 / Li[0]

    for i in range(1, size - 1, 1):
        Li_1[i] = xdiff[i - 1] / Li[i - 1]
        Li[i] = sqrt(2 * (xdiff[i - 1] + xdiff[i]) - Li_1[i - 1] * Li_1[i - 1])
        Bi = 6 * (ydiff[i] / xdiff[i] - ydiff[i - 1] / xdiff[i - 1])
        z[i] = (Bi - Li_1[i - 1] * z[i - 1]) / Li[i]

    i = size - 1
    Li_1[i - 1] = xdiff[-1] / Li[i - 1]
    Li[i] = sqrt(2 * xdiff[-1] - Li_1[i - 1] * Li_1[i - 1])
    Bi = 0.0  # natural boundary
    z[i] = (Bi - Li_1[i - 1] * z[i - 1]) / Li[i]

    # solve [L.T][x] = [y]
    i = size - 1
    z[i] = z[i] / Li[i]
    for i in range(size - 2, -1, -1):
        z[i] = (z[i] - Li_1[i - 1] * z[i + 1]) / Li[i]

    # find index
    index = x.searchsorted(x0)
    np.clip(index, 1, size - 1, index)

    xi1, xi0 = x[index], x[index - 1]
    yi1, yi0 = y[index], y[index - 1]
    zi1, zi0 = z[index], z[index - 1]
    hi1 = xi1 - xi0

    # calculate cubic
    f0 = zi0/(6*hi1)*(xi1-x0)**3 + \
         zi1/(6*hi1)*(x0-xi0)**3 + \
         (yi1/hi1 - zi1*hi1/6)*(x0-xi0) + \
         (yi0/hi1 - zi0*hi1/6)*(xi1-x0)

    return f0


def pchip_interpolate(xi, yi, x, mode="mono", verbose=False):
    '''
    Functionality:
        1D PCHIP interpolation
    Authors:
        Michael Taylor <[email protected]>
        Mathieu Virbel <[email protected]>
    Link:
        https://gist.github.com/tito/553f1135959921ce6699652bf656150d
    '''

    if mode not in ("mono", "quad"):
        raise ValueError("Unrecognized mode string")

    # Search for [xi,xi+1] interval for each x
    xi = xi.astype("double")
    yi = yi.astype("double")

    x_index = zeros(len(x), dtype="int")
    xi_steps = diff(xi)
    if not all(xi_steps > 0):
        raise ValueError("x-coordinates are not in increasing order.")

    x_steps = diff(x)
    if xi_steps.max() / xi_steps.min() < 1.000001:
        # uniform input grid
        if verbose:
            print("pchip: uniform input grid")
        xi_start = xi[0]
        xi_step = (xi[-1] - xi[0]) / (len(xi) - 1)
        x_index = minimum(maximum(floor((x - xi_start) / xi_step).astype(int), 0), len(xi) - 2)

        # Calculate gradients d
        h = (xi[-1] - xi[0]) / (len(xi) - 1)
        d = zeros(len(xi), dtype="double")
        if mode == "quad":
            # quadratic polynomial fit
            d[[0]] = (yi[1] - yi[0]) / h
            d[[-1]] = (yi[-1] - yi[-2]) / h
            d[1:-1] = (yi[2:] - yi[0:-2]) / 2 / h
        else:
            # mode=='mono', Fritsch-Carlson algorithm from fortran numerical
            # recipe
            delta = diff(yi) / h
            d = concatenate((delta[0:1], 2 / (1 / delta[0:-1] + 1 / delta[1:]), delta[-1:]))
            d[concatenate((array([False]), logical_xor(delta[0:-1] > 0, delta[1:] > 0), array([False])))] = 0
            d[logical_or(concatenate((array([False]), delta == 0)), concatenate(
                (delta == 0, array([False]))))] = 0
        # Calculate output values y
        dxxi = x - xi[x_index]
        dxxid = x - xi[1 + x_index]
        dxxi2 = pow(dxxi, 2)
        dxxid2 = pow(dxxid, 2)
        y = (2 / pow(h, 3) * (yi[x_index] * dxxid2 * (dxxi + h / 2) - yi[1 + x_index] * dxxi2 *
                              (dxxid - h / 2)) + 1 / pow(h, 2) *
             (d[x_index] * dxxid2 * dxxi + d[1 + x_index] * dxxi2 * dxxid))
    else:
        # not uniform input grid
        if (x_steps.max() / x_steps.min() < 1.000001 and x_steps.max() / x_steps.min() > 0.999999):
            # non-uniform input grid, uniform output grid
            if verbose:
                print("pchip: non-uniform input grid, uniform output grid")
            x_decreasing = x[-1] < x[0]
            if x_decreasing:
                x = x[::-1]
            x_start = x[0]
            x_step = (x[-1] - x[0]) / (len(x) - 1)
            x_indexprev = -1
            for xi_loop in range(len(xi) - 2):
                x_indexcur = max(int(floor((xi[1 + xi_loop] - x_start) / x_step)), -1)
                x_index[1 + x_indexprev:1 + x_indexcur] = xi_loop
                x_indexprev = x_indexcur
            x_index[1 + x_indexprev:] = len(xi) - 2
            if x_decreasing:
                x = x[::-1]
                x_index = x_index[::-1]
        elif all(x_steps > 0) or all(x_steps < 0):
            # non-uniform input/output grids, output grid monotonic
            if verbose:
                print("pchip: non-uniform in/out grid, output grid monotonic")
            x_decreasing = x[-1] < x[0]
            if x_decreasing:
                x = x[::-1]
            x_len = len(x)
            x_loop = 0
            for xi_loop in range(len(xi) - 1):
                while x_loop < x_len and x[x_loop] < xi[1 + xi_loop]:
                    x_index[x_loop] = xi_loop
                    x_loop += 1
            x_index[x_loop:] = len(xi) - 2
            if x_decreasing:
                x = x[::-1]
                x_index = x_index[::-1]
        else:
            # non-uniform input/output grids, output grid not monotonic
            if verbose:
                print("pchip: non-uniform in/out grids, output grid not monotonic")
            for index in range(len(x)):
                loc = where(x[index] < xi)[0]
                if loc.size == 0:
                    x_index[index] = len(xi) - 2
                elif loc[0] == 0:
                    x_index[index] = 0
                else:
                    x_index[index] = loc[0] - 1
        # Calculate gradients d
        h = diff(xi)
        d = zeros(len(xi), dtype="double")
        delta = diff(yi) / h
        if mode == "quad":
            # quadratic polynomial fit
            d[[0, -1]] = delta[[0, -1]]
            d[1:-1] = (delta[1:] * h[0:-1] + delta[0:-1] * h[1:]) / (h[0:-1] + h[1:])
        else:
            # mode=='mono', Fritsch-Carlson algorithm from fortran numerical
            # recipe
            d = concatenate(
                (delta[0:1], 3 * (h[0:-1] + h[1:]) / ((h[0:-1] + 2 * h[1:]) / delta[0:-1] +
                                                      (2 * h[0:-1] + h[1:]) / delta[1:]), delta[-1:]))
            d[concatenate((array([False]), logical_xor(delta[0:-1] > 0, delta[1:] > 0), array([False])))] = 0
            d[logical_or(concatenate((array([False]), delta == 0)), concatenate(
                (delta == 0, array([False]))))] = 0
        dxxi = x - xi[x_index]
        dxxid = x - xi[1 + x_index]
        dxxi2 = pow(dxxi, 2)
        dxxid2 = pow(dxxid, 2)
        y = (2 / pow(h[x_index], 3) *
             (yi[x_index] * dxxid2 * (dxxi + h[x_index] / 2) - yi[1 + x_index] * dxxi2 *
              (dxxid - h[x_index] / 2)) + 1 / pow(h[x_index], 2) *
             (d[x_index] * dxxid2 * dxxi + d[1 + x_index] * dxxi2 * dxxid))
    return y


def Interpolate1D(x, y, xx, method='nearest'):
    '''
    Functionality:
        1D interpolation with various methods
    Author:
        Kai Gao <[email protected]>
    '''

    n = len(x)
    nn = len(xx)
    yy = np.zeros(nn)

    # Nearest neighbour interpolation
    if method == 'nearest':
        for i in range(0, nn):
            xi = np.argmin(np.abs(xx[i] - x))  # was tf.argmin; NumPy keeps this consistent with the array maths below
            yy[i] = y[xi]

    # Linear interpolation
    elif method == 'linear':

        # # slower version
        # if n == 1:
        #     yy[:-1] = y[0]
        # else:
        #     for i in range(0, nn):
        #         if xx[i] < x[0]:
        #             t = (xx[i] - x[0]) / (x[1] - x[0])
        #             yy[i] = (1.0 - t) * y[0] + t * y[1]
        #         elif x[n - 1] <= xx[i]:
        #             t = (xx[i] - x[n - 2]) / (x[n - 1] - x[n - 2])
        #             yy[i] = (1.0 - t) * y[n - 2] + t * y[n - 1]
        #         else:
        #             for k in range(1, n):
        #                 if x[k - 1] <= xx[i] and xx[i] < x[k]:
        #                     t = (xx[i] - x[k - 1]) / (x[k] - x[k - 1])
        #                     yy[i] = (1.0 - t) * y[k - 1] + t * y[k]
        #                     break

        # faster version
        yy = linear_interpolate(x, y, xx)

    # Cubic interpolation
    elif method == 'cubic':
        yy = cubic_interpolate(x, y, xx)

    # Piecewise cubic Hermite interpolating polynomial (PCHIP)
    elif method == 'pchip':
        yy = pchip_interpolate(x, y, xx, mode='mono')

    return yy

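A quick usage sketch of Interpolate1D (the signal below is made up for illustration): upsample a 5-point signal onto 20 query points.

import numpy as np

x  = np.linspace(0.0, 4.0, 5)
y  = np.sin(x)
xx = np.linspace(0.0, 4.0, 20)
yy = Interpolate1D(x, y, xx, method='pchip')   # also: 'nearest', 'linear', 'cubic'
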
def Interpolate2D(x, y, f, xx, yy, method='nearest'):
    '''
    Functionality:
        2D interpolation implemented in a separable fashion
        There are methods that do real 2D non-separable interpolation, which are
        more difficult to implement.
    Author:
        Kai Gao <[email protected]>
    '''

    n1 = len(x)
    n2 = len(y)
    nn1 = len(xx)
    nn2 = len(yy)

    w = np.zeros((nn1, n2))
    ff = np.zeros((nn1, nn2))

    # Interpolate along the 1st dimension
    for j in range(0, n2):
        w[:, j] = Interpolate1D(x, f[:, j], xx, method)

    # Interpolate along the 2nd dimension
    for i in range(0, nn1):
        ff[i, :] = Interpolate1D(y, w[i, :], yy, method)

    return ff


def Interpolate3D(x, y, z, f, xx, yy, zz, method='nearest'):
    '''
    Functionality:
        3D interpolation implemented in a separable fashion
        There are methods that do real 3D non-separable interpolation, which are
        more difficult to implement.
    Author:
        Kai Gao <[email protected]>
    '''

    n1 = len(x)
    n2 = len(y)
    n3 = len(z)
    nn1 = len(xx)
    nn2 = len(yy)
    nn3 = len(zz)

    # NOTE: NumPy buffers (the original used tf.zeros, but the loops below
    # assign items, which tf.Tensors do not allow).
    w1 = np.zeros((nn1, n2, n3))
    w2 = np.zeros((nn1, nn2, n3))
    ff = np.zeros((nn1, nn2, nn3))

    # Interpolate along the 1st dimension
    for k in range(0, n3):
        for j in range(0, n2):
            w1[:, j, k] = Interpolate1D(x, f[:, j, k], xx, method)

    # Interpolate along the 2nd dimension
    for k in range(0, n3):
        for i in range(0, nn1):
            w2[i, :, k] = Interpolate1D(y, w1[i, :, k], yy, method)

    # Interpolate along the 3rd dimension
    for j in range(0, nn2):
        for i in range(0, nn1):
            ff[i, j, :] = Interpolate1D(z, w2[i, j, :], zz, method)

    return ff


def UpInterpolate1D(x, size=2, interpolation='nearest', data_format='channels_first', align_corners=True):
    '''
    Functionality:
        1D upsampling interpolation for tf
    Author:
        Kai Gao <[email protected]>
    '''

    x = x.numpy()

    if data_format == 'channels_last':
        nb, nr, nh = x.shape
    elif data_format == 'channels_first':
        nb, nh, nr = x.shape

    r = size
    ir = np.linspace(0.0, nr - 1.0, num=nr)

    if align_corners:
        # align_corners=True assumes that values are sampled at discrete points
        iir = np.linspace(0.0, nr - 1.0, num=nr * r)
    else:
        # align_corners=False assumes that values are sampled at centers of discrete blocks
        iir = np.linspace(0.0 - 0.5 + 0.5 / r, nr - 1.0 + 0.5 - 0.5 / r, num=nr * r)
        iir = np.clip(iir, 0.0, nr - 1.0)

    if data_format == 'channels_last':
        xx = np.zeros((nb, nr * r, nh))
        for i in range(0, nb):
            for j in range(0, nh):
                t = np.reshape(x[i, :, j], (nr))
                xx[i, :, j] = Interpolate1D(ir, t, iir, interpolation)

    elif data_format == 'channels_first':
        xx = np.zeros((nb, nh, nr * r))
        for i in range(0, nb):
            for j in range(0, nh):
                t = np.reshape(x[i, j, :], (nr))
                xx[i, j, :] = Interpolate1D(ir, t, iir, interpolation)

    return tf.convert_to_tensor(xx, dtype=x.dtype)


def UpInterpolate2D(x,
                    size=(2, 2),
                    interpolation='nearest',
                    data_format='channels_first',
                    align_corners=True):
    '''
    Functionality:
        2D upsampling interpolation for tf
    Author:
        Kai Gao <[email protected]>
    '''

    x = x.numpy()

    if data_format == 'channels_last':
        nb, nr, nc, nh = x.shape
    elif data_format == 'channels_first':
        nb, nh, nr, nc = x.shape

    r = size[0]
    c = size[1]
    ir = np.linspace(0.0, nr - 1.0, num=nr)
    ic = np.linspace(0.0, nc - 1.0, num=nc)

    if align_corners:
        # align_corners=True assumes that values are sampled at discrete points
        iir = np.linspace(0.0, nr - 1.0, num=nr * r)
        iic = np.linspace(0.0, nc - 1.0, num=nc * c)
    else:
        # align_corners=False assumes that values are sampled at centers of discrete blocks
        iir = np.linspace(0.0 - 0.5 + 0.5 / r, nr - 1.0 + 0.5 - 0.5 / r, num=nr * r)
        iic = np.linspace(0.0 - 0.5 + 0.5 / c, nc - 1.0 + 0.5 - 0.5 / c, num=nc * c)
        iir = np.clip(iir, 0.0, nr - 1.0)
        iic = np.clip(iic, 0.0, nc - 1.0)

    if data_format == 'channels_last':
        xx = np.zeros((nb, nr * r, nc * c, nh))
        for i in range(0, nb):
            for j in range(0, nh):
                t = np.reshape(x[i, :, :, j], (nr, nc))
                xx[i, :, :, j] = Interpolate2D(ir, ic, t, iir, iic, interpolation)

    elif data_format == 'channels_first':
        xx = np.zeros((nb, nh, nr * r, nc * c))
        for i in range(0, nb):
            for j in range(0, nh):
                t = np.reshape(x[i, j, :, :], (nr, nc))
                xx[i, j, :, :] = Interpolate2D(ir, ic, t, iir, iic, interpolation)

    return tf.convert_to_tensor(xx, dtype=x.dtype)


def UpInterpolate3D(x,
                    size=(2, 2, 2),
                    interpolation='nearest',
                    data_format='channels_first',
                    align_corners=True):
    '''
    Functionality:
        3D upsampling interpolation for tf
    Author:
        Kai Gao <[email protected]>
    '''

    # As in the 1D/2D versions: work in NumPy, since the separable helpers
    # rely on item assignment that tf.Tensors do not support. (The original
    # left x = x.numpy() commented out and mixed tf and NumPy buffers.)
    x = x.numpy()

    if data_format == 'channels_last':
        nb, nr, nc, nd, nh = x.shape
    elif data_format == 'channels_first':
        nb, nh, nr, nc, nd = x.shape

    r = size[0]
    c = size[1]
    d = size[2]
    ir = np.linspace(0.0, nr - 1.0, num=nr)
    ic = np.linspace(0.0, nc - 1.0, num=nc)
    i_d = np.linspace(0.0, nd - 1.0, num=nd)  # renamed from `id` to avoid shadowing the builtin

    if align_corners:
        # align_corners=True assumes that values are sampled at discrete points
        iir = np.linspace(0.0, nr - 1.0, num=nr * r)
        iic = np.linspace(0.0, nc - 1.0, num=nc * c)
        iid = np.linspace(0.0, nd - 1.0, num=nd * d)
    else:
        # align_corners=False assumes that values are sampled at centers of discrete blocks
        iir = np.linspace(0.0 - 0.5 + 0.5 / r, nr - 1.0 + 0.5 - 0.5 / r, num=nr * r)
        iic = np.linspace(0.0 - 0.5 + 0.5 / c, nc - 1.0 + 0.5 - 0.5 / c, num=nc * c)
        iid = np.linspace(0.0 - 0.5 + 0.5 / d, nd - 1.0 + 0.5 - 0.5 / d, num=nd * d)
        iir = np.clip(iir, 0.0, nr - 1.0)
        iic = np.clip(iic, 0.0, nc - 1.0)
        iid = np.clip(iid, 0.0, nd - 1.0)

    if data_format == 'channels_last':
        xx = np.zeros((nb, nr * r, nc * c, nd * d, nh))
        for i in range(0, nb):
            for j in range(0, nh):
                t = np.reshape(x[i, :, :, :, j], (nr, nc, nd))
                xx[i, :, :, :, j] = Interpolate3D(ir, ic, i_d, t, iir, iic, iid, interpolation)

    elif data_format == 'channels_first':
        xx = np.zeros((nb, nh, nr * r, nc * c, nd * d))
        for i in range(0, nb):
            for j in range(0, nh):
                t = np.reshape(x[i, j, :, :, :], (nr, nc, nd))
                xx[i, j, :, :, :] = Interpolate3D(ir, ic, i_d, t, iir, iic, iid, interpolation)

    return tf.convert_to_tensor(xx, dtype=x.dtype)

527 |
# ################################################################################
@keras_export('keras.layers.UpSampling1D')
class UpSampling1D(Layer):
    """Upsampling layer for 1D inputs.

    Repeats each temporal step `size` times along the time axis.

    Examples:

    >>> input_shape = (2, 2, 3)
    >>> x = np.arange(np.prod(input_shape)).reshape(input_shape)
    >>> print(x)
    [[[ 0  1  2]
      [ 3  4  5]]
     [[ 6  7  8]
      [ 9 10 11]]]
    >>> y = tf.keras.layers.UpSampling1D(size=2)(x)
    >>> print(y)
    tf.Tensor(
      [[[ 0  1  2]
        [ 0  1  2]
        [ 3  4  5]
        [ 3  4  5]]
       [[ 6  7  8]
        [ 6  7  8]
        [ 9 10 11]
        [ 9 10 11]]], shape=(2, 4, 3), dtype=int64)

    Args:
        size: Integer. Upsampling factor.
        interpolation: A string, one of `nearest`, `linear`, `cubic` or `pchip`.

    Input shape:
        3D tensor with shape: `(batch_size, steps, features)`.

    Output shape:
        3D tensor with shape: `(batch_size, upsampled_steps, features)`.
    """

    def __init__(self, size=2, data_format=None, interpolation='nearest', align_corners=True, **kwargs):
        super(UpSampling1D, self).__init__(**kwargs)
        self.data_format = conv_utils.normalize_data_format(data_format)
        self.size = int(size)
        self.input_spec = InputSpec(ndim=3)
        self.interpolation = interpolation
        if self.interpolation not in {'nearest', 'linear', 'cubic', 'pchip'}:
            raise ValueError('`interpolation` argument should be one of `"nearest"` '
                             'or `"linear"` '
                             'or `"cubic"` '
                             'or `"pchip"`.')
        self.align_corners = align_corners

    def compute_output_shape(self, input_shape):
        input_shape = tf.TensorShape(input_shape).as_list()
        size = self.size * input_shape[1] if input_shape[1] is not None else None
        return tf.TensorShape([input_shape[0], size, input_shape[2]])

    def call(self, inputs):
        return UpInterpolate1D(inputs,
                               self.size,
                               data_format=self.data_format,
                               interpolation=self.interpolation,
                               align_corners=self.align_corners)

    def get_config(self):
        # Serialize every constructor argument so the layer round-trips through get_config/from_config
        config = {'size': self.size, 'data_format': self.data_format,
                  'interpolation': self.interpolation, 'align_corners': self.align_corners}
        base_config = super(UpSampling1D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

@keras_export('keras.layers.UpSampling2D')
class UpSampling2D(Layer):
    """Upsampling layer for 2D inputs.

    Repeats the rows and columns of the data by `size[0]` and `size[1]` respectively.

    Examples:

    >>> input_shape = (2, 2, 1, 3)
    >>> x = np.arange(np.prod(input_shape)).reshape(input_shape)
    >>> print(x)
    [[[[ 0  1  2]]
      [[ 3  4  5]]]
     [[[ 6  7  8]]
      [[ 9 10 11]]]]
    >>> y = tf.keras.layers.UpSampling2D(size=(1, 2))(x)
    >>> print(y)
    tf.Tensor(
      [[[[ 0  1  2]
         [ 0  1  2]]
        [[ 3  4  5]
         [ 3  4  5]]]
       [[[ 6  7  8]
         [ 6  7  8]]
        [[ 9 10 11]
         [ 9 10 11]]]], shape=(2, 2, 2, 3), dtype=int64)

    Args:
        size: Int, or tuple of 2 integers.
            The upsampling factors for rows and columns.
        data_format: A string,
            one of `channels_last` (default) or `channels_first`.
            The ordering of the dimensions in the inputs.
            `channels_last` corresponds to inputs with shape
            `(batch_size, height, width, channels)` while `channels_first`
            corresponds to inputs with shape
            `(batch_size, channels, height, width)`.
            It defaults to the `image_data_format` value found in your
            Keras config file at `~/.keras/keras.json`.
            If you never set it, then it will be "channels_last".
        interpolation: A string, one of `nearest`, `bilinear`, `linear`, `cubic` or `pchip`.

    Input shape:
        4D tensor with shape:
        - If `data_format` is `"channels_last"`:
            `(batch_size, rows, cols, channels)`
        - If `data_format` is `"channels_first"`:
            `(batch_size, channels, rows, cols)`

    Output shape:
        4D tensor with shape:
        - If `data_format` is `"channels_last"`:
            `(batch_size, upsampled_rows, upsampled_cols, channels)`
        - If `data_format` is `"channels_first"`:
            `(batch_size, channels, upsampled_rows, upsampled_cols)`
    """

    def __init__(self, size=(2, 2), data_format=None, interpolation='nearest', align_corners=True, **kwargs):
        super(UpSampling2D, self).__init__(**kwargs)
        self.data_format = conv_utils.normalize_data_format(data_format)
        self.size = conv_utils.normalize_tuple(size, 2, 'size')
        self.input_spec = InputSpec(ndim=4)
        self.interpolation = interpolation
        if self.interpolation not in {'nearest', 'bilinear', 'linear', 'cubic', 'pchip'}:
            raise ValueError('`interpolation` argument should be one of `"nearest"` '
                             'or `"bilinear"` '
                             'or `"linear"` '
                             'or `"cubic"` '
                             'or `"pchip"`.')
        if self.interpolation == 'bilinear':
            self.interpolation = 'linear'
        self.align_corners = align_corners

    def compute_output_shape(self, input_shape):
        input_shape = tensor_shape.TensorShape(input_shape).as_list()
        if self.data_format == 'channels_first':
            height = self.size[0] * input_shape[2] if input_shape[2] is not None else None
            width = self.size[1] * input_shape[3] if input_shape[3] is not None else None
            return tensor_shape.TensorShape([input_shape[0], input_shape[1], height, width])
        else:
            height = self.size[0] * input_shape[1] if input_shape[1] is not None else None
            width = self.size[1] * input_shape[2] if input_shape[2] is not None else None
            return tensor_shape.TensorShape([input_shape[0], height, width, input_shape[3]])

    def call(self, inputs):
        return UpInterpolate2D(inputs,
                               self.size,
                               data_format=self.data_format,
                               interpolation=self.interpolation,
                               align_corners=self.align_corners)

    def get_config(self):
        config = {'size': self.size, 'data_format': self.data_format,
                  'interpolation': self.interpolation, 'align_corners': self.align_corners}
        base_config = super(UpSampling2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

@keras_export('keras.layers.UpSampling3D')
class UpSampling3D(Layer):
    """Upsampling layer for 3D inputs.

    Repeats the 1st, 2nd and 3rd dimensions of the data by `size[0]`, `size[1]`
    and `size[2]` respectively.

    Examples:

    >>> input_shape = (2, 1, 2, 1, 3)
    >>> x = tf.constant(1, shape=input_shape)
    >>> y = tf.keras.layers.UpSampling3D(size=2)(x)
    >>> print(y.shape)
    (2, 2, 4, 2, 3)

    Args:
        size: Int, or tuple of 3 integers.
            The upsampling factors for dim1, dim2 and dim3.
        data_format: A string,
            one of `channels_last` (default) or `channels_first`.
            The ordering of the dimensions in the inputs.
            `channels_last` corresponds to inputs with shape
            `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`
            while `channels_first` corresponds to inputs with shape
            `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.
            It defaults to the `image_data_format` value found in your
            Keras config file at `~/.keras/keras.json`.
            If you never set it, then it will be "channels_last".
        interpolation: A string, one of `nearest`, `trilinear`, `linear`, `cubic` or `pchip`.

    Input shape:
        5D tensor with shape:
        - If `data_format` is `"channels_last"`:
            `(batch_size, dim1, dim2, dim3, channels)`
        - If `data_format` is `"channels_first"`:
            `(batch_size, channels, dim1, dim2, dim3)`

    Output shape:
        5D tensor with shape:
        - If `data_format` is `"channels_last"`:
            `(batch_size, upsampled_dim1, upsampled_dim2, upsampled_dim3, channels)`
        - If `data_format` is `"channels_first"`:
            `(batch_size, channels, upsampled_dim1, upsampled_dim2, upsampled_dim3)`
    """

    def __init__(self,
                 size=(2, 2, 2),
                 data_format=None,
                 interpolation='nearest',
                 align_corners=True,
                 **kwargs):
        super(UpSampling3D, self).__init__(**kwargs)
        self.data_format = conv_utils.normalize_data_format(data_format)
        self.size = conv_utils.normalize_tuple(size, 3, 'size')
        self.input_spec = InputSpec(ndim=5)
        self.interpolation = interpolation
        if interpolation not in {'nearest', 'trilinear', 'linear', 'cubic', 'pchip'}:
            raise ValueError('`interpolation` argument should be one of `"nearest"` '
                             'or `"trilinear"` '
                             'or `"linear"` '
                             'or `"cubic"` '
                             'or `"pchip"`.')
        if self.interpolation == 'trilinear':
            self.interpolation = 'linear'
        self.align_corners = align_corners

    def compute_output_shape(self, input_shape):
        input_shape = tensor_shape.TensorShape(input_shape).as_list()
        if self.data_format == 'channels_first':
            dim1 = self.size[0] * input_shape[2] if input_shape[2] is not None else None
            dim2 = self.size[1] * input_shape[3] if input_shape[3] is not None else None
            dim3 = self.size[2] * input_shape[4] if input_shape[4] is not None else None
            return tensor_shape.TensorShape([input_shape[0], input_shape[1], dim1, dim2, dim3])
        else:
            dim1 = self.size[0] * input_shape[1] if input_shape[1] is not None else None
            dim2 = self.size[1] * input_shape[2] if input_shape[2] is not None else None
            dim3 = self.size[2] * input_shape[3] if input_shape[3] is not None else None
            return tensor_shape.TensorShape([input_shape[0], dim1, dim2, dim3, input_shape[4]])

    def call(self, inputs):
        return UpInterpolate3D(inputs,
                               self.size,
                               data_format=self.data_format,
                               interpolation=self.interpolation,
                               align_corners=self.align_corners)

    def get_config(self):
        config = {'size': self.size, 'data_format': self.data_format,
                  'interpolation': self.interpolation, 'align_corners': self.align_corners}
        base_config = super(UpSampling3D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
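A brief usage sketch for the `UpSampling3D` layer above (a sketch, not from the commit: it assumes eager execution, since the batch and channel loops in `UpInterpolate3D` need concrete shapes, and that the module is importable from the package layout shown in this commit):

import tensorflow as tf
from DeepDeformationMapRegistration.layers.upsampling import UpSampling3D

x = tf.random.uniform((1, 8, 8, 8, 1))  # [batch, dim1, dim2, dim3, channels]
up = UpSampling3D(size=(2, 2, 2), interpolation='trilinear',
                  data_format='channels_last', align_corners=True)
print(up(x).shape)  # expected: (1, 16, 16, 16, 1)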
DeepDeformationMapRegistration/losses.py
CHANGED
@@ -1,12 +1,17 @@
import tensorflow as tf
import tensorflow.keras.backend as K
from scipy.ndimage import generate_binary_structure
from sklearn.utils.extmath import cartesian

from DeepDeformationMapRegistration.utils.operators import soft_threshold, min_max_norm, hard_threshold
from DeepDeformationMapRegistration.utils.constants import EPS_tf
from DeepDeformationMapRegistration.utils.misc import function_decorator

import numpy as np
import warnings

class HausdorffDistanceErosion:
    def __init__(self, ndim=3, nerosion=10, im_shape: [list, tuple] = (64, 64, 64, 1), alpha=2):
        """
        Approximation of the Hausdorff distance based on erosion operations, following the work by Karimi D., et al.
        Karimi D., et al., "Reducing the Hausdorff Distance in Medical Image Segmentation with Convolutional Neural
        Networks". IEEE Transactions on Medical Imaging, 39, 2020. DOI 10.1109/TMI.2019.2930068

        :param ndim: Dimensionality of the images
        :param nerosion: Number of erosion steps. Defaults to 10.
        :param im_shape: Image shape including the channel axis, e.g., (64, 64, 64, 1)
        :param alpha: Parameter to penalize large segmentations. Defaults to 2
        """
        assert len(im_shape) == ndim + 1, "im_shape does not match with ndim. Missing channel dimension?"
        self.ndims = ndim
        axes = np.arange(0, self.ndims).tolist()
        self.before_erosion_transp = [axes[-1], *axes[:-1]]  # [H, W, ..., C] -> [C, H, W, ...]
        self.after_erosion_transp = [*axes[1:], axes[0]]  # [C, H, W, ...] -> [H, W, ..., C]
        self.nerosions = nerosion
        self.sum_range = tf.range(0, self.ndims)

        self.im_shape = im_shape
        self.im_vol = np.prod(im_shape[:-1])
        kernel = generate_binary_structure(self.ndims, 1).astype(int)
        kernel = kernel / np.sum(kernel)
        kernel = kernel[..., np.newaxis, np.newaxis]
        self.kernel = tf.convert_to_tensor(kernel, tf.float32)
        self.k_alpha = [np.power(k, alpha).astype(float) for k in range(1, nerosion + 1)]
        self.conv = getattr(tf.nn, 'conv%dd' % self.ndims)

    def _erode(self, in_tensor):
        indiv_channels = tf.split(in_tensor, self.im_shape[-1], -1)
        res = list()
        with tf.variable_scope('erode', reuse=tf.AUTO_REUSE):
            for ch in indiv_channels:
                res.append(self.conv(tf.expand_dims(ch, 0), self.kernel, [1] * (self.ndims + 2), 'SAME'))
        # out = -tf.nn.max_pool3d(-tf.expand_dims(in_tensor, 0), [3]*self.ndims, [1]*self.ndims, 'SAME', name='HDE_erosion')
        out = tf.concat(res, -1)
        out = tf.squeeze(out, axis=0)
        out = hard_threshold(out, 0.5, name='thresholding')  # soft_threshold(out, 0.5, name='thresholding')
        return out

    def _erosion_distance_single(self, y_true, y_pred):
        diff = tf.math.pow(y_pred - y_true, 2, name='HDE_diff')

        ret = 0.
        for k in range(1, self.nerosions + 1):
            er = diff
            # k successive erosions
            for j in range(k):
                er = self._erode(er)  # er contains the eroded version along the channels
            ret += tf.reduce_sum(tf.multiply(er, self.k_alpha[k - 1]), self.sum_range, name='HDE_ret')

        return tf.divide(ret, self.im_vol)  # Divide by the image size

    @function_decorator('Hausdorff_erosion__loss')
    def loss(self, y_true, y_pred, name='HDE_loss'):
        batched_dist = tf.map_fn(lambda x: self._erosion_distance_single(x[0], x[1]), (y_true, y_pred),
                                 dtype=tf.float32, name=name + '_map_fn')

        return tf.reduce_mean(batched_dist)

    @function_decorator('Hausdorff_erosion__metric')
    def metric(self, y_true, y_pred):
        return self.loss(y_true, y_pred, name='HDE_metric')

    def debug(self, y_true, y_pred):
        return tf.map_fn(lambda x: self._erosion_distance_single(x[0], x[1]), (y_true, y_pred),
                         dtype=tf.float32, name='HDE_loss_map_fn')

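A usage sketch for the erosion-based Hausdorff loss above (a sketch under TF1 graph-mode assumptions, matching the `tf.variable_scope` call in `_erode`; the 64³ single-channel shape is illustrative):

import tensorflow as tf
from DeepDeformationMapRegistration.losses import HausdorffDistanceErosion

hde = HausdorffDistanceErosion(ndim=3, nerosion=10, im_shape=(64, 64, 64, 1), alpha=2)

# Binary segmentations in [0, 1], shaped [batch, H, W, D, C]
y_true = tf.placeholder(tf.float32, (None, 64, 64, 64, 1))
y_pred = tf.placeholder(tf.float32, (None, 64, 64, 64, 1))
loss_op = hde.loss(y_true, y_pred)  # scalar, averaged over the batch

# or hand both callables to Keras:
# model.compile(optimizer='adam', loss=hde.loss, metrics=[hde.metric])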

# class HausdorffDistanceConvolution:
#     def __init__(self, ndim=3, im_shape: tuple = (64, 64, 64, 1), max_kernel_size=9, step_kernel_size=3, alpha=2):
#         """
#         Approximation of the Hausdorff distance based on erosion operations, following the work by Karimi D., et al.
#         Karimi D., et al., "Reducing the Hausdorff Distance in Medical Image Segmentation with Convolutional Neural
#         Networks". IEEE Transactions on Medical Imaging, 39, 2020. DOI 10.1109/TMI.2019.2930068
#
#         :param ndim: Dimensionality of the images
#         :param nerosion: Number of erosion steps. Defaults to 10.
#         :param alpha: Parameter to penalize large segmentations. Defaults to 2
#         """
#         assert len(im_shape) == ndim + 1, "im_shape does not match with ndim. Missing channel dimension?"
#         self.ndims = ndim
#         axes = np.arange(0, self.ndims).tolist()
#         self.before_erosion_transp = [axes[-1], *axes[:-1]]  # [H, W, ..., C] -> [C, H, W, ...]
#         self.after_erosion_transp = [*axes[1:], axes[0]]  # [C, H, W, ...] -> [H, W, ..., C]
#         self.conv = getattr(tf.nn, 'conv%dd' % self.ndims)
#         self.sum_range = tf.range(0, self.ndims)
#
#         self.im_shape = im_shape
#         self.im_vol = np.prod(im_shape[:-1])
#         kernel = generate_binary_structure(self.ndims, 1).astype(int)
#         self.kernel = tf.constant(kernel / np.sum(kernel), tf.float32)
#         self.kernel = tf.expand_dims(tf.expand_dims(self.kernel, -1), -1)  # [H, W, D, C_in, C_out]
#         self.kernel = tf.tile(self.kernel, [*[1]*self.ndims, self.im_shape[-1], self.im_shape[-1]])
#         self.alpha = int(alpha)
#         self.radii = np.arange(1, max_kernel_size, step=step_kernel_size)
#         self.radii_alpha = [np.pow(r, alpha).astype(float) for r in self.radii]
#
#     def soft_diff(self, p, q):
#         return tf.multiply(tf.pow(p - q, 2.), q)
#
#     def body(self, y_true, y_pred):
#

"""
class HausdorffDistanceErosion_2:
    def __init__(self, im_shape, num_erosions, num_dimensions=3, alpha=2., loop_max_iterations=20):
        self.alpha = alpha
        self.ndims = num_dimensions
        self.conv = getattr(tf.nn, 'conv%dd' % self.ndims)

        self.iterator = tf.constant(num_erosions, name='num_erosions')
        self.norm = 1 / np.prod(im_shape)
        self.erosion_kernel = generate_binary_structure(self.ndims, 1).astype(float)
        self.erosion_kernel /= np.sum(self.erosion_kernel)
        self.erosion_kernel = tf.constant(self.erosion_kernel, tf.float32)

        self.loop_max_iterations = loop_max_iterations

    def erosion_sum(self, p, q, k):
        er_tensor = p - q
        er_tensor = tf.pow(er_tensor, 2.)

        def erode(in_tensor):
            # Erosion of in_tensor = Dilation of (1 - in_tensor)
            return self.conv(tf.expand_dims(1. - in_tensor, 0), self.erosion_kernel, [1] * (self.ndims + 2), 'SAME')

        def while_loop_body(i, in_tensor):
            in_tensor = erode(in_tensor)
            i -= 1
            return i, in_tensor

        def while_loop_condition(i, in_tensor):
            return tf.less_equal(i, 1), in_tensor

        er_iterator = tf.constant(k)
        _, er_tensor = tf.while_loop(while_loop_condition, while_loop_body, loop_vars=[er_iterator, er_tensor],
                                     maximum_iterations=self.loop_max_iterations)

        er_tensor *= tf.pow(k, self.alpha)
        return tf.reduce_sum(er_tensor)

    def loss(self, y_true, y_pred):
        hd_distance = tf.constant(0, name='hausdroff_distance')

        def while_loop_body(i, p, q, ret):
            i -= 1
            return i, p, q, ret + self.erosion_sum(p, q, i)

        _, _, _, hd_distance = tf.while_loop(lambda i, p, q, ret: tf.less_equal(i, 1),
                                             while_loop_body,
                                             loop_vars=[self.iterator, y_pred, y_true, hd_distance])
        hd_distance /= self.norm
        return hd_distance

"""

class WeightedHausdorffDistance:
    def __init__(self, input_shape, alpha=-1, threshold=0.5):
        """
        WARNING: Requires an insane amount of memory
        :param input_shape: [H, W, D, C] or [H, W, C]
        :param alpha: Parameter of the generalized mean. Ideally -inf, but then the function becomes less smooth.
        :param threshold: Threshold of segmentations, used in tf.where function
        """
        warnings.warn("This function requires an insane amount of memory")
        self.input_shape = input_shape
        self.dim = len(input_shape[:-1])
        self.ohe_segm = bool(input_shape[-1] > 1)  # One-Hot Encoded segmentations on the channel axis
        aux = np.arange(len(self.input_shape)).tolist()
        self.ohe_transpose = [aux[-1], *aux[:-1]]
        self.alpha = alpha
        self.threshold = threshold
        list_coords = [np.arange(c) for c in self.input_shape[:-1]]
        self.img_loc = tf.convert_to_tensor(cartesian(list_coords), dtype=tf.float32)
        self.max_dist = np.sqrt(np.sum(np.square(self.input_shape[:-1])))  # Largest diagonal

    def pairwise_distance(self, A, B):
        sq_norm_a = tf.reduce_sum(tf.square(A), 1)
        sq_norm_b = tf.reduce_sum(tf.square(B), 1)

        sq_norm_a = tf.reshape(sq_norm_a, [-1, 1])
        sq_norm_b = tf.reshape(sq_norm_b, [1, -1])

        return tf.sqrt(tf.maximum(sq_norm_a - 2 * tf.matmul(A, B, transpose_a=False, transpose_b=True) + sq_norm_b, 0.))

    def hausdorff(self, y_true, y_pred):
        if self.ohe_segm:
            y_true = tf.transpose(y_true, self.ohe_transpose)
            y_pred = tf.transpose(y_pred, self.ohe_transpose)
            hausdorff_per_ch = tf.map_fn(lambda x: self.hausdorff_per_channel(x[0], x[1]), (y_true, y_pred), tf.float32)
            return tf.reduce_mean(hausdorff_per_ch)
        else:
            return self.hausdorff_per_channel(y_true, y_pred)

    def hausdorff_per_channel(self, y_true, y_pred):
        Y = tf.cast(tf.where(y_true > self.threshold), dtype=tf.float32)
        p = K.flatten(y_pred)  # Flatten the predicted segmentation (activation map 'p' in d_WH)

        size_Y = tf.shape(Y)[0]
        S = tf.reduce_sum(p)

        p = tf.squeeze(K.repeat(tf.expand_dims(p, -1), size_Y))
        dist_mat = self.pairwise_distance(self.img_loc, Y)

        term_1 = tf.reduce_sum(p * tf.minimum(dist_mat, 1)) / (S + EPS_tf)

        term_2 = tf.minimum((dist_mat + EPS_tf) / (tf.pow(p, self.alpha) + (EPS_tf / self.max_dist)), 0.)
        term_2 = tf.clip_by_value(term_2, 0., self.max_dist)
        term_2 = tf.reduce_mean(term_2, axis=0)

        return term_1 + term_2

    @function_decorator('Weighted_Hausdorff__loss')
    def loss(self, y_true, y_pred):
        batch_hdist = tf.map_fn(lambda x: self.hausdorff(x[0], x[1]), (y_true, y_pred), dtype=tf.float32)

        return tf.reduce_mean(batch_hdist)

    @function_decorator('Weighted_Hausdorff__metric')
    def metric(self, y_true, y_pred):
        return self.loss(y_true, y_pred)

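The `pairwise_distance` helper relies on the expansion ‖a − b‖² = ‖a‖² − 2a·b + ‖b‖², clamped at zero before the square root to absorb rounding error. A NumPy cross-check of that identity (a sketch, independent of the class):

import numpy as np

A = np.random.rand(5, 3).astype(np.float32)   # 5 points in 3-D
B = np.random.rand(7, 3).astype(np.float32)   # 7 points in 3-D

# Expansion used by pairwise_distance: ||a||^2 - 2 a.b + ||b||^2
sq_a = np.sum(A ** 2, axis=1)[:, None]        # [5, 1]
sq_b = np.sum(B ** 2, axis=1)[None, :]        # [1, 7]
fast = np.sqrt(np.maximum(sq_a - 2 * A @ B.T + sq_b, 0.))

# Brute-force reference
ref = np.linalg.norm(A[:, None, :] - B[None, :, :], axis=-1)
assert np.allclose(fast, ref, atol=1e-5)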

class NCC:
    def __init__(self, in_shape, eps=EPS_tf):
@@ -69,48 +246,105 @@ class NCC
        f_yp = tf.reshape(y_pred, [-1])
        mean_yt = tf.reduce_mean(f_yt)
        mean_yp = tf.reduce_mean(f_yp)

        n_f_yt = f_yt - mean_yt
        n_f_yp = f_yp - mean_yp
        norm_yt = tf.norm(f_yt, ord='euclidean')
        norm_yp = tf.norm(f_yp, ord='euclidean')
        numerator = tf.reduce_sum(tf.multiply(n_f_yt, n_f_yp))
        denominator = norm_yt * norm_yp + self.__eps
        return tf.math.divide_no_nan(numerator, denominator)

    @function_decorator('NCC__loss')
    def loss(self, y_true, y_pred):
        # According to the documentation, the loss returns a scalar
        # Ref: https://www.tensorflow.org/api_docs/python/tf/keras/Model#compile
        return tf.reduce_mean(tf.map_fn(lambda x: 1 - self.ncc(x[0], x[1]), (y_true, y_pred), tf.float32))

    @function_decorator('NCC__metric')
    def metric(self, y_true, y_pred):
        return tf.reduce_mean(tf.map_fn(lambda x: self.ncc(x[0], x[1]), (y_true, y_pred), tf.float32))


def ncc(y_true, y_pred):
    y_true = K.flatten(K.cast(y_true, 'float32'))
    y_pred = K.flatten(K.cast(y_pred, 'float32'))

    mean_true = K.mean(y_true)
    mean_pred = K.mean(y_pred)

    std_true = K.std(y_true)
    std_pred = K.std(y_pred)

    num = K.mean((y_true - mean_true) * (y_pred - mean_pred))
    den = std_true * std_pred + EPS_tf
    batch_ncc = num / den

    return K.mean(batch_ncc)

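The standalone `ncc` above is a Pearson correlation over flattened intensities: 1 for identical volumes, −1 for sign-flipped ones, and invariant to intensity rescaling. A minimal eager-mode sketch (assuming the function is imported from this module):

import numpy as np
import tensorflow as tf
from DeepDeformationMapRegistration.losses import ncc

a = tf.constant(np.random.rand(1, 8, 8, 8, 1), tf.float32)

print(ncc(a, a))        # ~ 1.0 (up to the EPS_tf regularizer)
print(ncc(a, -a))       # ~ -1.0
print(ncc(a, a * 0.5))  # ~ 1.0: invariant to intensity rescaling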

class StructuralSimilarity:
    # Based on https://github.com/keras-team/keras-contrib/blob/master/keras_contrib/losses/dssim.py
    def __init__(self, k1=0.01, k2=0.03,
                 patch_size=32, dynamic_range=1., overlap=0.0, dim=3,
                 alpha=1., beta=1., gamma=1.,
                 **kwargs):
        """
        Structural (Di)Similarity Index Measure:

        :param k1: Internal parameter. Defaults to 0.01
        :param k2: Internal parameter. Defaults to 0.02
        :param patch_size: Size of the extracted patches. Defaults to 32. Recommendation: half the image size.
        :param dynamic_range: Maximum numerical intensity value (typ. 2^bits_per_pixel - 1). Defaults to 1.
        :param overlap: Patch overlap ratio. Must be in the range [0., 1.). Defaults to 0.
        :param dim: Data dimensionality. Must be {2, 3} (use patch_size=-1 to skip patch extraction). Defaults to 3.
        :param alpha, beta, gamma: Exponential parameters to balance the contribution of the luminance, contrast and
        structure measures. Default to 1.
        """
        assert (dim > 0) and (dim < 4), 'Invalid dimension. It must be 1, 2, or 3'
        assert overlap < 1., 'Invalid overlap. It must be in the range [0., 1.)'
        self.c1 = (k1 * dynamic_range) ** 2
        self.c2 = (k2 * dynamic_range) ** 2
        self.c3 = self.c2 / 2
        self.alpha = tf.cast(alpha, tf.float32)
        self.beta = tf.cast(beta, tf.float32)
        self.gamma = tf.cast(gamma, tf.float32)

        self.kernel_shape = [1] + [patch_size] * dim + [1]
        stride = int(patch_size * (1 - overlap))
        self.stride = [1] + [stride if stride else 1] * dim + [1]
        self.dim = dim
        self.patch_extractor = None
        self.reduce_axis = list()
        if dim == 2:
            self.patch_extractor = tf.extract_image_patches
            self.reduce_axis = [1, 2]
        elif dim == 3:
            self.patch_extractor = tf.extract_volume_patches
            self.reduce_axis = [1, 2, 3]
        else:
            raise ValueError('Invalid dimension value. Expected 2 or 3')

        if patch_size == -1:
            # Don't extract patches
            self.dim = 1

        self.L = None  # Luminance
        self.C = None  # Contrast
        self.S = None  # Structure

    def __int_shape(self, x):
        return tf.keras.backend.int_shape(x) if tf.keras.backend.backend() == 'tensorflow' else tf.keras.backend.shape(x)

    def ssim(self, y_true, y_pred):
        if self.dim > 1:
            # Don't use for training. The gradient doesn't backpropagate through the patch extractors
            # patches: [B, out_rows, out_cols, ..., krows*kcols*...*channels] -> out_rows * out_cols * ... = nb patches
            patches_true = self.patch_extractor(y_true, ksizes=self.kernel_shape, strides=self.stride, padding='VALID', name='patches_true')
            patches_pred = self.patch_extractor(y_pred, ksizes=self.kernel_shape, strides=self.stride, padding='VALID', name='patches_pred')
        else:
            patches_true = y_true
            patches_pred = y_pred

        # bs, w, h, d, *c = self.__int_shape(patches_pred)
        # patches_true = tf.reshape(patches_true, [-1, w, h, d, tf.reduce_prod(c)])
@@ -124,15 +358,496 @@ class StructuralSimilarity
        v_true = tf.math.reduce_variance(patches_true, axis=-1)
        v_pred = tf.math.reduce_variance(patches_pred, axis=-1)

        # Standard dev.
        s_true = tf.sqrt(v_true)
        s_pred = tf.sqrt(v_pred)

        # Covariance
        covar = tf.reduce_mean(patches_true * patches_pred, axis=-1) - u_true * u_pred

        # SSIM
        self.L = (2 * u_true * u_pred + self.c1) / (tf.square(u_true) + tf.square(u_pred) + self.c1)
        self.C = (2 * s_true * s_pred + self.c2) / (v_true + v_pred + self.c2)
        self.S = (covar + self.c3) / (s_true * s_pred + self.c3)
        self.L = tf.reduce_mean(self.L, axis=self.reduce_axis)
        self.C = tf.reduce_mean(self.C, axis=self.reduce_axis)
        self.S = tf.reduce_mean(self.S, axis=self.reduce_axis)

        return tf.pow(self.L, self.alpha) * tf.pow(self.C, self.beta) * tf.pow(self.S, self.gamma)

    @function_decorator('SSIM__loss')
    def loss(self, y_true, y_pred):
        return tf.reduce_mean((1. - self.ssim(y_true, y_pred)) / 2.0)

    @function_decorator('SSIM__metric')
    def metric(self, y_true, y_pred):
        return tf.reduce_mean(self.ssim(y_true, y_pred))

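A usage sketch wiring these classes into Keras (the toy model is a placeholder, not from the commit). Note the comment in `ssim()`: gradients do not flow through the patch extractors, so SSIM is attached only as a monitoring metric while NCC drives the training:

import tensorflow as tf

ncc_loss = NCC(in_shape=(64, 64, 64, 1))
ssim = StructuralSimilarity(patch_size=32, dim=3, dynamic_range=1.)

inp = tf.keras.layers.Input((64, 64, 64, 1))
out = tf.keras.layers.Conv3D(1, 3, padding='same', activation='sigmoid')(inp)
model = tf.keras.Model(inp, out)

model.compile(optimizer='adam', loss=ncc_loss.loss, metrics=[ssim.metric])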
class StructuralSimilarity_simplified:
    # Based on https://github.com/keras-team/keras-contrib/blob/master/keras_contrib/losses/dssim.py
    def __init__(self, k1=0.01, k2=0.03,
                 patch_size=32, dynamic_range=1., overlap=0.0, dim=3,
                 alpha=1., beta=1., gamma=1.,
                 **kwargs):
        """
        Structural (Di)Similarity Index Measure:

        :param k1: Internal parameter. Defaults to 0.01
        :param k2: Internal parameter. Defaults to 0.02
        :param patch_size: Size of the extracted patches. Defaults to 32. Recommendation: half the image size.
        :param dynamic_range: Maximum numerical intensity value (typ. 2^bits_per_pixel - 1). Defaults to 1.
        :param overlap: Patch overlap ratio. Must be in the range [0., 1.). Defaults to 0.
        :param dim: Data dimensionality. Must be {1, 2, 3}. Defaults to 3.
        :param alpha, beta, gamma: Exponential parameters to balance the contribution of the luminance, contrast and
        structure measures. Default to 1.
        """
        assert (dim > 0) and (dim < 4), 'Invalid dimension. It must be 1, 2, or 3'
        assert overlap < 1., 'Invalid overlap. It must be in the range [0., 1.)'
        self.c1 = (k1 * dynamic_range) ** 2
        self.c2 = (k2 * dynamic_range) ** 2
        self.c3 = self.c2 / 2
        self.alpha = tf.cast(alpha, tf.float32)
        self.beta = tf.cast(beta, tf.float32)
        self.gamma = tf.cast(gamma, tf.float32)

        self.kernel_shape = [1] + [patch_size] * dim + [1]
        stride = int(patch_size * (1 - overlap))
        self.stride = [1] + [stride if stride else 1] * dim + [1]
        self.dim = dim
        self.patch_extractor = None
        if dim == 2:
            self.patch_extractor = tf.extract_image_patches
        elif dim == 3:
            self.patch_extractor = tf.extract_volume_patches

        if patch_size == -1:
            # Don't extract patches
            self.dim = 1

        self.L = None  # Luminance
        self.C = None  # Contrast
        self.S = None  # Structure

    def __int_shape(self, x):
        return tf.keras.backend.int_shape(x) if tf.keras.backend.backend() == 'tensorflow' else tf.keras.backend.shape(x)

    def ssim(self, y_true, y_pred):
        if self.dim > 1:
            # Don't use for training. The gradient doesn't backpropagate through the patch extractors
            # patches: [B, out_rows, out_cols, ..., krows*kcols*...*channels] -> out_rows * out_cols * ... = nb patches
            patches_true = self.patch_extractor(y_true, ksizes=self.kernel_shape, strides=self.stride, padding='VALID', name='patches_true')
            patches_pred = self.patch_extractor(y_pred, ksizes=self.kernel_shape, strides=self.stride, padding='VALID', name='patches_pred')
        else:
            patches_true = y_true
            patches_pred = y_pred

        # bs, w, h, d, *c = self.__int_shape(patches_pred)
        # patches_true = tf.reshape(patches_true, [-1, w, h, d, tf.reduce_prod(c)])
        # patches_pred = tf.reshape(patches_pred, [-1, w, h, d, tf.reduce_prod(c)])

        # Mean
        u_true = tf.reduce_mean(patches_true, axis=-1)
        u_pred = tf.reduce_mean(patches_pred, axis=-1)

        # Variance
        v_true = tf.math.reduce_variance(patches_true, axis=-1)
        v_pred = tf.math.reduce_variance(patches_pred, axis=-1)

        # Covariance
        covar = tf.reduce_mean(patches_true * patches_pred, axis=-1) - u_true * u_pred

        # return tf.pow(self.L, self.alpha) * tf.pow(self.C, self.beta) * tf.pow(self.S, self.gamma)
        num = (2 * u_true * u_pred + self.c1) * (2 * covar + self.c2)
        den = ((tf.square(u_true) + tf.square(u_pred) + self.c1) * (v_pred + v_true + self.c2))
        return num / den

    @function_decorator('SSIM_simple__loss')
    def loss(self, y_true, y_pred):
        return tf.reduce_mean((1. - self.ssim(y_true, y_pred)) / 2.0)

    @function_decorator('SSIM_simple__metric')
    def metric(self, y_true, y_pred):
        return tf.reduce_mean(self.ssim(y_true, y_pred))

class MultiScaleStructuralSimilarity(StructuralSimilarity):
    def __init__(self, k1=0.01, k2=0.03, patch_size=3, dynamic_range=1., overlap=0.0, dim=3, nscales=3, alpha=1., beta=1., gamma=1.):
        """
        Multi Scale Structural (Di)Similarity Index Measure:
        Ref: [1] https://www.cns.nyu.edu/pub/eero/wang03b.pdf

        :param k1: Internal parameter. Defaults to 0.01
        :param k2: Internal parameter. Defaults to 0.02
        :param patch_size: Size of the extracted patches. Defaults to 3.
        :param dynamic_range: Maximum numerical intensity value (typ. 2^bits_per_pixel - 1). Defaults to 1.
        :param overlap: Patch overlap ratio. Must be in the range [0., 1.). Defaults to 0.
        :param dim: Data dimensionality. Must be {2, 3}. Defaults to 3.
        :param nscales: Number of scales to analyze. Defaults to 3.
        :param alpha, beta, gamma: Exponential parameters to balance the contribution of the luminance, contrast and
        structure measures. Default to 1.
        """
        assert dim > 1, 'Cannot be used with 1-D data'
        super(MultiScaleStructuralSimilarity, self).__init__(k1=k1, k2=k2, patch_size=patch_size,
                                                             dynamic_range=dynamic_range, overlap=overlap, dim=dim,
                                                             alpha=alpha, beta=beta, gamma=gamma)
        self.num_scales = nscales
        self.avg_pool = getattr(tf.nn, 'avg_pool%dd' % dim)
        self.ds_stride = self.ds_kernel = [1] + [2] * dim + [1]

        # In [1] these are set to the same value at the same scales and normalized across scales
        self.alpha = self.beta = self.gamma = 1 / nscales

    def _cond(self, cs_prod, scale_level, y_true, y_pred):
        return tf.less_equal(scale_level, self.num_scales)

    def _iteration(self, cs_prod, scale_level, y_true, y_pred):
        super(MultiScaleStructuralSimilarity, self).ssim(y_true, y_pred)
        cs_prod *= tf.reduce_mean(tf.pow(self.C, self.beta) * tf.pow(self.S, self.gamma))
        y_true = self.avg_pool(y_true, ksize=self.ds_kernel, strides=self.ds_stride, padding='VALID')
        y_pred = self.avg_pool(y_pred, ksize=self.ds_kernel, strides=self.ds_stride, padding='VALID')
        scale_level += 1
        return cs_prod, scale_level, y_true, y_pred

    def ssim(self, y_true, y_pred):
        return self.ms_ssim(y_true, y_pred)

    def ms_ssim(self, y_true, y_pred):
        cs_prod = tf.constant(1.)
        scale_level = tf.constant(1.)
        cs_prod, *_ = tf.while_loop(self._cond,
                                    self._iteration,
                                    (cs_prod, scale_level, y_true, y_pred),
                                    (cs_prod.get_shape(), scale_level.get_shape(),
                                     tf.TensorShape(([1] + [None] * self.dim + [1])),
                                     tf.TensorShape(([1] + [None] * self.dim + [1]))))

        ms_ssim = tf.reduce_mean(tf.pow(self.L, self.alpha)) * cs_prod

        return tf.reduce_mean(ms_ssim)

    @function_decorator('MS_SSIM__loss')
    def loss(self, y_true, y_pred):
        return tf.reduce_mean((1. - self.ms_ssim(y_true, y_pred)) / 2.0)

class MultiScaleStructuralSimilarity_v2(StructuralSimilarity):
    def __init__(self, k1=0.01, k2=0.03, patch_size=3, dynamic_range=1., overlap=0.0, dim=3, nscales=3, alpha=1., beta=1., gamma=1.):
        """
        Multi Scale Structural (Di)Similarity Index Measure:
        Ref: [1] https://www.cns.nyu.edu/pub/eero/wang03b.pdf

        :param k1: Internal parameter. Defaults to 0.01
        :param k2: Internal parameter. Defaults to 0.02
        :param patch_size: Size of the extracted patches. Defaults to 3.
        :param dynamic_range: Maximum numerical intensity value (typ. 2^bits_per_pixel - 1). Defaults to 1.
        :param overlap: Patch overlap ratio. Must be in the range [0., 1.). Defaults to 0.
        :param dim: Data dimensionality. Must be {2, 3}. Defaults to 3.
        :param nscales: Number of scales to analyze. Defaults to 3.
        :param alpha, beta, gamma: Exponential parameters to balance the contribution of the luminance, contrast and
        structure measures. Default to 1.
        """
        assert dim > 1, 'Cannot be used with 1-D data'
        super(MultiScaleStructuralSimilarity_v2, self).__init__(k1=k1, k2=k2, patch_size=patch_size,
                                                                dynamic_range=dynamic_range, overlap=overlap, dim=dim,
                                                                alpha=alpha, beta=beta, gamma=gamma)
        self.num_scales = nscales
        self.avg_pool = getattr(tf.nn, 'avg_pool%dd' % dim)
        self.ds_stride = self.ds_kernel = [1] + [2] * dim + [1]

        # In [1] these are set to the same value at the same scales and normalized across scales
        self.alpha = self.beta = self.gamma = 1 / nscales

    def _cond(self, cs_prod, scale_level, y_true, y_pred):
        return tf.less_equal(scale_level, self.num_scales)

    def _iteration(self, cs_prod, scale_level, y_true, y_pred):
        super(MultiScaleStructuralSimilarity_v2, self).ssim(y_true, y_pred)
        cs_prod *= tf.reduce_mean(tf.pow(self.C, self.beta) * tf.pow(self.S, self.gamma))
        y_true = self.avg_pool(y_true, ksize=self.ds_kernel, strides=self.ds_stride, padding='VALID')
        y_pred = self.avg_pool(y_pred, ksize=self.ds_kernel, strides=self.ds_stride, padding='VALID')
        scale_level += 1
        return cs_prod, scale_level, y_true, y_pred

    def ssim(self, y_true, y_pred):
        return self.ms_ssim(y_true, y_pred)

    def ms_ssim(self, y_true, y_pred):
        cs_prod = tf.constant(1.)
        scale_level = tf.constant(1.)
        cs_prod, *_ = tf.while_loop(self._cond,
                                    self._iteration,
                                    (cs_prod, scale_level, y_true, y_pred),
                                    (cs_prod.get_shape(), scale_level.get_shape(),
                                     tf.TensorShape(([1] + [None] * self.dim + [1])),
                                     tf.TensorShape(([1] + [None] * self.dim + [1]))))

        ms_ssim = tf.reduce_mean(tf.pow(self.L, self.alpha)) * cs_prod

        return tf.reduce_mean(ms_ssim)

    @function_decorator('MS_SSIM_v2__loss')
    def loss(self, y_true, y_pred):
        return tf.reduce_mean((1. - self.ms_ssim(y_true, y_pred)) / 2.0)

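For reference, both multi-scale classes implement the score of Wang et al. (the wang03b.pdf reference cited in their docstrings): contrast and structure terms are accumulated over the M scales while the luminance term is taken only at the coarsest one,

MS-SSIM(x, y) = [l_M(x, y)]^(alpha_M) * prod_{j=1..M} [c_j(x, y)]^(beta_j) * [s_j(x, y)]^(gamma_j),

with every exponent set here to 1/nscales, as noted in the constructors above.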
class StructuralSimilarityGaussian:
    # This is equivalent to StructuralSimilarity(patch_size=img_size)
    def __init__(self, k1=0.01, k2=0.03, dynamic_range=1., gauss_sigma=5., dim=3, alpha=1., beta=1., gamma=1.):
        """
        SSIM using a Gaussian filter to approximate the statistics of the images
        Ref: https://www.cns.nyu.edu/pub/eero/wang03b.pdf
             https://arxiv.org/pdf/1511.08861.pdf
             https://github.com/NVlabs/PL4NN/blob/master/src/loss.py

        :param k1: Internal parameter. Defaults to 0.01
        :param k2: Internal parameter. Defaults to 0.02
        :param dynamic_range: Maximum numerical intensity value (typ. 2^bits_per_pixel - 1). Defaults to 1.
        :param gauss_sigma: Sigma of the Gaussian filter. Defaults to 5.
        :param dim: Data dimensionality. Must be {2, 3}. Defaults to 3.
        :param alpha, beta, gamma: Exponential parameters to balance the contribution of the luminance, contrast and
        structure measures. Default to 1.
        """
        self.c1 = (k1 * dynamic_range) ** 2
        self.c2 = (k2 * dynamic_range) ** 2
        self.c3 = self.c2 / 2
        self.alpha = tf.cast(alpha, tf.float32)
        self.beta = tf.cast(beta, tf.float32)
        self.gamma = tf.cast(gamma, tf.float32)
        self.dim = dim
        self.convDN = getattr(tf.nn, 'conv%dd' % dim)
        self.sigma = gauss_sigma

    def build_gaussian_filter(self, size, sigma, num_channels=1):
        range_1d = tf.range(-(size / 2) + 1, size // 2 + 1)
        g_1d = tf.math.exp(-1.0 * tf.pow(range_1d, 2) / (2. * tf.pow(sigma, 2)))
        g_1d_expanded = tf.expand_dims(g_1d, -1)
        iterator = tf.constant(1)
        self.__GF = tf.while_loop(lambda iterator, g_1d: tf.less(iterator, self.dim),
                                  lambda iterator, g_1d: (iterator + 1, tf.expand_dims(g_1d, -1) * tf.transpose(g_1d_expanded)),
                                  [iterator, g_1d],
                                  [iterator.get_shape(), tf.TensorShape([None] * self.dim)]  # Shape invariants
                                  )[-1]

        self.__GF = tf.divide(self.__GF, tf.reduce_sum(self.__GF))  # Normalization
        self.__GF = tf.reshape(self.__GF, (*[size] * self.dim, 1, 1))  # Add Ch_in and Ch_out for convolution
        self.__GF = tf.tile(self.__GF, (*[1] * self.dim, num_channels, num_channels,))

    def format_data(self, in_data):
        ret_val = in_data
        if self.dim == 3:
            ret_val = tf.transpose(ret_val, [0, 3, 1, 2, 4])
        return ret_val

    def ssim(self, y_true, y_pred):
        self.build_gaussian_filter(y_pred.shape[1], self.sigma)
        y_true_tr = self.format_data(y_true)
        y_pred_tr = self.format_data(y_pred)

        u_true = self.convDN(y_true_tr, self.__GF, [1] * (self.dim + 2), 'SAME')
        u_pred = self.convDN(y_pred_tr, self.__GF, [1] * (self.dim + 2), 'SAME')

        v_true = self.convDN(tf.pow(y_true_tr, 2), self.__GF, [1] * (self.dim + 2), 'SAME') - tf.pow(u_true, 2)
        v_pred = self.convDN(tf.pow(y_pred_tr, 2), self.__GF, [1] * (self.dim + 2), 'SAME') - tf.pow(u_pred, 2)
        covar = self.convDN(tf.multiply(y_true_tr, y_pred_tr), self.__GF, [1] * (self.dim + 2), 'SAME') - u_true * u_pred

        self.L = (2 * u_true * u_pred + self.c1) / (tf.square(u_true) + tf.square(u_pred) + self.c1)
        self.C = (2 * tf.sqrt(v_true) * tf.sqrt(v_pred) + self.c2) / (v_true + v_pred + self.c2)
        self.S = (covar + self.c3) / (tf.sqrt(v_true) * tf.sqrt(v_pred) + self.c3)
        ssim = tf.pow(self.L, self.alpha) * tf.pow(self.C, self.beta) * tf.pow(self.S, self.gamma)

        return tf.reduce_mean(ssim)

    @function_decorator('SSIM_Gaus__loss')
    def loss(self, y_true, y_pred):
        return tf.reduce_mean((1. - self.ssim(y_true, y_pred)) / 2.)

    @function_decorator('SSIM_Gaus__metric')
    def metric(self, y_true, y_pred):
        return tf.reduce_mean(self.ssim(y_true, y_pred))

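`build_gaussian_filter` lifts a 1-D Gaussian to N dimensions by repeated outer products and then normalizes it to sum to one. The NumPy equivalent for dim = 2 (a sketch, not part of the module):

import numpy as np

size, sigma = 5, 5.0
r = np.arange(-(size / 2) + 1, size // 2 + 1)  # same sample points as the tf.range above
g1 = np.exp(-r ** 2 / (2. * sigma ** 2))

g2 = np.outer(g1, g1)                          # 2-D kernel via outer product
g2 /= g2.sum()                                 # normalization, as in the class
assert np.isclose(g2.sum(), 1.0)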
class MultiScaleStructuralSimilarityGaussian(StructuralSimilarityGaussian):
    def __init__(self, k1=0.01, k2=0.03, dynamic_range=1., gauss_sigma=5., dim=3, nscales=3, alpha=1., beta=1., gamma=1.):
        """
        Multi Scale SSIM inheriting from the StructuralSimilarityGaussian class
        Ref: https://www.cns.nyu.edu/pub/eero/wang03b.pdf
             https://arxiv.org/pdf/1511.08861.pdf
             https://github.com/NVlabs/PL4NN/blob/master/src/loss.py

        :param k1: Internal parameter. Defaults to 0.01
        :param k2: Internal parameter. Defaults to 0.02
        :param dynamic_range: Maximum numerical intensity value (typ. 2^bits_per_pixel - 1). Defaults to 1.
        :param gauss_sigma: Sigma of the Gaussian filter. Defaults to 5.
        :param dim: Data dimensionality. Must be {2, 3}. Defaults to 3.
        :param nscales: Number of scales to analyze. Defaults to 3.
        :param alpha, beta, gamma: Exponential parameters to balance the contribution of the luminance, contrast and
        structure measures. Default to 1.
        """
        super(MultiScaleStructuralSimilarityGaussian, self).__init__(k1=k1, k2=k2, dynamic_range=dynamic_range,
                                                                     gauss_sigma=gauss_sigma, dim=dim,
                                                                     alpha=alpha, beta=beta, gamma=gamma)
        self.__num_scales = nscales

    # # If using the Gaussian approximation of the pyramid MS approach described in https://arxiv.org/pdf/1511.08861.pdf
    # def build_sigma_scales(self):
    #     iterator = tf.constant(0)
    #     scales = tf.expand_dims(self.sigma, -1)
    #     last_sigma = scales
    #     self.sigma_scales = tf.while_loop(lambda iterator, last_sigma, scales: tf.less_equal(iterator, self.__num_scales),
    #                                       lambda iterator, last_sigma, scales: (iterator + 1, tf.concat([scales, last_sigma/2], 0), last_sigma/2),
    #                                       [iterator, last_sigma, scales])[-1]
    #
    # def build_gaussian_filters_scales(self, size):
    #     self.__GFS = tf.map_fn(lambda sigma: self.build_gaussian_filter(size, sigma), self.sigma, tf.float32)

    def _iteration(self, cs_prod, scale_level, y_true, y_pred):
        # Compute the SSIM, so CS and L have the correct value
        self.ssim(y_true, y_pred)

        cs_prod *= tf.reduce_mean(tf.pow(self.C, self.beta) * tf.pow(self.S, self.gamma))
        scale_level += 1

        # Downsample the images to half the resolution for the next iteration
        y_true = tf.nn.avg_pool(y_true, [1] + [2] * self.dim + [1], [1] + [2] * self.dim + [1], 'SAME')
        y_pred = tf.nn.avg_pool(y_pred, [1] + [2] * self.dim + [1], [1] + [2] * self.dim + [1], 'SAME')
        return cs_prod, scale_level, y_true, y_pred

    def ms_ssim(self, y_true, y_pred):
        scale_level = tf.constant(0.)
        cs_prod = tf.constant(1.)
        # The loop condition must be a callable; L is taken from the last scale
        cs_prod, *_ = tf.while_loop(lambda cs_prod, scale_level, y_true, y_pred: tf.less(scale_level, self.__num_scales),
                                    self._iteration,
                                    (cs_prod, scale_level, y_true, y_pred),
                                    (cs_prod.get_shape(), scale_level.get_shape(),
                                     tf.TensorShape(([1] + [None] * self.dim + [1])),
                                     tf.TensorShape(([1] + [None] * self.dim + [1]))))
        return tf.reduce_mean(tf.pow(self.L, self.alpha)) * cs_prod

    @function_decorator('MS_SSIM_Gaus__loss')
    def loss(self, y_true, y_pred):
        return tf.reduce_mean((1. - self.ms_ssim(y_true, y_pred)) / 2.)

class DICEScore:
    def __init__(self, input_shape: list):
        """
        DICE Score.
        :param input_shape: Shape of the input image, without the batch dimension, e.g., 2D: [H, W, C], 3D: [H, W, D, C]
        """
        self.axes = list(range(1, len(input_shape)))  # The list will not include the channel axis [1, ..., num_dims)

    def dice(self, y_true, y_pred):
        numerator = 2 * tf.reduce_sum(y_true * y_pred, self.axes)
        denominator = tf.reduce_sum(y_true + y_pred, self.axes)
        return tf.reduce_mean(tf.div_no_nan(numerator, denominator))

    @function_decorator('DICE__loss')
    def loss(self, y_true, y_pred):
        return 1 - 2 * tf.reduce_mean(self.dice(y_true, y_pred))

    @function_decorator('DICE__metric')
    def metric(self, y_true, y_pred):
        return tf.reduce_mean(self.dice(y_true, y_pred))

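For binary masks, `dice` reduces to the familiar 2|A∩B| / (|A| + |B|). A quick numeric check (a sketch, assuming the TF1.x API used throughout this file, with eager execution enabled):

import numpy as np
import tensorflow as tf

ds = DICEScore(input_shape=[4, 4, 1])

a = np.zeros((1, 4, 4, 1), np.float32); a[0, :2, :, 0] = 1.   # 8 voxels
b = np.zeros((1, 4, 4, 1), np.float32); b[0, 1:3, :, 0] = 1.  # 8 voxels, 4 overlapping

print(ds.dice(tf.constant(a), tf.constant(b)))  # 2*4 / (8+8) = 0.5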
755 |
+
class GeneralizedDICEScore:
    def __init__(self, input_shape: list, num_labels: int = None):
        """
        Generalized DICE Score. Implementation based on Carole H. Sudre, et al., "Generalised Dice Overlap as a Deep
        Learning Loss Function for Highly Unbalanced Segmentations", https://arxiv.org/abs/1707.03237
        :param input_shape: Shape of the input image, without the batch dimension, e.g., 2D: [H, W, C], 3D: [H, W, D, C]
        :param num_labels: Number of labels. Required when the segmentations are not one-hot encoded along the channels
        """
        if input_shape[-1] > 1:
            self.flat_shape = [-1, np.prod(np.asarray(input_shape[:-1])), input_shape[-1]]
            self.hot_encode = False
        elif num_labels is not None:
            self.flat_shape = [-1, np.prod(np.asarray(input_shape[:-1])), num_labels]
            self.one_hot_enc_shape = [-1, *input_shape[:-1]]
            self.hot_encode = True
            warnings.warn('Differentiable one-hot encoding not yet implemented')
        else:
            raise ValueError('If input_shape is not one-hot encoded, then num_labels must be provided')

    def one_hot_encoding(self, in_img, name=''):
        # TODO: Test if differentiable!
        labels, indices = tf.unique(tf.reshape(in_img, [-1]), tf.int32, name=name + '_unique')
        one_hot = tf.one_hot(indices, tf.size(labels), name=name + '_one_hot')
        one_hot = tf.reshape(one_hot, self.one_hot_enc_shape + [tf.size(labels)], name=name + '_reshape')
        one_hot = tf.slice(one_hot, [0]*len(self.one_hot_enc_shape) + [1], [-1]*(len(self.one_hot_enc_shape) + 1),
                           name=name + '_remove_bg')  # Drop the background channel
        return one_hot

    def weighted_dice(self, y_true, y_pred):
        # y_true = [B, -1, L]
        # y_pred = [B, -1, L]
        if self.hot_encode:
            y_true = self.one_hot_encoding(y_true, name='GDICE_one_hot_encoding_y_true')
            y_pred = self.one_hot_encoding(y_pred, name='GDICE_one_hot_encoding_y_pred')
        y_true = tf.reshape(y_true, self.flat_shape, name='GDICE_reshape_y_true')  # Flatten along the volume dimensions
        y_pred = tf.reshape(y_pred, self.flat_shape, name='GDICE_reshape_y_pred')  # Flatten along the volume dimensions

        size_y_true = tf.reduce_sum(y_true, axis=1, name='GDICE_size_y_true')
        size_y_pred = tf.reduce_sum(y_pred, axis=1, name='GDICE_size_y_pred')
        w = tf.div_no_nan(1., tf.pow(size_y_true, 2), name='GDICE_weight')  # Each label weighted by its inverse volume
        numerator = w * tf.reduce_sum(y_true * y_pred, axis=1)
        denominator = w * (size_y_true + size_y_pred)
        return tf.div_no_nan(2 * tf.reduce_sum(numerator, axis=-1), tf.reduce_sum(denominator, axis=-1))

    def macro_dice(self, y_true, y_pred):
        # y_true = [B, -1, L]
        # y_pred = [B, -1, L]
        if self.hot_encode:
            y_true = self.one_hot_encoding(y_true, name='GDICE_one_hot_encoding_y_true')
            y_pred = self.one_hot_encoding(y_pred, name='GDICE_one_hot_encoding_y_pred')
        y_true = tf.reshape(y_true, self.flat_shape, name='GDICE_reshape_y_true')  # Flatten along the volume dimensions
        y_pred = tf.reshape(y_pred, self.flat_shape, name='GDICE_reshape_y_pred')  # Flatten along the volume dimensions

        size_y_true = tf.reduce_sum(y_true, axis=1, name='GDICE_size_y_true')
        size_y_pred = tf.reduce_sum(y_pred, axis=1, name='GDICE_size_y_pred')
        numerator = tf.reduce_sum(y_true * y_pred, axis=1)
        denominator = (size_y_true + size_y_pred)
        return tf.div_no_nan(2 * numerator, denominator)

    @function_decorator('GeneralizeDICE__loss')
    def loss(self, y_true, y_pred):
        return 1 - tf.reduce_mean(self.weighted_dice(y_true, y_pred))

    @function_decorator('GeneralizeDICE__metric')
    def metric(self, y_true, y_pred):
        return tf.reduce_mean(self.weighted_dice(y_true, y_pred))

    @function_decorator('GeneralizeDICE__loss_macro')
    def loss_macro(self, y_true, y_pred):
        return 1 - tf.reduce_mean(self.macro_dice(y_true, y_pred))

    @function_decorator('GeneralizeDICE__metric_macro')
    def metric_macro(self, y_true, y_pred):
        return tf.reduce_mean(self.macro_dice(y_true, y_pred))

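A minimal sketch of using the class above with channel-encoded (one-hot) label maps, in which case `num_labels` is not needed; shapes are illustrative:

    # One-hot segmentations with 3 labels: tensors of shape [B, 64, 64, 64, 3]
    gdice = GeneralizedDICEScore(input_shape=[64, 64, 64, 3])
    loss_fn = gdice.loss          # volume-weighted variant
    macro_fn = gdice.loss_macro   # unweighted per-label variant
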
def target_registration_error(y_true, y_pred, average=True):
    """
    Target Registration Error (TRE), measured as the Euclidean distance between y_true and y_pred.
    :param y_true: [N, D] target points
    :param y_pred: [N, D] predicted points
    :param average: if True, return the average TRE; otherwise an [N,] array with the TRE of each point
    :return: average TRE, or [N,] array with the TRE of each point
    """
    assert y_true.shape == y_pred.shape, "y_true and y_pred must have the same shape"
    if average:
        return tf.reduce_mean(tf.linalg.norm(y_pred - y_true, axis=1))
    else:
        return tf.linalg.norm(y_pred - y_true, axis=1)

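For intuition, a tiny worked example of the TRE above (eager execution assumed):

    import numpy as np
    import tensorflow as tf

    y_true = tf.constant(np.array([[0., 0., 0.], [1., 1., 1.]], np.float32))
    y_pred = tf.constant(np.array([[3., 4., 0.], [1., 1., 1.]], np.float32))
    print(target_registration_error(y_true, y_pred, average=False))  # [5., 0.]
    print(target_registration_error(y_true, y_pred))                 # 2.5
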
# TODO: tensorflow-graphics has an implementation of the Hausdorff distance.
# However, it is not where it should be and I can't find it.
# def HausdorffDistance_exact(y_true, y_pred, ohe=False, name='hd_exact'):
#     if ohe:
#         y_true = tf.transpose(y_true, [0, 4, 1, 2, 3])
#         y_pred = tf.transpose(y_pred, [0, 4, 1, 2, 3])
#     y_true_coords = tf.where(y_true)
#     y_pred_coords = tf.where(y_pred)
#
#     return tfg_nn.loss.hausdorff_distance.evaluate(y_true_coords, y_pred_coords, name=name)
DeepDeformationMapRegistration/ms_ssim_tf.py
ADDED
@@ -0,0 +1,642 @@
# SRC: https://github.com/tensorflow/tensorflow/blob/a4dfb8d1a71385bd6d122e4f27f86dcebb96712d/tensorflow/python/ops/image_ops_impl.py
from tensorflow.python import nn_ops
from tensorflow.python import math_ops
from tensorflow.python import array_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import nn
from tensorflow.python.util.tf_export import tf_export
from tensorflow.python.util import dispatch
from DeepDeformationMapRegistration.utils.misc import function_decorator


@tf_export('image.convert_image_dtype')
@dispatch.add_dispatch_support
def convert_image_dtype(image, dtype, saturate=False, name=None):
  """Convert `image` to `dtype`, scaling its values if needed.

  The operation supports data types (for `image` and `dtype`) of
  `uint8`, `uint16`, `uint32`, `uint64`, `int8`, `int16`, `int32`, `int64`,
  `float16`, `float32`, `float64`, `bfloat16`.

  Images that are represented using floating point values are expected to have
  values in the range [0,1). Image data stored in integer data types are
  expected to have values in the range `[0,MAX]`, where `MAX` is the largest
  positive representable number for the data type.

  This op converts between data types, scaling the values appropriately before
  casting.

  Usage Example:

  >>> x = [[[1, 2, 3], [4, 5, 6]],
  ...      [[7, 8, 9], [10, 11, 12]]]
  >>> x_int8 = tf.convert_to_tensor(x, dtype=tf.int8)
  >>> tf.image.convert_image_dtype(x_int8, dtype=tf.float16, saturate=False)
  <tf.Tensor: shape=(2, 2, 3), dtype=float16, numpy=
  array([[[0.00787, 0.01575, 0.02362],
          [0.0315 , 0.03937, 0.04724]],
         [[0.0551 , 0.063  , 0.07086],
          [0.07874, 0.0866 , 0.0945 ]]], dtype=float16)>

  Converting integer types to floating point types returns normalized floating
  point values in the range [0, 1); the values are normalized by the `MAX` value
  of the input dtype. Consider the following two examples:

  >>> a = [[[1], [2]], [[3], [4]]]
  >>> a_int8 = tf.convert_to_tensor(a, dtype=tf.int8)
  >>> tf.image.convert_image_dtype(a_int8, dtype=tf.float32)
  <tf.Tensor: shape=(2, 2, 1), dtype=float32, numpy=
  array([[[0.00787402],
          [0.01574803]],
         [[0.02362205],
          [0.03149606]]], dtype=float32)>

  >>> a_int32 = tf.convert_to_tensor(a, dtype=tf.int32)
  >>> tf.image.convert_image_dtype(a_int32, dtype=tf.float32)
  <tf.Tensor: shape=(2, 2, 1), dtype=float32, numpy=
  array([[[4.6566129e-10],
          [9.3132257e-10]],
         [[1.3969839e-09],
          [1.8626451e-09]]], dtype=float32)>

  Despite having identical values of `a` and output dtype of `float32`, the
  outputs differ due to the different input dtypes (`int8` vs. `int32`). This
  is, again, because the values are normalized by the `MAX` value of the input
  dtype.

  Note that converting floating point values to integer type may lose precision.
  In the example below, an image tensor `b` of dtype `float32` is converted to
  `int8` and back to `float32`. The final output, however, is different from
  the original input `b` due to precision loss.

  >>> b = [[[0.12], [0.34]], [[0.56], [0.78]]]
  >>> b_float32 = tf.convert_to_tensor(b, dtype=tf.float32)
  >>> b_int8 = tf.image.convert_image_dtype(b_float32, dtype=tf.int8)
  >>> tf.image.convert_image_dtype(b_int8, dtype=tf.float32)
  <tf.Tensor: shape=(2, 2, 1), dtype=float32, numpy=
  array([[[0.11811024],
          [0.33858266]],
         [[0.5590551 ],
          [0.77952754]]], dtype=float32)>

  Scaling up from an integer type (input dtype) to another integer type (output
  dtype) will not map input dtype's `MAX` to output dtype's `MAX` but converting
  back and forth should result in no change. For example, as shown below, the
  `MAX` value of int8 (=127) is not mapped to the `MAX` value of int16 (=32,767)
  but, when scaled back, we get the same, original values of `c`.

  >>> c = [[[1], [2]], [[127], [127]]]
  >>> c_int8 = tf.convert_to_tensor(c, dtype=tf.int8)
  >>> c_int16 = tf.image.convert_image_dtype(c_int8, dtype=tf.int16)
  >>> print(c_int16)
  tf.Tensor(
  [[[  256]
    [  512]]
   [[32512]
    [32512]]], shape=(2, 2, 1), dtype=int16)
  >>> c_int8_back = tf.image.convert_image_dtype(c_int16, dtype=tf.int8)
  >>> print(c_int8_back)
  tf.Tensor(
  [[[  1]
    [  2]]
   [[127]
    [127]]], shape=(2, 2, 1), dtype=int8)

  Scaling down from an integer type to another integer type can be a lossy
  conversion. Notice in the example below that converting `int16` to `uint8` and
  back to `int16` has lost precision.

  >>> d = [[[1000], [2000]], [[3000], [4000]]]
  >>> d_int16 = tf.convert_to_tensor(d, dtype=tf.int16)
  >>> d_uint8 = tf.image.convert_image_dtype(d_int16, dtype=tf.uint8)
  >>> d_int16_back = tf.image.convert_image_dtype(d_uint8, dtype=tf.int16)
  >>> print(d_int16_back)
  tf.Tensor(
  [[[ 896]
    [1920]]
   [[2944]
    [3968]]], shape=(2, 2, 1), dtype=int16)

  Note that converting from floating point inputs to integer types may lead to
  over/underflow problems. Set saturate to `True` to avoid such problem in
  problematic conversions. If enabled, saturation will clip the output into the
  allowed range before performing a potentially dangerous cast (and only before
  performing such a cast, i.e., when casting from a floating point to an integer
  type, and when casting from a signed to an unsigned type; `saturate` has no
  effect on casts between floats, or on casts that increase the type's range).

  Args:
    image: An image.
    dtype: A `DType` to convert `image` to.
    saturate: If `True`, clip the input before casting (if necessary).
    name: A name for this operation (optional).

  Returns:
    `image`, converted to `dtype`.

  Raises:
    AttributeError: Raises an attribute error when dtype is neither
      float nor integer.
  """
  image = ops.convert_to_tensor(image, name='image')
  dtype = dtypes.as_dtype(dtype)
  if not dtype.is_floating and not dtype.is_integer:
    raise AttributeError('dtype must be either floating point or integer')
  if dtype == image.dtype:
    return array_ops.identity(image, name=name)

  with ops.name_scope(name, 'convert_image', [image]) as name:
    # Both integer: use integer multiplication in the larger range
    if image.dtype.is_integer and dtype.is_integer:
      scale_in = image.dtype.max
      scale_out = dtype.max
      if scale_in > scale_out:
        # Scaling down, scale first, then cast. The scaling factor will
        # cause in.max to be mapped to above out.max but below out.max+1,
        # so that the output is safely in the supported range.
        scale = (scale_in + 1) // (scale_out + 1)
        scaled = math_ops.floordiv(image, scale)

        if saturate:
          return math_ops.saturate_cast(scaled, dtype, name=name)
        else:
          return math_ops.cast(scaled, dtype, name=name)
      else:
        # Scaling up, cast first, then scale. The scale will not map in.max to
        # out.max, but converting back and forth should result in no change.
        if saturate:
          cast = math_ops.saturate_cast(image, dtype)
        else:
          cast = math_ops.cast(image, dtype)
        scale = (scale_out + 1) // (scale_in + 1)
        return math_ops.multiply(cast, scale, name=name)
    elif image.dtype.is_floating and dtype.is_floating:
      # Both float: Just cast, no possible overflows in the allowed ranges.
      # Note: We're ignoring float overflows. If your image dynamic range
      # exceeds float range, you're on your own.
      return math_ops.cast(image, dtype, name=name)
    else:
      if image.dtype.is_integer:
        # Converting to float: first cast, then scale. No saturation possible.
        cast = math_ops.cast(image, dtype)
        scale = 1. / image.dtype.max
        return math_ops.multiply(cast, scale, name=name)
      else:
        # Converting from float: first scale, then cast
        scale = dtype.max + 0.5  # avoid rounding problems in the cast
        scaled = math_ops.multiply(image, scale)
        if saturate:
          return math_ops.saturate_cast(scaled, dtype, name=name)
        else:
          return math_ops.cast(scaled, dtype, name=name)

def _verify_compatible_image_shapes(img1, img2):
  """Checks if two image tensors are compatible for applying SSIM or PSNR.

  This function checks if two sets of volumetric images have ranks at least 4,
  and if the last four dimensions match.

  Args:
    img1: Tensor containing the first image batch.
    img2: Tensor containing the second image batch.

  Returns:
    A tuple containing: the first tensor shape, the second tensor shape, and a
    list of control_flow_ops.Assert() ops implementing the checks.

  Raises:
    ValueError: When static shape check fails.
  """
  shape1 = img1.get_shape().with_rank_at_least(4)  # at least [H, W, D, C]
  shape2 = img2.get_shape().with_rank_at_least(4)  # at least [H, W, D, C]
  shape1[-4:].assert_is_compatible_with(shape2[-4:])

  if shape1.ndims is not None and shape2.ndims is not None:
    for dim1, dim2 in zip(
        reversed(shape1.dims[:-4]), reversed(shape2.dims[:-4])):
      if not (dim1 == 1 or dim2 == 1 or dim1.is_compatible_with(dim2)):
        raise ValueError('Two images are not compatible: %s and %s' %
                         (shape1, shape2))

  # Now assign shape tensors.
  shape1, shape2 = array_ops.shape_n([img1, img2])

  # TODO(sjhwang): Check if shape1[:-4] and shape2[:-4] are broadcastable.
  checks = []
  checks.append(
      control_flow_ops.Assert(
          math_ops.greater_equal(array_ops.size(shape1), 4), [shape1, shape2],
          summarize=10))
  checks.append(
      control_flow_ops.Assert(
          math_ops.reduce_all(math_ops.equal(shape1[-4:], shape2[-4:])),
          [shape1, shape2],
          summarize=10))
  return shape1, shape2, checks

def _ssim_helper(x, y, reducer, max_val, compensation=1.0, k1=0.01, k2=0.03):
  r"""Helper function for computing SSIM.

  SSIM estimates covariances with weighted sums. The default parameters
  use a biased estimate of the covariance:
  Suppose `reducer` is a weighted sum, then the mean estimators are
    \mu_x = \sum_i w_i x_i,
    \mu_y = \sum_i w_i y_i,
  where the w_i's are the weighted-sum weights, and the covariance estimator is
    cov_{xy} = \sum_i w_i (x_i - \mu_x) (y_i - \mu_y)
  with the assumption \sum_i w_i = 1. This covariance estimator is biased, since
    E[cov_{xy}] = (1 - \sum_i w_i ^ 2) Cov(X, Y).
  For an SSIM measure with unbiased covariance estimators, pass
  (1 - \sum_i w_i ^ 2) as the `compensation` argument.

  Args:
    x: First set of images.
    y: Second set of images.
    reducer: Function that computes 'local' averages from the set of images. For
      the non-convolutional version, this is usually tf.reduce_mean(x, [1, 2]),
      and for the convolutional version, this is usually tf.nn.avg_pool2d or
      tf.nn.conv3d with a weighted-sum kernel.
    max_val: The dynamic range (i.e., the difference between the maximum
      possible allowed value and the minimum allowed value).
    compensation: Compensation factor. See above.
    k1: Default value 0.01.
    k2: Default value 0.03 (SSIM is less sensitive to k2 for lower values, so
      it would be better if we took values in the range of 0 < k2 < 0.4).

  Returns:
    A pair containing the luminance measure, and the contrast-structure measure.
  """

  c1 = (k1 * max_val)**2
  c2 = (k2 * max_val)**2

  # SSIM luminance measure is
  # (2 * mu_x * mu_y + c1) / (mu_x ** 2 + mu_y ** 2 + c1).
  mean0 = reducer(x)
  mean1 = reducer(y)
  num0 = mean0 * mean1 * 2.0
  den0 = math_ops.square(mean0) + math_ops.square(mean1)
  luminance = (num0 + c1) / (den0 + c1)

  # SSIM contrast-structure measure is
  #   (2 * cov_{xy} + c2) / (cov_{xx} + cov_{yy} + c2).
  # Note that `reducer` is a weighted sum with weights w_i, \sum_i w_i = 1, then
  #   cov_{xy} = \sum_i w_i (x_i - \mu_x) (y_i - \mu_y)
  #            = \sum_i w_i x_i y_i - (\sum_i w_i x_i) (\sum_j w_j y_j).
  num1 = reducer(x * y) * 2.0
  den1 = reducer(math_ops.square(x) + math_ops.square(y))
  c2 *= compensation
  cs = (num1 - num0 + c2) / (den1 - den0 + c2)

  # SSIM score is the product of the luminance and contrast-structure measures.
  return luminance, cs

def _fspecial_gauss(size, sigma):
  """Function to mimic the 'fspecial' gaussian MATLAB function, extended to 3D."""
  size = ops.convert_to_tensor(size, dtypes.int32)
  sigma = ops.convert_to_tensor(sigma, dtypes.float32)

  coords = math_ops.cast(math_ops.range(size), sigma.dtype)
  coords -= math_ops.cast(size - 1, sigma.dtype) / 2.0

  g = math_ops.square(coords)
  g *= -0.5 / math_ops.square(sigma)

  g = array_ops.reshape(g, shape=[1, -1]) + array_ops.reshape(g, shape=[-1, 1])
  g = array_ops.reshape(g, shape=[size, size, 1]) + array_ops.reshape(g, shape=[1, size, size])
  g = array_ops.reshape(g, shape=[1, -1])  # For tf.nn.softmax().
  g = nn_ops.softmax(g)
  return array_ops.reshape(g, shape=[size, size, size, 1, 1])

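A quick check of the kernel produced above (eager execution and `import tensorflow as tf` assumed; illustrative only): the softmax normalization makes the 5x5x5 weights sum to one, and the final reshape yields a conv3d-compatible filter:

    k = _fspecial_gauss(size=5, sigma=1.5)
    print(k.shape)                  # (5, 5, 5, 1, 1)
    print(float(tf.reduce_sum(k)))  # ~1.0
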
def _ssim_per_channel(img1,
                      img2,
                      max_val=1.0,
                      filter_size=11,
                      filter_sigma=1.5,
                      k1=0.01,
                      k2=0.03):
  """Computes the SSIM index between img1 and img2 per color channel.

  This function matches the standard SSIM implementation from:
  Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004). Image
  quality assessment: from error visibility to structural similarity. IEEE
  transactions on image processing.

  Details:
    - An 11x11 Gaussian filter of width 1.5 is used.
    - k1 = 0.01, k2 = 0.03 as in the original paper.

  Args:
    img1: First image batch.
    img2: Second image batch.
    max_val: The dynamic range of the images (i.e., the difference between the
      maximum and the minimum allowed values).
    filter_size: Default value 11 (size of the Gaussian filter).
    filter_sigma: Default value 1.5 (width of the Gaussian filter).
    k1: Default value 0.01.
    k2: Default value 0.03 (SSIM is less sensitive to k2 for lower values, so
      it would be better if we took values in the range of 0 < k2 < 0.4).

  Returns:
    A pair of tensors containing the channel-wise SSIM and contrast-structure
    values. The shape is [..., channels].
  """
  filter_size = constant_op.constant(filter_size, dtype=dtypes.int32)
  filter_sigma = constant_op.constant(filter_sigma, dtype=img1.dtype)

  shape1, shape2 = array_ops.shape_n([img1, img2])
  checks = [
      control_flow_ops.Assert(
          math_ops.reduce_all(
              math_ops.greater_equal(shape1[-4:-1], filter_size)),
          [shape1, filter_size],
          summarize=8),
      control_flow_ops.Assert(
          math_ops.reduce_all(
              math_ops.greater_equal(shape2[-4:-1], filter_size)),
          [shape2, filter_size],
          summarize=8)
  ]

  # Enforce the check to run before computation.
  with ops.control_dependencies(checks):
    img1 = array_ops.identity(img1)

  # TODO(sjhwang): Try to cache kernels and compensation factor.
  kernel = _fspecial_gauss(filter_size, filter_sigma)
  kernel = array_ops.tile(kernel, multiples=[1, 1, 1, shape1[-1], 1])

  # The correct compensation factor is `1.0 - tf.reduce_sum(tf.square(kernel))`,
  # but to match the MATLAB implementation of MS-SSIM, we use 1.0 instead.
  compensation = 1.0

  # TODO(sjhwang): Try FFT.
  # TODO(sjhwang): The Gaussian kernel is separable in space. Consider applying
  #   1-by-n and n-by-1 Gaussian filters instead of an n-by-n filter.
  def reducer(x):
    shape = array_ops.shape(x)
    x = array_ops.reshape(x, shape=array_ops.concat([[-1], shape[-4:]], 0))
    y = nn.conv3d(x, kernel, strides=[1, 1, 1, 1, 1], padding='VALID')
    return array_ops.reshape(y, array_ops.concat([shape[:-4], array_ops.shape(y)[1:]], 0))

  luminance, cs = _ssim_helper(img1, img2, reducer, max_val, compensation, k1,
                               k2)

  # Average over the second, third and fourth from the last: height, width, depth.
  axes = constant_op.constant([-4, -3, -2], dtype=dtypes.int32)
  ssim_val = math_ops.reduce_mean(luminance * cs, axes)
  cs = math_ops.reduce_mean(cs, axes)
  return ssim_val, cs

@tf_export('image.ssim')
@dispatch.add_dispatch_support
def ssim(img1,
         img2,
         max_val,
         filter_size=11,
         filter_sigma=1.5,
         k1=0.01,
         k2=0.03):
  """Computes the SSIM index between img1 and img2.

  This function is based on the standard SSIM implementation from:
  Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004). Image
  quality assessment: from error visibility to structural similarity. IEEE
  transactions on image processing.

  Note: The true SSIM is only defined on grayscale. This function does not
  perform any colorspace transform. (If the input is already YUV, then it will
  compute the YUV SSIM average.)

  Details:
    - An 11x11 Gaussian filter of width 1.5 is used.
    - k1 = 0.01, k2 = 0.03 as in the original paper.

  The image sizes must be at least 11x11 because of the filter size.

  Example (from the original 2D implementation):

  ```python
  # Read images (of size 255 x 255) from file.
  im1 = tf.image.decode_image(tf.io.read_file('path/to/im1.png'))
  im2 = tf.image.decode_image(tf.io.read_file('path/to/im2.png'))
  tf.shape(im1)  # `img1.png` has 3 channels; shape is `(255, 255, 3)`
  tf.shape(im2)  # `img2.png` has 3 channels; shape is `(255, 255, 3)`
  # Add an outer batch for each image.
  im1 = tf.expand_dims(im1, axis=0)
  im2 = tf.expand_dims(im2, axis=0)
  # Compute SSIM over tf.uint8 Tensors.
  ssim1 = tf.image.ssim(im1, im2, max_val=255, filter_size=11,
                        filter_sigma=1.5, k1=0.01, k2=0.03)
  # Compute SSIM over tf.float32 Tensors.
  im1 = tf.image.convert_image_dtype(im1, tf.float32)
  im2 = tf.image.convert_image_dtype(im2, tf.float32)
  ssim2 = tf.image.ssim(im1, im2, max_val=1.0, filter_size=11,
                        filter_sigma=1.5, k1=0.01, k2=0.03)
  # ssim1 and ssim2 both have type tf.float32 and are almost equal.
  ```

  Args:
    img1: First image batch. Tensor of shape `[batch, height, width, depth,
      channels]` with only positive pixel values.
    img2: Second image batch. Tensor of shape `[batch, height, width, depth,
      channels]` with only positive pixel values.
    max_val: The dynamic range of the images (i.e., the difference between the
      maximum and the minimum allowed values).
    filter_size: Default value 11 (size of the Gaussian filter).
    filter_sigma: Default value 1.5 (width of the Gaussian filter).
    k1: Default value 0.01.
    k2: Default value 0.03 (SSIM is less sensitive to k2 for lower values, so
      it would be better if we took values in the range of 0 < k2 < 0.4).

  Returns:
    A tensor containing an SSIM value for each image in batch. Returned SSIM
    values are in range (-1, 1], when pixel values are non-negative. Returns
    a tensor with shape: broadcast(img1.shape[:-4], img2.shape[:-4]).
  """
  with ops.name_scope(None, 'SSIM', [img1, img2]):
    # Convert to tensor if needed.
    img1 = ops.convert_to_tensor(img1, name='img1')
    img2 = ops.convert_to_tensor(img2, name='img2')
    # Shape checking.
    _, _, checks = _verify_compatible_image_shapes(img1, img2)
    with ops.control_dependencies(checks):
      img1 = array_ops.identity(img1)

    # Need to convert the images to float32. Scale max_val accordingly so that
    # SSIM is computed correctly.
    max_val = math_ops.cast(max_val, img1.dtype)
    max_val = convert_image_dtype(max_val, dtypes.float32)
    img1 = convert_image_dtype(img1, dtypes.float32)
    img2 = convert_image_dtype(img2, dtypes.float32)
    ssim_per_channel, _ = _ssim_per_channel(img1, img2, max_val, filter_size,
                                            filter_sigma, k1, k2)
    # Compute average over color channels.
    return math_ops.reduce_mean(ssim_per_channel, [-1])


# Default values obtained by Wang et al.
_MSSSIM_WEIGHTS = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)

@tf_export('image.ssim_multiscale')
@dispatch.add_dispatch_support
def ssim_multiscale(img1,
                    img2,
                    max_val,
                    power_factors=_MSSSIM_WEIGHTS,
                    filter_size=11,
                    filter_sigma=1.5,
                    k1=0.01,
                    k2=0.03):
  """Computes the MS-SSIM between img1 and img2.

  This function assumes that `img1` and `img2` are image batches, i.e. the last
  four dimensions are [height, width, depth, channels].

  Note: The true SSIM is only defined on grayscale. This function does not
  perform any colorspace transform. (If the input is already YUV, then it will
  compute the YUV SSIM average.)

  Original paper: Wang, Zhou, Eero P. Simoncelli, and Alan C. Bovik. "Multiscale
  structural similarity for image quality assessment." Signals, Systems and
  Computers, 2004.

  Args:
    img1: First image batch with only positive pixel values.
    img2: Second image batch with only positive pixel values. Must have the
      same rank as img1.
    max_val: The dynamic range of the images (i.e., the difference between the
      maximum and the minimum allowed values).
    power_factors: Iterable of weights for each of the scales. The number of
      scales used is the length of the list. Index 0 is the unscaled
      resolution's weight and each increasing scale corresponds to the image
      being downsampled by 2. Defaults to (0.0448, 0.2856, 0.3001, 0.2363,
      0.1333), which are the values obtained in the original paper.
    filter_size: Default value 11 (size of the Gaussian filter).
    filter_sigma: Default value 1.5 (width of the Gaussian filter).
    k1: Default value 0.01.
    k2: Default value 0.03 (SSIM is less sensitive to k2 for lower values, so
      it would be better if we took values in the range of 0 < k2 < 0.4).

  Returns:
    A tensor containing an MS-SSIM value for each image in batch. The values
    are in range [0, 1]. Returns a tensor with shape:
    broadcast(img1.shape[:-4], img2.shape[:-4]).
  """
  with ops.name_scope(None, 'MS-SSIM', [img1, img2]):
    # Convert to tensor if needed.
    img1 = ops.convert_to_tensor(img1, name='img1')
    img2 = ops.convert_to_tensor(img2, name='img2')
    # Shape checking.
    shape1, shape2, checks = _verify_compatible_image_shapes(img1, img2)
    with ops.control_dependencies(checks):
      img1 = array_ops.identity(img1)

    # Need to convert the images to float32. Scale max_val accordingly so that
    # SSIM is computed correctly.
    max_val = math_ops.cast(max_val, img1.dtype)
    max_val = convert_image_dtype(max_val, dtypes.float32)
    img1 = convert_image_dtype(img1, dtypes.float32)
    img2 = convert_image_dtype(img2, dtypes.float32)

    imgs = [img1, img2]
    shapes = [shape1, shape2]

    # img1 and img2 are assumed to be a (multi-dimensional) batch of
    # 4-dimensional images (height, width, depth, channels). `heads` contain the
    # batch dimensions, and `tails` contain the image dimensions.
    heads = [s[:-4] for s in shapes]
    tails = [s[-4:] for s in shapes]

    divisor = [1, 2, 2, 2, 1]
    divisor_tensor = constant_op.constant(divisor[1:], dtype=dtypes.int32)

    def do_pad(images, remainder):
      padding = array_ops.expand_dims(remainder, -1)
      padding = array_ops.pad(padding, [[1, 0], [1, 0]])
      return [array_ops.pad(x, padding, mode='SYMMETRIC') for x in images]

    mcs = []
    for k in range(len(power_factors)):
      with ops.name_scope(None, 'Scale%d' % k, imgs):
        if k > 0:
          # Avg pool takes rank-5 tensors. Flatten leading dimensions.
          flat_imgs = [
              array_ops.reshape(x, array_ops.concat([[-1], t], 0))
              for x, t in zip(imgs, tails)
          ]

          remainder = tails[0] % divisor_tensor
          need_padding = math_ops.reduce_any(math_ops.not_equal(remainder, 0))
          # pylint: disable=cell-var-from-loop
          padded = control_flow_ops.cond(need_padding,
                                         lambda: do_pad(flat_imgs, remainder),
                                         lambda: flat_imgs)
          # pylint: enable=cell-var-from-loop

          downscaled = [
              nn_ops.avg_pool3d(
                  x, ksize=divisor, strides=divisor, padding='VALID', data_format='NDHWC')
              for x in padded
          ]
          tails = [x[1:] for x in array_ops.shape_n(downscaled)]
          imgs = [
              array_ops.reshape(x, array_ops.concat([h, t], 0))
              for x, h, t in zip(downscaled, heads, tails)
          ]

        # Overwrite previous ssim value since we only need the last one.
        ssim_per_channel, cs = _ssim_per_channel(
            *imgs,
            max_val=max_val,
            filter_size=filter_size,
            filter_sigma=filter_sigma,
            k1=k1,
            k2=k2)
        mcs.append(nn_ops.relu(cs))

    # Remove the cs score for the last scale. In the MS-SSIM calculation,
    # we use the l(p) at the highest scale. l(p) * cs(p) is ssim(p).
    mcs.pop()  # Remove the cs score for the last scale.
    mcs_and_ssim = array_ops.stack(
        mcs + [nn_ops.relu(ssim_per_channel)], axis=-1)
    # Take the weighted geometric mean across the scale axis.
    ms_ssim = math_ops.reduce_prod(
        math_ops.pow(mcs_and_ssim, power_factors), [-1])

    return math_ops.reduce_mean(ms_ssim, [-1])  # Avg over color channels.

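For reference, the weighted geometric mean computed above is the standard MS-SSIM combination, with the luminance term taken only at the coarsest scale M:

    MS-SSIM(x, y) = [l_M(x, y)]^{w_M} \prod_{j=1}^{M} [cs_j(x, y)]^{w_j}

where the w_j are the `power_factors` and cs_j is the contrast-structure term at scale j.
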
class MultiScaleStructuralSimilarity:
    def __init__(self,
                 max_val,
                 power_factors=_MSSSIM_WEIGHTS,
                 filter_size=11,
                 filter_sigma=1.5,
                 k1=0.01,
                 k2=0.03):
        self.max_val = max_val
        self.power_factors = power_factors
        self.filter_size = int(filter_size)
        self.filter_sigma = filter_sigma
        self.k1 = k1
        self.k2 = k2

    @function_decorator('MS_SSIM__loss')
    def loss(self, y_true, y_pred):
        return math_ops.reduce_mean((1 - ssim_multiscale(y_true, y_pred, self.max_val, self.power_factors,
                                                         self.filter_size, self.filter_sigma, self.k1, self.k2))/2)

    @function_decorator('MS_SSIM__metric')
    def metric(self, y_true, y_pred):
        return ssim_multiscale(y_true, y_pred, self.max_val, self.power_factors, self.filter_size, self.filter_sigma,
                               self.k1, self.k2)

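A minimal sketch of wiring the wrapper above into training (the Keras `model` is hypothetical; the decorated names are what appear in the Keras logs):

    ms_ssim = MultiScaleStructuralSimilarity(max_val=1., filter_size=5)
    model.compile(optimizer='adam',
                  loss=ms_ssim.loss,         # logged as 'MS_SSIM__loss'
                  metrics=[ms_ssim.metric])  # logged as 'MS_SSIM__metric'
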
if __name__ == '__main__':
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    import tensorflow as tf
    tf.enable_eager_execution()
    import nibabel as nib
    import numpy as np
    from DeepDeformationMapRegistration.utils.operators import min_max_norm
    from skimage.metrics import structural_similarity

    img1 = nib.load('test_images/ixi_image.nii.gz')
    img1 = np.asarray(img1.dataobj)
    img1 = img1[np.newaxis, ..., np.newaxis]  # Add batch and channel dimensions

    img2 = nib.load('test_images/ixi_image2.nii.gz')
    img2 = np.asarray(img2.dataobj)
    img2 = img2[np.newaxis, ..., np.newaxis]

    img1 = min_max_norm(img1)
    img2 = min_max_norm(img2)

    ssim_tf_1_2 = ssim(img1, img2, 1., filter_size=5)
    assert ssim(img1, img1, 1., filter_size=5).numpy()[0] == 1., 'TF SSIM returned an unexpected value'
    ssim_skimage = structural_similarity(img1[0, ..., 0], img2[0, ..., 0], win_size=5)

    ms_ssim_tf_1_2 = ssim_multiscale(img1, img2, 1., filter_size=5)
    assert ssim_multiscale(img1, img1, 1., filter_size=5).numpy()[0] == 1., 'TF MS-SSIM returned an unexpected value'

    print('SSIM TF: {}\nSSIM skimage: {}\nMS SSIM TF: {}\n'.format(ssim_tf_1_2, ssim_skimage, ms_ssim_tf_1_2))

    batch_img1 = np.stack([img1, img2], axis=0)
    batch_img2 = np.stack([img2, img2], axis=0)
    batch_ssim_tf = ssim(batch_img1, batch_img2, 1., filter_size=5)
    batch_ms_ssim_tf = ssim_multiscale(batch_img1, batch_img2, 1., filter_size=5)

    print('Batch SSIM TF: {}\nBatch MS SSIM TF: {}\n'.format(batch_ssim_tf, batch_ms_ssim_tf))

    img1 = img1[:, :127, :127, :127, :]
    img2 = img2[:, :127, :127, :127, :]
    MS_SSIM = MultiScaleStructuralSimilarity(1., filter_size=5)
    print('MS SSIM Loss: {}'.format(MS_SSIM.loss(img1, img2)))
DeepDeformationMapRegistration/networks.py
CHANGED
@@ -8,12 +8,14 @@ PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'
import tensorflow as tf
import voxelmorph as vxm
from voxelmorph.tf.modelio import LoadableModel, store_config_args
+from tensorflow.keras.layers import UpSampling3D


class WeaklySupervised(LoadableModel):

    @store_config_args
-    def __init__(self, inshape, all_labels: [list, tuple], nb_unet_features=None, int_steps=5, bidir=False,
+    def __init__(self, inshape, all_labels: [list, tuple], nb_unet_features=None, int_steps=5, bidir=False,
+                 int_downsize=1, outshape=None, **kwargs):
        """
        Parameters:
            inshape: Input shape. e.g. (192, 192, 192)
@@ -21,34 +23,53 @@ class WeaklySupervised(LoadableModel):
            hot_labels: List of labels to output as one-hot maps.
            nb_unet_features: Unet convolutional features. See VxmDense documentation for more information.
            int_steps: Number of flow integration steps. The warp is non-diffeomorphic when this value is 0.
+            int_downsize: Downsampling factor of the displacement map. Integer.
            kwargs: Forwarded to the internal VxmDense model.
        """
-        fix_segm = tf.keras.Input((*inshape, len(all_labels)), name='fix_segmentations_input')
        mov_segm = tf.keras.Input((*inshape, len(all_labels)), name='mov_segmentations_input')

+        fix_img = tf.keras.Input((*inshape, 1), name='fix_image_input')
        mov_img = tf.keras.Input((*inshape, 1), name='mov_image_input')

+        input_model = tf.keras.Model(inputs=[mov_img, fix_img], outputs=[mov_img, fix_img])

        vxm_model = vxm.networks.VxmDense(inshape=inshape,
                                          nb_unet_features=nb_unet_features,
+                                          input_model=input_model,
                                          int_steps=int_steps,
                                          bidir=bidir,
+                                          int_downsize=int_downsize,
+                                          **kwargs)
+
+        pred_segm = vxm.layers.SpatialTransformer(interp_method='linear', indexing='ij', name='interp_segm')(
+            [mov_segm, vxm_model.references.pos_flow])
+
+        inputs = [mov_img, fix_img, mov_segm]  # mov_img, mov_segm, fix_segm
+        model_outputs = vxm_model.outputs
+        if outshape is not None:
+            scale_factors = [o//i for i, o in zip(inshape, outshape)]
+            upsampling_layer = UpSampling3D(scale_factors)  # Doesn't perform trilinear, only nearest
+            # Image
+            model_outputs[0] = upsampling_layer(model_outputs[0])
+            # Segmentation
+            pred_segm = upsampling_layer(pred_segm)
+            # Displacement map: upsample and rescale the displacement vectors to the new resolution
+            model_outputs[1] = upsampling_layer(model_outputs[1])
+            model_outputs[1] = tf.multiply(model_outputs[1], tf.cast(scale_factors, model_outputs[1].dtype))
+
+        # Just renaming
+        pred_fix_image = tf.identity(model_outputs[0], name='pred_fix_image')
+        pred_dm = tf.identity(model_outputs[1], name='pred_dm')
+        pred_segm = tf.identity(pred_segm, name='pred_fix_segm')
+        outputs = [pred_fix_image, pred_segm, pred_dm]

        self.references = LoadableModel.ReferenceContainer()
+        self.references.pred_segm = pred_segm
+        self.references.pred_img = vxm_model.outputs[0]
        self.references.pos_flow = vxm_model.references.pos_flow

-        super().__init__(inputs=inputs, outputs=outputs)
+        super(WeaklySupervised, self).__init__(inputs=inputs, outputs=outputs)

    def get_registration_model(self):
        return tf.keras.Model(self.inputs, self.references.pos_flow)
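As a rough instantiation sketch of the updated class (shapes, labels and the Unet feature lists are illustrative only; `outshape` triggers the nearest-neighbour upsampling branch above):

    model = WeaklySupervised(inshape=(64, 64, 64), all_labels=[0, 1, 2],
                             nb_unet_features=[[16, 32], [32, 16]],
                             int_steps=5, int_downsize=1, outshape=(128, 128, 128))
    # inputs:  [mov_img, fix_img, mov_segm]
    # outputs: [pred_fix_image, pred_fix_segm, pred_dm]
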
DeepDeformationMapRegistration/utils/acummulated_optimizer.py
CHANGED
@@ -1,5 +1,9 @@
-from tensorflow.keras.optimizers import Optimizer
-from tensorflow.keras import backend as K
+from tensorflow.python.keras.optimizers import Optimizer
+from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2
+from tensorflow.python import ops, math_ops, state_ops, control_flow_ops
+from tensorflow.python.keras import backend as K
+from tensorflow.python.keras import backend_config
+import tensorflow as tf


class AccumOptimizer(Optimizer):
@@ -14,7 +18,7 @@ class AccumOptimizer(Optimizer):
    a new keras optimizer.
    """
    def __init__(self, optimizer, steps_per_update=1, **kwargs):
-        super(AccumOptimizer, self).__init__(**kwargs)
+        super(AccumOptimizer, self).__init__(name='AccumOptimizer', **kwargs)
        self.optimizer = optimizer
        with K.name_scope(self.__class__.__name__):
            self.steps_per_update = steps_per_update
@@ -55,3 +59,134 @@ class AccumOptimizer(Optimizer):
        config = self.optimizer.get_config()
        K.set_value(self.iterations, iterations)
        return config
+
+
+__all__ = ['AdamAccumulated']
+
+
+# SRC: https://github.com/CyberZHG/keras-gradient-accumulation/blob/master/keras_gradient_accumulation/optimizer_v2.py
+class AdamAccumulated(OptimizerV2):
+    """Optimizer that implements the Adam algorithm with gradient accumulation."""
+
+    def __init__(self,
+                 accumulation_steps,
+                 learning_rate=0.001,
+                 beta_1=0.9,
+                 beta_2=0.999,
+                 epsilon=1e-7,
+                 amsgrad=False,
+                 name='Adam',
+                 **kwargs):
+        r"""Construct a new Adam optimizer.
+
+        Args:
+            accumulation_steps: An integer. Update the gradient every accumulation steps.
+            learning_rate: A Tensor or a floating point value. The learning rate.
+            beta_1: A float value or a constant float tensor. The exponential decay
+                rate for the 1st moment estimates.
+            beta_2: A float value or a constant float tensor. The exponential decay
+                rate for the 2nd moment estimates.
+            epsilon: A small constant for numerical stability. This epsilon is
+                "epsilon hat" in the Kingma and Ba paper (in the formula just before
+                Section 2.1), not the epsilon in Algorithm 1 of the paper.
+            amsgrad: boolean. Whether to apply the AMSGrad variant of this algorithm
+                from the paper "On the Convergence of Adam and Beyond".
+            name: Optional name for the operations created when applying gradients.
+                Defaults to "Adam". @compatibility(eager) When eager execution is
+                enabled, `learning_rate`, `beta_1`, `beta_2`, and `epsilon` can each be
+                a callable that takes no arguments and returns the actual value to use.
+                This can be useful for changing these values across different
+                invocations of optimizer functions. @end_compatibility
+            **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`,
+                `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is clip
+                gradients by value; `decay` is included for backward compatibility to
+                allow time inverse decay of the learning rate. `lr` is included for
+                backward compatibility; it is recommended to use `learning_rate` instead.
+        """
+
+        super(AdamAccumulated, self).__init__(name, **kwargs)
+        self._set_hyper('accumulation_steps', accumulation_steps)
+        self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
+        self._set_hyper('decay', self._initial_decay)
+        self._set_hyper('beta_1', beta_1)
+        self._set_hyper('beta_2', beta_2)
+        self.epsilon = epsilon or backend_config.epsilon()
+        self.amsgrad = amsgrad
+
+    def _create_slots(self, var_list):
+        for var in var_list:
+            self.add_slot(var, 'g')
+        for var in var_list:
+            self.add_slot(var, 'm')
+        for var in var_list:
+            self.add_slot(var, 'v')
+        if self.amsgrad:
+            for var in var_list:
+                self.add_slot(var, 'vhat')
+
+    def set_weights(self, weights):
+        params = self.weights
+        num_vars = int((len(params) - 1) / 2)
+        if len(weights) == 3 * num_vars + 1:
+            weights = weights[:len(params)]
+        super(AdamAccumulated, self).set_weights(weights)
+
+    def _resource_apply_dense(self, grad, var):
+        var_dtype = var.dtype.base_dtype
+        lr_t = self._decayed_lr(var_dtype)
+        beta_1_t = self._get_hyper('beta_1', var_dtype)
+        beta_2_t = self._get_hyper('beta_2', var_dtype)
+        accumulation_steps = self._get_hyper('accumulation_steps', 'int64')
+        update_cond = tf.equal((self.iterations + 1) % accumulation_steps, 0)
+        sub_step = self.iterations % accumulation_steps + 1
+        local_step = math_ops.cast(self.iterations // accumulation_steps + 1, var_dtype)
+        beta_1_power = math_ops.pow(beta_1_t, local_step)
+        beta_2_power = math_ops.pow(beta_2_t, local_step)
+        epsilon_t = ops.convert_to_tensor(self.epsilon, var_dtype)
+        lr = (lr_t * math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power))
+        lr = tf.where(update_cond, lr, 0.0)  # The update is a no-op between accumulation steps
+
+        # Running average of the gradient over the accumulation window
+        g = self.get_slot(var, 'g')
+        g_a = grad / math_ops.cast(accumulation_steps, var_dtype)
+        g_t = tf.where(tf.equal(sub_step, 1),
+                       g_a,
+                       g + (g_a - g) / math_ops.cast(sub_step, var_dtype))
+        g_t = state_ops.assign(g, g_t, use_locking=self._use_locking)
+
+        m = self.get_slot(var, 'm')
+        m_t = tf.where(update_cond, m * beta_1_t + g_t * (1 - beta_1_t), m)
+        m_t = state_ops.assign(m, m_t, use_locking=self._use_locking)
+
+        v = self.get_slot(var, 'v')
+        v_t = tf.where(update_cond, v * beta_2_t + (g_t * g_t) * (1 - beta_2_t), v)
+        v_t = state_ops.assign(v, v_t, use_locking=self._use_locking)
+
+        if not self.amsgrad:
+            v_sqrt = math_ops.sqrt(v_t)
+            var_update = state_ops.assign_sub(
+                var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking)
+            return control_flow_ops.group(*[var_update, m_t, v_t])
+        else:
+            v_hat = self.get_slot(var, 'vhat')
+            v_hat_t = tf.where(update_cond, math_ops.maximum(v_hat, v_t), v_hat)
+            with ops.control_dependencies([v_hat_t]):
+                v_hat_t = state_ops.assign(
+                    v_hat, v_hat_t, use_locking=self._use_locking)
+            v_hat_sqrt = math_ops.sqrt(v_hat_t)
+            var_update = state_ops.assign_sub(
+                var,
+                lr * m_t / (v_hat_sqrt + epsilon_t),
+                use_locking=self._use_locking)
+            return control_flow_ops.group(*[var_update, m_t, v_t, v_hat_t])
+
+    def get_config(self):
+        config = super(AdamAccumulated, self).get_config()
+        config.update({
+            'accumulation_steps': self._serialize_hyperparameter('accumulation_steps'),
+            'learning_rate': self._serialize_hyperparameter('learning_rate'),
+            'decay': self._serialize_hyperparameter('decay'),
+            'beta_1': self._serialize_hyperparameter('beta_1'),
+            'beta_2': self._serialize_hyperparameter('beta_2'),
+            'epsilon': self.epsilon,
+            'amsgrad': self.amsgrad,
+        })
+        return config
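A minimal usage sketch for the accumulated optimizer above (the Keras `model` is hypothetical; with `accumulation_steps=4` and a batch size of 8, each weight update sees an effective batch of 32):

    opt = AdamAccumulated(accumulation_steps=4, learning_rate=1e-4)
    model.compile(optimizer=opt, loss='mse')
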
DeepDeformationMapRegistration/utils/constants.py
CHANGED
@@ -52,6 +52,8 @@ H5_MOV_TUMORS_MASK = 'input/{}'.format(MOVING_TUMORS_MASK)
 H5_FIX_TUMORS_MASK = 'input/{}'.format(FIXED_TUMORS_MASK)
 H5_FIX_SEGMENTATIONS = 'input/{}'.format(FIXED_SEGMENTATIONS)
 H5_MOV_SEGMENTATIONS = 'input/{}'.format(MOVING_SEGMENTATIONS)
+H5_FIX_CENTROID = 'input/fix_centroid'
+H5_MOV_CENTROID = 'input/mov_centroid'
 
 H5_GT_DISP = 'output/{}'.format(DISP_MAP_GT)
 H5_GT_IMG = 'output/{}'.format(PRED_IMG_GT)
@@ -64,6 +66,7 @@ MAX_ANGLE = 45.0  # degrees
 MAX_FLIPS = 2  # Axes to flip over
 NUM_ROTATIONS = 5
 MAX_WORKERS = 10
+DEG_TO_RAD = np.pi/180.
 
 # Labels to pass to the input_labels and output_labels parameter of DataGeneratorManager
 DG_LBL_FIX_IMG = H5_FIX_IMG
@@ -75,6 +78,8 @@ DG_LBL_MOV_VESSELS = H5_MOV_VESSELS_MASK
 DG_LBL_MOV_PARENCHYMA = H5_MOV_PARENCHYMA_MASK
 DG_LBL_MOV_TUMOR = H5_MOV_TUMORS_MASK
 DG_LBL_ZERO_GRADS = 'zero_gradients'
+DG_LBL_FIX_CENTROID = H5_FIX_CENTROID
+DG_LBL_MOV_CENTROID = H5_MOV_CENTROID
 
 # Training constants
 MODEL = 'unet'
@@ -91,6 +96,8 @@ DATA_FORMAT = 'channels_last'  # or 'channels_first'
 DATA_DIR = './data'
 MODEL_CHECKPOINT = './model_checkpoint'
 BATCH_SIZE = 8
+ACCUM_GRADIENT_STEP = 1
+EARLY_STOP_PATIENCE = 30  # Weights are updated every ACCUM_GRADIENT_STEPth step
 EPOCHS = 100
 SAVE_EPOCH = EPOCHS // 10  # Epoch when to save the model
 VERBOSE_EPOCH = EPOCHS // 10
@@ -99,7 +106,6 @@ VALIDATION_ERR_LIMIT_COUNTER = 10  # Number of successive times the validation e
 VALIDATION_ERR_LIMIT_COUNTER_BACKUP = 10
 THRESHOLD = 0.5  # Threshold to select the centerline in the interpolated images
 RESTORE_TRAINING = True  # look for previously saved models to resume training
-EARLY_STOP_PATIENCE = 10
 LOG_FIELD_NAMES = ['time', 'epoch', 'step',
                    'training_loss_mean', 'training_loss_std',
                    'training_loss1_mean', 'training_loss1_std',
@@ -362,10 +368,13 @@ COORDS_GRID = CoordinatesGrid()
 class VisualizationParameters:
     def __init__(self):
         self.__scale = None  # See https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.axes.Axes.quiver.html
-        self.__spacing =
+        self.__spacing = 15
 
     def set_spacing(self, img_shape: tf.TensorShape):
-
+        if isinstance(img_shape, tf.TensorShape):
+            self.__spacing = int(5 * np.log(img_shape[W]))
+        else:
+            self.__spacing = img_shape
 
     @property
     def spacing(self):
@@ -495,3 +504,19 @@ MANUAL_W = [1.] * len(PRIOR_W)
 
 REG_PRIOR_W = [1e-3]
 REG_MANUAL_W = [1.] * len(REG_PRIOR_W)
+
+# Constants for augmentation layer
+# .../T1/training/zoom_factors.csv contain the scale factors of all the training samples from isotropic to 128x128x128
+# The augmentation values will be scaled using the average+std
+IXI_DATASET_iso_to_cubic_scales = np.asarray([0.655491 + 0.039223, 0.496783 + 0.029349, 0.499691 + 0.028155])
+MAX_AUG_DISP_ISOT = 30
+MAX_AUG_DEF_ISOT = 6
+MAX_AUG_DISP = np.max(MAX_AUG_DISP_ISOT * IXI_DATASET_iso_to_cubic_scales)  # Scaled displacements
+MAX_AUG_DEF = np.max(MAX_AUG_DEF_ISOT * IXI_DATASET_iso_to_cubic_scales)  # Scaled deformations
+MAX_AUG_ANGLE = np.max([np.arctan(np.tan(10*np.pi/180) * IXI_DATASET_iso_to_cubic_scales[1] / IXI_DATASET_iso_to_cubic_scales[0]) * 180 / np.pi,
+                        np.arctan(np.tan(10*np.pi/180) * IXI_DATASET_iso_to_cubic_scales[2] / IXI_DATASET_iso_to_cubic_scales[1]) * 180 / np.pi,
+                        np.arctan(np.tan(10*np.pi/180) * IXI_DATASET_iso_to_cubic_scales[2] / IXI_DATASET_iso_to_cubic_scales[0]) * 180 / np.pi])  # Scaled angles
+GAMMA_AUGMENTATION = False
+BRIGHTNESS_AUGMENTATION = False
+NUM_CONTROL_PTS_AUG = 10
+NUM_AUGMENTATIONS = 1
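The MAX_AUG_ANGLE expression above rescales a 10-degree rotation limit, defined in isotropic space, into the anisotropically scaled 128x128x128 space: a line at angle theta to an axis has slope tan(theta), and scaling the two axes involved by s_num and s_den changes that slope to tan(theta) * s_num / s_den. A minimal sketch of that relation, using made-up scale factors rather than the IXI values above:

import numpy as np

def scaled_angle(theta_deg, s_num, s_den):
    # slope tan(theta) becomes tan(theta) * s_num / s_den after axis scaling
    return np.rad2deg(np.arctan(np.tan(np.deg2rad(theta_deg)) * s_num / s_den))

print(scaled_angle(10., 0.5, 0.65))  # ~7.7 deg: the anisotropy shrinks the apparent angle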
DeepDeformationMapRegistration/utils/misc.py
CHANGED
@@ -1,11 +1,19 @@
 import os
 import errno
+import shutil
+import numpy as np
+from scipy.interpolate import griddata, Rbf, LinearNDInterpolator, NearestNDInterpolator
+from skimage.measure import regionprops
+from DeepDeformationMapRegistration.layers.b_splines import interpolate_spline
+from DeepDeformationMapRegistration.utils.thin_plate_splines import ThinPlateSplines
+from tensorflow import squeeze
 
-def try_mkdir(dir):
+
+def try_mkdir(dir, verbose=True):
     try:
         os.makedirs(dir)
     except OSError as err:
-        if err.errno == errno.EEXIST:
+        if err.errno == errno.EEXIST and verbose:
             print("Directory " + dir + " already exists")
         else:
             raise ValueError("Can't create dir " + dir)
@@ -23,3 +31,120 @@ def function_decorator(new_name):
         func.__name__ = new_name
         return func
     return decorator
+
+
+class DatasetCopy:
+    def __init__(self, dataset_location, copy_location=None, verbose=True):
+        self.__copy_loc = os.path.join(os.getcwd(), 'temp_dataset') if copy_location is None else copy_location
+        self.__dst_loc = dataset_location
+        self.__verbose = verbose
+
+    def copy_dataset(self):
+        shutil.copytree(self.__dst_loc, self.__copy_loc)
+        if self.__verbose:
+            print('{} copied to {}'.format(self.__dst_loc, self.__copy_loc))
+        return self.__copy_loc
+
+    def delete_temp(self):
+        shutil.rmtree(self.__copy_loc)
+        if self.__verbose:
+            print('Deleted: ', self.__copy_loc)
+
+
+class DisplacementMapInterpolator:
+    def __init__(self,
+                 image_shape=[64, 64, 64],
+                 method='rbf'):
+        assert method in ['rbf', 'griddata', 'tf', 'tps'], "Method must be 'rbf' or 'griddata'"
+        self.method = method
+        self.image_shape = image_shape
+
+        self.grid = self.__regular_grid()
+
+    def __regular_grid(self):
+        xx = np.linspace(0, self.image_shape[0], self.image_shape[0], endpoint=False, dtype=np.uint16)
+        yy = np.linspace(0, self.image_shape[0], self.image_shape[0], endpoint=False, dtype=np.uint16)
+        zz = np.linspace(0, self.image_shape[0], self.image_shape[0], endpoint=False, dtype=np.uint16)
+
+        xx, yy, zz = np.meshgrid(xx, yy, zz)
+
+        return np.stack([xx.flatten(), yy.flatten(), zz.flatten()], axis=0).T
+
+    def __call__(self, disp_map, interp_points, backwards=False):
+        disp_map = disp_map.reshape([-1, 3])
+        grid_pts = self.grid.copy()
+        if backwards:
+            grid_pts = np.add(grid_pts, disp_map).astype(np.float32)
+            disp_map *= -1
+
+        if self.method == 'rbf':
+            interpolator = Rbf(grid_pts[:, 0], grid_pts[:, 1], grid_pts[:, 2], disp_map[:, :],
+                               method='thin_plate', mode='N-D')
+            disp = interpolator(interp_points)
+        elif self.method == 'griddata':
+            linear_interp = LinearNDInterpolator(grid_pts, disp_map)
+            disp = linear_interp(interp_points).copy()
+            del linear_interp
+
+            if np.any(np.isnan(disp)):
+                # It might happen (though it shouldn't) that the interpolation point is outside the convex hull of grid points.
+                # in this situation, linear interpolation fails and will put NaN. Nearest can give a value, so we are going to
+                # substitute those unexpected NaNs with the nearest value. Unexpected == not in interp_points
+                nan_disp_idx = set(np.unique(np.argwhere(np.isnan(disp))[:, 0]))
+                nan_interp_pts_idx = set(np.unique(np.argwhere(np.isnan(interp_points))[:, 0]))
+                idx = nan_disp_idx - nan_interp_pts_idx if len(nan_disp_idx) > len(nan_interp_pts_idx) else nan_interp_pts_idx - nan_disp_idx
+                idx = list(idx)
+                if len(idx):
+                    # We have unexpected NaNs
+                    near_interp = NearestNDInterpolator(grid_pts, disp_map)
+                    near_disp = near_interp(interp_points[idx, ...]).copy()
+                    del near_interp
+                    for n, i in enumerate(idx):
+                        disp[i, ...] = near_disp[n, ...]
+        elif self.method == 'tf':
+            # Order: 1 -> linear, 2 -> thin plate, 3 -> cubic
+            disp = squeeze(interpolate_spline(grid_pts[np.newaxis, ...][::4, :],  # Batch axis
+                                              disp_map[np.newaxis, ...][::4, :],
+                                              interp_points[np.newaxis, ...], order=2), axis=0)
+        else:
+            tps_interp = ThinPlateSplines(grid_pts[::8, :], self.grid.copy().astype(np.float32)[::8, :])
+            disp = tps_interp.interpolate(interp_points).eval()
+            del tps_interp
+
+        return disp
+
+
+def get_segmentations_centroids(segmentations, ohe=True, expected_lbls=range(0, 28), missing_centroid=[np.nan]*3, brain_study=True):
+    segmentations = np.squeeze(segmentations)
+    if ohe:
+        segmentations = np.sum(segmentations, axis=-1).astype(np.uint8)
+        missing_lbls = set(expected_lbls) - set(np.unique(segmentations))
+        if brain_study:
+            segmentations += np.ones_like(segmentations)  # Regionsprops neglect the label 0. But we need it, so offset all labels by 1
+    else:
+        missing_lbls = set(expected_lbls) - set(np.unique(segmentations))
+
+    seg_props = regionprops(segmentations)
+    centroids = np.asarray([c.centroid for c in seg_props]).astype(np.float32)
+
+    for lbl in missing_lbls:
+        idx = expected_lbls.index(lbl)
+        centroids = np.insert(centroids, idx, missing_centroid, axis=0)
+    return centroids.copy(), missing_lbls
+
+
+def segmentation_ohe_to_cardinal(segmentation):
+    cpy = segmentation.copy()
+    for lbl in range(segmentation.shape[-1]):
+        cpy[..., lbl] *= (lbl + 1)
+    # Add the Background
+    cpy = np.concatenate([np.zeros(segmentation.shape[:-1])[..., np.newaxis], cpy], axis=-1)
+    return np.argmax(cpy, axis=-1)[..., np.newaxis]
+
+
+def segmentation_cardinal_to_ohe(segmentation):
+    # Keep in mind that we don't handle the overlap between the segmentations!
+    cpy = np.tile(np.zeros_like(segmentation), (1, 1, 1, len(np.unique(segmentation)[1:])))
+    for ch, lbl in enumerate(np.unique(segmentation)[1:]):
+        cpy[segmentation == lbl, ch] = 1
+    return cpy
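A usage sketch for DisplacementMapInterpolator as defined above (grid size and landmark coordinates are invented for illustration):

import numpy as np

disp_map = np.random.randn(16, 16, 16, 3).astype(np.float32)  # hypothetical dense field on a 16^3 grid
landmarks = np.asarray([[10.5, 12.0, 3.2],
                        [4.0, 2.3, 8.8]], dtype=np.float32)   # points at which to sample it

dmi = DisplacementMapInterpolator(image_shape=[16, 16, 16], method='griddata')
disp_at_landmarks = dmi(disp_map, landmarks)  # [2, 3] displacement vectors, NaNs patched by nearest-neighbour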
DeepDeformationMapRegistration/utils/nifti_utils.py
ADDED
@@ -0,0 +1,43 @@
+import os
+import nibabel as nb
+import numpy as np
+import zipfile
+
+
+TEMP_UNZIP_PATH = '/mnt/EncryptedData1/Users/javier/ext_datasets/LITS17/temp'
+NII_EXTENSION = '.nii'
+
+
+def save_nifti(data, save_path, header=None, verbose=True):
+    if header is None:
+        data_nifti = nb.Nifti1Image(data, affine=np.eye(4))
+    else:
+        data_nifti = nb.Nifti1Image(data, affine=None, header=header)
+
+    data_nifti.header.get_xyzt_units()
+    try:
+        nb.save(data_nifti, save_path)  # Save as NiBabel file
+        if verbose:
+            print('Saved {}'.format(save_path))
+    except ValueError:
+        print('Could not save {}'.format(save_path))
+
+
+def unzip_nii_file(file_path):
+    file_dir, file_name = os.path.split(file_path)
+    file_name = file_name.split('.zip')[0]
+
+    dest_path = os.path.join(TEMP_UNZIP_PATH, file_name)
+    zipfile.ZipFile(file_path).extractall(TEMP_UNZIP_PATH)
+
+    if not os.path.exists(dest_path):
+        print('ERR: File {} not unzip-ed!'.format(file_path))
+        dest_path = None
+    return dest_path
+
+
+def delete_temp(file_path, verbose=False):
+    assert NII_EXTENSION in file_path
+    os.remove(file_path)
+    if verbose:
+        print('Deleted file: ', file_path)
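A minimal call sketch for save_nifti (the volume and output path are placeholders):

import numpy as np

warped = np.zeros((128, 128, 128), dtype=np.float32)  # placeholder volume
save_nifti(warped, '/tmp/warped.nii')  # header=None, so an identity affine is attached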
DeepDeformationMapRegistration/utils/operators.py
CHANGED
@@ -11,14 +11,23 @@ def min_max_norm(img: np.ndarray, out_max_val=1.):
     return out_img * out_max_val
 
 
-def soft_threshold(x, threshold, name=
+def soft_threshold(x, threshold, name='soft_threshold'):
     # https://www.tensorflow.org/probability/api_docs/python/tfp/math/soft_threshold
-
+    # Foucart S., Rauhut H. (2013) Basic Algorithms. In: A Mathematical Introduction to Compressive Sensing.
+    # Applied and Numerical Harmonic Analysis. Birkhäuser, New York, NY. https://doi.org/10.1007/978-0-8176-4948-7_3
+    # Chapter 3, page 72
+    with tf.name_scope(name):
         x = tf.convert_to_tensor(x, name='x')
         threshold = tf.convert_to_tensor(threshold, dtype=x.dtype, name='threshold')
         return tf.sign(x) * tf.maximum(tf.abs(x) - threshold, 0.)
 
 
+def hard_threshold(x, threshold, name='hard_threshold'):
+    with tf.name_scope(name):
+        threshold = tf.convert_to_tensor(threshold, dtype=x.dtype, name='threshold')
+        return tf.sign(tf.maximum(tf.abs(x) - threshold, 0.))
+
+
 def binary_activation(x):
     # https://stackoverflow.com/questions/37743574/hard-limiting-threshold-activation-function-in-tensorflow
     cond = tf.less(x, tf.zeros(tf.shape(x)))
@@ -26,3 +35,31 @@ def binary_activation(x):
 
     return out
 
+
+def gaussian_kernel(kernel_size, sigma, in_ch, out_ch, dim, dtype=tf.float32):
+    # SRC: https://stackoverflow.com/questions/59286171/gaussian-blur-image-in-dataset-pipeline-in-tensorflow
+    x = tf.range(-kernel_size // 2 + 1, kernel_size // 2 + 1, dtype=dtype)
+    g = tf.math.exp(-(tf.pow(x, 2) / (2 * tf.pow(tf.cast(sigma, dtype), 2))))
+
+    g_kernel = tf.identity(g)
+    g_kernel = tf.tensordot(g_kernel, g, 0)
+    g_kernel = tf.tensordot(g_kernel, g, 0)
+
+    # i = tf.constant(0)
+    # cond = lambda i, g_kern: tf.less(i, dim - 1)
+    # mult_kern = lambda i, g_kern: [tf.add(i, 1), tf.tensordot(g_kern, g, 0)]
+    # _, g_kernel = tf.while_loop(cond, mult_kern,
+    #                             loop_vars=[i, g_kernel],
+    #                             shape_invariants=[i.get_shape(), tf.TensorShape([kernel_size, None, None])])
+
+    g_kernel = g_kernel / tf.reduce_sum(g_kernel)
+    g_kernel = tf.expand_dims(tf.expand_dims(g_kernel, axis=-1), axis=-1)
+    return tf.tile(g_kernel, (*(1,)*dim, in_ch, out_ch))
+
+
+def sample_unique(population, samples, tout=tf.int32):
+    # src: https://github.com/tensorflow/tensorflow/issues/9260#issuecomment-437875125
+    z = -tf.log(-tf.log(tf.random_uniform((tf.shape(population)[0],), 0, 1)))
+    _, indices = tf.nn.top_k(z, samples)
+    ret_val = tf.gather(population, indices)
+    return tf.cast(ret_val, tout)
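A quick worked check of the two thresholding operators (TF 1.x session style, matching the tf.log/tf.random_uniform idioms used in this file; the input values are made up):

import tensorflow as tf

x = tf.constant([-2.0, -0.3, 0.0, 0.5, 3.0])
soft = soft_threshold(x, 1.0)  # shrink towards zero: [-1., 0., 0., 0., 2.]
hard = hard_threshold(x, 1.0)  # sign of what survives: [-1., 0., 0., 0., 1.]

with tf.Session() as sess:
    print(sess.run([soft, hard]))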
DeepDeformationMapRegistration/utils/thin_plate_splines.py
ADDED
@@ -0,0 +1,154 @@
+import numpy as np
+import tensorflow as tf
+
+
+class ThinPlateSplines:
+    def __init__(self, ctrl_pts: tf.Tensor, target_pts: tf.Tensor, reg=0.0):
+        """
+
+        :param ctrl_pts: [N, d] tensor of control d-dimensional points
+        :param target_pts: [N, d] tensor of target d-dimensional points
+        :param reg: regularization coefficient
+        """
+        self.__ctrl_pts = ctrl_pts
+        self.__target_pts = target_pts
+        self.__reg = reg
+        self.__num_ctrl_pts = ctrl_pts.shape[0]
+        self.__dim = ctrl_pts.shape[1]
+
+        self.__compute_coeffs()
+        # self.__aff_params = self.__coeffs[self.__num_ctrl_pts:, ...]  # Affine parameters of the TPS
+        self.__non_aff_paramms = self.__coeffs[:self.__num_ctrl_pts, ...]  # Non-affine parameters of the TPS
+
+    def __compute_coeffs(self):
+        target_pts_aug = tf.concat([self.__target_pts,
+                                    tf.zeros([self.__dim + 1, self.__dim], dtype=self.__target_pts.dtype)],
+                                   axis=0)
+
+        # T = self.__make_T()
+        T_i = tf.cast(tf.linalg.inv(self.__make_T()), target_pts_aug.dtype)
+        self.__coeffs = tf.cast(tf.matmul(T_i, target_pts_aug), tf.float32)
+
+    def __make_T(self):
+        # cp: [K x 2] control points
+        # T: [(num_pts+dim+1) x (num_pts+dim+1)]
+        num_pts = self.__ctrl_pts.shape[0]
+
+        P = tf.concat([tf.ones([self.__num_ctrl_pts, 1], dtype=tf.float32), self.__ctrl_pts], axis=1)
+        zeros = np.zeros([self.__dim + 1, self.__dim + 1], dtype=np.float)
+        self.__K = self.__U_dist(self.__ctrl_pts)
+        alfa = tf.reduce_mean(self.__K)
+
+        self.__K = self.__K + tf.ones_like(self.__K) * tf.pow(alfa, 2) * self.__reg
+
+        # top = tf.concat([self.__K, P], axis=1)
+        # bottom = tf.concat([tf.transpose(P), zeros], axis=1)
+
+        return tf.concat([tf.concat([self.__K, P], axis=1), tf.concat([tf.transpose(P), zeros], axis=1)], axis=0)
+
+    def __U_dist(self, ctrl_pts, int_pts=None):
+        if int_pts is None:
+            dist = self.__pairwise_distance_equal(ctrl_pts)  # Already squared!
+        else:
+            dist = self.__pairwise_distance_different(ctrl_pts, int_pts)  # Already squared!
+
+        # U(x, y) = p_w_dist(x, y)^2 * log(p_w_dist(x, y))  (dist() >= 0); 0 otw
+        if ctrl_pts.shape[-1] == 2:
+            u_dist = dist * tf.math.log(dist + 1e-6)
+        else:
+            # Src: https://github.com/vaipatel/morphops/blob/master/morphops/tps.py
+            # In particular, if k = 2, then U(r) = r^2 * log(r^2), else U(r) = r
+            u_dist = tf.sqrt(dist)
+            # tf.matrix_set_diag(u_dist, tf.constant(0, dtype=dist_sq.dtype))
+        # reg_term = self.__reg * tf.pow(alfa, 2) * tf.eye(self.__num_ctrl_pts)
+
+        return u_dist  # + reg_term
+
+    def __pairwise_distance_sq(self, pts_a, pts_b):
+        with tf.variable_scope('pairwise_distance'):
+            if np.all(pts_a == pts_b):
+                # This implementation works better when doing the pairwise distance of a single set of points
+                pts_a_ = tf.reshape(pts_a, [-1, 1, 3])
+                pts_b_ = tf.reshape(pts_b, [1, -1, 3])
+                dist = tf.reduce_sum(tf.square(pts_a_ - pts_b_), 2)  # squared pairwise distance
+            else:
+                # PwD^2 = A_norm^2 - 2*A*B' + B_norm^2
+                pts_a_ = tf.reduce_sum(tf.square(pts_a), 1)
+                pts_b_ = tf.reduce_sum(tf.square(pts_b), 1)
+
+                pts_a_ = tf.expand_dims(pts_a_, 1)
+                pts_b_ = tf.expand_dims(pts_b_, 0)
+
+                pts_a_pts_b_ = tf.matmul(pts_a, pts_b, adjoint_b=True)
+
+                dist = pts_a_ - 2 * pts_a_pts_b_ + pts_b_
+
+        return tf.cast(dist, tf.float32)
+
+    @staticmethod
+    def __pairwise_distance_equal(pts):
+        # This implementation works better when doing the pairwise distance of a single set of points
+        dist = tf.reduce_sum(tf.square(tf.reshape(pts, [-1, 1, 3]) - tf.reshape(pts, [1, -1, 3])), 2)  # squared pairwise distance
+        return tf.cast(dist, tf.float32)
+
+    @staticmethod
+    def __pairwise_distance_different(pts_a, pts_b):
+        pts_a_ = tf.reduce_sum(tf.square(pts_a), 1)
+        pts_b_ = tf.reduce_sum(tf.square(pts_b), 1)
+
+        pts_a_ = tf.expand_dims(pts_a_, 1)
+        pts_b_ = tf.expand_dims(pts_b_, 0)
+
+        pts_a_pts_b_ = tf.matmul(pts_a, pts_b, adjoint_b=True)
+
+        dist = pts_a_ - 2 * pts_a_pts_b_ + pts_b_
+        return tf.cast(dist, tf.float32)
+
+    def __lift_pts(self, int_pts: tf.Tensor, num_pts):
+        # int_pts: [N x 2], input points
+        # cp: [K x 2], control points
+        # pLift: [N x (3+K)], lifted input points
+
+        # u_dist = self.__U_dist(int_pts, self.__ctrl_pts)
+
+        int_pts_lift = tf.concat([self.__U_dist(int_pts, self.__ctrl_pts),
+                                  tf.ones([num_pts, 1], dtype=tf.float32),
+                                  int_pts], axis=1)
+        return int_pts_lift
+
+    @property
+    def bending_energy(self):
+        aux = tf.matmul(self.__non_aff_paramms, self.__K, transpose_a=True)
+        return tf.matmul(aux, self.__non_aff_paramms)
+
+    def interpolate(self, int_points):  # , num_pts):
+        """
+
+        :param int_points: [K, d] flattened d-points of a mesh
+        :return:
+        """
+        num_pts = tf.shape(int_points)[0]
+        int_points_lift = self.__lift_pts(int_points, num_pts)
+
+        return tf.matmul(int_points_lift, self.__coeffs)
+
+    def __call__(self, int_points, num_pts, **kwargs):
+        return self.interpolate(int_points)  # , num_pts)
+
+
+def thin_plate_splines_batch(ctrl_pts: tf.Tensor, target_pts: tf.Tensor, int_pts: tf.Tensor, reg=0.0):
+    _batches = ctrl_pts.shape[0]
+
+    if tf.get_default_session() is not None:
+        print('DEBUG TIME')
+
+    def tps_sample(in_data):
+        cp, tp, ip = in_data
+        # _num_pts = ip.shape[0]
+        tps = ThinPlateSplines(cp, tp, reg)
+        interp = tps.interpolate(ip)  # , _num_pts)
+        return interp
+
+    return tf.map_fn(tps_sample, elems=(ctrl_pts, target_pts, int_pts), dtype=tf.float32)
+
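A sketch of how the batched helper can be driven (TF 1.x graph mode, consistent with the tf.get_default_session call above; the point clouds are random placeholders):

import numpy as np
import tensorflow as tf

ctrl = tf.constant(np.random.rand(2, 27, 3), tf.float32)     # [batch, K, d] control points
target = tf.constant(np.random.rand(2, 27, 3), tf.float32)   # [batch, K, d] displaced control points
query = tf.constant(np.random.rand(2, 1000, 3), tf.float32)  # [batch, N, d] points to interpolate at

warped = thin_plate_splines_batch(ctrl, target, query, reg=0.01)
with tf.Session() as sess:
    print(sess.run(warped).shape)  # (2, 1000, 3)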
DeepDeformationMapRegistration/utils/visualization.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
-
|
2 |
-
#
|
3 |
import matplotlib.pyplot as plt
|
4 |
from mpl_toolkits.mplot3d import Axes3D
|
5 |
import matplotlib.colors as mcolors
|
@@ -8,7 +8,7 @@ from matplotlib import cm
|
|
8 |
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
9 |
import tensorflow as tf
|
10 |
import numpy as np
|
11 |
-
import
|
12 |
from skimage.exposure import rescale_intensity
|
13 |
import scipy.misc as scpmisc
|
14 |
import os
|
@@ -175,11 +175,11 @@ def plot_training(list_imgs: [np.ndarray], affine_transf=True, filename='img', f
|
|
175 |
fig.clear()
|
176 |
plt.figure(fig.number)
|
177 |
else:
|
178 |
-
fig = plt.figure(dpi=
|
179 |
|
180 |
ax_fix = fig.add_subplot(231)
|
181 |
im_fix = ax_fix.imshow(list_imgs[0][:, :, 0])
|
182 |
-
ax_fix.set_title('Fix image', fontsize=
|
183 |
ax_fix.tick_params(axis='both',
|
184 |
which='both',
|
185 |
bottom=False,
|
@@ -188,7 +188,7 @@ def plot_training(list_imgs: [np.ndarray], affine_transf=True, filename='img', f
|
|
188 |
labelbottom=False)
|
189 |
ax_mov = fig.add_subplot(232)
|
190 |
im_mov = ax_mov.imshow(list_imgs[1][:, :, 0])
|
191 |
-
ax_mov.set_title('Moving image', fontsize=
|
192 |
ax_mov.tick_params(axis='both',
|
193 |
which='both',
|
194 |
bottom=False,
|
@@ -198,7 +198,7 @@ def plot_training(list_imgs: [np.ndarray], affine_transf=True, filename='img', f
|
|
198 |
|
199 |
ax_pred_im = fig.add_subplot(233)
|
200 |
im_pred_im = ax_pred_im.imshow(list_imgs[2][:, :, 0])
|
201 |
-
ax_pred_im.set_title('Prediction', fontsize=
|
202 |
ax_pred_im.tick_params(axis='both',
|
203 |
which='both',
|
204 |
bottom=False,
|
@@ -228,8 +228,8 @@ def plot_training(list_imgs: [np.ndarray], affine_transf=True, filename='img', f
|
|
228 |
else:
|
229 |
cx, cy, dx, dy, s = _prepare_quiver_map(list_imgs[3])
|
230 |
im_pred_disp = ax_pred_disp.imshow(s, interpolation='none', aspect='equal')
|
231 |
-
ax_pred_disp.quiver(cx, cy, dx, dy, scale=
|
232 |
-
ax_pred_disp.set_title('Pred disp map', fontsize=
|
233 |
ax_pred_disp.tick_params(axis='both',
|
234 |
which='both',
|
235 |
bottom=False,
|
@@ -259,8 +259,8 @@ def plot_training(list_imgs: [np.ndarray], affine_transf=True, filename='img', f
|
|
259 |
else:
|
260 |
cx, cy, dx, dy, s = _prepare_quiver_map(list_imgs[4])
|
261 |
im_gt_disp = ax_gt_disp.imshow(s, interpolation='none', aspect='equal')
|
262 |
-
ax_gt_disp.quiver(cx, cy, dx, dy, scale=
|
263 |
-
ax_gt_disp.set_title('GT disp map', fontsize=
|
264 |
ax_gt_disp.tick_params(axis='both',
|
265 |
which='both',
|
266 |
bottom=False,
|
@@ -276,7 +276,7 @@ def plot_training(list_imgs: [np.ndarray], affine_transf=True, filename='img', f
|
|
276 |
|
277 |
if filename is not None:
|
278 |
plt.savefig(filename, format='png') # Call before show
|
279 |
-
if not
|
280 |
plt.show()
|
281 |
else:
|
282 |
plt.close()
|
@@ -288,7 +288,7 @@ def save_centreline_img(img, title, filename, fig=None):
|
|
288 |
fig.clear()
|
289 |
plt.figure(fig.number)
|
290 |
else:
|
291 |
-
fig = plt.figure(dpi=
|
292 |
|
293 |
dim = len(img.shape[:-1])
|
294 |
|
@@ -321,19 +321,19 @@ def save_centreline_img(img, title, filename, fig=None):
|
|
321 |
plt.close()
|
322 |
|
323 |
|
324 |
-
def save_disp_map_img(disp_map, title, filename, affine_transf=False, fig=None):
|
325 |
if fig is not None:
|
326 |
fig.clear()
|
327 |
plt.figure(fig.number)
|
328 |
else:
|
329 |
-
fig = plt.figure(dpi=
|
330 |
-
|
331 |
dim = disp_map.shape[-1]
|
332 |
|
333 |
if dim == 2:
|
334 |
ax_x = fig.add_subplot(131)
|
335 |
ax_x.set_title('H displacement')
|
336 |
-
im_x = ax_x.imshow(disp_map[...,
|
337 |
ax_x.tick_params(axis='both',
|
338 |
which='both',
|
339 |
bottom=False,
|
@@ -344,7 +344,7 @@ def save_disp_map_img(disp_map, title, filename, affine_transf=False, fig=None):
|
|
344 |
|
345 |
ax_y = fig.add_subplot(132)
|
346 |
ax_y.set_title('W displacement')
|
347 |
-
im_y = ax_y.imshow(disp_map[...,
|
348 |
ax_y.tick_params(axis='both',
|
349 |
which='both',
|
350 |
bottom=False,
|
@@ -373,8 +373,8 @@ def save_disp_map_img(disp_map, title, filename, affine_transf=False, fig=None):
|
|
373 |
else:
|
374 |
c, d, s = _prepare_quiver_map(disp_map, dim=dim)
|
375 |
im = ax.imshow(s, interpolation='none', aspect='equal')
|
376 |
-
ax.quiver(c[
|
377 |
-
scale=
|
378 |
cb = _set_colorbar(fig, ax, im, False)
|
379 |
ax.set_title('Displacement map')
|
380 |
ax.tick_params(axis='both',
|
@@ -387,9 +387,8 @@ def save_disp_map_img(disp_map, title, filename, affine_transf=False, fig=None):
|
|
387 |
else:
|
388 |
ax = fig.add_subplot(111, projection='3d')
|
389 |
c, d, s = _prepare_quiver_map(disp_map[0, ...], dim=dim)
|
390 |
-
ax.quiver(c[
|
391 |
-
|
392 |
-
_square_3d_plot(np.arange(0, 63), np.arange(0, 63), np.arange(0, 63), ax)
|
393 |
fig.suptitle('Displacement map')
|
394 |
ax.tick_params(axis='both', # Same parameters as in 2D https://matplotlib.org/api/_as_gen/mpl_toolkits.mplot3d.axes3d.Axes3D.html
|
395 |
which='both',
|
@@ -397,10 +396,14 @@ def save_disp_map_img(disp_map, title, filename, affine_transf=False, fig=None):
|
|
397 |
left=False,
|
398 |
labelleft=False,
|
399 |
labelbottom=False)
|
|
|
400 |
fig.suptitle(title)
|
401 |
|
402 |
plt.savefig(filename, format='png')
|
|
|
|
|
403 |
plt.close()
|
|
|
404 |
|
405 |
|
406 |
def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, filename='img', fig=None,
|
@@ -409,16 +412,16 @@ def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, fi
|
|
409 |
fig.clear()
|
410 |
plt.figure(fig.number)
|
411 |
else:
|
412 |
-
fig = plt.figure(dpi=
|
413 |
|
414 |
dim = len(list_imgs[0].shape[:-1])
|
415 |
|
416 |
if dim == 2:
|
417 |
# TRAINING
|
418 |
ax_input = fig.add_subplot(241)
|
419 |
-
ax_input.set_ylabel(title_first_row, fontsize=
|
420 |
im_fix = ax_input.imshow(list_imgs[0][:, :, 0])
|
421 |
-
ax_input.set_title('Fix image', fontsize=
|
422 |
ax_input.tick_params(axis='both',
|
423 |
which='both',
|
424 |
bottom=False,
|
@@ -427,7 +430,7 @@ def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, fi
|
|
427 |
labelbottom=False)
|
428 |
ax_mov = fig.add_subplot(242)
|
429 |
im_mov = ax_mov.imshow(list_imgs[1][:, :, 0])
|
430 |
-
ax_mov.set_title('Moving image', fontsize=
|
431 |
ax_mov.tick_params(axis='both',
|
432 |
which='both',
|
433 |
bottom=False,
|
@@ -437,7 +440,7 @@ def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, fi
|
|
437 |
|
438 |
ax_pred_im = fig.add_subplot(244)
|
439 |
im_pred_im = ax_pred_im.imshow(list_imgs[2][:, :, 0])
|
440 |
-
ax_pred_im.set_title('Predicted fix image', fontsize=
|
441 |
ax_pred_im.tick_params(axis='both',
|
442 |
which='both',
|
443 |
bottom=False,
|
@@ -467,8 +470,8 @@ def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, fi
|
|
467 |
else:
|
468 |
cx, cy, dx, dy, s = _prepare_quiver_map(list_imgs[3], dim=dim)
|
469 |
im_pred_disp = ax_pred_disp.imshow(s, interpolation='none', aspect='equal')
|
470 |
-
ax_pred_disp.quiver(cx, cy, dx, dy, scale=
|
471 |
-
ax_pred_disp.set_title('Pred disp map', fontsize=
|
472 |
ax_pred_disp.tick_params(axis='both',
|
473 |
which='both',
|
474 |
bottom=False,
|
@@ -478,9 +481,9 @@ def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, fi
|
|
478 |
|
479 |
# VALIDATION
|
480 |
axinput_val = fig.add_subplot(245)
|
481 |
-
axinput_val.set_ylabel(title_second_row, fontsize=
|
482 |
im_fix_val = axinput_val.imshow(list_imgs[4][:, :, 0])
|
483 |
-
axinput_val.set_title('Fix image', fontsize=
|
484 |
axinput_val.tick_params(axis='both',
|
485 |
which='both',
|
486 |
bottom=False,
|
@@ -489,7 +492,7 @@ def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, fi
|
|
489 |
labelbottom=False)
|
490 |
ax_mov_val = fig.add_subplot(246)
|
491 |
im_mov_val = ax_mov_val.imshow(list_imgs[5][:, :, 0])
|
492 |
-
ax_mov_val.set_title('Moving image', fontsize=
|
493 |
ax_mov_val.tick_params(axis='both',
|
494 |
which='both',
|
495 |
bottom=False,
|
@@ -499,7 +502,7 @@ def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, fi
|
|
499 |
|
500 |
ax_pred_im_val = fig.add_subplot(248)
|
501 |
im_pred_im_val = ax_pred_im_val.imshow(list_imgs[6][:, :, 0])
|
502 |
-
ax_pred_im_val.set_title('Predicted fix image', fontsize=
|
503 |
ax_pred_im_val.tick_params(axis='both',
|
504 |
which='both',
|
505 |
bottom=False,
|
@@ -529,8 +532,8 @@ def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, fi
|
|
529 |
else:
|
530 |
c, d, s = _prepare_quiver_map(list_imgs[7], dim=dim)
|
531 |
im_pred_disp_val = ax_pred_disp_val.imshow(s, interpolation='none', aspect='equal')
|
532 |
-
ax_pred_disp_val.quiver(c[0], c[1], d[0], d[1], scale=
|
533 |
-
ax_pred_disp_val.set_title('Pred disp map', fontsize=
|
534 |
ax_pred_disp_val.tick_params(axis='both',
|
535 |
which='both',
|
536 |
bottom=False,
|
@@ -552,10 +555,10 @@ def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, fi
|
|
552 |
# 3D
|
553 |
# TRAINING
|
554 |
ax_input = fig.add_subplot(231, projection='3d')
|
555 |
-
ax_input.set_ylabel(title_first_row, fontsize=
|
556 |
im_fix = ax_input.voxels(list_imgs[0][..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed')
|
557 |
im_mov = ax_input.voxels(list_imgs[1][..., 0] > 0.0, facecolors='blue', edgecolors='blue', label='Moving')
|
558 |
-
ax_input.set_title('Fix image', fontsize=
|
559 |
ax_input.tick_params(axis='both',
|
560 |
which='both',
|
561 |
bottom=False,
|
@@ -566,7 +569,7 @@ def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, fi
|
|
566 |
ax_pred_im = fig.add_subplot(232, projection='3d')
|
567 |
im_pred_im = ax_input.voxels(list_imgs[2][..., 0] > 0.0, facecolors='green', edgecolors='green', label='Prediction')
|
568 |
im_fix = ax_input.voxels(list_imgs[0][..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed')
|
569 |
-
ax_pred_im.set_title('Predicted fix image', fontsize=
|
570 |
ax_pred_im.tick_params(axis='both',
|
571 |
which='both',
|
572 |
bottom=False,
|
@@ -578,9 +581,9 @@ def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, fi
|
|
578 |
|
579 |
c, d, s = _prepare_quiver_map(list_imgs[3], dim=dim)
|
580 |
im_pred_disp = ax_pred_disp.imshow(s, interpolation='none', aspect='equal')
|
581 |
-
ax_pred_disp.quiver(c[
|
582 |
-
d[
|
583 |
-
ax_pred_disp.set_title('Pred disp map', fontsize=
|
584 |
ax_pred_disp.tick_params(axis='both',
|
585 |
which='both',
|
586 |
bottom=False,
|
@@ -590,10 +593,10 @@ def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, fi
|
|
590 |
|
591 |
# VALIDATION
|
592 |
axinput_val = fig.add_subplot(234, projection='3d')
|
593 |
-
axinput_val.set_ylabel(title_second_row, fontsize=
|
594 |
im_fix_val = ax_input.voxels(list_imgs[4][..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed (val)')
|
595 |
im_mov_val = ax_input.voxels(list_imgs[5][..., 0] > 0.0, facecolors='blue', edgecolors='blue', label='Moving (val)')
|
596 |
-
axinput_val.set_title('Fix image', fontsize=
|
597 |
axinput_val.tick_params(axis='both',
|
598 |
which='both',
|
599 |
bottom=False,
|
@@ -604,7 +607,7 @@ def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, fi
|
|
604 |
ax_pred_im_val = fig.add_subplot(235, projection='3d')
|
605 |
im_pred_im_val = ax_input.voxels(list_imgs[2][..., 0] > 0.0, facecolors='green', edgecolors='green', label='Prediction (val)')
|
606 |
im_fix_val = ax_input.voxels(list_imgs[0][..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed (val)')
|
607 |
-
ax_pred_im_val.set_title('Predicted fix image', fontsize=
|
608 |
ax_pred_im_val.tick_params(axis='both',
|
609 |
which='both',
|
610 |
bottom=False,
|
@@ -615,10 +618,10 @@ def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, fi
|
|
615 |
ax_pred_disp_val = fig.add_subplot(236, projection='3d')
|
616 |
c, d, s = _prepare_quiver_map(list_imgs[7], dim=dim)
|
617 |
im_pred_disp_val = ax_pred_disp_val.imshow(s, interpolation='none', aspect='equal')
|
618 |
-
ax_pred_disp_val.quiver(c[
|
619 |
-
d[
|
620 |
-
scale=
|
621 |
-
ax_pred_disp_val.set_title('Pred disp map', fontsize=
|
622 |
ax_pred_disp_val.tick_params(axis='both',
|
623 |
which='both',
|
624 |
bottom=False,
|
@@ -628,7 +631,7 @@ def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, fi
|
|
628 |
|
629 |
if filename is not None:
|
630 |
plt.savefig(filename, format='png') # Call before show
|
631 |
-
if not
|
632 |
plt.show()
|
633 |
else:
|
634 |
plt.close()
|
@@ -644,21 +647,21 @@ def _set_colorbar(fig, ax, im, drawedges=True):
|
|
644 |
return im_cb
|
645 |
|
646 |
|
647 |
-
def _prepare_quiver_map(disp_map: np.ndarray, dim=2, spc=
|
648 |
if isinstance(disp_map, tf.Tensor):
|
649 |
if tf.executing_eagerly():
|
650 |
disp_map = disp_map.numpy()
|
651 |
else:
|
652 |
disp_map = disp_map.eval()
|
653 |
-
dx = disp_map[...,
|
654 |
-
dy = disp_map[...,
|
655 |
if dim > 2:
|
656 |
-
dz = disp_map[...,
|
657 |
|
658 |
-
img_size_x = disp_map.shape[
|
659 |
-
img_size_y = disp_map.shape[
|
660 |
if dim > 2:
|
661 |
-
img_size_z = disp_map.shape[
|
662 |
|
663 |
if dim > 2:
|
664 |
s = np.sqrt(np.square(dx) + np.square(dy) + np.square(dz))
|
@@ -728,7 +731,7 @@ def plot_input_data(fix_img, mov_img, img_size=(64, 64), title=None, filename=No
|
|
728 |
|
729 |
if filename is not None:
|
730 |
plt.savefig(filename, format='png') # Call before show
|
731 |
-
if not
|
732 |
plt.show()
|
733 |
else:
|
734 |
plt.close()
|
@@ -807,29 +810,42 @@ def plot_dataset_3d(img_sets):
|
|
807 |
return fig
|
808 |
|
809 |
|
810 |
-
def plot_predictions(fix_img_batch, mov_img_batch, disp_map_batch, pred_img_batch, filename='predictions', fig=None):
|
811 |
num_rows = fix_img_batch.shape[0]
|
812 |
-
|
|
|
813 |
if fig is not None:
|
814 |
fig.clear()
|
815 |
plt.figure(fig.number)
|
816 |
-
ax = fig.add_subplot(nrows=num_rows, ncols=
|
817 |
else:
|
818 |
-
fig, ax = plt.subplots(nrows=num_rows, ncols=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
819 |
|
820 |
for row in range(num_rows):
|
821 |
-
fix_img = fix_img_batch[row, :, :, 0]
|
822 |
-
mov_img = mov_img_batch[row, :, :, 0]
|
823 |
-
disp_map = disp_map_batch[row, :, :, :]
|
824 |
-
pred_img = pred_img_batch[row, :, :, 0]
|
825 |
-
ax[row, 0].imshow(fix_img)
|
826 |
ax[row, 0].tick_params(axis='both',
|
827 |
which='both',
|
828 |
bottom=False,
|
829 |
left=False,
|
830 |
labelleft=False,
|
831 |
labelbottom=False)
|
832 |
-
ax[row, 1].imshow(mov_img)
|
833 |
ax[row, 1].tick_params(axis='both',
|
834 |
which='both',
|
835 |
bottom=False,
|
@@ -837,11 +853,12 @@ def plot_predictions(fix_img_batch, mov_img_batch, disp_map_batch, pred_img_batc
|
|
837 |
labelleft=False,
|
838 |
labelbottom=False)
|
839 |
|
840 |
-
|
|
|
|
|
841 |
disp_map_color = _prepare_colormap(disp_map)
|
842 |
-
ax[row, 2].imshow(disp_map_color, interpolation='none', aspect='equal')
|
843 |
-
ax[row, 2].quiver(cx
|
844 |
-
ax[row, 2].figure.set_size_inches(img_size)
|
845 |
ax[row, 2].tick_params(axis='both',
|
846 |
which='both',
|
847 |
bottom=False,
|
@@ -849,6 +866,8 @@ def plot_predictions(fix_img_batch, mov_img_batch, disp_map_batch, pred_img_batc
|
|
849 |
labelleft=False,
|
850 |
labelbottom=False)
|
851 |
|
|
|
|
|
852 |
ax[row, 3].tick_params(axis='both',
|
853 |
which='both',
|
854 |
bottom=False,
|
@@ -856,18 +875,26 @@ def plot_predictions(fix_img_batch, mov_img_batch, disp_map_batch, pred_img_batc
|
|
856 |
labelleft=False,
|
857 |
labelbottom=False)
|
858 |
|
859 |
-
|
860 |
-
|
861 |
-
|
862 |
-
|
863 |
-
|
|
|
|
|
864 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
865 |
if filename is not None:
|
866 |
plt.savefig(filename, format='png') # Call before show
|
867 |
-
if
|
868 |
plt.show()
|
869 |
-
|
870 |
-
plt.close()
|
871 |
return fig
|
872 |
|
873 |
|
@@ -876,7 +903,7 @@ def inspect_disp_map_generation(fix_img, mov_img, disp_map, filename=None, fig=N
|
|
876 |
fig.clear()
|
877 |
plt.figure(fig.number)
|
878 |
else:
|
879 |
-
fig = plt.figure(dpi=
|
880 |
|
881 |
ax0 = fig.add_subplot(221)
|
882 |
im0 = ax0.imshow(fix_img[..., 0])
|
@@ -900,7 +927,7 @@ def inspect_disp_map_generation(fix_img, mov_img, disp_map, filename=None, fig=N
|
|
900 |
ax2 = fig.add_subplot(223)
|
901 |
im2 = ax2.imshow(s, interpolation='none', aspect='equal')
|
902 |
|
903 |
-
ax2.quiver(cx, cy, dx, dy, scale=
|
904 |
# ax2.figure.set_size_inches(img_size)
|
905 |
ax2.tick_params(axis='both',
|
906 |
which='both',
|
@@ -912,7 +939,7 @@ def inspect_disp_map_generation(fix_img, mov_img, disp_map, filename=None, fig=N
|
|
912 |
ax3 = fig.add_subplot(224)
|
913 |
dif = fix_img[..., 0] - mov_img[..., 0]
|
914 |
im3 = ax3.imshow(dif)
|
915 |
-
ax3.quiver(cx, cy, dx, dy, scale=
|
916 |
ax3.tick_params(axis='both',
|
917 |
which='both',
|
918 |
bottom=False,
|
@@ -921,10 +948,10 @@ def inspect_disp_map_generation(fix_img, mov_img, disp_map, filename=None, fig=N
|
|
921 |
labelbottom=False)
|
922 |
|
923 |
plt.axis('off')
|
924 |
-
ax0.set_title('Fixed img ($I_f$)', fontsize=
|
925 |
-
ax1.set_title('Moving img ($I_m$)', fontsize=
|
926 |
-
ax2.set_title('Displacement map', fontsize=
|
927 |
-
ax3.set_title('Fix - Mov', fontsize=
|
928 |
|
929 |
im0_cb = _set_colorbar(fig, ax0, im0, False)
|
930 |
im1_cb = _set_colorbar(fig, ax1, im1, False)
|
@@ -933,7 +960,7 @@ def inspect_disp_map_generation(fix_img, mov_img, disp_map, filename=None, fig=N
|
|
933 |
|
934 |
if filename is not None:
|
935 |
plt.savefig(filename, format='png') # Call before show
|
936 |
-
if not
|
937 |
plt.show()
|
938 |
else:
|
939 |
plt.close()
|
@@ -950,7 +977,7 @@ def inspect_displacement_grid(ctrl_coords, dense_coords, target_coords, disp_coo
|
|
950 |
fig = plt.figure()
|
951 |
|
952 |
ax_grid = fig.add_subplot(231)
|
953 |
-
ax_grid.set_title('Grids', fontsize=
|
954 |
ax_grid.scatter(ctrl_coords[:, 0], ctrl_coords[:, 1], marker='+', c='r', s=20)
|
955 |
ax_grid.scatter(dense_coords[:, 0], dense_coords[:, 1], marker='.', c='r', s=1)
|
956 |
ax_grid.tick_params(axis='both',
|
@@ -966,7 +993,7 @@ def inspect_displacement_grid(ctrl_coords, dense_coords, target_coords, disp_coo
|
|
966 |
ax_grid.set_aspect('equal')
|
967 |
|
968 |
ax_disp = fig.add_subplot(232)
|
969 |
-
ax_disp.set_title('Displacement map', fontsize=
|
970 |
cx, cy, dx, dy, s = _prepare_quiver_map(disp_map)
|
971 |
ax_disp.imshow(s, interpolation='none', aspect='equal')
|
972 |
ax_disp.tick_params(axis='both',
|
@@ -977,7 +1004,7 @@ def inspect_displacement_grid(ctrl_coords, dense_coords, target_coords, disp_coo
|
|
977 |
labelbottom=False)
|
978 |
|
979 |
ax_mask = fig.add_subplot(233)
|
980 |
-
ax_mask.set_title('Mask', fontsize=
|
981 |
ax_mask.imshow(mask)
|
982 |
ax_mask.tick_params(axis='both',
|
983 |
which='both',
|
@@ -987,7 +1014,7 @@ def inspect_displacement_grid(ctrl_coords, dense_coords, target_coords, disp_coo
|
|
987 |
labelbottom=False)
|
988 |
|
989 |
ax_fix = fig.add_subplot(234)
|
990 |
-
ax_fix.set_title('Fix image', fontsize=
|
991 |
ax_fix.imshow(fix_img[..., 0])
|
992 |
ax_fix.tick_params(axis='both',
|
993 |
which='both',
|
@@ -997,7 +1024,7 @@ def inspect_displacement_grid(ctrl_coords, dense_coords, target_coords, disp_coo
|
|
997 |
labelbottom=False)
|
998 |
|
999 |
ax_mov = fig.add_subplot(235)
|
1000 |
-
ax_mov.set_title('Moving image', fontsize=
|
1001 |
ax_mov.imshow(mov_img[..., 0])
|
1002 |
ax_mov.tick_params(axis='both',
|
1003 |
which='both',
|
@@ -1007,7 +1034,7 @@ def inspect_displacement_grid(ctrl_coords, dense_coords, target_coords, disp_coo
|
|
1007 |
labelbottom=False)
|
1008 |
|
1009 |
ax_dif = fig.add_subplot(236)
|
1010 |
-
ax_dif.set_title('Fix - Moving image', fontsize=
|
1011 |
ax_dif.imshow(fix_img[..., 0] - mov_img[..., 0], cmap=cmap_bin)
|
1012 |
ax_dif.tick_params(axis='both',
|
1013 |
which='both',
|
@@ -1023,7 +1050,7 @@ def inspect_displacement_grid(ctrl_coords, dense_coords, target_coords, disp_coo
|
|
1023 |
|
1024 |
if filename is not None:
|
1025 |
plt.savefig(filename, format='png') # Call before show
|
1026 |
-
if not
|
1027 |
plt.show()
|
1028 |
|
1029 |
return fig
|
@@ -1037,10 +1064,10 @@ def compare_disp_maps(disp_m_f, disp_f_m, fix_img, mov_img, filename=None, fig=N
|
|
1037 |
fig = plt.figure()
|
1038 |
|
1039 |
ax_d_m_f = fig.add_subplot(131)
|
1040 |
-
ax_d_m_f.set_title('Disp M->F', fontsize=
|
1041 |
cx, cy, dx, dy, s = _prepare_quiver_map(disp_m_f)
|
1042 |
ax_d_m_f.imshow(s, interpolation='none', aspect='equal')
|
1043 |
-
ax_d_m_f.quiver(cx, cy, dx, dy, scale=
|
1044 |
ax_d_m_f.tick_params(axis='both',
|
1045 |
which='both',
|
1046 |
bottom=False,
|
@@ -1049,9 +1076,9 @@ def compare_disp_maps(disp_m_f, disp_f_m, fix_img, mov_img, filename=None, fig=N
|
|
1049 |
labelbottom=False)
|
1050 |
|
1051 |
ax_d_f_m = fig.add_subplot(132)
|
1052 |
-
ax_d_f_m.set_title('Disp F->M', fontsize=
|
1053 |
cx, cy, dx, dy, s = _prepare_quiver_map(disp_f_m)
|
1054 |
-
ax_d_f_m.quiver(cx, cy, dx, dy, scale=
|
1055 |
ax_d_f_m.imshow(s, interpolation='none', aspect='equal')
|
1056 |
ax_d_f_m.tick_params(axis='both',
|
1057 |
which='both',
|
@@ -1061,7 +1088,7 @@ def compare_disp_maps(disp_m_f, disp_f_m, fix_img, mov_img, filename=None, fig=N
|
|
1061 |
labelbottom=False)
|
1062 |
|
1063 |
ax_dif = fig.add_subplot(133)
|
1064 |
-
ax_dif.set_title('Fix - Moving image', fontsize=
|
1065 |
ax_dif.imshow(fix_img[..., 0] - mov_img[..., 0], cmap=cmap_bin)
|
1066 |
ax_dif.tick_params(axis='both',
|
1067 |
which='both',
|
@@ -1078,7 +1105,7 @@ def compare_disp_maps(disp_m_f, disp_f_m, fix_img, mov_img, filename=None, fig=N
|
|
1078 |
|
1079 |
if filename is not None:
|
1080 |
plt.savefig(filename, format='png') # Call before show
|
1081 |
-
if not
|
1082 |
plt.show()
|
1083 |
else:
|
1084 |
plt.close()
|
@@ -1104,7 +1131,7 @@ def plot_train_step(list_imgs: [np.ndarray], fig_title='TRAINING', dest_folder='
|
|
1104 |
fig.tight_layout(pad=5.0)
|
1105 |
ax = fig.add_subplot(2, num_cols, 1, projection='3d')
|
1106 |
ax.voxels(list_imgs[0][0, ..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed')
|
1107 |
-
ax.set_title('Fix image', fontsize=
|
1108 |
_square_3d_plot(np.arange(0, 63), np.arange(0, 63), np.arange(0, 63), ax)
|
1109 |
|
1110 |
for i in range(2, num_preds+2):
|
@@ -1116,23 +1143,23 @@ def plot_train_step(list_imgs: [np.ndarray], fig_title='TRAINING', dest_folder='
|
|
1116 |
|
1117 |
ax = fig.add_subplot(2, num_cols, num_preds+2, projection='3d')
|
1118 |
ax.voxels(list_imgs[1][0, ..., 0] > 0.0, facecolors='blue', edgecolors='blue', label='Moving')
|
1119 |
-
ax.set_title('Fix image', fontsize=
|
1120 |
_square_3d_plot(np.arange(0, 63), np.arange(0, 63), np.arange(0, 63), ax)
|
1121 |
|
1122 |
for i in range(num_preds+2, 2 * num_preds + 2):
|
1123 |
ax = fig.add_subplot(2, num_cols, i + 1, projection='3d')
|
1124 |
c, d, s = _prepare_quiver_map(list_imgs[i][0, ...], dim=3)
|
1125 |
-
ax.quiver(c[
|
1126 |
-
d[
|
1127 |
norm=True)
|
1128 |
ax.set_title('Disp. #{}'.format(i - 5))
|
1129 |
_square_3d_plot(np.arange(0, 63), np.arange(0, 63), np.arange(0, 63), ax)
|
1130 |
|
1131 |
-
fig.suptitle(fig_title, fontsize=
|
1132 |
|
1133 |
if save_file:
|
1134 |
plt.savefig(os.path.join(dest_folder, fig_title+'.png'), format='png') # Call before show
|
1135 |
-
if not
|
1136 |
plt.show()
|
1137 |
else:
|
1138 |
plt.close()
|
@@ -1149,3 +1176,66 @@ def _square_3d_plot(X, Y, Z, ax):
|
|
1149 |
ax.set_ylim(mid_y - max_range, mid_y + max_range)
|
1150 |
ax.set_zlim(mid_z - max_range, mid_z + max_range)
|
1151 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib
|
2 |
+
#matplotlib.use('TkAgg')
|
3 |
import matplotlib.pyplot as plt
|
4 |
from mpl_toolkits.mplot3d import Axes3D
|
5 |
import matplotlib.colors as mcolors
|
|
|
8 |
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
9 |
import tensorflow as tf
|
10 |
import numpy as np
|
11 |
+
import DeepDeformationMapRegistration.utils.constants as C
|
12 |
from skimage.exposure import rescale_intensity
|
13 |
import scipy.misc as scpmisc
|
14 |
import os
|
|
|
175 |
fig.clear()
|
176 |
plt.figure(fig.number)
|
177 |
else:
|
178 |
+
fig = plt.figure(dpi=C.DPI)
|
179 |
|
180 |
ax_fix = fig.add_subplot(231)
|
181 |
im_fix = ax_fix.imshow(list_imgs[0][:, :, 0])
|
182 |
+
ax_fix.set_title('Fix image', fontsize=C.FONT_SIZE)
|
183 |
ax_fix.tick_params(axis='both',
|
184 |
which='both',
|
185 |
bottom=False,
|
|
|
188 |
labelbottom=False)
|
189 |
ax_mov = fig.add_subplot(232)
|
190 |
im_mov = ax_mov.imshow(list_imgs[1][:, :, 0])
|
191 |
+
ax_mov.set_title('Moving image', fontsize=C.FONT_SIZE)
|
192 |
ax_mov.tick_params(axis='both',
|
193 |
which='both',
|
194 |
bottom=False,
|
|
|
198 |
|
199 |
ax_pred_im = fig.add_subplot(233)
|
200 |
im_pred_im = ax_pred_im.imshow(list_imgs[2][:, :, 0])
|
201 |
+
ax_pred_im.set_title('Prediction', fontsize=C.FONT_SIZE)
|
202 |
ax_pred_im.tick_params(axis='both',
|
203 |
which='both',
|
204 |
bottom=False,
|
|
|
228 |
else:
|
229 |
cx, cy, dx, dy, s = _prepare_quiver_map(list_imgs[3])
|
230 |
im_pred_disp = ax_pred_disp.imshow(s, interpolation='none', aspect='equal')
|
231 |
+
ax_pred_disp.quiver(cx, cy, dx, dy, scale=C.QUIVER_PARAMS.arrow_scale)
|
232 |
+
ax_pred_disp.set_title('Pred disp map', fontsize=C.FONT_SIZE)
|
233 |
ax_pred_disp.tick_params(axis='both',
|
234 |
which='both',
|
235 |
bottom=False,
|
|
|
259 |
else:
|
260 |
cx, cy, dx, dy, s = _prepare_quiver_map(list_imgs[4])
|
261 |
im_gt_disp = ax_gt_disp.imshow(s, interpolation='none', aspect='equal')
|
262 |
+
ax_gt_disp.quiver(cx, cy, dx, dy, scale=C.QUIVER_PARAMS.arrow_scale)
|
263 |
+
ax_gt_disp.set_title('GT disp map', fontsize=C.FONT_SIZE)
|
264 |
ax_gt_disp.tick_params(axis='both',
|
265 |
which='both',
|
266 |
bottom=False,
|
|
|
276 |
|
277 |
if filename is not None:
|
278 |
plt.savefig(filename, format='png') # Call before show
|
279 |
+
if not C.REMOTE:
|
280 |
plt.show()
|
281 |
else:
|
282 |
plt.close()
|
|
|
288 |
fig.clear()
|
289 |
plt.figure(fig.number)
|
290 |
else:
|
291 |
+
fig = plt.figure(dpi=C.DPI)
|
292 |
|
293 |
dim = len(img.shape[:-1])
|
294 |
|
|
|
321 |
plt.close()
|
322 |
|
323 |
|
324 |
+
def save_disp_map_img(disp_map, title, filename, affine_transf=False, fig=None, show=False):
|
325 |
if fig is not None:
|
326 |
fig.clear()
|
327 |
plt.figure(fig.number)
|
328 |
else:
|
329 |
+
fig = plt.figure(dpi=C.DPI)
|
330 |
+
dim_h, dim_w, dim_d = disp_map.shape[1:-1]
|
331 |
dim = disp_map.shape[-1]
|
332 |
|
333 |
if dim == 2:
|
334 |
ax_x = fig.add_subplot(131)
|
335 |
ax_x.set_title('H displacement')
|
336 |
+
im_x = ax_x.imshow(disp_map[..., C.H_DISP])
|
337 |
ax_x.tick_params(axis='both',
|
338 |
which='both',
|
339 |
bottom=False,
|
|
|
344 |
|
345 |
ax_y = fig.add_subplot(132)
|
346 |
ax_y.set_title('W displacement')
|
347 |
+
im_y = ax_y.imshow(disp_map[..., C.W_DISP])
|
348 |
ax_y.tick_params(axis='both',
|
349 |
which='both',
|
350 |
bottom=False,
|
|
|
373 |
else:
|
374 |
c, d, s = _prepare_quiver_map(disp_map, dim=dim)
|
375 |
im = ax.imshow(s, interpolation='none', aspect='equal')
|
376 |
+
ax.quiver(c[C.H_DISP], c[C.W_DISP], d[C.H_DISP], d[C.W_DISP],
|
377 |
+
scale=C.QUIVER_PARAMS.arrow_scale)
|
378 |
cb = _set_colorbar(fig, ax, im, False)
|
379 |
ax.set_title('Displacement map')
|
380 |
ax.tick_params(axis='both',
|
|
|
387 |
else:
|
388 |
ax = fig.add_subplot(111, projection='3d')
|
389 |
c, d, s = _prepare_quiver_map(disp_map[0, ...], dim=dim)
|
390 |
+
ax.quiver(c[C.H_DISP], c[C.W_DISP], c[C.D_DISP], d[C.H_DISP], d[C.W_DISP], d[C.D_DISP])
|
391 |
+
_square_3d_plot(np.arange(0, dim_h-1), np.arange(0, dim_w-1), np.arange(0, dim_d-1), ax)
|
|
|
392 |
fig.suptitle('Displacement map')
|
393 |
ax.tick_params(axis='both', # Same parameters as in 2D https://matplotlib.org/api/_as_gen/mpl_toolkits.mplot3d.axes3d.Axes3D.html
|
394 |
which='both',
|
|
|
396 |
left=False,
|
397 |
labelleft=False,
|
398 |
labelbottom=False)
|
399 |
+
add_axes_arrows_3d(ax, xyz_label=['R', 'A', 'S'])
|
400 |
fig.suptitle(title)
|
401 |
|
402 |
plt.savefig(filename, format='png')
|
403 |
+
if show:
|
404 |
+
plt.show()
|
405 |
plt.close()
|
406 |
+
return fig
|
407 |
|
408 |
|
409 |
def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, filename='img', fig=None,
|
|
|
412 |
fig.clear()
|
413 |
plt.figure(fig.number)
|
414 |
else:
|
415 |
+
fig = plt.figure(dpi=C.DPI)
|
416 |
|
417 |
dim = len(list_imgs[0].shape[:-1])
|
418 |
|
419 |
if dim == 2:
|
420 |
# TRAINING
|
421 |
ax_input = fig.add_subplot(241)
|
422 |
+
ax_input.set_ylabel(title_first_row, fontsize=C.FONT_SIZE)
|
423 |
im_fix = ax_input.imshow(list_imgs[0][:, :, 0])
|
424 |
+
ax_input.set_title('Fix image', fontsize=C.FONT_SIZE)
|
425 |
ax_input.tick_params(axis='both',
|
426 |
which='both',
|
427 |
bottom=False,
|
|
|
430 |
labelbottom=False)
|
431 |
ax_mov = fig.add_subplot(242)
|
432 |
im_mov = ax_mov.imshow(list_imgs[1][:, :, 0])
|
433 |
+
ax_mov.set_title('Moving image', fontsize=C.FONT_SIZE)
|
434 |
ax_mov.tick_params(axis='both',
|
435 |
which='both',
|
436 |
bottom=False,
|
|
|
440 |
|
441 |
ax_pred_im = fig.add_subplot(244)
|
442 |
im_pred_im = ax_pred_im.imshow(list_imgs[2][:, :, 0])
|
443 |
+
ax_pred_im.set_title('Predicted fix image', fontsize=C.FONT_SIZE)
|
444 |
ax_pred_im.tick_params(axis='both',
|
445 |
which='both',
|
446 |
bottom=False,
|
|
|
470 |
else:
|
471 |
cx, cy, dx, dy, s = _prepare_quiver_map(list_imgs[3], dim=dim)
|
472 |
im_pred_disp = ax_pred_disp.imshow(s, interpolation='none', aspect='equal')
|
473 |
+
ax_pred_disp.quiver(cx, cy, dx, dy, scale=C.QUIVER_PARAMS.arrow_scale)
|
474 |
+
ax_pred_disp.set_title('Pred disp map', fontsize=C.FONT_SIZE)
|
475 |
ax_pred_disp.tick_params(axis='both',
|
476 |
which='both',
|
477 |
bottom=False,
|
|
|
481 |
|
482 |
# VALIDATION
|
483 |
axinput_val = fig.add_subplot(245)
|
484 |
+
axinput_val.set_ylabel(title_second_row, fontsize=C.FONT_SIZE)
|
485 |
im_fix_val = axinput_val.imshow(list_imgs[4][:, :, 0])
|
486 |
+
axinput_val.set_title('Fix image', fontsize=C.FONT_SIZE)
|
487 |
axinput_val.tick_params(axis='both',
|
488 |
which='both',
|
489 |
bottom=False,
|
|
|
492 |
labelbottom=False)
|
493 |
ax_mov_val = fig.add_subplot(246)
|
494 |
im_mov_val = ax_mov_val.imshow(list_imgs[5][:, :, 0])
|
495 |
+
ax_mov_val.set_title('Moving image', fontsize=C.FONT_SIZE)
|
496 |
ax_mov_val.tick_params(axis='both',
|
497 |
which='both',
|
498 |
bottom=False,
|
|
|
502 |
|
503 |
ax_pred_im_val = fig.add_subplot(248)
|
504 |
im_pred_im_val = ax_pred_im_val.imshow(list_imgs[6][:, :, 0])
|
505 |
+
ax_pred_im_val.set_title('Predicted fix image', fontsize=C.FONT_SIZE)
|
506 |
ax_pred_im_val.tick_params(axis='both',
|
507 |
which='both',
|
508 |
bottom=False,
|
|
|
532 |
else:
|
533 |
c, d, s = _prepare_quiver_map(list_imgs[7], dim=dim)
|
534 |
im_pred_disp_val = ax_pred_disp_val.imshow(s, interpolation='none', aspect='equal')
|
535 |
+
ax_pred_disp_val.quiver(c[0], c[1], d[0], d[1], scale=C.QUIVER_PARAMS.arrow_scale)
|
536 |
+
ax_pred_disp_val.set_title('Pred disp map', fontsize=C.FONT_SIZE)
|
537 |
ax_pred_disp_val.tick_params(axis='both',
|
538 |
which='both',
|
539 |
bottom=False,
|
|
|
555 |
# 3D
|
556 |
# TRAINING
|
557 |
ax_input = fig.add_subplot(231, projection='3d')
|
558 |
+
ax_input.set_ylabel(title_first_row, fontsize=C.FONT_SIZE)
|
559 |
im_fix = ax_input.voxels(list_imgs[0][..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed')
|
560 |
im_mov = ax_input.voxels(list_imgs[1][..., 0] > 0.0, facecolors='blue', edgecolors='blue', label='Moving')
|
561 |
+
ax_input.set_title('Fix image', fontsize=C.FONT_SIZE)
|
562 |
ax_input.tick_params(axis='both',
|
563 |
which='both',
|
564 |
bottom=False,
|
|
|
569 |
ax_pred_im = fig.add_subplot(232, projection='3d')
|
570 |
im_pred_im = ax_input.voxels(list_imgs[2][..., 0] > 0.0, facecolors='green', edgecolors='green', label='Prediction')
|
571 |
im_fix = ax_input.voxels(list_imgs[0][..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed')
|
572 |
+
ax_pred_im.set_title('Predicted fix image', fontsize=C.FONT_SIZE)
|
573 |
ax_pred_im.tick_params(axis='both',
|
574 |
which='both',
|
575 |
bottom=False,
|
|
|
581 |
|
582 |
c, d, s = _prepare_quiver_map(list_imgs[3], dim=dim)
|
583 |
im_pred_disp = ax_pred_disp.imshow(s, interpolation='none', aspect='equal')
|
584 |
+
ax_pred_disp.quiver(c[C.H_DISP], c[C.W_DISP], c[C.D_DISP],
|
585 |
+
d[C.H_DISP], d[C.W_DISP], d[C.D_DISP], scale=C.QUIVER_PARAMS.arrow_scale)
|
586 |
+
ax_pred_disp.set_title('Pred disp map', fontsize=C.FONT_SIZE)
|
587 |
ax_pred_disp.tick_params(axis='both',
|
588 |
which='both',
|
589 |
bottom=False,
|
|
|
593 |
|
594 |
# VALIDATION
|
595 |
axinput_val = fig.add_subplot(234, projection='3d')
|
596 |
+
axinput_val.set_ylabel(title_second_row, fontsize=C.FONT_SIZE)
|
597 |
im_fix_val = ax_input.voxels(list_imgs[4][..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed (val)')
|
598 |
im_mov_val = ax_input.voxels(list_imgs[5][..., 0] > 0.0, facecolors='blue', edgecolors='blue', label='Moving (val)')
|
599 |
+
axinput_val.set_title('Fix image', fontsize=C.FONT_SIZE)
|
600 |
axinput_val.tick_params(axis='both',
|
601 |
which='both',
|
602 |
bottom=False,
|
|
|
607 |
ax_pred_im_val = fig.add_subplot(235, projection='3d')
|
608 |
im_pred_im_val = ax_input.voxels(list_imgs[2][..., 0] > 0.0, facecolors='green', edgecolors='green', label='Prediction (val)')
|
609 |
im_fix_val = ax_input.voxels(list_imgs[0][..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed (val)')
|
610 |
+
ax_pred_im_val.set_title('Predicted fix image', fontsize=C.FONT_SIZE)
|
611 |
ax_pred_im_val.tick_params(axis='both',
|
612 |
which='both',
|
613 |
bottom=False,
|
|
|
618 |
ax_pred_disp_val = fig.add_subplot(236, projection='3d')
|
619 |
c, d, s = _prepare_quiver_map(list_imgs[7], dim=dim)
|
620 |
im_pred_disp_val = ax_pred_disp_val.imshow(s, interpolation='none', aspect='equal')
|
621 |
+
ax_pred_disp_val.quiver(c[C.H_DISP], c[C.W_DISP], c[C.D_DISP],
|
622 |
+
d[C.H_DISP], d[C.W_DISP], d[C.D_DISP],
|
623 |
+
scale=C.QUIVER_PARAMS.arrow_scale)
|
624 |
+
ax_pred_disp_val.set_title('Pred disp map', fontsize=C.FONT_SIZE)
|
625 |
ax_pred_disp_val.tick_params(axis='both',
|
626 |
which='both',
|
627 |
bottom=False,
|
|
|
631 |
|
632 |
if filename is not None:
|
633 |
plt.savefig(filename, format='png') # Call before show
|
634 |
+
if not C.REMOTE:
|
635 |
plt.show()
|
636 |
else:
|
637 |
plt.close()
|
|
|
    return im_cb


+ def _prepare_quiver_map(disp_map: np.ndarray, dim=2, spc=C.QUIVER_PARAMS.spacing):
    if isinstance(disp_map, tf.Tensor):
        if tf.executing_eagerly():
            disp_map = disp_map.numpy()
        else:
            disp_map = disp_map.eval()
+   dx = disp_map[..., C.H_DISP]
+   dy = disp_map[..., C.W_DISP]
    if dim > 2:
+       dz = disp_map[..., C.D_DISP]

+   img_size_x = disp_map.shape[C.H_DISP]
+   img_size_y = disp_map.shape[C.W_DISP]
    if dim > 2:
+       img_size_z = disp_map.shape[C.D_DISP]

    if dim > 2:
        s = np.sqrt(np.square(dx) + np.square(dy) + np.square(dz))
@@ ... @@
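The hunk cuts off before the subsampling and return statement, but the call sites later in this file (c, d, s = _prepare_quiver_map(...), then cx, cy = c and dx, dy = d) imply it returns subsampled grid coordinates, the displacement components at those coordinates, and the per-pixel magnitude s used as a background image. A self-contained 2-D sketch of that idea; the function name and the spc subsampling logic are assumptions, not the package's exact code:

    import numpy as np
    import matplotlib.pyplot as plt

    def prepare_quiver_map_2d(disp_map, spc=10):
        # disp_map: (H, W, 2) displacement field -> (coords, vectors, magnitude)
        dx, dy = disp_map[..., 0], disp_map[..., 1]
        s = np.sqrt(np.square(dx) + np.square(dy))  # per-pixel magnitude
        h, w = disp_map.shape[:2]
        cy, cx = np.meshgrid(np.arange(0, h, spc), np.arange(0, w, spc), indexing='ij')
        return (cx, cy), (dx[::spc, ::spc], dy[::spc, ::spc]), s

    # Usage: magnitude as background, one arrow every spc pixels on top
    disp = np.random.randn(64, 64, 2)  # stand-in displacement map
    (cx, cy), (dxs, dys), s = prepare_quiver_map_2d(disp, spc=8)
    plt.imshow(s, interpolation='none', aspect='equal')
    plt.quiver(cx, cy, dxs, dys, units='dots', scale=1)
    plt.show()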
    if filename is not None:
        plt.savefig(filename, format='png')  # Call before show
+   if not C.REMOTE:
        plt.show()
    else:
        plt.close()
@@ ... @@
    return fig


+ def plot_predictions(fix_img_batch, mov_img_batch, disp_map_batch, pred_img_batch, filename='predictions', fig=None, show=False):
    num_rows = fix_img_batch.shape[0]
+   img_dim = len(fix_img_batch.shape) - 2
+   img_size = fix_img_batch.shape[1:-1]
    if fig is not None:
        fig.clear()
        plt.figure(fig.number)
+       ax = fig.subplots(nrows=num_rows, ncols=5)  # fig.add_subplot() accepts no nrows/ncols/figsize/dpi kwargs
    else:
+       fig, ax = plt.subplots(nrows=num_rows, ncols=5, figsize=(10, 3 * num_rows), dpi=C.DPI)
+   if num_rows == 1:
+       ax = ax[np.newaxis, ...]
+
+   if img_dim == 3:  # Extract slices from the images
+       selected_slice = img_size[0] // 2
+       fix_img_batch = fix_img_batch[:, selected_slice, ...]
+       mov_img_batch = mov_img_batch[:, selected_slice, ...]
+       pred_img_batch = pred_img_batch[:, selected_slice, ...]
+       disp_map_batch = disp_map_batch[:, selected_slice, ..., 1:]  # Only the sagittal and longitudinal axes
+       img_size = fix_img_batch.shape[1:-1]
+   elif img_dim != 2:
+       raise ValueError('Images have a bad shape: {}'.format(fix_img_batch.shape))

    for row in range(num_rows):
+       fix_img = fix_img_batch[row, :, :, 0].transpose()
+       mov_img = mov_img_batch[row, :, :, 0].transpose()
+       disp_map = disp_map_batch[row, :, :, :].transpose((1, 0, 2))
+       pred_img = pred_img_batch[row, :, :, 0].transpose()
+       ax[row, 0].imshow(fix_img, origin='lower')
        ax[row, 0].tick_params(axis='both',
                               which='both',
                               bottom=False,
                               left=False,
                               labelleft=False,
                               labelbottom=False)
+       ax[row, 1].imshow(mov_img, origin='lower')
        ax[row, 1].tick_params(axis='both',
                               which='both',
                               bottom=False,
                               left=False,
                               labelleft=False,
                               labelbottom=False)

+       c, d, s = _prepare_quiver_map(disp_map, spc=5)
+       cx, cy = c
+       dx, dy = d
        disp_map_color = _prepare_colormap(disp_map)
+       ax[row, 2].imshow(disp_map_color, interpolation='none', aspect='equal', origin='lower')
+       ax[row, 2].quiver(cx, cy, dx, dy, units='dots', scale=1)
        ax[row, 2].tick_params(axis='both',
                               which='both',
                               bottom=False,
                               left=False,
                               labelleft=False,
                               labelbottom=False)

+       ax[row, 3].imshow(mov_img, origin='lower')
+       ax[row, 3].quiver(cx, cy, dx, dy, units='dots', scale=1, color='w')
        ax[row, 3].tick_params(axis='both',
                               which='both',
                               bottom=False,
                               left=False,
                               labelleft=False,
                               labelbottom=False)

+       ax[row, 4].imshow(pred_img, origin='lower')
+       ax[row, 4].tick_params(axis='both',
+                              which='both',
+                              bottom=False,
+                              left=False,
+                              labelleft=False,
+                              labelbottom=False)

+   plt.axis('off')
+   ax[0, 0].set_title('Fixed img ($I_f$)', fontsize=C.FONT_SIZE)
+   ax[0, 1].set_title('Moving img ($I_m$)', fontsize=C.FONT_SIZE)
+   ax[0, 2].set_title('Backwards\ndisp. map ($\delta$)', fontsize=C.FONT_SIZE)
+   ax[0, 3].set_title('Disp. map over $I_m$', fontsize=C.FONT_SIZE)
+   ax[0, 4].set_title('Predicted $I_m$', fontsize=C.FONT_SIZE)
+   plt.tight_layout()
    if filename is not None:
        plt.savefig(filename, format='png')  # Call before show
+   if show:
        plt.show()
+   plt.close()

    return fig
@@ ... @@
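A quick smoke test of plot_predictions with random batches; shapes follow the function's (batch, H, W, channels) convention for 2-D (for 3-D, prepend a depth axis and use a 3-channel displacement map):

    import numpy as np
    from DeepDeformationMapRegistration.utils.visualization import plot_predictions

    # Hypothetical 2-D batch: 2 samples of 64x64 single-channel images
    fix = np.random.rand(2, 64, 64, 1)
    mov = np.random.rand(2, 64, 64, 1)
    pred = np.random.rand(2, 64, 64, 1)
    disp = np.random.randn(2, 64, 64, 2)  # (dx, dy) per pixel

    fig = plot_predictions(fix, mov, disp, pred, filename='predictions.png', show=False)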
        fig.clear()
        plt.figure(fig.number)
    else:
+       fig = plt.figure(dpi=C.DPI)

    ax0 = fig.add_subplot(221)
    im0 = ax0.imshow(fix_img[..., 0])
@@ ... @@
    ax2 = fig.add_subplot(223)
    im2 = ax2.imshow(s, interpolation='none', aspect='equal')

+   ax2.quiver(cx, cy, dx, dy, scale=C.QUIVER_PARAMS.arrow_scale)
    # ax2.figure.set_size_inches(img_size)
    ax2.tick_params(axis='both',
                    which='both',
@@ ... @@
    ax3 = fig.add_subplot(224)
    dif = fix_img[..., 0] - mov_img[..., 0]
    im3 = ax3.imshow(dif)
+   ax3.quiver(cx, cy, dx, dy, scale=C.QUIVER_PARAMS.arrow_scale)
    ax3.tick_params(axis='both',
                    which='both',
                    bottom=False,
@@ ... @@
                    labelbottom=False)

    plt.axis('off')
+   ax0.set_title('Fixed img ($I_f$)', fontsize=C.FONT_SIZE)
+   ax1.set_title('Moving img ($I_m$)', fontsize=C.FONT_SIZE)
+   ax2.set_title('Displacement map', fontsize=C.FONT_SIZE)
+   ax3.set_title('Fix - Mov', fontsize=C.FONT_SIZE)

    im0_cb = _set_colorbar(fig, ax0, im0, False)
    im1_cb = _set_colorbar(fig, ax1, im1, False)
@@ ... @@

    if filename is not None:
        plt.savefig(filename, format='png')  # Call before show
+   if not C.REMOTE:
        plt.show()
    else:
        plt.close()
@@ ... @@
    fig = plt.figure()

    ax_grid = fig.add_subplot(231)
+   ax_grid.set_title('Grids', fontsize=C.FONT_SIZE)
    ax_grid.scatter(ctrl_coords[:, 0], ctrl_coords[:, 1], marker='+', c='r', s=20)
    ax_grid.scatter(dense_coords[:, 0], dense_coords[:, 1], marker='.', c='r', s=1)
    ax_grid.tick_params(axis='both',
@@ ... @@
    ax_grid.set_aspect('equal')

    ax_disp = fig.add_subplot(232)
+   ax_disp.set_title('Displacement map', fontsize=C.FONT_SIZE)
    (cx, cy), (dx, dy), s = _prepare_quiver_map(disp_map)  # unpacking updated to the new (coords, deltas, magnitude) return
    ax_disp.imshow(s, interpolation='none', aspect='equal')
    ax_disp.tick_params(axis='both',
@@ ... @@
                        labelbottom=False)

    ax_mask = fig.add_subplot(233)
+   ax_mask.set_title('Mask', fontsize=C.FONT_SIZE)
    ax_mask.imshow(mask)
    ax_mask.tick_params(axis='both',
                        which='both',
@@ ... @@
                        labelbottom=False)

    ax_fix = fig.add_subplot(234)
+   ax_fix.set_title('Fix image', fontsize=C.FONT_SIZE)
    ax_fix.imshow(fix_img[..., 0])
    ax_fix.tick_params(axis='both',
                       which='both',
@@ ... @@
                       labelbottom=False)

    ax_mov = fig.add_subplot(235)
+   ax_mov.set_title('Moving image', fontsize=C.FONT_SIZE)
    ax_mov.imshow(mov_img[..., 0])
    ax_mov.tick_params(axis='both',
                       which='both',
@@ ... @@
                       labelbottom=False)

    ax_dif = fig.add_subplot(236)
+   ax_dif.set_title('Fix - Moving image', fontsize=C.FONT_SIZE)
    ax_dif.imshow(fix_img[..., 0] - mov_img[..., 0], cmap=cmap_bin)
    ax_dif.tick_params(axis='both',
                       which='both',
@@ ... @@

    if filename is not None:
        plt.savefig(filename, format='png')  # Call before show
+   if not C.REMOTE:
        plt.show()

    return fig
@@ ... @@
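The 'Grids' panel overlays the sparse control-point grid on the dense sampling grid it drives, as in thin-plate-spline or B-spline warps. A small sketch of how such coordinate sets can be built and plotted; the grid sizes here are arbitrary, not the package's defaults:

    import numpy as np
    import matplotlib.pyplot as plt

    # Hypothetical grids: 5x5 control points over a 64x64 dense sampling grid
    ctrl = np.stack(np.meshgrid(np.linspace(0, 63, 5), np.linspace(0, 63, 5)), -1).reshape(-1, 2)
    dense = np.stack(np.meshgrid(np.arange(0, 64, 4), np.arange(0, 64, 4)), -1).reshape(-1, 2)

    fig, ax = plt.subplots()
    ax.scatter(ctrl[:, 0], ctrl[:, 1], marker='+', c='r', s=20)   # control points
    ax.scatter(dense[:, 0], dense[:, 1], marker='.', c='k', s=1)  # dense grid
    ax.set_aspect('equal')
    plt.show()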
    fig = plt.figure()

    ax_d_m_f = fig.add_subplot(131)
+   ax_d_m_f.set_title('Disp M->F', fontsize=C.FONT_SIZE)
    (cx, cy), (dx, dy), s = _prepare_quiver_map(disp_m_f)  # unpacking updated to the new (coords, deltas, magnitude) return
    ax_d_m_f.imshow(s, interpolation='none', aspect='equal')
+   ax_d_m_f.quiver(cx, cy, dx, dy, scale=C.QUIVER_PARAMS.arrow_scale)
    ax_d_m_f.tick_params(axis='both',
                         which='both',
                         bottom=False,
@@ ... @@
                         labelbottom=False)

    ax_d_f_m = fig.add_subplot(132)
+   ax_d_f_m.set_title('Disp F->M', fontsize=C.FONT_SIZE)
    (cx, cy), (dx, dy), s = _prepare_quiver_map(disp_f_m)
+   ax_d_f_m.quiver(cx, cy, dx, dy, scale=C.QUIVER_PARAMS.arrow_scale)
    ax_d_f_m.imshow(s, interpolation='none', aspect='equal')
    ax_d_f_m.tick_params(axis='both',
                         which='both',
@@ ... @@
                         labelbottom=False)

    ax_dif = fig.add_subplot(133)
+   ax_dif.set_title('Fix - Moving image', fontsize=C.FONT_SIZE)
    ax_dif.imshow(fix_img[..., 0] - mov_img[..., 0], cmap=cmap_bin)
    ax_dif.tick_params(axis='both',
                       which='both',
@@ ... @@

    if filename is not None:
        plt.savefig(filename, format='png')  # Call before show
+   if not C.REMOTE:
        plt.show()
    else:
        plt.close()
@@ ... @@
    fig.tight_layout(pad=5.0)
    ax = fig.add_subplot(2, num_cols, 1, projection='3d')
    ax.voxels(list_imgs[0][0, ..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed')
+   ax.set_title('Fix image', fontsize=C.FONT_SIZE)
    _square_3d_plot(np.arange(0, 63), np.arange(0, 63), np.arange(0, 63), ax)

    for i in range(2, num_preds+2):
@@ ... @@

    ax = fig.add_subplot(2, num_cols, num_preds+2, projection='3d')
    ax.voxels(list_imgs[1][0, ..., 0] > 0.0, facecolors='blue', edgecolors='blue', label='Moving')
+   ax.set_title('Moving image', fontsize=C.FONT_SIZE)  # was 'Fix image'; this panel shows the moving image
    _square_3d_plot(np.arange(0, 63), np.arange(0, 63), np.arange(0, 63), ax)

    for i in range(num_preds+2, 2 * num_preds + 2):
        ax = fig.add_subplot(2, num_cols, i + 1, projection='3d')
        c, d, s = _prepare_quiver_map(list_imgs[i][0, ...], dim=3)
+       ax.quiver(c[C.H_DISP], c[C.W_DISP], c[C.D_DISP],
+                 d[C.H_DISP], d[C.W_DISP], d[C.D_DISP],
                  norm=True)  # NB: Axes3D.quiver's documented normalization kwarg is 'normalize'
        ax.set_title('Disp. #{}'.format(i - 5))
        _square_3d_plot(np.arange(0, 63), np.arange(0, 63), np.arange(0, 63), ax)

+   fig.suptitle(fig_title, fontsize=C.FONT_SIZE)

    if save_file:
        plt.savefig(os.path.join(dest_folder, fig_title + '.png'), format='png')  # Call before show
+   if not C.REMOTE:
        plt.show()
    else:
        plt.close()
@@ ... @@
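These panels use Axes3D.voxels to render binarized volumes, thresholding at 0.0 so every positive intensity becomes a filled voxel. A tiny self-contained example of the same call; the random volume and 0.8 threshold are arbitrary:

    import numpy as np
    import matplotlib.pyplot as plt

    vol = np.random.rand(16, 16, 16)  # stand-in for an image volume
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.voxels(vol > 0.8, facecolors='red', edgecolors='red')  # binarize, then draw filled voxels
    plt.show()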
    ax.set_ylim(mid_y - max_range, mid_y + max_range)
    ax.set_zlim(mid_z - max_range, mid_z + max_range)

+
+ def remove_tick_labels(ax, project_3d=False):
+     ax.set_xticklabels([])
+     ax.set_yticklabels([])
+     if project_3d:
+         ax.set_zticklabels([])
+     return ax
+
+
+ def add_axes_arrows_3d(ax, arrow_length=10, xyz_colours=['r', 'g', 'b'], xyz_label=['X', 'Y', 'Z'], dist_arrow_text=3):
+     x_limits = ax.get_xlim3d()
+     y_limits = ax.get_ylim3d()
+     z_limits = ax.get_zlim3d()
+
+     x_len = x_limits[1] - x_limits[0]
+     y_len = y_limits[1] - y_limits[0]
+     z_len = z_limits[1] - z_limits[0]
+
+     # Frame lines along the three axes, no arrow heads: quiver(start_xyz, direction_xyz, params)
+     ax.quiver(x_limits[0], y_limits[0], z_limits[0], x_len, 0, 0, color=xyz_colours[0], arrow_length_ratio=0)
+     ax.quiver(x_limits[0], y_limits[0], z_limits[0], 0, y_len, 0, color=xyz_colours[1], arrow_length_ratio=0)
+     ax.quiver(x_limits[0], y_limits[0], z_limits[0], 0, 0, z_len, color=xyz_colours[2], arrow_length_ratio=0)
+
+     # X axis
+     ax.quiver(x_limits[1], y_limits[0], z_limits[0], arrow_length, 0, 0, color=xyz_colours[0])
+     ax.text(x_limits[1] + arrow_length + dist_arrow_text, y_limits[0], z_limits[0], xyz_label[0], fontsize=20, ha='right', va='top')
+
+     # Y axis
+     ax.quiver(x_limits[0], y_limits[1], z_limits[0], 0, arrow_length, 0, color=xyz_colours[1])
+     ax.text(x_limits[0], y_limits[1] + arrow_length + dist_arrow_text, z_limits[0], xyz_label[1], fontsize=20, ha='left', va='top')  # was xyz_label[0]
+
+     # Z axis
+     ax.quiver(x_limits[0], y_limits[0], z_limits[1], 0, 0, arrow_length, color=xyz_colours[2])
+     ax.text(x_limits[0], y_limits[0], z_limits[1] + arrow_length + dist_arrow_text, xyz_label[2], fontsize=20, ha='center', va='bottom')  # was xyz_label[0]
+
+     return ax
+
+
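A hedged usage sketch for these annotation helpers on an existing Axes3D; the random volume and arrow length are arbitrary:

    import numpy as np
    import matplotlib.pyplot as plt

    # Assumes remove_tick_labels and add_axes_arrows_3d from this module are in scope
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.voxels(np.random.rand(8, 8, 8) > 0.7)
    ax = remove_tick_labels(ax, project_3d=True)  # hide tick labels on all three axes
    ax = add_axes_arrows_3d(ax, arrow_length=3)   # labelled X/Y/Z direction arrows at the box corners
    plt.show()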
+ def add_axes_arrows_2d(ax, arrow_length=10, xy_colour=['r', 'g'], xy_label=['X', 'Y']):
+     x_limits = list(ax.get_xlim())
+     y_limits = list(ax.get_ylim())
+     origin = (x_limits[0], y_limits[1])
+
+     ax.annotate('', xy=(origin[0] + arrow_length, origin[1]), xytext=origin,
+                 arrowprops=dict(headlength=8., headwidth=10., width=5., color=xy_colour[0]))
+     ax.annotate('', xy=(origin[0], origin[1] + arrow_length), xytext=origin,
+                 arrowprops=dict(headlength=8., headwidth=10., width=5., color=xy_colour[1]))  # was xy_colour[0]; the Y arrow takes the second colour
+
+     ax.text(origin[0] + arrow_length, origin[1], xy_label[0], fontsize=25, ha='left', va='bottom')
+     ax.text(origin[0] - 1, origin[1] + arrow_length, xy_label[1], fontsize=25, ha='right', va='top')
+
+     return ax
+
+
+ def set_axes_size(w, h, ax=None):
+     """w, h: width, height in inches"""
+     if ax is None:
+         ax = plt.gca()
+     l = ax.figure.subplotpars.left
+     r = ax.figure.subplotpars.right
+     t = ax.figure.subplotpars.top
+     b = ax.figure.subplotpars.bottom
+     figw = float(w) / (r - l)
+     figh = float(h) / (t - b)
+     ax.figure.set_size_inches(figw, figh)