|
import os, sys |
|
|
|
# Append the directory one level above this file to the import search path (presumably so
# that top-level project packages resolve when the module is run from a source checkout).
currentdir = os.path.split(os.path.realpath(__file__))[0]
parentdir = os.path.split(currentdir)[0]
sys.path.append(parentdir)

# True only when the PYCHARM_EXEC environment variable is set to exactly the string 'True'.
PYCHARM_EXEC = os.environ.get('PYCHARM_EXEC') == 'True'
|
|
|
import tensorflow.keras.layers as kl |
|
import tensorflow as tf |
|
from tensorflow.python.framework.errors import InvalidArgumentError |
|
|
|
from ddmr.utils.operators import soft_threshold, gaussian_kernel, sample_unique |
|
import ddmr.utils.constants as C |
|
from ddmr.utils.thin_plate_splines import ThinPlateSplines |
|
from voxelmorph.tf.layers import SpatialTransformer |
|
|
|
|
|
class AugmentationLayer(kl.Layer):
    """Keras layer that turns 3D volumes into (fixed, moving) training pairs.

    When augmentation is enabled, channel 0 of each sample is the image and the remaining
    channels are segmentations (unless ``only_image`` is set). A moving sample is synthesised
    by warping the fixed one with a random thin-plate-spline (TPS) deformation composed with a
    random rotation + translation. Gamma and brightness augmentations are optional. Volumes are
    downsampled to ``out_img_shape`` when one is given.

    When augmentation is disabled, the moving image, moving segmentation and displacement map
    are all zeros and only resizing/normalisation is applied.
    """

    def __init__(self,
                 max_deformation,
                 max_displacement,
                 max_rotation,
                 num_control_points,
                 in_img_shape,
                 out_img_shape,
                 num_augmentations=1,
                 gamma_augmentation=True,
                 brightness_augmentation=True,
                 only_image=False,
                 only_resize=True,
                 return_displacement_map=False,
                 **kwargs):
        """
        :param max_deformation: scale of the random per-axis offset applied to each sampled TPS
            control point (uniform in [-max_deformation, max_deformation]).
        :param max_displacement: maximum per-axis translation of the global transformation.
        :param max_rotation: maximum rotation angle of the global transformation; it is
            multiplied by C.DEG_TO_RAD, so presumably expressed in degrees.
        :param num_control_points: number of TPS control points sampled inside the image.
        :param in_img_shape: spatial shape (3 axes) of the input volumes.
        :param out_img_shape: spatial shape of the output volumes, or None to skip downsampling.
        :param num_augmentations: stored (and serialised) but not otherwise used by this layer.
        :param gamma_augmentation: apply random gamma augmentation when augmenting.
        :param brightness_augmentation: apply a random brightness shift when augmenting.
        :param only_image: treat the whole input sample as image (no segmentation channels).
        :param only_resize: when True, augmentation is off unless `call` receives ``training``.
        :param return_displacement_map: also return the displacement map from `call`.
        """
        super(AugmentationLayer, self).__init__(**kwargs)

        self.max_deformation = max_deformation
        self.max_displacement = max_displacement
        self.max_rotation = max_rotation
        self.num_control_points = num_control_points
        self.num_augmentations = num_augmentations
        self.in_img_shape = in_img_shape
        self.out_img_shape = out_img_shape
        self.only_image = only_image
        self.return_disp_map = return_displacement_map

        self.do_gamma_augm = gamma_augmentation
        self.do_brightness_augm = brightness_augmentation

        # Regular TPS control grid (presumably C.TPS_NUM_CTRL_PTS_PER_AXIS points per axis).
        grid = C.CoordinatesGrid()
        grid.set_coords_grid(in_img_shape, [C.TPS_NUM_CTRL_PTS_PER_AXIS] * 3)
        self.control_grid = tf.identity(grid.grid_flat(), name='control_grid')
        self.target_grid = tf.identity(grid.grid_flat(), name='target_grid')

        # Dense grid on which the TPS is evaluated; deform_image reshapes the resulting
        # displacement to (*in_img_shape, -1), i.e. presumably one point per input voxel.
        grid.set_coords_grid(in_img_shape, in_img_shape)
        self.fine_grid = tf.identity(grid.grid_flat(), 'fine_grid')

        if out_img_shape is not None:
            # Integer per-axis ratio between the input and output spatial sizes.
            self.downsample_factor = [i // o for o, i in zip(out_img_shape, in_img_shape)]
        # Smoothing filter applied before subsampling images in downsize_image.
        self.img_gauss_filter = gaussian_kernel(3, 0.001, 1, 1, 3)

        # Kept so that get_config can reproduce the constructor arguments.
        self.only_resize = only_resize
        self.augment = not only_resize

    def compute_output_shape(self, input_shape):
        """Return the static shapes of the tensors produced by `call`.

        The spatial dimensions are ``out_img_shape``, or ``in_img_shape`` when no downsampling
        is configured.
        """
        input_shape = tf.TensorShape(input_shape).as_list()
        spatial_shape = self.out_img_shape if self.out_img_shape is not None else self.in_img_shape
        img_shape = (input_shape[0], *spatial_shape, 1)
        seg_shape = (input_shape[0], *spatial_shape, input_shape[-1] - 1)
        disp_shape = (input_shape[0], *spatial_shape, 3)

        if self.return_disp_map:
            return (img_shape, img_shape, seg_shape, seg_shape, disp_shape)
        else:
            return (img_shape, img_shape, seg_shape, seg_shape)

    def call(self, in_data, training=None):
        """Build a batch of (fixed, moving) pairs from ``in_data``.

        :param in_data: input volume(s); inputs of rank < 5 get a leading batch axis.
        :param training: when not None, it replaces ``self.augment``.
        """
        if training is not None:
            # NOTE(review): this overwrites layer state, so the value also applies to later
            # calls that leave ``training`` as None.
            self.augment = training
        return self.build_batch(in_data)

    def build_batch(self, fix_data: tf.Tensor):
        """Apply `augment_sample` to every sample along the first axis of ``fix_data``.

        :param fix_data: batch of samples; an input of rank < 5 gets a leading axis.
        :return: (fix_img, mov_img, fix_segm, mov_segm), plus the displacement map when
            ``return_displacement_map`` was set.
        """
        if len(fix_data.get_shape().as_list()) < 5:
            fix_data = tf.expand_dims(fix_data, axis=0)

        fix_img_batch, mov_img_batch, fix_seg_batch, mov_seg_batch, disp_map = tf.map_fn(
            lambda x: self.augment_sample(x),
            fix_data,
            dtype=(tf.float32, tf.float32, tf.float32, tf.float32, tf.float32))

        if self.return_disp_map:
            return fix_img_batch, mov_img_batch, fix_seg_batch, mov_seg_batch, disp_map
        else:
            return fix_img_batch, mov_img_batch, fix_seg_batch, mov_seg_batch

    def augment_sample(self, fix_data: tf.Tensor):
        """Produce one (fixed, moving) pair from a single sample (no batch axis).

        :param fix_data: one sample as sliced by map_fn; the trailing axis is channels.
        :return: fix_img, mov_img, fix_segm, mov_segm, disp_map.
        """
        if self.only_image or not self.augment:
            # NOTE(review): when not augmenting, the image keeps every input channel, while
            # compute_output_shape assumes a single image channel — confirm intended.
            fix_img = fix_data
            fix_segm = tf.zeros_like(fix_data, dtype=tf.float32)
        else:
            # Channel 0 is the image; the remaining channels are segmentations.
            fix_img = tf.expand_dims(fix_data[..., 0], -1)
            fix_segm = fix_data[..., 1:]

        if self.augment:
            fix_img = self.min_max_normalization(fix_img)

            mov_img, mov_segm, disp_map = self.deform_image(tf.squeeze(fix_img), fix_segm)
            # deform_image squeezes the channel axis away; restore it.
            mov_img = tf.expand_dims(mov_img, -1)

            if self.out_img_shape is not None:
                fix_img = self.downsize_image(fix_img)
                mov_img = self.downsize_image(mov_img)

                fix_segm = self.downsize_segmentation(fix_segm)
                mov_segm = self.downsize_segmentation(mov_segm)

                disp_map = self.downsize_displacement_map(disp_map)

            if self.do_gamma_augm:
                fix_img = self.gamma_augmentation(fix_img)
                mov_img = self.gamma_augmentation(mov_img)

            if self.do_brightness_augm:
                fix_img = self.brightness_augmentation(fix_img)
                mov_img = self.brightness_augmentation(mov_img)
        else:
            mov_img = tf.zeros_like(fix_img)
            mov_segm = tf.zeros_like(fix_segm)

            # All-zero displacement with 3 components per voxel. The sample has no batch axis
            # here, so tile needs one multiple per axis of a rank-4 tensor; slicing a single
            # channel keeps the result at exactly 3 components whatever the channel count.
            disp_map = tf.tile(tf.zeros_like(fix_img[..., :1]), [1, 1, 1, 3])

            if self.out_img_shape is not None:
                fix_img = self.downsize_image(fix_img)
                mov_img = self.downsize_image(mov_img)

                fix_segm = self.downsize_segmentation(fix_segm)
                mov_segm = self.downsize_segmentation(mov_segm)

                disp_map = self.downsize_displacement_map(disp_map)

        fix_img = self.min_max_normalization(fix_img)
        mov_img = self.min_max_normalization(mov_img)
        return fix_img, mov_img, fix_segm, mov_segm, disp_map

    def downsize_image(self, img):
        """Gaussian-smooth ``img`` and subsample it by ``downsample_factor``.

        Max-pooling with a window of 1 and stride ``downsample_factor`` keeps every k-th voxel.
        """
        img = tf.expand_dims(img, axis=0)
        img = tf.nn.conv3d(img, self.img_gauss_filter, strides=[1, ] * 5, padding='SAME', data_format='NDHWC')
        img = tf.layers.MaxPooling3D([1]*3, self.downsample_factor, padding='valid', data_format='channels_last')(img)
        return tf.squeeze(img, axis=0)

    def downsize_segmentation(self, segm):
        """Subsample ``segm`` by ``downsample_factor`` without smoothing, returned as float32."""
        segm = tf.expand_dims(segm, axis=0)
        segm = tf.layers.MaxPooling3D([1]*3, self.downsample_factor, padding='valid', data_format='channels_last')(segm)
        segm = tf.cast(segm, tf.float32)
        return tf.squeeze(segm, axis=0)

    def downsize_displacement_map(self, disp_map):
        """Subsample a displacement map and rescale its components to the output grid.

        :param disp_map: displacement field whose last axis holds the 3 components.
        """
        disp_map = tf.expand_dims(disp_map, axis=0)
        disp_map = tf.layers.AveragePooling3D([1]*3, self.downsample_factor, padding='valid', data_format='channels_last')(disp_map)
        # Divide each component by the factor of its own axis (x/f0, y/f1, z/f2), so that
        # displacements follow the coarser voxel spacing. Dividing per axis is correct for
        # equal, partially-equal and all-different factors alike.
        scale = tf.constant(self.downsample_factor, dtype=disp_map.dtype)
        disp_map = disp_map / scale
        return tf.squeeze(disp_map, axis=0)

    def gamma_augmentation(self, in_img: tf.Tensor):
        """Raise ``in_img`` to a random power 2**f (f ~ U(-1, 1)) and clip to [0, 1]."""
        # Small offset, presumably to keep the base strictly positive.
        in_img += 1e-5
        f = tf.random.uniform((), -1, 1, tf.float32)
        gamma = tf.pow(2.0, f)
        return tf.clip_by_value(tf.pow(in_img, gamma), 0, 1)

    def brightness_augmentation(self, in_img: tf.Tensor):
        """Add a random offset c ~ U(-0.2, 0.2) to ``in_img`` and clip to [0, 1]."""
        c = tf.random.uniform((), -0.2, 0.2, tf.float32)
        return tf.clip_by_value(c + in_img, 0, 1)

    def min_max_normalization(self, in_img: tf.Tensor):
        """Linearly rescale ``in_img`` to [0, 1] using its global min and max.

        A constant image (max == min) is mapped to all zeros instead of NaN (0 / 0). This is
        the case for the all-zero moving image produced when augmentation is disabled.
        """
        min_val = tf.reduce_min(in_img)
        value_range = tf.subtract(tf.reduce_max(in_img), min_val)
        return tf.math.divide_no_nan(tf.subtract(in_img, min_val), value_range)

    def deform_image(self, fix_img: tf.Tensor, fix_segm: tf.Tensor):
        """Warp ``fix_img`` / ``fix_segm`` with a random TPS deformation plus a global transform.

        The source control points are ``self.control_grid`` plus ``num_control_points``
        locations drawn by ``sample_unique`` among voxels where ``fix_img > 0``. Their targets
        are the same points, with the sampled ones offset by up to ``max_deformation`` per axis.
        The whole target set is then moved by `global_transformation`.

        :param fix_img: image without channel axis (the caller squeezes it).
        :param fix_segm: segmentation with a trailing channel axis.
        :return: (moving image, squeezed; moving segmentation; displacement map reshaped to
            (*in_img_shape, -1)).
        """
        idx_points_in_label = tf.where(tf.greater(fix_img, 0.0))
        disp_location = sample_unique(idx_points_in_label, self.num_control_points, tf.float32)

        rand_disp = tf.random.uniform((self.num_control_points, 3), minval=-1, maxval=1, dtype=tf.float32) * self.max_deformation
        warped_location = disp_location + rand_disp

        control_grid = tf.concat([self.control_grid, disp_location], axis=0)
        base_trg_grid = tf.concat([self.control_grid, warped_location], axis=0)

        while True:
            # Draw a fresh global transformation of the untouched target grid on every attempt,
            # so a rejected transformation is not compounded into the next one.
            trg_grid, _ = self.global_transformation(base_trg_grid)
            try:
                tps = ThinPlateSplines(control_grid, trg_grid)
                def_grid = tps.interpolate(self.fine_grid)
            except InvalidArgumentError:
                # NOTE(review): this can only trigger when the ops execute eagerly; in graph
                # mode a singular system would surface later, at run time.
                tf.print('TPS non invertible matrix', output_stream=sys.stdout)
            else:
                break

        disp_map = self.fine_grid - def_grid
        disp_map = tf.reshape(disp_map, (*self.in_img_shape, -1))

        # SpatialTransformer expects batch and channel axes.
        fix_img = tf.expand_dims(tf.expand_dims(fix_img, -1), 0)
        fix_segm = tf.expand_dims(fix_segm, 0)
        disp_map = tf.cast(tf.expand_dims(disp_map, 0), tf.float32)

        mov_img = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([fix_img, disp_map])
        mov_segm = SpatialTransformer(interp_method='nearest', indexing='ij', single_transform=False)([fix_segm, disp_map])

        # Replace non-finite values produced by the warp with zeros.
        mov_img = tf.where(tf.is_nan(mov_img), tf.zeros_like(mov_img), mov_img)
        mov_img = tf.where(tf.is_inf(mov_img), tf.zeros_like(mov_img), mov_img)

        mov_segm = tf.where(tf.is_nan(mov_segm), tf.zeros_like(mov_segm), mov_segm)
        mov_segm = tf.where(tf.is_inf(mov_segm), tf.zeros_like(mov_segm), mov_segm)

        return tf.squeeze(mov_img), tf.squeeze(mov_segm, axis=0), tf.squeeze(disp_map, axis=0)

    def global_transformation(self, points: tf.Tensor):
        """Apply a random rotation about one axis plus a random translation to ``points``.

        A value drawn uniformly in [0, 3) selects which angle (alpha, beta or gamma) is
        non-zero. That angle is drawn in [-max_rotation, max_rotation] and converted with
        C.DEG_TO_RAD. Each translation component is drawn in
        [-max_displacement, max_displacement].

        :param points: (N, 3) coordinates.
        :return: (transformed (N, 3) points, 4x4 affine matrix).
        """
        axis = tf.random.uniform((), 0, 3)

        alpha = C.DEG_TO_RAD * tf.cond(tf.logical_and(tf.greater(axis, 0.), tf.less_equal(axis, 1.)),
                                       lambda: tf.random.uniform((), -self.max_rotation, self.max_rotation),
                                       lambda: tf.zeros((), tf.float32))
        beta = C.DEG_TO_RAD * tf.cond(tf.logical_and(tf.greater(axis, 1.), tf.less_equal(axis, 2.)),
                                      lambda: tf.random.uniform((), -self.max_rotation, self.max_rotation),
                                      lambda: tf.zeros((), tf.float32))
        gamma = C.DEG_TO_RAD * tf.cond(tf.logical_and(tf.greater(axis, 2.), tf.less_equal(axis, 3.)),
                                       lambda: tf.random.uniform((), -self.max_rotation, self.max_rotation),
                                       lambda: tf.zeros((), tf.float32))

        ti = tf.random.uniform((), minval=-1, maxval=1, dtype=tf.float32) * self.max_displacement
        tj = tf.random.uniform((), minval=-1, maxval=1, dtype=tf.float32) * self.max_displacement
        tk = tf.random.uniform((), minval=-1, maxval=1, dtype=tf.float32) * self.max_displacement

        M = self.build_affine_transformation(tf.convert_to_tensor(self.in_img_shape, tf.float32),
                                             alpha, beta, gamma, ti, tj, tk)

        # Work with (3, N) column vectors: rotate/scale part, then add the translation column.
        points = tf.transpose(points)
        new_pts = tf.matmul(M[:3, :3], points)
        new_pts = tf.expand_dims(M[:3, -1], -1) + new_pts
        return tf.transpose(new_pts), M

    @staticmethod
    def build_affine_transformation(img_shape, alpha, beta, gamma, ti, tj, tk):
        """Build the 4x4 homogeneous matrix T · Tc · R · Tc⁻¹.

        R = Ri(alpha) · Rj(beta) · Rk(gamma) rotates about the image centre (img_shape / 2);
        T then translates by (ti, tj, tk).
        """
        img_centre = tf.divide(img_shape, 2.)

        T = tf.convert_to_tensor([[1, 0, 0, ti],
                                  [0, 1, 0, tj],
                                  [0, 0, 1, tk],
                                  [0, 0, 0, 1]], tf.float32)

        Ri = tf.convert_to_tensor([[1, 0, 0, 0],
                                   [0, tf.math.cos(alpha), -tf.math.sin(alpha), 0],
                                   [0, tf.math.sin(alpha), tf.math.cos(alpha), 0],
                                   [0, 0, 0, 1]], tf.float32)

        Rj = tf.convert_to_tensor([[ tf.math.cos(beta), 0, tf.math.sin(beta), 0],
                                   [0, 1, 0, 0],
                                   [-tf.math.sin(beta), 0, tf.math.cos(beta), 0],
                                   [0, 0, 0, 1]], tf.float32)

        Rk = tf.convert_to_tensor([[tf.math.cos(gamma), -tf.math.sin(gamma), 0, 0],
                                   [tf.math.sin(gamma), tf.math.cos(gamma), 0, 0],
                                   [0, 0, 1, 0],
                                   [0, 0, 0, 1]], tf.float32)

        R = tf.matmul(tf.matmul(Ri, Rj), Rk)

        # Translate to the image centre and back, so the rotation pivots on the centre.
        Tc = tf.convert_to_tensor([[1, 0, 0, img_centre[0]],
                                   [0, 1, 0, img_centre[1]],
                                   [0, 0, 1, img_centre[2]],
                                   [0, 0, 0, 1]], tf.float32)

        Tc_ = tf.convert_to_tensor([[1, 0, 0, -img_centre[0]],
                                    [0, 1, 0, -img_centre[1]],
                                    [0, 0, 1, -img_centre[2]],
                                    [0, 0, 0, 1]], tf.float32)

        return tf.matmul(T, tf.matmul(Tc, tf.matmul(R, Tc_)))

    def get_config(self):
        """Return the layer config including every constructor argument.

        This lets Keras re-create the layer with ``from_config``, e.g. when loading a saved
        model.
        """
        config = super(AugmentationLayer, self).get_config()
        config.update({
            'max_deformation': self.max_deformation,
            'max_displacement': self.max_displacement,
            'max_rotation': self.max_rotation,
            'num_control_points': self.num_control_points,
            'in_img_shape': self.in_img_shape,
            'out_img_shape': self.out_img_shape,
            'num_augmentations': self.num_augmentations,
            'gamma_augmentation': self.do_gamma_augm,
            'brightness_augmentation': self.do_brightness_augm,
            'only_image': self.only_image,
            'only_resize': self.only_resize,
            'return_displacement_map': self.return_disp_map,
        })
        return config
|
|
|
|
|
|
|
|