Javier Pérez de Frutos commited on
Commit
371bf06
·
unverified ·
2 Parent(s): c44df3d 9b94746

Merge pull request #1 from jpdefrutosSINTEF/HausdorffConstantIndependent

Browse files
DeepDeformationMapRegistration/losses.py CHANGED
@@ -1,19 +1,19 @@
1
- import os, sys
2
- currentdir = os.path.dirname(os.path.realpath(__file__))
3
- parentdir = os.path.dirname(currentdir)
4
- sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
5
-
6
- PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'
7
-
8
  import tensorflow as tf
9
  from scipy.ndimage import generate_binary_structure
10
 
11
- import DeepDeformationMapRegistration.utils.constants as C
12
  from DeepDeformationMapRegistration.utils.operators import soft_threshold
13
 
14
 
15
- class HausdorffDistance:
16
  def __init__(self, ndim=3, nerosion=10):
 
 
 
 
 
 
 
 
17
  self.ndims = ndim
18
  self.conv = getattr(tf.nn, 'conv%dd' % self.ndims)
19
  self.nerosions = nerosion
@@ -37,11 +37,12 @@ class HausdorffDistance:
37
  er = self._erode(diff, kernel)
38
  ret += tf.reduce_sum(tf.multiply(er, tf.pow(i + 1., alpha)))
39
 
40
- return tf.multiply(C.IMG_SIZE ** -self.ndims, ret) # Divide by the image size
 
41
 
42
  def loss(self, y_true, y_pred):
43
  batched_dist = tf.map_fn(lambda x: self._erosion_distance_single(x[0], x[1]), (y_true, y_pred),
44
  dtype=tf.float32)
45
 
46
- return batched_dist # tf.reduce_mean(batched_dist)
47
 
 
 
 
 
 
 
 
 
1
  import tensorflow as tf
2
  from scipy.ndimage import generate_binary_structure
3
 
 
4
  from DeepDeformationMapRegistration.utils.operators import soft_threshold
5
 
6
 
7
+ class HausdorffDistanceErosion:
8
  def __init__(self, ndim=3, nerosion=10):
9
+ """
10
+ Approximation of the Hausdorff distance based on erosion operations based on the work done by Karimi D., et al.
11
+ Karimi D., et al., "Reducing the Hausdorff Distance in Medical Image Segmentation with Convolutional Neural
12
+ Networks". IEEE Transactions on Medical Imaging, 39, 2020. DOI 10.1109/TMI.2019.2930068
13
+
14
+ :param ndim: Dimensionality of the images
15
+ :param nerosion: Number of erosion steps. Defaults to 10.
16
+ """
17
  self.ndims = ndim
18
  self.conv = getattr(tf.nn, 'conv%dd' % self.ndims)
19
  self.nerosions = nerosion
 
37
  er = self._erode(diff, kernel)
38
  ret += tf.reduce_sum(tf.multiply(er, tf.pow(i + 1., alpha)))
39
 
40
+ img_vol = tf.cast(tf.reduce_prod(y_true.shape), tf.float32)
41
+ return tf.divide(ret, img_vol) # Divide by the image size
42
 
43
  def loss(self, y_true, y_pred):
44
  batched_dist = tf.map_fn(lambda x: self._erosion_distance_single(x[0], x[1]), (y_true, y_pred),
45
  dtype=tf.float32)
46
 
47
+ return batched_dist
48
 
TrainingScripts/Train_2d.py CHANGED
@@ -14,7 +14,7 @@ from datetime import datetime
14
  import DeepDeformationMapRegistration.utils.constants as C
15
  from DeepDeformationMapRegistration.data_generator import DataGeneratorManager2D
16
  from DeepDeformationMapRegistration.utils.misc import try_mkdir
17
- from DeepDeformationMapRegistration.losses import HausdorffDistance
18
 
19
 
20
  os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
