File size: 6,310 Bytes
ab9857f
 
 
 
8fb4d8e
ab9857f
 
9b94746
081d04d
9b94746
 
 
 
 
 
 
081d04d
9b94746
ab9857f
 
 
081d04d
ab9857f
 
 
 
 
081d04d
 
 
 
 
 
 
ab9857f
 
0902b38
ab9857f
c44df3d
ab9857f
 
081d04d
ab9857f
 
0902b38
 
 
 
 
 
ab9857f
081d04d
9b94746
ab9857f
 
 
 
 
0902b38
ab9857f
8fb4d8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0902b38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import tensorflow as tf
from scipy.ndimage import generate_binary_structure

from DeepDeformationMapRegistration.utils.operators import soft_threshold
from DeepDeformationMapRegistration.utils.constants import EPS_tf


class HausdorffDistanceErosion:
    def __init__(self, ndim=3, nerosion=10, value_per_channel=False):
        """
        Approximation of the Hausdorff distance based on erosion operations based on the work done by Karimi D., et al.
            Karimi D., et al., "Reducing the Hausdorff Distance in Medical Image Segmentation with Convolutional Neural
            Networks". IEEE Transactions on Medical Imaging, 39, 2020. DOI 10.1109/TMI.2019.2930068

        :param ndim: Number of spatial dimensions of the images. Selects the convolution ``tf.nn.conv<ndim>d``.
        :param nerosion: Number of erosion steps. Defaults to 10.
        :param value_per_channel: If True, the per-sample distance is summed over the spatial axes only (one value per
            channel); otherwise it is summed over every axis. Note that ``loss`` averages the per-sample result into a
            single scalar in both cases.
        """
        self.ndims = ndim
        self.conv = getattr(tf.nn, 'conv%dd' % self.ndims)
        self.nerosions = nerosion
        # Axes passed to tf.reduce_sum: the leading (spatial) axes, or None to reduce over everything
        self.sum_range = tf.range(0, self.ndims) if value_per_channel else None

    def _channel_permutations(self):
        """
        Build the axis permutations used to iterate over the channels of a channels-last tensor.

        :return: Tuple ``(to_front, to_back)``. ``to_front`` moves the last axis to position 0 and ``to_back`` undoes
            it, for a tensor with ``self.ndims`` spatial axes followed by one channel axis.
        """
        to_front = [self.ndims] + list(range(self.ndims))
        to_back = list(range(1, self.ndims + 1)) + [0]
        return to_front, to_back

    def _averaging_kernel(self):
        """
        Build the erosion kernel as a NumPy array.

        :return: Array of shape ``(3,) * self.ndims`` holding the connectivity-1 binary structure scaled by
            ``1 / (2 * self.ndims + 1)``.
        """
        norm = 1 / (self.ndims * 2 + 1)
        return generate_binary_structure(self.ndims, 1).astype(int) * norm

    def _erode(self, in_tensor, kernel):
        """
        Soft erosion of one image: the complement of the image is convolved with the kernel, complemented again and
        passed through ``soft_threshold`` at 0.5.

        :param in_tensor: Image without batch axis (a batch axis of size 1 is added for the convolution and removed
            afterwards). Presumably shaped [*spatial, C_in] to match the kernel -- confirm against the callers.
        :param kernel: Convolution filter.
        :return: Result of ``soft_threshold`` applied to the eroded image.
        """
        out = 1. - tf.squeeze(self.conv(tf.expand_dims(1. - in_tensor, 0), kernel, [1] * (self.ndims + 2), 'SAME'), axis=0)
        return soft_threshold(out, 0.5, name='soft_thresholding')

    def _erode_per_channel(self, in_tensor, kernel):
        """
        Erode every channel of a channels-last image independently.

        :param in_tensor: Image with ``self.ndims`` spatial axes followed by the channel axis.
        :param kernel: Single-channel convolution filter.
        :return: Eroded image, channels last.
        """
        to_front, to_back = self._channel_permutations()

        def erode_channel(channel):
            # Add a fictitious channel for the convolution and drop only that axis afterwards. An unqualified
            # squeeze would also remove any spatial axis of size 1.
            # NOTE(review): relies on soft_threshold preserving the shape of its input -- confirm.
            return tf.squeeze(self._erode(tf.expand_dims(channel, -1), kernel), axis=-1)

        # Iterate along the channel dimension, which is moved to the front for tf.map_fn
        er_tensor = tf.map_fn(erode_channel, tf.transpose(in_tensor, to_front), tf.float32)

        return tf.transpose(er_tensor, to_back)  # move the channels back to the end

    def _erosion_distance_single(self, y_true, y_pred):
        """
        Erosion-based distance for a single sample (no batch axis).

        :param y_true: Reference image, channels last.
        :param y_pred: Predicted image, same shape as ``y_true``.
        :return: Sum over k = 1..nerosions of ``k**2`` times the k-times eroded squared difference, reduced over
            ``self.sum_range`` and divided by the number of elements of one channel.
        """
        diff = tf.math.pow(y_pred - y_true, 2)
        alpha = 2

        kernel = tf.constant(self._averaging_kernel(), tf.float32)
        kernel = tf.expand_dims(tf.expand_dims(kernel, -1), -1)   # [H, W, D, C_in, C_out]

        ret = 0.
        er = diff
        for k in range(1, self.nerosions + 1):
            # The k-th erosion is one more erosion of the (k-1)-th one, so the previous result is reused instead
            # of eroding the original difference k times again.
            er = self._erode_per_channel(er, kernel)
            ret += tf.reduce_sum(tf.multiply(er, float(k ** alpha)), self.sum_range)

        img_vol = tf.cast(tf.reduce_prod(tf.shape(y_true)[:-1]), tf.float32)  # Volume of each channel
        return tf.divide(ret, img_vol)  # Divide by the image size

    def loss(self, y_true, y_pred):
        """
        Mean erosion-based distance over a batch.

        :param y_true: Batch of reference images; the first axis is iterated over.
        :param y_pred: Batch of predicted images, same shape as ``y_true``.
        :return: Scalar tensor with the mean of the per-sample distances.
        """
        batched_dist = tf.map_fn(lambda x: self._erosion_distance_single(x[0], x[1]), (y_true, y_pred),
                                 dtype=tf.float32)

        return tf.reduce_mean(batched_dist)


