jpdefrutos commited on
Commit
61f0e36
·
1 Parent(s): 6f67742

Imported and adapted SpatialTransformer from VoxelMorph v0.1

Browse files
DeepDeformationMapRegistration/layers/SpatialTransformer.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow.keras.layers as kl
2
+ import tensorflow.keras.backend as K
3
+ import tensorflow as tf
4
+ import neurite as ne
5
+
6
+
7
+ class SpatialTransformer(kl.Layer):
8
+ """
9
+ Adapted SpatialTransformer layer taken from VoxelMorph v0.1
10
+ https://github.com/voxelmorph/voxelmorph/blob/dev/voxelmorph/tf/layers.py
11
+ Removed unused options to ease portability
12
+
13
+ N-D Spatial Transformer Tensorflow / Keras Layer
14
+
15
+ The Layer can handle ONLY dense transforms.
16
+ Transforms are meant to give a 'shift' from the current position.
17
+ Therefore, a dense transform gives displacements (not absolute locations) at each voxel.
18
+
19
+ If you find this function useful, please cite:
20
+ Unsupervised Learning for Fast Probabilistic Diffeomorphic Registration
21
+ Adrian V. Dalca, Guha Balakrishnan, John Guttag, Mert R. Sabuncu
22
+ MICCAI 2018.
23
+
24
+ Originally, this code was based on voxelmorph code, which
25
+ was in turn transformed to be dense with the help of (affine) STN code
26
+ via https://github.com/kevinzakka/spatial-transformer-network
27
+
28
+ Since then, we've re-written the code to be generalized to any
29
+ dimensions, and along the way wrote grid and interpolation functions
30
+ """
31
+
32
+ def __init__(self,
33
+ interp_method='linear',
34
+ indexing='ij',
35
+ single_transform=False,
36
+ fill_value=None,
37
+ add_identity=True,
38
+ shift_center=True,
39
+ **kwargs):
40
+ """
41
+ Parameters:
42
+ interp_method: 'linear' or 'nearest'
43
+ single_transform: whether a single transform supplied for the whole batch
44
+ indexing (default: 'ij'): 'ij' (matrix) or 'xy' (cartesian)
45
+ 'xy' indexing will have the first two entries of the flow
46
+ (along last axis) flipped compared to 'ij' indexing
47
+ fill_value (default: None): value to use for points outside the domain.
48
+ If None, the nearest neighbors will be used.
49
+ add_identity (default: True): whether the identity matrix is added
50
+ to affine transforms.
51
+ shift_center (default: True): whether the grid is shifted to the center
52
+ of the image when converting affine transforms to warp fields.
53
+ """
54
+ self.interp_method = interp_method
55
+ self.fill_value = fill_value
56
+ self.add_identity = add_identity
57
+ self.shift_center = shift_center
58
+ self.ndims = None
59
+ self.inshape = None
60
+ self.single_transform = single_transform
61
+
62
+ assert indexing in ['ij', 'xy'], "indexing has to be 'ij' (matrix) or 'xy' (cartesian)"
63
+ self.indexing = indexing
64
+
65
+ super(self.__class__, self).__init__(**kwargs)
66
+
67
+ def get_config(self):
68
+ config = super().get_config().copy()
69
+ config.update({
70
+ 'interp_method': self.interp_method,
71
+ 'indexing': self.indexing,
72
+ 'single_transform': self.single_transform,
73
+ 'fill_value': self.fill_value,
74
+ 'add_identity': self.add_identity,
75
+ 'shift_center': self.shift_center,
76
+ })
77
+ return config
78
+
79
+ def build(self, input_shape):
80
+ """
81
+ input_shape should be a list for two inputs:
82
+ input1: image.
83
+ input2: transform Tensor
84
+ if affine:
85
+ should be a N x N+1 matrix
86
+ *or* a N*N+1 tensor (which will be reshape to N x (N+1) and an identity row added)
87
+ if not affine:
88
+ should be a *vol_shape x N
89
+ """
90
+
91
+ if len(input_shape) > 2:
92
+ raise Exception('Spatial Transformer must be called on a list of length 2.'
93
+ 'First argument is the image, second is the transform.')
94
+
95
+ # set up number of dimensions
96
+ self.ndims = len(input_shape[0]) - 2
97
+ self.inshape = input_shape
98
+ vol_shape = input_shape[0][1:-1]
99
+ trf_shape = input_shape[1][1:]
100
+
101
+ # the transform is an affine iff:
102
+ # it's a 1D Tensor [dense transforms need to be at least ndims + 1]
103
+ # it's a 2D Tensor and shape == [N+1, N+1] or [N, N+1]
104
+ # [dense with N=1, which is the only one that could have a transform shape of 2, would be of size Mx1]
105
+ is_matrix = len(trf_shape) == 2 and trf_shape[0] in (self.ndims, self.ndims + 1) and trf_shape[
106
+ 1] == self.ndims + 1
107
+ assert not (len(trf_shape) == 1 or is_matrix), "Invalid transformation. Expected a dense displacement map"
108
+
109
+ # check sizes
110
+ if trf_shape[-1] != self.ndims:
111
+ raise Exception('Offset flow field size expected: %d, found: %d'
112
+ % (self.ndims, trf_shape[-1]))
113
+
114
+ # confirm built
115
+ self.built = True
116
+
117
+ def call(self, inputs):
118
+ """
119
+ Parameters
120
+ inputs: list with two entries
121
+ """
122
+
123
+ # check shapes
124
+ assert len(inputs) == 2, "inputs has to be len 2, found: %d" % len(inputs)
125
+ vol = inputs[0]
126
+ trf = inputs[1]
127
+
128
+ # necessary for multi_gpu models...
129
+ vol = K.reshape(vol, [-1, *self.inshape[0][1:]])
130
+ trf = K.reshape(trf, [-1, *self.inshape[1][1:]])
131
+
132
+ # prepare location shift
133
+ if self.indexing == 'xy': # shift the first two dimensions
134
+ trf_split = tf.split(trf, trf.shape[-1], axis=-1)
135
+ trf_lst = [trf_split[1], trf_split[0], *trf_split[2:]]
136
+ trf = tf.concat(trf_lst, -1)
137
+
138
+ # map transform across batch
139
+ if self.single_transform:
140
+ fn = lambda x: self._single_transform([x, trf[0, :]])
141
+ return tf.map_fn(fn, vol, dtype=tf.float32)
142
+ else:
143
+ return tf.map_fn(self._single_transform, [vol, trf], dtype=tf.float32)
144
+
145
+ def _single_transform(self, inputs):
146
+ return self._transform(inputs[0], inputs[1], interp_method=self.interp_method, fill_value=self.fill_value)
147
+
148
+ def _transform(self, vol, loc_shift, interp_method='linear', indexing='ij', fill_value=None):
149
+ """
150
+ transform (interpolation N-D volumes (features) given shifts at each location in tensorflow
151
+
152
+ Essentially interpolates volume vol at locations determined by loc_shift.
153
+ This is a spatial transform in the sense that at location [x] we now have the data from,
154
+ [x + shift] so we've moved data.
155
+
156
+ Parameters:
157
+ vol: volume with size vol_shape or [*vol_shape, nb_features]
158
+ loc_shift: shift volume [*new_vol_shape, N]
159
+ interp_method (default:'linear'): 'linear', 'nearest'
160
+ indexing (default: 'ij'): 'ij' (matrix) or 'xy' (cartesian).
161
+ In general, prefer to leave this 'ij'
162
+ fill_value (default: None): value to use for points outside the domain.
163
+ If None, the nearest neighbors will be used.
164
+
165
+ Return:
166
+ new interpolated volumes in the same size as loc_shift[0]
167
+
168
+ Keyworks:
169
+ interpolation, sampler, resampler, linear, bilinear
170
+ """
171
+
172
+ # parse shapes
173
+
174
+ if isinstance(loc_shift.shape, (tf.compat.v1.Dimension, tf.TensorShape)):
175
+ volshape = loc_shift.shape[:-1].as_list()
176
+ else:
177
+ volshape = loc_shift.shape[:-1]
178
+ nb_dims = len(volshape)
179
+
180
+ # location should be mesh and delta
181
+ mesh = ne.utils.volshape_to_meshgrid(volshape, indexing=indexing) # volume mesh
182
+ loc = [tf.cast(mesh[d], 'float32') + loc_shift[..., d] for d in range(nb_dims)]
183
+
184
+ # test single
185
+ return ne.utils.interpn(vol, loc, interp_method=interp_method, fill_value=fill_value)
186
+
DeepDeformationMapRegistration/main.py CHANGED
@@ -14,8 +14,7 @@ from scipy.ndimage import gaussian_filter, zoom
14
  from skimage.measure import regionprops