@@ -52,7 +52,7 @@ vxm_model = vxm.networks.VxmDense(inshape=in_shape, nb_unet_features=nb_features
52
 
53
  # Losses and loss weights
54
  def comb_loss(y_true, y_pred):
55
- return 1e-3 * HausdorffDistance(ndim=2, nerosion=2).loss(y_true, y_pred) + vxm.losses.Dice().loss(y_true, y_pred)
56
 
57
 
58
  losses = [comb_loss, vxm.losses.Grad('l2').loss]
 
14
  import DeepDeformationMapRegistration.utils.constants as C
15
  from DeepDeformationMapRegistration.data_generator import DataGeneratorManager2D
16
  from DeepDeformationMapRegistration.utils.misc import try_mkdir
17
+ from DeepDeformationMapRegistration.losses import HausdorffDistanceErosion
18
 
19
 
20
  os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
 
52
 
53
  # Losses and loss weights
54
  def comb_loss(y_true, y_pred):
55
+ return 1e-3 * HausdorffDistanceErosion(ndim=2, nerosion=2).loss(y_true, y_pred) + vxm.losses.Dice().loss(y_true, y_pred)
56
 
57
 
58
  losses = [comb_loss, vxm.losses.Grad('l2').loss]
TrainingScripts/Train_2d_uncertaintyWeighting.py CHANGED
@@ -17,7 +17,7 @@ from datetime import datetime
17
  import DeepDeformationMapRegistration.utils.constants as C
18
  from DeepDeformationMapRegistration.data_generator import DataGeneratorManager2D
19
  from DeepDeformationMapRegistration.utils.misc import try_mkdir
20
- from DeepDeformationMapRegistration.losses import HausdorffDistance
21
  from DeepDeformationMapRegistration.layers import UncertaintyWeighting
22
 
23
 
@@ -66,7 +66,7 @@ def dice_loss(y_true, y_pred):
66
  #fixed_pred, dm_pred = vxm_model([moving, fixed])
67
  multiLoss = UncertaintyWeighting(num_loss_fns=2,
68
  num_reg_fns=1,
69
- loss_fns=[HausdorffDistance(2, 2).loss, dice_loss],
70
  reg_fns=[vxm.losses.Grad('l2').loss],
71
  prior_loss_w=[1., 1.],
72
  prior_reg_w=[0.01],
 
17
  import DeepDeformationMapRegistration.utils.constants as C
18
  from DeepDeformationMapRegistration.data_generator import DataGeneratorManager2D
19
  from DeepDeformationMapRegistration.utils.misc import try_mkdir
20
+ from DeepDeformationMapRegistration.losses import HausdorffDistanceErosion
21
  from DeepDeformationMapRegistration.layers import UncertaintyWeighting
22
 
23
 
 
66
  #fixed_pred, dm_pred = vxm_model([moving, fixed])
67
  multiLoss = UncertaintyWeighting(num_loss_fns=2,
68
  num_reg_fns=1,
69
+ loss_fns=[HausdorffDistanceErosion(2, 2).loss, dice_loss],
70
  reg_fns=[vxm.losses.Grad('l2').loss],
71
  prior_loss_w=[1., 1.],
72
  prior_reg_w=[0.01],
TrainingScripts/Train_3d_weaklySupervised.py CHANGED
@@ -17,7 +17,7 @@ import DeepDeformationMapRegistration.utils.constants as C
17
  from DeepDeformationMapRegistration.data_generator import DataGeneratorManager
18
  from DeepDeformationMapRegistration.utils.misc import try_mkdir
19
  from DeepDeformationMapRegistration.networks import WeaklySupervised
20
- from DeepDeformationMapRegistration.losses import HausdorffDistance
21
  from DeepDeformationMapRegistration.layers import UncertaintyWeighting
22
 
23
 
@@ -49,16 +49,16 @@ vxm_model = WeaklySupervised(inshape=in_shape, all_labels=[1], nb_unet_features=
49
  # Losses and loss weights
50
 
51
  grad = tf.keras.Input(shape=(*in_shape, 3), name='multiLoss_grad_input', dtype=tf.float32)
52
- fix_img = tf.keras.Input(shape=(*in_shape, 1), name='multiLoss_fix_img_input', dtype=tf.float32)
53
  def dice_loss(y_true, y_pred):
54
  # Dice().loss returns -Dice score
55
  return 1 + vxm.losses.Dice().loss(y_true, y_pred)
56
 
57
  multiLoss = UncertaintyWeighting(num_loss_fns=3,
58
  num_reg_fns=1,
59
- loss_fns=[HausdorffDistance(3, 5).loss, dice_loss, vxm.losses.NCC().loss],
60
  reg_fns=[vxm.losses.Grad('l2').loss],
61
- prior_loss_w=[1., 1., 1.],
62
  prior_reg_w=[0.01],
63
  name='MultiLossLayer')
64
  loss = multiLoss([vxm_model.inputs[1], vxm_model.inputs[1], fix_img,
@@ -66,7 +66,7 @@ loss = multiLoss([vxm_model.inputs[1], vxm_model.inputs[1], fix_img,
66
  grad,
67
  vxm_model.references.pos_flow])
68
 
69
- full_model = tf.keras.Model(inputs=vxm_model.inputs + [fix_img, grad], outputs=vxm_model.outputs + [loss])
70
 
71
  # Compile the model
72
  full_model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-4), loss=None)
 
17
  from DeepDeformationMapRegistration.data_generator import DataGeneratorManager
18
  from DeepDeformationMapRegistration.utils.misc import try_mkdir
19
  from DeepDeformationMapRegistration.networks import WeaklySupervised
20
+ from DeepDeformationMapRegistration.losses import HausdorffDistanceErosion
21
  from DeepDeformationMapRegistration.layers import UncertaintyWeighting
22
 
23
 
 
49
  # Losses and loss weights
50
 
51
  grad = tf.keras.Input(shape=(*in_shape, 3), name='multiLoss_grad_input', dtype=tf.float32)
52
+ # fix_img = tf.keras.Input(shape=(*in_shape, 1), name='multiLoss_fix_img_input', dtype=tf.float32)
53
  def dice_loss(y_true, y_pred):
54
  # Dice().loss returns -Dice score
55
  return 1 + vxm.losses.Dice().loss(y_true, y_pred)
56
 
57
  multiLoss = UncertaintyWeighting(num_loss_fns=3,
58
  num_reg_fns=1,
59
+ loss_fns=[HausdorffDistanceErosion(3, 5).loss, dice_loss, vxm.losses.NCC().loss],
60
  reg_fns=[vxm.losses.Grad('l2').loss],
61
+ prior_loss_w=[1., 1.],
62
  prior_reg_w=[0.01],
63
  name='MultiLossLayer')
64
  loss = multiLoss([vxm_model.inputs[1], vxm_model.inputs[1], fix_img,
 
66
  grad,
67
  vxm_model.references.pos_flow])
68
 
69
+ full_model = tf.keras.Model(inputs=vxm_model.inputs + [grad], outputs=vxm_model.outputs + [loss])
70
 
71
  # Compile the model
72
  full_model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-4), loss=None)