class NCC:
    """Normalized cross-correlation between two images and the loss ``1 - NCC`` derived from it."""

    def __init__(self, in_shape, eps=EPS_tf):
        """
        :param in_shape: Shape whose product is used as the element count in the NCC denominator. Presumably the
            shape of one sample -- confirm against the callers.
        :param eps: Small constant added to the denominator.
        """
        self.__eps = eps
        self.__shape_size = tf.cast(tf.reduce_prod(in_shape), tf.float32)

    def ncc(self, y_true, y_pred):
        """
        Normalized cross-correlation of two tensors, computed over all of their elements.

        :param y_true: Reference tensor; it is flattened before use.
        :param y_pred: Predicted tensor; it is flattened before use.
        :return: Scalar tensor: sum of the products of the mean-centred values divided by
            ``std(y_true) * std(y_pred) * shape_size + eps`` (0 where the denominator is 0).
        """
        flat_true = tf.reshape(y_true, [-1])
        flat_pred = tf.reshape(y_pred, [-1])

        # Mean-centred signals
        centred_true = flat_true - tf.reduce_mean(flat_true)
        centred_pred = flat_pred - tf.reduce_mean(flat_pred)

        cross_term = tf.reduce_sum(centred_true * centred_pred)
        scale = tf.math.reduce_std(flat_true) * tf.math.reduce_std(flat_pred) * self.__shape_size + self.__eps
        return tf.math.divide_no_nan(cross_term, scale)

    def loss(self, y_true, y_pred):
        """
        Mean of ``1 - NCC`` over the first axis of the inputs.

        :param y_true: Batch of reference tensors.
        :param y_pred: Batch of predicted tensors.
        :return: Scalar tensor.
        """
        def sample_loss(pair):
            return 1 - self.ncc(pair[0], pair[1])

        # According to the documentation, the loss returns a scalar
        # Ref: https://www.tensorflow.org/api_docs/python/tf/keras/Model#compile
        per_sample = tf.map_fn(sample_loss, (y_true, y_pred), tf.float32)
        return tf.reduce_mean(per_sample)


