Merge pull request #32 from jpdefrutos/HF_spatialtransformer
Browse files
DeepDeformationMapRegistration/layers/SpatialTransformer.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow.keras.layers as kl
|
| 2 |
+
import tensorflow.keras.backend as K
|
| 3 |
+
import tensorflow as tf
|
| 4 |
+
import neurite as ne
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class SpatialTransformer(kl.Layer):
|
| 8 |
+
"""
|
| 9 |
+
Adapted SpatialTransformer layer taken from VoxelMorph v0.1
|
| 10 |
+
https://github.com/voxelmorph/voxelmorph/blob/dev/voxelmorph/tf/layers.py
|
| 11 |
+
Removed unused options to ease portability
|
| 12 |
+
|
| 13 |
+
N-D Spatial Transformer Tensorflow / Keras Layer
|
| 14 |
+
|
| 15 |
+
The Layer can handle ONLY dense transforms.
|
| 16 |
+
Transforms are meant to give a 'shift' from the current position.
|
| 17 |
+
Therefore, a dense transform gives displacements (not absolute locations) at each voxel.
|
| 18 |
+
|
| 19 |
+
If you find this function useful, please cite:
|
| 20 |
+
Unsupervised Learning for Fast Probabilistic Diffeomorphic Registration
|
| 21 |
+
Adrian V. Dalca, Guha Balakrishnan, John Guttag, Mert R. Sabuncu
|
| 22 |
+
MICCAI 2018.
|
| 23 |
+
|
| 24 |
+
Originally, this code was based on voxelmorph code, which
|
| 25 |
+
was in turn transformed to be dense with the help of (affine) STN code
|
| 26 |
+
via https://github.com/kevinzakka/spatial-transformer-network
|
| 27 |
+
|
| 28 |
+
Since then, we've re-written the code to be generalized to any
|
| 29 |
+
dimensions, and along the way wrote grid and interpolation functions
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
def __init__(self,
|
| 33 |
+
interp_method='linear',
|
| 34 |
+
indexing='ij',
|
| 35 |
+
single_transform=False,
|
| 36 |
+
fill_value=None,
|
| 37 |
+
add_identity=True,
|
| 38 |
+
shift_center=True,
|
| 39 |
+
**kwargs):
|
| 40 |
+
"""
|
| 41 |
+
Parameters:
|
| 42 |
+
interp_method: 'linear' or 'nearest'
|
| 43 |
+
single_transform: whether a single transform supplied for the whole batch
|
| 44 |
+
indexing (default: 'ij'): 'ij' (matrix) or 'xy' (cartesian)
|
| 45 |
+
'xy' indexing will have the first two entries of the flow
|
| 46 |
+
(along last axis) flipped compared to 'ij' indexing
|
| 47 |
+
fill_value (default: None): value to use for points outside the domain.
|
| 48 |
+
If None, the nearest neighbors will be used.
|
| 49 |
+
add_identity (default: True): whether the identity matrix is added
|
| 50 |
+
to affine transforms.
|
| 51 |
+
shift_center (default: True): whether the grid is shifted to the center
|
| 52 |
+
of the image when converting affine transforms to warp fields.
|
| 53 |
+
"""
|
| 54 |
+
self.interp_method = interp_method
|
| 55 |
+
self.fill_value = fill_value
|
| 56 |
+
self.add_identity = add_identity
|
| 57 |
+
self.shift_center = shift_center
|
| 58 |
+
self.ndims = None
|
| 59 |
+
self.inshape = None
|
| 60 |
+
self.single_transform = single_transform
|
| 61 |
+
|
| 62 |
+
assert indexing in ['ij', 'xy'], "indexing has to be 'ij' (matrix) or 'xy' (cartesian)"
|
| 63 |
+
self.indexing = indexing
|
| 64 |
+
|
| 65 |
+
super(self.__class__, self).__init__(**kwargs)
|
| 66 |
+
|
| 67 |
+
def get_config(self):
|
| 68 |
+
config = super().get_config().copy()
|
| 69 |
+
config.update({
|
| 70 |
+
'interp_method': self.interp_method,
|
| 71 |
+
'indexing': self.indexing,
|
| 72 |
+
'single_transform': self.single_transform,
|
| 73 |
+
'fill_value': self.fill_value,
|
| 74 |
+
'add_identity': self.add_identity,
|
| 75 |
+
'shift_center': self.shift_center,
|
| 76 |
+
})
|
| 77 |
+
return config
|
| 78 |
+
|
| 79 |
+
def build(self, input_shape):
|
| 80 |
+
"""
|
| 81 |
+
input_shape should be a list for two inputs:
|
| 82 |
+
input1: image.
|
| 83 |
+
input2: transform Tensor
|
| 84 |
+
if affine:
|
| 85 |
+
should be a N x N+1 matrix
|
| 86 |
+
*or* a N*N+1 tensor (which will be reshape to N x (N+1) and an identity row added)
|
| 87 |
+
if not affine:
|
| 88 |
+
should be a *vol_shape x N
|
| 89 |
+
"""
|
| 90 |
+
|
| 91 |
+
if len(input_shape) > 2:
|
| 92 |
+
raise Exception('Spatial Transformer must be called on a list of length 2.'
|
| 93 |
+
'First argument is the image, second is the transform.')
|
| 94 |
+
|
| 95 |
+
# set up number of dimensions
|
| 96 |
+
self.ndims = len(input_shape[0]) - 2
|
| 97 |
+
self.inshape = input_shape
|
| 98 |
+
vol_shape = input_shape[0][1:-1]
|
| 99 |
+
trf_shape = input_shape[1][1:]
|
| 100 |
+
|
| 101 |
+
# the transform is an affine iff:
|
| 102 |
+
# it's a 1D Tensor [dense transforms need to be at least ndims + 1]
|
| 103 |
+
# it's a 2D Tensor and shape == [N+1, N+1] or [N, N+1]
|
| 104 |
+
# [dense with N=1, which is the only one that could have a transform shape of 2, would be of size Mx1]
|
| 105 |
+
is_matrix = len(trf_shape) == 2 and trf_shape[0] in (self.ndims, self.ndims + 1) and trf_shape[
|
| 106 |
+
1] == self.ndims + 1
|
| 107 |
+
assert not (len(trf_shape) == 1 or is_matrix), "Invalid transformation. Expected a dense displacement map"
|
| 108 |
+
|
| 109 |
+
# check sizes
|
| 110 |
+
if trf_shape[-1] != self.ndims:
|
| 111 |
+
raise Exception('Offset flow field size expected: %d, found: %d'
|
| 112 |
+
% (self.ndims, trf_shape[-1]))
|
| 113 |
+
|
| 114 |
+
# confirm built
|
| 115 |
+
self.built = True
|
| 116 |
+
|
| 117 |
+
def call(self, inputs):
|
| 118 |
+
"""
|
| 119 |
+
Parameters
|
| 120 |
+
inputs: list with two entries
|
| 121 |
+
"""
|
| 122 |
+
|
| 123 |
+
# check shapes
|
| 124 |
+
assert len(inputs) == 2, "inputs has to be len 2, found: %d" % len(inputs)
|
| 125 |
+
vol = inputs[0]
|
| 126 |
+
trf = inputs[1]
|
| 127 |
+
|
| 128 |
+
# necessary for multi_gpu models...
|
| 129 |
+
vol = K.reshape(vol, [-1, *self.inshape[0][1:]])
|
| 130 |
+
trf = K.reshape(trf, [-1, *self.inshape[1][1:]])
|
| 131 |
+
|
| 132 |
+
# prepare location shift
|
| 133 |
+
if self.indexing == 'xy': # shift the first two dimensions
|
| 134 |
+
trf_split = tf.split(trf, trf.shape[-1], axis=-1)
|
| 135 |
+
trf_lst = [trf_split[1], trf_split[0], *trf_split[2:]]
|
| 136 |
+
trf = tf.concat(trf_lst, -1)
|
| 137 |
+
|
| 138 |
+
# map transform across batch
|
| 139 |
+
if self.single_transform:
|
| 140 |
+
fn = lambda x: self._single_transform([x, trf[0, :]])
|
| 141 |
+
return tf.map_fn(fn, vol, dtype=tf.float32)
|
| 142 |
+
else:
|
| 143 |
+
return tf.map_fn(self._single_transform, [vol, trf], dtype=tf.float32)
|
| 144 |
+
|
| 145 |
+
def _single_transform(self, inputs):
|
| 146 |
+
return self._transform(inputs[0], inputs[1], interp_method=self.interp_method, fill_value=self.fill_value)
|
| 147 |
+
|
| 148 |
+
def _transform(self, vol, loc_shift, interp_method='linear', indexing='ij', fill_value=None):
|
| 149 |
+
"""
|
| 150 |
+
transform (interpolation N-D volumes (features) given shifts at each location in tensorflow
|
| 151 |
+
|
| 152 |
+
Essentially interpolates volume vol at locations determined by loc_shift.
|
| 153 |
+
This is a spatial transform in the sense that at location [x] we now have the data from,
|
| 154 |
+
[x + shift] so we've moved data.
|
| 155 |
+
|
| 156 |
+
Parameters:
|
| 157 |
+
vol: volume with size vol_shape or [*vol_shape, nb_features]
|
| 158 |
+
loc_shift: shift volume [*new_vol_shape, N]
|
| 159 |
+
interp_method (default:'linear'): 'linear', 'nearest'
|
| 160 |
+
indexing (default: 'ij'): 'ij' (matrix) or 'xy' (cartesian).
|
| 161 |
+
In general, prefer to leave this 'ij'
|
| 162 |
+
fill_value (default: None): value to use for points outside the domain.
|
| 163 |
+
If None, the nearest neighbors will be used.
|
| 164 |
+
|
| 165 |
+
Return:
|
| 166 |
+
new interpolated volumes in the same size as loc_shift[0]
|
| 167 |
+
|
| 168 |
+
Keyworks:
|
| 169 |
+
interpolation, sampler, resampler, linear, bilinear
|
| 170 |
+
"""
|
| 171 |
+
|
| 172 |
+
# parse shapes
|
| 173 |
+
|
| 174 |
+
if isinstance(loc_shift.shape, (tf.compat.v1.Dimension, tf.TensorShape)):
|
| 175 |
+
volshape = loc_shift.shape[:-1].as_list()
|
| 176 |
+
else:
|
| 177 |
+
volshape = loc_shift.shape[:-1]
|
| 178 |
+
nb_dims = len(volshape)
|
| 179 |
+
|
| 180 |
+
# location should be mesh and delta
|
| 181 |
+
mesh = ne.utils.volshape_to_meshgrid(volshape, indexing=indexing) # volume mesh
|
| 182 |
+
loc = [tf.cast(mesh[d], 'float32') + loc_shift[..., d] for d in range(nb_dims)]
|
| 183 |
+
|
| 184 |
+
# test single
|
| 185 |
+
return ne.utils.interpn(vol, loc, interp_method=interp_method, fill_value=fill_value)
|
| 186 |
+
|
DeepDeformationMapRegistration/main.py
CHANGED
|
@@ -14,8 +14,7 @@ from scipy.ndimage import gaussian_filter, zoom
|
|
| 14 |
from skimage.measure import regionprops
|
| 15 |
import SimpleITK as sitk
|
| 16 |
|
| 17 |
-
from
|
| 18 |
-
|
| 19 |
import DeepDeformationMapRegistration.utils.constants as C
|
| 20 |
from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
|
| 21 |
from DeepDeformationMapRegistration.utils.operators import min_max_norm
|
|
@@ -298,8 +297,7 @@ def main():
|
|
| 298 |
|
| 299 |
LOGGER.info('Applying displacement map...')
|
| 300 |
time_pred_img_start = time.time()
|
| 301 |
-
|
| 302 |
-
pred_image = np.zeros_like(moving_image[np.newaxis, ...]) # @TODO: Replace this with Keras' Model with SpatialTransformer Layer
|
| 303 |
time_pred_img_end = time.time()
|
| 304 |
LOGGER.info(f'\t... done ({time_pred_img_end - time_pred_img_start} s)')
|
| 305 |
pred_image = pred_image[0, ...]
|
|
|
|
| 14 |
from skimage.measure import regionprops
|
| 15 |
import SimpleITK as sitk
|
| 16 |
|
| 17 |
+
from DeepDeformationMapRegistration.layers.SpatialTransformer import SpatialTransformer
|
|
|
|
| 18 |
import DeepDeformationMapRegistration.utils.constants as C
|
| 19 |
from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
|
| 20 |
from DeepDeformationMapRegistration.utils.operators import min_max_norm
|
|
|
|
| 297 |
|
| 298 |
LOGGER.info('Applying displacement map...')
|
| 299 |
time_pred_img_start = time.time()
|
| 300 |
+
pred_image = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([moving_image[np.newaxis, ...], disp_map[np.newaxis, ...]]).eval()
|
|
|
|
| 301 |
time_pred_img_end = time.time()
|
| 302 |
LOGGER.info(f'\t... done ({time_pred_img_end - time_pred_img_start} s)')
|
| 303 |
pred_image = pred_image[0, ...]
|