import tensorflow.keras.layers as kl
import tensorflow.keras.backend as K
import tensorflow as tf
import neurite as ne
from ddmr.utils.constants import IMG_SHAPE, DISP_MAP_SHAPE
@tf.keras.utils.register_keras_serializable()
class SpatialTransformer(kl.Layer):
    """
    Adapted SpatialTransformer layer taken from VoxelMorph v0.1
    https://github.com/voxelmorph/voxelmorph/blob/dev/voxelmorph/tf/layers.py
    Removed unused options to ease portability
    N-D Spatial Transformer Tensorflow / Keras Layer
    The Layer can handle ONLY dense transforms.
    Transforms are meant to give a 'shift' from the current position.
    Therefore, a dense transform gives displacements (not absolute locations) at each voxel.
    If you find this function useful, please cite:
    Unsupervised Learning for Fast Probabilistic Diffeomorphic Registration
    Adrian V. Dalca, Guha Balakrishnan, John Guttag, Mert R. Sabuncu
    MICCAI 2018.
    Originally, this code was based on voxelmorph code, which
    was in turn transformed to be dense with the help of (affine) STN code
    via https://github.com/kevinzakka/spatial-transformer-network
    Since then, we've re-written the code to be generalized to any
    dimensions, and along the way wrote grid and interpolation functions
    """
    def __init__(self,
                 interp_method='linear',
                 indexing='ij',
                 single_transform=False,
                 fill_value=None,
                 add_identity=True,
                 shift_center=True,
                 **kwargs):
        """
        Parameters:
            interp_method: 'linear' or 'nearest'
            single_transform: whether a single transform supplied for the whole batch
            indexing (default: 'ij'): 'ij' (matrix) or 'xy' (cartesian)
                'xy' indexing will have the first two entries of the flow
                (along last axis) flipped compared to 'ij' indexing
            fill_value (default: None): value to use for points outside the domain.
                If None, the nearest neighbors will be used.
            add_identity (default: True): whether the identity matrix is added
                to affine transforms.
            shift_center (default: True): whether the grid is shifted to the center
                of the image when converting affine transforms to warp fields.

        Raises:
            ValueError: if `indexing` is neither 'ij' nor 'xy'.
        """
        self.interp_method = interp_method
        self.fill_value = fill_value
        self.add_identity = add_identity
        self.shift_center = shift_center
        # ndims / inshape are filled in by build() once input shapes are known.
        self.ndims = None
        self.inshape = None
        self.single_transform = single_transform
        # Explicit raise instead of assert: asserts are stripped under `python -O`,
        # which would silently disable this validation.
        if indexing not in ('ij', 'xy'):
            raise ValueError("indexing has to be 'ij' (matrix) or 'xy' (cartesian)")
        self.indexing = indexing
        # NOTE: the original used super(self.__class__, self), which recurses
        # infinitely when this class is subclassed; zero-argument super() is safe.
        super().__init__(**kwargs)

    def get_config(self):
        """Return the layer configuration so the layer can be re-serialized."""
        config = super().get_config().copy()
        config.update({
            'interp_method': self.interp_method,
            'indexing': self.indexing,
            'single_transform': self.single_transform,
            'fill_value': self.fill_value,
            'add_identity': self.add_identity,
            'shift_center': self.shift_center,
        })
        return config

    def build(self, input_shape):
        """
        input_shape should be a list for two inputs:
        input1: image.
        input2: transform Tensor
            if affine:
                should be a N x N+1 matrix
                *or* a N*N+1 tensor (which will be reshape to N x (N+1) and an identity row added)
            if not affine:
                should be a *vol_shape x N

        Raises:
            Exception: if the input list does not have exactly two entries, or
                the flow field's last dimension does not match the image ndims.
            ValueError: if the transform looks affine (this layer is dense-only).
        """
        # Require exactly two inputs (the original only rejected > 2, letting a
        # single-input list slip through to an IndexError later).
        if len(input_shape) != 2:
            raise Exception('Spatial Transformer must be called on a list of length 2.'
                            'First argument is the image, second is the transform.')
        # set up number of dimensions (batch and channel axes excluded)
        self.ndims = len(input_shape[0]) - 2
        self.inshape = input_shape
        vol_shape = input_shape[0][1:-1]
        trf_shape = input_shape[1][1:]
        # the transform is an affine iff:
        # it's a 1D Tensor [dense transforms need to be at least ndims + 1]
        # it's a 2D Tensor and shape == [N+1, N+1] or [N, N+1]
        # [dense with N=1, which is the only one that could have a transform shape of 2, would be of size Mx1]
        is_matrix = len(trf_shape) == 2 and trf_shape[0] in (self.ndims, self.ndims + 1) and trf_shape[
            1] == self.ndims + 1
        # Dense-only layer: reject anything shaped like an affine matrix.
        # Explicit raise instead of assert (asserts vanish under `python -O`).
        if len(trf_shape) == 1 or is_matrix:
            raise ValueError("Invalid transformation. Expected a dense displacement map")
        # check sizes: the flow must carry one offset per spatial dimension
        if trf_shape[-1] != self.ndims:
            raise Exception('Offset flow field size expected: %d, found: %d'
                            % (self.ndims, trf_shape[-1]))
        # confirm built
        super().build(input_shape)

    def call(self, inputs):
        """
        Warp inputs[0] (the image) by the dense displacement field inputs[1].

        Parameters
            inputs: list with two entries

        Raises:
            ValueError: if `inputs` does not have exactly two entries.
        """
        # check shapes (explicit raise; asserts are stripped under `python -O`)
        if len(inputs) != 2:
            raise ValueError("inputs has to be len 2, found: %d" % len(inputs))
        vol = inputs[0]
        trf = inputs[1]
        # necessary for multi_gpu models...
        vol = K.reshape(vol, [-1, *self.inshape[0][1:]])
        trf = K.reshape(trf, [-1, *self.inshape[1][1:]])
        # prepare location shift
        if self.indexing == 'xy':  # shift the first two dimensions
            trf_split = tf.split(trf, trf.shape[-1], axis=-1)
            trf_lst = [trf_split[1], trf_split[0], *trf_split[2:]]
            trf = tf.concat(trf_lst, -1)
        # map transform across batch
        if self.single_transform:
            # one flow field for the whole batch: reuse trf[0] for every volume
            fn = lambda x: self._single_transform([x, trf[0, :]])
            return tf.map_fn(fn, vol, dtype=tf.float32)
        else:
            return tf.map_fn(self._single_transform, [vol, trf], dtype=tf.float32)

    def _single_transform(self, inputs):
        # Helper for tf.map_fn: inputs is a (volume, flow) pair for one sample.
        return self._transform(inputs[0], inputs[1], interp_method=self.interp_method, fill_value=self.fill_value)

    def _transform(self, vol, loc_shift, interp_method='linear', indexing='ij', fill_value=None):
        """
        transform (interpolation N-D volumes (features) given shifts at each location in tensorflow
        Essentially interpolates volume vol at locations determined by loc_shift.
        This is a spatial transform in the sense that at location [x] we now have the data from,
        [x + shift] so we've moved data.
        Parameters:
            vol: volume with size vol_shape or [*vol_shape, nb_features]
            loc_shift: shift volume [*new_vol_shape, N]
            interp_method (default:'linear'): 'linear', 'nearest'
            indexing (default: 'ij'): 'ij' (matrix) or 'xy' (cartesian).
                In general, prefer to leave this 'ij'
            fill_value (default: None): value to use for points outside the domain.
                If None, the nearest neighbors will be used.
        Return:
            new interpolated volumes in the same size as loc_shift[0]
        Keyworks:
            interpolation, sampler, resampler, linear, bilinear
        """
        # parse shapes
        if isinstance(loc_shift.shape, (tf.compat.v1.Dimension, tf.TensorShape)):
            volshape = loc_shift.shape[:-1].as_list()
        else:
            volshape = loc_shift.shape[:-1]
        nb_dims = len(volshape)
        # location should be mesh and delta
        mesh = ne.utils.volshape_to_meshgrid(volshape, indexing=indexing)  # volume mesh
        loc = [tf.cast(mesh[d], 'float32') + loc_shift[..., d] for d in range(nb_dims)]
        # test single
        return ne.utils.interpn(vol, loc, interp_method=interp_method, fill_value=fill_value)
if __name__ == "__main__":
    # Smoke test: wrap the layer in a tiny Keras model and serialize it to HDF5.
    save_path = './spatialtransformer.h5'
    disp_map_input = tf.keras.Input(DISP_MAP_SHAPE)
    image_input = tf.keras.Input(IMG_SHAPE)
    warp_layer = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)
    warped = warp_layer([image_input, disp_map_input])
    tf.keras.Model(inputs=[image_input, disp_map_input], outputs=warped).save(save_path)
    print(f"SpatialTransformer layer saved in: {save_path}")