class StructuralSimilarity:
    # Based on https://github.com/keras-team/keras-contrib/blob/master/keras_contrib/losses/dssim.py
    def __init__(self, k1=0.01, k2=0.03, patch_size=3, dynamic_range=1., overlap=0.0):
        """
        Structural (Di)Similarity Index Measure:

        :param k1: Internal parameter. Defaults to 0.01
        :param k2: Internal parameter. Defaults to 0.03
        :param patch_size: Size of the extracted patches
        :param dynamic_range: Maximum numerical intensity value (typ. 2^bits_per_pixel - 1). Defaults to 1.
        :param overlap: Fraction of the patch size by which consecutive patches overlap. The patch stride is
            ``int(patch_size * (1 - overlap))``, clamped to a minimum of 1. Defaults to 0.
        """
        self.__c1 = (k1 * dynamic_range) ** 2
        self.__c2 = (k2 * dynamic_range) ** 2
        self.__kernel_shape = [1] + [patch_size] * 3 + [1]
        # A stride below 1 (overlap >= 1, or rounding down to 0) is not a valid step, so fall back to 1
        stride = max(1, int(patch_size * (1 - overlap)))
        self.__stride = [1] + [stride] * 3 + [1]
        self.__max_val = dynamic_range

    def __int_shape(self, x):
        # Static shape when the Keras backend is TensorFlow, backend shape otherwise. Not used by the other
        # methods of this class.
        return tf.keras.backend.int_shape(x) if tf.keras.backend.backend() == 'tensorflow' else tf.keras.backend.shape(x)

    def ssim(self, y_true, y_pred):
        """
        Mean SSIM over all the patches extracted from the two volumes.

        :param y_true: Reference volumes, as accepted by ``tf.extract_volume_patches``.
        :param y_pred: Predicted volumes, same shape as ``y_true``.
        :return: Scalar tensor with the mean SSIM.
        """
        patches_true = tf.extract_volume_patches(y_true, self.__kernel_shape, self.__stride, 'VALID',
                                                 'patches_true')
        patches_pred = tf.extract_volume_patches(y_pred, self.__kernel_shape, self.__stride, 'VALID',
                                                 'patches_pred')

        # The statistics below are taken along the last axis of the extracted patches

        # Mean
        u_true = tf.reduce_mean(patches_true, axis=-1)
        u_pred = tf.reduce_mean(patches_pred, axis=-1)

        # Variance
        v_true = tf.math.reduce_variance(patches_true, axis=-1)
        v_pred = tf.math.reduce_variance(patches_pred, axis=-1)

        # Covariance
        covar = tf.reduce_mean(patches_true * patches_pred, axis=-1) - u_true * u_pred

        # SSIM
        numerator = (2 * u_true * u_pred + self.__c1) * (2 * covar + self.__c2)
        denominator = ((tf.square(u_true) + tf.square(u_pred) + self.__c1) * (v_pred + v_true + self.__c2))
        ssim = numerator / denominator

        return tf.reduce_mean(ssim)

    def dssim(self, y_true, y_pred):
        """
        Structural dissimilarity, ``(1 - SSIM) / 2``.

        :param y_true: Reference volumes.
        :param y_pred: Predicted volumes.
        :return: Scalar tensor.
        """
        return tf.reduce_mean((1. - self.ssim(y_true, y_pred)) / 2.0)