File size: 15,471 Bytes
74c6a32
 
 
 
 
 
 
 
 
 
 
 
a27d55f
 
 
74c6a32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78ae283
74c6a32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e5764e7
 
 
74c6a32
 
 
 
e5764e7
 
74c6a32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e5764e7
74c6a32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
import os, sys

# Put the directory that contains this file's folder on the import search
# path (PYTHON > 3.3 does not allow relative referencing).
currentdir = os.path.split(os.path.realpath(__file__))[0]
parentdir = os.path.split(currentdir)[0]
sys.path.append(parentdir)

# True only when the PYCHARM_EXEC environment variable is exactly 'True'.
PYCHARM_EXEC = os.environ.get('PYCHARM_EXEC') == 'True'

import tensorflow.keras.layers as kl
import tensorflow as tf
from tensorflow.python.framework.errors import InvalidArgumentError

from ddmr.utils.operators import soft_threshold, gaussian_kernel, sample_unique
import ddmr.utils.constants as C
from ddmr.utils.thin_plate_splines import ThinPlateSplines
from voxelmorph.tf.layers import SpatialTransformer


class AugmentationLayer(kl.Layer):
    def __init__(self,
                 max_deformation,
                 max_displacement,
                 max_rotation,
                 num_control_points,
                 in_img_shape,
                 out_img_shape,
                 num_augmentations=1,
                 gamma_augmentation=True,
                 brightness_augmentation=True,
                 only_image=False,
                 only_resize=True,
                 return_displacement_map=False,
                 **kwargs):
        """Store the augmentation settings and precompute the coordinate grids.

        :param max_deformation: scale applied to the random offsets (drawn in
            [-1, 1)) of the control points sampled in deform_image.
        :param max_displacement: scale applied to the random per-axis
            translation (drawn in [-1, 1)) in global_transformation.
        :param max_rotation: bound of the random rotation angle; it is
            multiplied by C.DEG_TO_RAD before use, so presumably in degrees.
        :param num_control_points: number of control points sampled inside the
            image in deform_image.
        :param in_img_shape: spatial shape of the incoming volumes.
        :param out_img_shape: spatial shape to downsize to, or None to skip
            the downsizing step.
        :param num_augmentations: only stored on the layer here.
        :param gamma_augmentation: enable the random gamma correction.
        :param brightness_augmentation: enable the random brightness shift.
        :param only_image: treat the whole input as image (no segmentations).
        :param only_resize: when True the layer starts with augmentation
            disabled (self.augment is its negation).
        :param return_displacement_map: also return the displacement map.
        :param kwargs: forwarded to the base layer constructor.
        """
        super().__init__(**kwargs)

        # Shapes
        self.in_img_shape = in_img_shape
        self.out_img_shape = out_img_shape

        # Magnitude of the random transformations
        self.max_deformation = max_deformation
        self.max_displacement = max_displacement
        self.max_rotation = max_rotation
        self.num_control_points = num_control_points
        self.num_augmentations = num_augmentations

        # Behaviour switches
        self.only_image = only_image
        self.return_disp_map = return_displacement_map
        self.do_gamma_augm = gamma_augmentation
        self.do_brightness_augm = brightness_augmentation
        self.augment = not only_resize

        # Coarse grid: C.TPS_NUM_CTRL_PTS_PER_AXIS points per axis. The same
        # flattened coordinates are kept under two names.
        coords = C.CoordinatesGrid()
        coords.set_coords_grid(in_img_shape, [C.TPS_NUM_CTRL_PTS_PER_AXIS] * 3)
        self.control_grid = tf.identity(coords.grid_flat(), name='control_grid')
        self.target_grid = tf.identity(coords.grid_flat(), name='target_grid')

        # Fine grid: as many points per axis as the input image has
        coords.set_coords_grid(in_img_shape, in_img_shape)
        self.fine_grid = tf.identity(coords.grid_flat(), name='fine_grid')

        if out_img_shape is None:
            # No downsizing requested: the attributes below are not needed
            return

        # Integer ratio in_shape // out_shape along each axis
        self.downsample_factor = [in_dim // out_dim
                                  for in_dim, out_dim in zip(in_img_shape, out_img_shape)]
        self.img_gauss_filter = gaussian_kernel(3, 0.001, 1, 1, 3)

    def compute_output_shape(self, input_shape):
        """Return the shapes of the tensors produced by the layer.

        The input is expected to hold the image and the segmentations in the
        same tensor, stacked along the last axis.

        :param input_shape: shape of the input tensor (anything accepted by
            tf.TensorShape).
        :return: tuple with the shapes of the fixed image, moving image, fixed
            segmentation and moving segmentation, followed by the shape of
            the displacement map when self.return_disp_map is set.
        """
        input_shape = tf.TensorShape(input_shape).as_list()
        batch_size, nb_channels = input_shape[0], input_shape[-1]

        # out_img_shape=None is accepted by the constructor and means "no
        # downsizing": the spatial shape of the input is kept in that case.
        if self.out_img_shape is None:
            spatial_shape = self.in_img_shape
        else:
            spatial_shape = self.out_img_shape

        # All the channels but the first one are segmentations. An unknown
        # number of channels stays unknown.
        nb_segm = None if nb_channels is None else nb_channels - 1

        # NOTE(review): augment_sample uses the whole input as image and a
        # zero tensor of the same shape as segmentation when only_image is set
        # or augmentation is off; the channel counts reported here may not
        # match in those modes - confirm with the callers.
        img_shape = (batch_size, *spatial_shape, 1)
        seg_shape = (batch_size, *spatial_shape, nb_segm)
        disp_shape = (batch_size, *spatial_shape, 3)

        output_shapes = (img_shape, img_shape, seg_shape, seg_shape)
        if self.return_disp_map:
            output_shapes += (disp_shape,)
        return output_shapes

    #@tf.custom_gradient
    def call(self, in_data, training=None):
        """Run the layer on a batch.

        :param in_data: input tensor, handed over unchanged to build_batch.
        :param training: when not None it replaces self.augment, and the new
            value stays on the layer after the call. None keeps the current
            flag.
        :return: the tuple of tensors returned by build_batch.
        """
        keep_current_mode = training is None
        if not keep_current_mode:
            self.augment = training
        batch = self.build_batch(in_data)
        return batch

    def build_batch(self, fix_data: tf.Tensor):
        """Apply augment_sample to every element of the batch.

        :param fix_data: input tensor. A batch axis is prepended when it has
            fewer than 5 dimensions.
        :return: (fix_img, mov_img, fix_seg, mov_seg) batches, plus the
            displacement maps when self.return_disp_map is set.
        """
        missing_batch_axis = len(fix_data.get_shape().as_list()) < 5
        if missing_batch_axis:
            fix_data = tf.expand_dims(fix_data, axis=0)

        # map_fn unstacks elems on axis 0 and stacks the five outputs back
        per_sample = tf.map_fn(lambda sample: self.augment_sample(sample),
                               fix_data,
                               dtype=(tf.float32,) * 5)
        fix_imgs, mov_imgs, fix_segs, mov_segs, disp_maps = per_sample

        outputs = (fix_imgs, mov_imgs, fix_segs, mov_segs)
        if self.return_disp_map:
            outputs += (disp_maps,)
        return outputs

    def augment_sample(self, fix_data: tf.Tensor):
        """Build the fixed/moving pair for a single sample of the batch.

        When only_image is set or augmentation is off, the whole input is used
        as the image and the segmentation is a zero tensor of the same shape.
        Otherwise channel 0 is the image and the remaining channels are the
        segmentation masks.

        With augmentation on, the moving image/segmentation are a randomly
        deformed copy of the fixed ones and the intensity augmentations are
        applied after the (optional) downsizing. With augmentation off, the
        moving tensors and the displacement map are zero placeholders and the
        sample is only downsized.

        :param fix_data: one sample, without batch axis (build_batch feeds
            the elements of a 5-D batch one at a time, i.e. 4-D tensors).
        :return: (fix_img, mov_img, fix_segm, mov_segm, disp_map).
        """
        if self.only_image or not self.augment:
            fix_img = fix_data
            fix_segm = tf.zeros_like(fix_data, dtype=tf.float32)
        else:
            fix_img = tf.expand_dims(fix_data[..., 0], -1)
            fix_segm = fix_data[..., 1:]        # We expect several segmentation masks

        if self.augment:
            # If we are training, do the full-fledged augmentation
            fix_img = self.min_max_normalization(fix_img)

            # Drop only the channel axis: a plain squeeze would also remove
            # any spatial axis of size 1.
            mov_img, mov_segm, disp_map = self.deform_image(tf.squeeze(fix_img, axis=-1), fix_segm)
            mov_img = tf.expand_dims(mov_img, -1)  # Add the removed channel axis
        else:
            # During inference, just resize the input images
            mov_img = tf.zeros_like(fix_img)
            mov_segm = tf.zeros_like(fix_segm)

            # Zero displacement with 3 components per voxel. The sample has no
            # batch axis here, so the shape is the spatial shape of the image
            # followed by 3 (the former 5-element tile did not match the rank
            # of the sample).
            disp_shape = tf.concat([tf.shape(fix_img)[:-1], [3]], axis=0)
            disp_map = tf.zeros(disp_shape, dtype=tf.float32)

        # Resample to output_shape (same steps in both modes)
        if self.out_img_shape is not None:
            fix_img = self.downsize_image(fix_img)
            mov_img = self.downsize_image(mov_img)

            fix_segm = self.downsize_segmentation(fix_segm)
            mov_segm = self.downsize_segmentation(mov_segm)

            disp_map = self.downsize_displacement_map(disp_map)

        # Intensity augmentations only make sense when augmenting
        if self.augment:
            if self.do_gamma_augm:
                fix_img = self.gamma_augmentation(fix_img)
                mov_img = self.gamma_augmentation(mov_img)

            if self.do_brightness_augm:
                fix_img = self.brightness_augmentation(fix_img)
                mov_img = self.brightness_augmentation(mov_img)

        fix_img = self.min_max_normalization(fix_img)
        mov_img = self.min_max_normalization(mov_img)
        return fix_img, mov_img, fix_segm, mov_segm, disp_map

    def downsize_image(self, img):
        """Smooth an image with self.img_gauss_filter and subsample it.

        :param img: a single image without batch axis; the batch axis needed
            by the 3D operations is added here and removed before returning.
        :return: the filtered image subsampled by self.downsample_factor.
        """
        img = tf.expand_dims(img, axis=0)
        # NOTE(review): the original author states the filter is symmetrical
        # along the three axes, hence no transposing of the H and D dims.
        img = tf.nn.conv3d(img, self.img_gauss_filter, strides=[1, 1, 1, 1, 1],
                           padding='SAME', data_format='NDHWC')
        # A pooling window of one voxel with stride = downsample_factor keeps
        # one voxel every `factor` along each axis. The Keras layer is used
        # instead of the deprecated tf.layers namespace, which is TF1-only.
        subsample = kl.MaxPooling3D(pool_size=[1] * 3, strides=self.downsample_factor,
                                    padding='valid', data_format='channels_last')
        img = subsample(img)

        return tf.squeeze(img, axis=0)

    def downsize_segmentation(self, segm):
        """Subsample a segmentation by self.downsample_factor.

        No smoothing is applied, so the label values are left untouched.

        :param segm: a single segmentation without batch axis; the batch axis
            needed by the pooling layer is added here and removed at the end.
        :return: the subsampled segmentation, cast to float32.
        """
        segm = tf.expand_dims(segm, axis=0)
        # A pooling window of one voxel with stride = downsample_factor keeps
        # one voxel every `factor` along each axis. The Keras layer is used
        # instead of the deprecated tf.layers namespace, which is TF1-only.
        subsample = kl.MaxPooling3D(pool_size=[1] * 3, strides=self.downsample_factor,
                                    padding='valid', data_format='channels_last')
        segm = tf.cast(subsample(segm), tf.float32)

        return tf.squeeze(segm, axis=0)

    def downsize_displacement_map(self, disp_map):
        """Subsample a displacement map and rescale its magnitude.

        :param disp_map: a single displacement map without batch axis, with
            the displacement components along the last axis (one per spatial
            axis).
        :return: the displacement map subsampled by self.downsample_factor,
            with component k divided by self.downsample_factor[k].
        """
        disp_map = tf.expand_dims(disp_map, axis=0)
        # A pooling window of one voxel with stride = downsample_factor keeps
        # one voxel every `factor` along each axis. The Keras layer is used
        # instead of the deprecated tf.layers namespace, which is TF1-only.
        subsample = kl.AveragePooling3D(pool_size=[1] * 3, strides=self.downsample_factor,
                                        padding='valid', data_format='channels_last')
        disp_map = subsample(disp_map)

        # self.downsample_factor = in_shape / out_shape, but here we need
        # out_shape / in_shape, hence the division. Every component is scaled
        # by the factor of its own axis: the vector of factors broadcasts over
        # the last axis. (The former `a != b != c` test missed anisotropic
        # factors such as [2, 2, 4] and scaled every axis by the first one.)
        per_axis_factor = tf.constant(list(self.downsample_factor), dtype=disp_map.dtype)
        disp_map = disp_map / per_axis_factor

        return tf.squeeze(disp_map, axis=0)

    def gamma_augmentation(self, in_img: tf.Tensor):
        """Apply a random gamma correction to an image.

        The exponent is 2**f with f drawn uniformly from [-1, 1), i.e. a gamma
        between 0.5 and 2.

        :param in_img: image tensor.
        :return: the corrected image, clipped to [0, 1].
        """
        shifted_img = in_img + 1e-5  # To prevent NaNs
        exponent = tf.pow(2.0, tf.random.uniform(shape=(), minval=-1, maxval=1, dtype=tf.float32))
        corrected_img = tf.pow(shifted_img, exponent)

        return tf.clip_by_value(corrected_img, clip_value_min=0, clip_value_max=1)

    def brightness_augmentation(self, in_img: tf.Tensor):
        """Add a random offset, the same for the whole image.

        :param in_img: image tensor.
        :return: the image plus an offset drawn uniformly from [-0.2, 0.2)
            (20% shift), clipped to [0, 1].
        """
        offset = tf.random.uniform(shape=(), minval=-0.2, maxval=0.2, dtype=tf.float32)
        shifted_img = offset + in_img
        return tf.clip_by_value(shifted_img, clip_value_min=0, clip_value_max=1)

    def min_max_normalization(self, in_img: tf.Tensor):
        """Rescale a tensor linearly using its global minimum and maximum.

        :param in_img: tensor to normalize (floating point expected).
        :return: (in_img - min) / (max - min). A constant input (max == min)
            yields zeros instead of the NaNs a 0 / 0 division would produce.
        """
        min_val = tf.reduce_min(in_img)
        value_range = tf.subtract(tf.reduce_max(in_img), min_val)

        # The numerator is all zeros when the range is zero, so dividing by 1
        # in that case returns zeros and leaves any other input untouched.
        safe_range = tf.where(tf.equal(value_range, 0),
                              tf.ones_like(value_range),
                              value_range)
        return tf.divide(tf.subtract(in_img, min_val), safe_range)

    def deform_image(self, fix_img: tf.Tensor, fix_segm: tf.Tensor):
        """Deform an image and its segmentation with a random transformation.

        Control points sampled inside the image are moved randomly, a random
        global transformation is applied on top, and the dense displacement
        map is interpolated with thin plate splines on the fine grid.

        :param fix_img: image without batch and channel axes.
        :param fix_segm: segmentation without batch axis.
        :return: (mov_img, mov_segm, disp_map): the deformed image without
            batch and channel axes, the deformed segmentation and the
            displacement map, both without batch axis.
        """
        # Get locations where the intensity > 0.0
        idx_points_in_label = tf.where(tf.greater(fix_img, 0.0))

        # Select num_control_points of those locations
        # NOTE(review): sample_unique presumably returns distinct rows cast to
        # float32 - confirm against ddmr.utils.operators.
        disp_location = sample_unique(idx_points_in_label, self.num_control_points, tf.float32)

        # Get the coordinates of the displaced control points
        rand_disp = tf.random.uniform((self.num_control_points, 3), minval=-1, maxval=1, dtype=tf.float32) * self.max_deformation
        warped_location = disp_location + rand_disp

        # Add the selected locations to the control grid and the warped locations to the target grid
        control_grid = tf.concat([self.control_grid, disp_location], axis=0)
        local_trg_grid = tf.concat([self.control_grid, warped_location], axis=0)

        # Apply the global transformation and interpolate the displacement map
        while True:
            # Every attempt starts from the locally warped grid, so a retry
            # draws a new global transformation instead of stacking it on top
            # of the rejected one.
            trg_grid, _ = self.global_transformation(local_trg_grid)

            try:
                tps = ThinPlateSplines(control_grid, trg_grid)
                def_grid = tps.interpolate(self.fine_grid)
            except InvalidArgumentError:
                # If the transformation raises a non-invertible error,
                # try again until we get a valid transformation
                tf.print('TPS non invertible matrix', output_stream=sys.stdout)
                continue
            break

        disp_map = self.fine_grid - def_grid
        disp_map = tf.reshape(disp_map, (*self.in_img_shape, -1))

        # Apply the displacement map: batch (and channel) axes are added first
        fix_img = tf.expand_dims(tf.expand_dims(fix_img, -1), 0)
        fix_segm = tf.expand_dims(fix_segm, 0)
        disp_map = tf.cast(tf.expand_dims(disp_map, 0), tf.float32)

        mov_img = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([fix_img, disp_map])
        mov_segm = SpatialTransformer(interp_method='nearest', indexing='ij', single_transform=False)([fix_segm, disp_map])

        # Replace NaN and +/-Inf values by zeros
        mov_img = tf.where(tf.math.is_finite(mov_img), mov_img, tf.zeros_like(mov_img))
        mov_segm = tf.where(tf.math.is_finite(mov_segm), mov_segm, tf.zeros_like(mov_segm))

        # Remove only the axes added above, never a spatial axis of size 1
        return tf.squeeze(mov_img, axis=[0, -1]), tf.squeeze(mov_segm, axis=0), tf.squeeze(disp_map, axis=0)

    def global_transformation(self, points: tf.Tensor):
        """Apply a random rotation and translation to a set of points.

        A value drawn uniformly from [0, 3) selects which one of the three
        rotation angles is random (the intervals (0, 1], (1, 2] and (2, 3]
        are disjoint, so at most one angle is non-zero). The translation is
        drawn independently for each axis.

        :param points: tensor with one 3D point per row.
        :return: (transformed points, one per row; the 4x4 matrix returned by
            build_affine_transformation).
        """
        selector = tf.random.uniform((), 0, 3)

        def selected(lower, upper):
            # True when lower < selector <= upper
            return tf.logical_and(tf.greater(selector, lower), tf.less_equal(selector, upper))

        def random_angle():
            return tf.random.uniform((), -self.max_rotation, self.max_rotation)

        def zero_angle():
            return tf.zeros((), tf.float32)

        alpha = C.DEG_TO_RAD * tf.cond(selected(0., 1.), random_angle, zero_angle)
        beta = C.DEG_TO_RAD * tf.cond(selected(1., 2.), random_angle, zero_angle)
        gamma = C.DEG_TO_RAD * tf.cond(selected(2., 3.), random_angle, zero_angle)

        # One random shift per axis, in [-max_displacement, max_displacement)
        ti, tj, tk = [tf.random.uniform((), minval=-1, maxval=1, dtype=tf.float32) * self.max_displacement
                      for _ in range(3)]

        img_shape = tf.convert_to_tensor(self.in_img_shape, tf.float32)
        M = self.build_affine_transformation(img_shape, alpha, beta, gamma, ti, tj, tk)

        # Points as columns: new = M[:3, :3] . p + M[:3, 3]
        column_pts = tf.transpose(points)
        rotated_pts = tf.matmul(M[:3, :3], column_pts)
        moved_pts = tf.expand_dims(M[:3, -1], -1) + rotated_pts
        return tf.transpose(moved_pts), M

    @staticmethod
    def build_affine_transformation(img_shape, alpha, beta, gamma, ti, tj, tk):
        """Build the 4x4 homogeneous matrix T . Tc . R . Tc_.

        R = Ri(alpha) . Rj(beta) . Rk(gamma) is applied around the image
        centre (img_shape / 2): Tc_ moves the centre to the origin and Tc
        moves it back. T is the translation by (ti, tj, tk).

        :param img_shape: tensor with the three spatial sizes of the image.
        :param alpha: angle of Ri, in radians (tf.cos and tf.sin expect radians).
        :param beta: angle of Rj, in radians.
        :param gamma: angle of Rk, in radians.
        :param ti: translation along the first axis.
        :param tj: translation along the second axis.
        :param tk: translation along the third axis.
        :return: float32 tensor of shape (4, 4).
        """
        def translation(di, dj, dk):
            # Homogeneous translation by (di, dj, dk)
            return tf.convert_to_tensor([[1, 0, 0, di],
                                         [0, 1, 0, dj],
                                         [0, 0, 1, dk],
                                         [0, 0, 0, 1]], tf.float32)

        centre = tf.divide(img_shape, 2.)

        shift = translation(ti, tj, tk)

        rot_i = tf.convert_to_tensor([[1, 0, 0, 0],
                                      [0, tf.math.cos(alpha), -tf.math.sin(alpha), 0],
                                      [0, tf.math.sin(alpha),  tf.math.cos(alpha), 0],
                                      [0, 0, 0, 1]], tf.float32)

        rot_j = tf.convert_to_tensor([[ tf.math.cos(beta), 0, tf.math.sin(beta), 0],
                                      [0, 1, 0, 0],
                                      [-tf.math.sin(beta), 0, tf.math.cos(beta), 0],
                                      [0, 0, 0, 1]], tf.float32)

        rot_k = tf.convert_to_tensor([[tf.math.cos(gamma), -tf.math.sin(gamma), 0, 0],
                                      [tf.math.sin(gamma),  tf.math.cos(gamma), 0, 0],
                                      [0, 0, 1, 0],
                                      [0, 0, 0, 1]], tf.float32)

        rotation = tf.matmul(tf.matmul(rot_i, rot_j), rot_k)

        # Rotation around the image centre: R* = T(p) R(ang) T(-p)
        to_centre = translation(centre[0], centre[1], centre[2])
        to_origin = translation(-centre[0], -centre[1], -centre[2])
        centred_rotation = tf.matmul(to_centre, tf.matmul(rotation, to_origin))

        return tf.matmul(shift, centred_rotation)

    def get_config(self):
        """Return the configuration needed to rebuild the layer.

        The constructor arguments are added to the base configuration so that
        from_config can instantiate the layer again; the base configuration
        alone lacks the required arguments of __init__.

        :return: dict with the base layer configuration plus one entry per
            constructor argument.
        """
        config = super(AugmentationLayer, self).get_config()
        config.update({
            'max_deformation': self.max_deformation,
            'max_displacement': self.max_displacement,
            'max_rotation': self.max_rotation,
            'num_control_points': self.num_control_points,
            # Shapes are stored as plain lists
            'in_img_shape': list(self.in_img_shape),
            'out_img_shape': None if self.out_img_shape is None else list(self.out_img_shape),
            'num_augmentations': self.num_augmentations,
            'gamma_augmentation': self.do_gamma_augm,
            'brightness_augmentation': self.do_brightness_augm,
            'only_image': self.only_image,
            # The constructor flag is not stored as such: the layer keeps its
            # negation in self.augment. NOTE: call(training=...) overwrites
            # self.augment, so this reflects the current mode of the layer.
            'only_resize': not self.augment,
            'return_displacement_map': self.return_disp_map,
        })
        return config