File size: 8,125 Bytes
61f0e36
 
 
 
 
a27d55f
5b0dbe4
61f0e36
ec717e8
61f0e36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b0dbe4
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import tensorflow.keras.layers as kl
import tensorflow.keras.backend as K
import tensorflow as tf
import neurite as ne

from ddmr.utils.constants import IMG_SHAPE, DISP_MAP_SHAPE


@tf.keras.utils.register_keras_serializable()
class SpatialTransformer(kl.Layer):
    """
    Adapted SpatialTransformer layer taken from VoxelMorph v0.1
    https://github.com/voxelmorph/voxelmorph/blob/dev/voxelmorph/tf/layers.py
    Removed unused options to ease portability

    N-D Spatial Transformer Tensorflow / Keras Layer

    The Layer can handle ONLY dense transforms; affine transforms are rejected in `build`.
    Transforms are meant to give a 'shift' from the current position.
    Therefore, a dense transform gives displacements (not absolute locations) at each voxel.

    The layer is called on a list [image, displacement]:
        image: (batch, *vol_shape, features)
        displacement: (batch, *new_vol_shape, N), with N = len(vol_shape)

    If you find this function useful, please cite:
      Unsupervised Learning for Fast Probabilistic Diffeomorphic Registration
      Adrian V. Dalca, Guha Balakrishnan, John Guttag, Mert R. Sabuncu
      MICCAI 2018.

    Originally, this code was based on voxelmorph code, which
    was in turn transformed to be dense with the help of (affine) STN code
    via https://github.com/kevinzakka/spatial-transformer-network

    Since then, we've re-written the code to be generalized to any
    dimensions, and along the way wrote grid and interpolation functions
    """

    def __init__(self,
                 interp_method='linear',
                 indexing='ij',
                 single_transform=False,
                 fill_value=None,
                 add_identity=True,
                 shift_center=True,
                 **kwargs):
        """
        Parameters:
            interp_method: 'linear' or 'nearest'
            indexing (default: 'ij'): 'ij' (matrix) or 'xy' (cartesian)
                'xy' indexing will have the first two entries of the flow
                (along last axis) flipped compared to 'ij' indexing
            single_transform: whether a single transform is supplied for the whole batch.
                If True, the first transform of the batch is applied to every image.
            fill_value (default: None): value to use for points outside the domain.
                If None, the nearest neighbors will be used.
            add_identity (default: True): stored and serialized by `get_config`, but not
                otherwise used by this layer (it only accepts dense transforms).
            shift_center (default: True): stored and serialized by `get_config`, but not
                otherwise used by this layer (it only accepts dense transforms).
            **kwargs: forwarded to the Keras `Layer` constructor (e.g. `name`).

        Raises:
            ValueError: if `indexing` is neither 'ij' nor 'xy'.
        """
        # Explicit raise instead of assert: asserts disappear under `python -O`.
        if indexing not in ['ij', 'xy']:
            raise ValueError("indexing has to be 'ij' (matrix) or 'xy' (cartesian)")

        # Zero-argument super(): `super(self.__class__, self)` recurses forever
        # as soon as this layer is subclassed.
        super().__init__(**kwargs)

        self.interp_method = interp_method
        self.indexing = indexing
        self.single_transform = single_transform
        self.fill_value = fill_value
        self.add_identity = add_identity
        self.shift_center = shift_center
        # Filled in by `build` once the input shapes are known.
        self.ndims = None
        self.inshape = None

    def get_config(self):
        """Return the layer configuration so the layer can be re-created on model load."""
        config = super().get_config().copy()
        config.update({
            'interp_method': self.interp_method,
            'indexing': self.indexing,
            'single_transform': self.single_transform,
            'fill_value': self.fill_value,
            'add_identity': self.add_identity,
            'shift_center': self.shift_center,
        })
        return config

    def build(self, input_shape):
        """
        Validate the two input shapes and record them for use in `call`.

        Parameters:
            input_shape: list of two shapes [image_shape, transform_shape]
                image_shape: (batch, *vol_shape, features)
                transform_shape: (batch, *new_vol_shape, N), a dense displacement
                    field with one channel per spatial dimension N.

        Raises:
            ValueError: if not exactly two shapes are given, if the transform has the
                shape of an affine matrix/vector, or if its last axis is not N.
        """
        if len(input_shape) != 2:
            raise ValueError('Spatial Transformer must be called on a list of length 2. '
                             'First argument is the image, second is the transform.')

        # number of spatial dimensions: image rank minus the batch and feature axes
        self.ndims = len(input_shape[0]) - 2
        self.inshape = input_shape
        trf_shape = input_shape[1][1:]

        # the transform is an affine iff:
        # it's a 1D Tensor [dense transforms need to be at least ndims + 1]
        # it's a 2D Tensor and shape == [N+1, N+1] or [N, N+1]
        #   [dense with N=1, which is the only one that could have a transform shape of 2, would be of size Mx1]
        is_matrix = (len(trf_shape) == 2
                     and trf_shape[0] in (self.ndims, self.ndims + 1)
                     and trf_shape[1] == self.ndims + 1)
        if len(trf_shape) == 1 or is_matrix:
            raise ValueError('Invalid transformation. Expected a dense displacement map')

        # one displacement channel per spatial dimension; %s so that an unknown
        # (None) dimension is reported instead of crashing the formatting
        if trf_shape[-1] != self.ndims:
            raise ValueError('Offset flow field size expected: %d, found: %s'
                             % (self.ndims, trf_shape[-1]))

        # confirm built
        self.built = True

    def call(self, inputs):
        """
        Warp each image of the batch with its displacement field.

        Parameters:
            inputs: list [image, transform] with the shapes validated in `build`.

        Returns:
            float32 tensor of warped images, computed per sample with `tf.map_fn`
            (see `_transform` for the per-sample output size).

        Raises:
            ValueError: if `inputs` does not contain exactly two tensors.
        """
        if len(inputs) != 2:
            raise ValueError("inputs has to be len 2, found: %d" % len(inputs))
        vol, trf = inputs

        # necessary for multi_gpu models...
        vol = K.reshape(vol, [-1, *self.inshape[0][1:]])
        trf = K.reshape(trf, [-1, *self.inshape[1][1:]])

        # 'xy' indexing: swap the first two displacement channels so that
        # `_transform` can always work in 'ij' order
        if self.indexing == 'xy':
            trf_split = tf.split(trf, trf.shape[-1], axis=-1)
            trf_lst = [trf_split[1], trf_split[0], *trf_split[2:]]
            trf = tf.concat(trf_lst, -1)

        # map transform across batch
        if self.single_transform:
            # every image is warped with the first transform of the batch
            fn = lambda x: self._single_transform([x, trf[0, :]])
            return tf.map_fn(fn, vol, dtype=tf.float32)
        return tf.map_fn(self._single_transform, [vol, trf], dtype=tf.float32)

    def _single_transform(self, inputs):
        """Warp one sample; `inputs` is [volume, displacement] without the batch axis."""
        return self._transform(inputs[0], inputs[1], interp_method=self.interp_method, fill_value=self.fill_value)

    def _transform(self, vol, loc_shift, interp_method='linear', indexing='ij', fill_value=None):
        """
        Transform (interpolate) an N-D volume (features) given shifts at each location.

        Essentially interpolates volume vol at locations determined by loc_shift.
        This is a spatial transform in the sense that at location [x] we now have the data from,
        [x + shift] so we've moved data.

        Parameters:
            vol: volume with size vol_shape or [*vol_shape, nb_features]
            loc_shift: shift volume [*new_vol_shape, N]
            interp_method (default:'linear'): 'linear', 'nearest'
            indexing (default: 'ij'): 'ij' (matrix) or 'xy' (cartesian).
                In general, prefer to leave this 'ij'
            fill_value (default: None): value to use for points outside the domain.
                If None, the nearest neighbors will be used.

        Return:
            new interpolated volume sampled on the grid of loc_shift's spatial
            shape (new_vol_shape)

        Keywords:
            interpolation, sampler, resampler, linear, bilinear
        """

        # spatial grid size is given by loc_shift without its channel axis
        if isinstance(loc_shift.shape, (tf.compat.v1.Dimension, tf.TensorShape)):
            volshape = loc_shift.shape[:-1].as_list()
        else:
            volshape = loc_shift.shape[:-1]
        nb_dims = len(volshape)

        # sampling location = identity grid (mesh) + displacement
        mesh = ne.utils.volshape_to_meshgrid(volshape, indexing=indexing)  # volume mesh
        loc = [tf.cast(mesh[d], 'float32') + loc_shift[..., d] for d in range(nb_dims)]

        # interpolate vol at the shifted locations
        return ne.utils.interpn(vol, loc, interp_method=interp_method, fill_value=fill_value)


if __name__ == "__main__":
    # Build a tiny model around the layer and save it to disk, so that loading
    # it elsewhere exercises the layer's Keras serialization.
    save_path = './spatialtransformer.h5'

    # Inputs are created displacement-first to keep the original auto-generated names.
    displacement_input = tf.keras.Input(DISP_MAP_SHAPE)
    image_input = tf.keras.Input(IMG_SHAPE)

    warp_layer = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)
    warped_image = warp_layer([image_input, displacement_input])

    tf.keras.Model(inputs=[image_input, displacement_input], outputs=warped_image).save(save_path)
    print(f"SpatialTransformer layer saved in: {save_path}")