jpdefrutos commited on
Commit
0902b38
·
1 Parent(s): 081d04d

Fxed Hausdorff implementation. The successive erosions were not applied to the eroded diff variable, but to the original diff variable (hence, no successive erosion was done at all)

Browse files

Implemented 3D Structural Similarity Metric Index, based on https://github.com/keras-team/keras-contrib/blob/master/keras_contrib/losses/dssim.py
The patch size, overlap and dynamic range can be modified at will.

DeepDeformationMapRegistration/losses.py CHANGED
@@ -34,7 +34,7 @@ class HausdorffDistanceErosion:
34
 
35
  def _erosion_distance_single(self, y_true, y_pred):
36
  diff = tf.math.pow(y_pred - y_true, 2)
37
- alpha = 2.
38
 
39
  norm = 1 / (self.ndims * 2 + 1)
40
  kernel = generate_binary_structure(self.ndims, 1).astype(int) * norm
@@ -42,10 +42,12 @@ class HausdorffDistanceErosion:
42
  kernel = tf.expand_dims(tf.expand_dims(kernel, -1), -1) # [H, W, D, C_in, C_out]
43
 
44
  ret = 0.
45
- for i in range(self.nerosions):
46
- for j in range(i + 1):
47
- er = self._erode_per_channel(diff, kernel)
48
- ret += tf.reduce_sum(tf.multiply(er, tf.pow(i + 1., alpha)), self.sum_range)
 
 
49
 
50
  img_vol = tf.cast(tf.reduce_prod(tf.shape(y_true)[:-1]), tf.float32) # Volume of each channel
51
  return tf.divide(ret, img_vol) # Divide by the image size
@@ -54,7 +56,7 @@ class HausdorffDistanceErosion:
54
  batched_dist = tf.map_fn(lambda x: self._erosion_distance_single(x[0], x[1]), (y_true, y_pred),
55
  dtype=tf.float32)
56
 
57
- return batched_dist
58
 
59
 
60
  class NCC:
@@ -77,4 +79,60 @@ class NCC:
77
  return tf.math.divide_no_nan(numerator, denominator)
78
 
79
  def loss(self, y_true, y_pred):
80
- return tf.map_fn(lambda x: 1 - self.ncc(x[0], x[1]), (y_true, y_pred), tf.float32)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  def _erosion_distance_single(self, y_true, y_pred):
36
  diff = tf.math.pow(y_pred - y_true, 2)
37
+ alpha = 2
38
 
39
  norm = 1 / (self.ndims * 2 + 1)
40
  kernel = generate_binary_structure(self.ndims, 1).astype(int) * norm
 
42
  kernel = tf.expand_dims(tf.expand_dims(kernel, -1), -1) # [H, W, D, C_in, C_out]
43
 
44
  ret = 0.
45
+ for k in range(1, self.nerosions+1):
46
+ er = diff
47
+ # k successive erosions
48
+ for j in range(k):
49
+ er = self._erode_per_channel(er, kernel)
50
+ ret += tf.reduce_sum(tf.multiply(er, tf.cast(tf.pow(k, alpha), tf.float32)), self.sum_range)
51
 
52
  img_vol = tf.cast(tf.reduce_prod(tf.shape(y_true)[:-1]), tf.float32) # Volume of each channel
53
  return tf.divide(ret, img_vol) # Divide by the image size
 
56
  batched_dist = tf.map_fn(lambda x: self._erosion_distance_single(x[0], x[1]), (y_true, y_pred),
57
  dtype=tf.float32)
58
 
59
+ return tf.reduce_mean(batched_dist)
60
 
61
 
62
  class NCC:
 
79
  return tf.math.divide_no_nan(numerator, denominator)
80
 
81
  def loss(self, y_true, y_pred):
82
+ # According to the documentation, the loss returns a scalar
83
+ # Ref: https://www.tensorflow.org/api_docs/python/tf/keras/Model#compile
84
+ return tf.reduce_mean(tf.map_fn(lambda x: 1 - self.ncc(x[0], x[1]), (y_true, y_pred), tf.float32))
85
+
86
+
87
+ class StructuralSimilarity:
88
+ # Based on https://github.com/keras-team/keras-contrib/blob/master/keras_contrib/losses/dssim.py
89
+ def __init__(self, k1=0.01, k2=0.03, patch_size=3, dynamic_range=1., overlap=0.0):
90
+ """
91
+ Structural (Di)Similarity Index Measure:
92
+
93
+ :param k1: Internal parameter. Defaults to 0.01
94
+ :param k2: Internal parameter. Defaults to 0.02
95
+ :param patch_size: Size of the extracted patches
96
+ :param dynamic_range: Maximum numerical intensity value (typ. 2^bits_per_pixel - 1). Defaults to 1.
97
+ """
98
+ self.__c1 = (k1 * dynamic_range) ** 2
99
+ self.__c2 = (k2 * dynamic_range) ** 2
100
+ self.__kernel_shape = [1] + [patch_size] * 3 + [1]
101
+ stride = int(patch_size * (1 - overlap))
102
+ self.__stride = [1] + [stride if stride else 1] * 3 + [1]
103
+ self.__max_val = dynamic_range
104
+
105
+ def __int_shape(self, x):
106
+ return tf.keras.backend.int_shape(x) if tf.keras.backend.backend() == 'tensorflow' else tf.keras.backend.shape(x)
107
+
108
+ def ssim(self, y_true, y_pred):
109
+
110
+ patches_true = tf.extract_volume_patches(y_true, self.__kernel_shape, self.__stride, 'VALID',
111
+ 'patches_true')
112
+ patches_pred = tf.extract_volume_patches(y_pred, self.__kernel_shape, self.__stride, 'VALID',
113
+ 'patches_pred')
114
+
115
+ #bs, w, h, d, *c = self.__int_shape(patches_pred)
116
+ #patches_true = tf.reshape(patches_true, [-1, w, h, d, tf.reduce_prod(c)])
117
+ #patches_pred = tf.reshape(patches_pred, [-1, w, h, d, tf.reduce_prod(c)])
118
+
119
+ # Mean
120
+ u_true = tf.reduce_mean(patches_true, axis=-1)
121
+ u_pred = tf.reduce_mean(patches_pred, axis=-1)
122
+
123
+ # Variance
124
+ v_true = tf.math.reduce_variance(patches_true, axis=-1)
125
+ v_pred = tf.math.reduce_variance(patches_pred, axis=-1)
126
+
127
+ # Covariance
128
+ covar = tf.reduce_mean(patches_true * patches_pred, axis=-1) - u_true * u_pred
129
+
130
+ # SSIM
131
+ numerator = (2 * u_true * u_pred + self.__c1) * (2 * covar + self.__c2)
132
+ denominator = ((tf.square(u_true) + tf.square(u_pred) + self.__c1) * (v_pred + v_true + self.__c2))
133
+ ssim = numerator / denominator
134
+
135
+ return tf.reduce_mean(ssim)
136
+
137
+ def dssim(self, y_true, y_pred):
138
+ return tf.reduce_mean((1. - self.ssim(y_true, y_pred)) / 2.0)