15
  import SimpleITK as sitk
16
 
17
- from voxelmorph.tf.layers import SpatialTransformer
18
-
19
  import DeepDeformationMapRegistration.utils.constants as C
20
  from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
21
  from DeepDeformationMapRegistration.utils.operators import min_max_norm
@@ -298,8 +297,7 @@ def main():
298
 
299
  LOGGER.info('Applying displacement map...')
300
  time_pred_img_start = time.time()
301
- #pred_image = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([moving_image[np.newaxis, ...], disp_map[np.newaxis, ...]]).eval()
302
- pred_image = np.zeros_like(moving_image[np.newaxis, ...]) # @TODO: Replace this with Keras' Model with SpatialTransformer Layer
303
  time_pred_img_end = time.time()
304
  LOGGER.info(f'\t... done ({time_pred_img_end - time_pred_img_start} s)')
305
  pred_image = pred_image[0, ...]
 
14
  from skimage.measure import regionprops
15
  import SimpleITK as sitk
16
 
17
+ from DeepDeformationMapRegistration.layers.SpatialTransformer import SpatialTransformer
 
18
  import DeepDeformationMapRegistration.utils.constants as C
19
  from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
20
  from DeepDeformationMapRegistration.utils.operators import min_max_norm
 
297
 
298
  LOGGER.info('Applying displacement map...')
299
  time_pred_img_start = time.time()
300
+ pred_image = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([moving_image[np.newaxis, ...], disp_map[np.newaxis, ...]]).eval()
 
301
  time_pred_img_end = time.time()
302
  LOGGER.info(f'\t... done ({time_pred_img_end - time_pred_img_start} s)')
303
  pred_image = pred_image[0, ...]