Commit
·
61f0e36
1
Parent(s):
6f67742
Imported and adapted SpatialTransformer from VoxelMorph v0.1
Browse files
DeepDeformationMapRegistration/layers/SpatialTransformer.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow.keras.layers as kl
|
2 |
+
import tensorflow.keras.backend as K
|
3 |
+
import tensorflow as tf
|
4 |
+
import neurite as ne
|
5 |
+
|
6 |
+
|
7 |
+
class SpatialTransformer(kl.Layer):
|
8 |
+
"""
|
9 |
+
Adapted SpatialTransformer layer taken from VoxelMorph v0.1
|
10 |
+
https://github.com/voxelmorph/voxelmorph/blob/dev/voxelmorph/tf/layers.py
|
11 |
+
Removed unused options to ease portability
|
12 |
+
|
13 |
+
N-D Spatial Transformer Tensorflow / Keras Layer
|
14 |
+
|
15 |
+
The Layer can handle ONLY dense transforms.
|
16 |
+
Transforms are meant to give a 'shift' from the current position.
|
17 |
+
Therefore, a dense transform gives displacements (not absolute locations) at each voxel.
|
18 |
+
|
19 |
+
If you find this function useful, please cite:
|
20 |
+
Unsupervised Learning for Fast Probabilistic Diffeomorphic Registration
|
21 |
+
Adrian V. Dalca, Guha Balakrishnan, John Guttag, Mert R. Sabuncu
|
22 |
+
MICCAI 2018.
|
23 |
+
|
24 |
+
Originally, this code was based on voxelmorph code, which
|
25 |
+
was in turn transformed to be dense with the help of (affine) STN code
|
26 |
+
via https://github.com/kevinzakka/spatial-transformer-network
|
27 |
+
|
28 |
+
Since then, we've re-written the code to be generalized to any
|
29 |
+
dimensions, and along the way wrote grid and interpolation functions
|
30 |
+
"""
|
31 |
+
|
32 |
+
def __init__(self,
|
33 |
+
interp_method='linear',
|
34 |
+
indexing='ij',
|
35 |
+
single_transform=False,
|
36 |
+
fill_value=None,
|
37 |
+
add_identity=True,
|
38 |
+
shift_center=True,
|
39 |
+
**kwargs):
|
40 |
+
"""
|
41 |
+
Parameters:
|
42 |
+
interp_method: 'linear' or 'nearest'
|
43 |
+
single_transform: whether a single transform supplied for the whole batch
|
44 |
+
indexing (default: 'ij'): 'ij' (matrix) or 'xy' (cartesian)
|
45 |
+
'xy' indexing will have the first two entries of the flow
|
46 |
+
(along last axis) flipped compared to 'ij' indexing
|
47 |
+
fill_value (default: None): value to use for points outside the domain.
|
48 |
+
If None, the nearest neighbors will be used.
|
49 |
+
add_identity (default: True): whether the identity matrix is added
|
50 |
+
to affine transforms.
|
51 |
+
shift_center (default: True): whether the grid is shifted to the center
|
52 |
+
of the image when converting affine transforms to warp fields.
|
53 |
+
"""
|
54 |
+
self.interp_method = interp_method
|
55 |
+
self.fill_value = fill_value
|
56 |
+
self.add_identity = add_identity
|
57 |
+
self.shift_center = shift_center
|
58 |
+
self.ndims = None
|
59 |
+
self.inshape = None
|
60 |
+
self.single_transform = single_transform
|
61 |
+
|
62 |
+
assert indexing in ['ij', 'xy'], "indexing has to be 'ij' (matrix) or 'xy' (cartesian)"
|
63 |
+
self.indexing = indexing
|
64 |
+
|
65 |
+
super(self.__class__, self).__init__(**kwargs)
|
66 |
+
|
67 |
+
def get_config(self):
|
68 |
+
config = super().get_config().copy()
|
69 |
+
config.update({
|
70 |
+
'interp_method': self.interp_method,
|
71 |
+
'indexing': self.indexing,
|
72 |
+
'single_transform': self.single_transform,
|
73 |
+
'fill_value': self.fill_value,
|
74 |
+
'add_identity': self.add_identity,
|
75 |
+
'shift_center': self.shift_center,
|
76 |
+
})
|
77 |
+
return config
|
78 |
+
|
79 |
+
def build(self, input_shape):
|
80 |
+
"""
|
81 |
+
input_shape should be a list for two inputs:
|
82 |
+
input1: image.
|
83 |
+
input2: transform Tensor
|
84 |
+
if affine:
|
85 |
+
should be a N x N+1 matrix
|
86 |
+
*or* a N*N+1 tensor (which will be reshape to N x (N+1) and an identity row added)
|
87 |
+
if not affine:
|
88 |
+
should be a *vol_shape x N
|
89 |
+
"""
|
90 |
+
|
91 |
+
if len(input_shape) > 2:
|
92 |
+
raise Exception('Spatial Transformer must be called on a list of length 2.'
|
93 |
+
'First argument is the image, second is the transform.')
|
94 |
+
|
95 |
+
# set up number of dimensions
|
96 |
+
self.ndims = len(input_shape[0]) - 2
|
97 |
+
self.inshape = input_shape
|
98 |
+
vol_shape = input_shape[0][1:-1]
|
99 |
+
trf_shape = input_shape[1][1:]
|
100 |
+
|
101 |
+
# the transform is an affine iff:
|
102 |
+
# it's a 1D Tensor [dense transforms need to be at least ndims + 1]
|
103 |
+
# it's a 2D Tensor and shape == [N+1, N+1] or [N, N+1]
|
104 |
+
# [dense with N=1, which is the only one that could have a transform shape of 2, would be of size Mx1]
|
105 |
+
is_matrix = len(trf_shape) == 2 and trf_shape[0] in (self.ndims, self.ndims + 1) and trf_shape[
|
106 |
+
1] == self.ndims + 1
|
107 |
+
assert not (len(trf_shape) == 1 or is_matrix), "Invalid transformation. Expected a dense displacement map"
|
108 |
+
|
109 |
+
# check sizes
|
110 |
+
if trf_shape[-1] != self.ndims:
|
111 |
+
raise Exception('Offset flow field size expected: %d, found: %d'
|
112 |
+
% (self.ndims, trf_shape[-1]))
|
113 |
+
|
114 |
+
# confirm built
|
115 |
+
self.built = True
|
116 |
+
|
117 |
+
def call(self, inputs):
|
118 |
+
"""
|
119 |
+
Parameters
|
120 |
+
inputs: list with two entries
|
121 |
+
"""
|
122 |
+
|
123 |
+
# check shapes
|
124 |
+
assert len(inputs) == 2, "inputs has to be len 2, found: %d" % len(inputs)
|
125 |
+
vol = inputs[0]
|
126 |
+
trf = inputs[1]
|
127 |
+
|
128 |
+
# necessary for multi_gpu models...
|
129 |
+
vol = K.reshape(vol, [-1, *self.inshape[0][1:]])
|
130 |
+
trf = K.reshape(trf, [-1, *self.inshape[1][1:]])
|
131 |
+
|
132 |
+
# prepare location shift
|
133 |
+
if self.indexing == 'xy': # shift the first two dimensions
|
134 |
+
trf_split = tf.split(trf, trf.shape[-1], axis=-1)
|
135 |
+
trf_lst = [trf_split[1], trf_split[0], *trf_split[2:]]
|
136 |
+
trf = tf.concat(trf_lst, -1)
|
137 |
+
|
138 |
+
# map transform across batch
|
139 |
+
if self.single_transform:
|
140 |
+
fn = lambda x: self._single_transform([x, trf[0, :]])
|
141 |
+
return tf.map_fn(fn, vol, dtype=tf.float32)
|
142 |
+
else:
|
143 |
+
return tf.map_fn(self._single_transform, [vol, trf], dtype=tf.float32)
|
144 |
+
|
145 |
+
def _single_transform(self, inputs):
|
146 |
+
return self._transform(inputs[0], inputs[1], interp_method=self.interp_method, fill_value=self.fill_value)
|
147 |
+
|
148 |
+
def _transform(self, vol, loc_shift, interp_method='linear', indexing='ij', fill_value=None):
|
149 |
+
"""
|
150 |
+
transform (interpolation N-D volumes (features) given shifts at each location in tensorflow
|
151 |
+
|
152 |
+
Essentially interpolates volume vol at locations determined by loc_shift.
|
153 |
+
This is a spatial transform in the sense that at location [x] we now have the data from,
|
154 |
+
[x + shift] so we've moved data.
|
155 |
+
|
156 |
+
Parameters:
|
157 |
+
vol: volume with size vol_shape or [*vol_shape, nb_features]
|
158 |
+
loc_shift: shift volume [*new_vol_shape, N]
|
159 |
+
interp_method (default:'linear'): 'linear', 'nearest'
|
160 |
+
indexing (default: 'ij'): 'ij' (matrix) or 'xy' (cartesian).
|
161 |
+
In general, prefer to leave this 'ij'
|
162 |
+
fill_value (default: None): value to use for points outside the domain.
|
163 |
+
If None, the nearest neighbors will be used.
|
164 |
+
|
165 |
+
Return:
|
166 |
+
new interpolated volumes in the same size as loc_shift[0]
|
167 |
+
|
168 |
+
Keyworks:
|
169 |
+
interpolation, sampler, resampler, linear, bilinear
|
170 |
+
"""
|
171 |
+
|
172 |
+
# parse shapes
|
173 |
+
|
174 |
+
if isinstance(loc_shift.shape, (tf.compat.v1.Dimension, tf.TensorShape)):
|
175 |
+
volshape = loc_shift.shape[:-1].as_list()
|
176 |
+
else:
|
177 |
+
volshape = loc_shift.shape[:-1]
|
178 |
+
nb_dims = len(volshape)
|
179 |
+
|
180 |
+
# location should be mesh and delta
|
181 |
+
mesh = ne.utils.volshape_to_meshgrid(volshape, indexing=indexing) # volume mesh
|
182 |
+
loc = [tf.cast(mesh[d], 'float32') + loc_shift[..., d] for d in range(nb_dims)]
|
183 |
+
|
184 |
+
# test single
|
185 |
+
return ne.utils.interpn(vol, loc, interp_method=interp_method, fill_value=fill_value)
|
186 |
+
|
DeepDeformationMapRegistration/main.py
CHANGED
@@ -14,8 +14,7 @@ from scipy.ndimage import gaussian_filter, zoom
|
|
14 |
from skimage.measure import regionprops
|
15 |
import SimpleITK as sitk
|
16 |
|
17 |
-
from
|
18 |
-
|
19 |
import DeepDeformationMapRegistration.utils.constants as C
|
20 |
from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
|
21 |
from DeepDeformationMapRegistration.utils.operators import min_max_norm
|
@@ -298,8 +297,7 @@ def main():
|
|
298 |
|
299 |
LOGGER.info('Applying displacement map...')
|
300 |
time_pred_img_start = time.time()
|
301 |
-
|
302 |
-
pred_image = np.zeros_like(moving_image[np.newaxis, ...]) # @TODO: Replace this with Keras' Model with SpatialTransformer Layer
|
303 |
time_pred_img_end = time.time()
|
304 |
LOGGER.info(f'\t... done ({time_pred_img_end - time_pred_img_start} s)')
|
305 |
pred_image = pred_image[0, ...]
|
|
|
14 |
from skimage.measure import regionprops
|
15 |
import SimpleITK as sitk
|
16 |
|
17 |
+
from DeepDeformationMapRegistration.layers.SpatialTransformer import SpatialTransformer
|
|
|
18 |
import DeepDeformationMapRegistration.utils.constants as C
|
19 |
from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
|
20 |
from DeepDeformationMapRegistration.utils.operators import min_max_norm
|
|
|
297 |
|
298 |
LOGGER.info('Applying displacement map...')
|
299 |
time_pred_img_start = time.time()
|
300 |
+
pred_image = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([moving_image[np.newaxis, ...], disp_map[np.newaxis, ...]]).eval()
|
|
|
301 |
time_pred_img_end = time.time()
|
302 |
LOGGER.info(f'\t... done ({time_pred_img_end - time_pred_img_start} s)')
|
303 |
pred_image = pred_image[0, ...]
|