Merge pull request #1 from jpdefrutosSINTEF/HausdorffConstantIndependent
Browse files
DeepDeformationMapRegistration/losses.py
CHANGED
@@ -1,19 +1,19 @@
|
|
1 |
-
import os, sys
|
2 |
-
currentdir = os.path.dirname(os.path.realpath(__file__))
|
3 |
-
parentdir = os.path.dirname(currentdir)
|
4 |
-
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
5 |
-
|
6 |
-
PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'
|
7 |
-
|
8 |
import tensorflow as tf
|
9 |
from scipy.ndimage import generate_binary_structure
|
10 |
|
11 |
-
import DeepDeformationMapRegistration.utils.constants as C
|
12 |
from DeepDeformationMapRegistration.utils.operators import soft_threshold
|
13 |
|
14 |
|
15 |
-
class
|
16 |
def __init__(self, ndim=3, nerosion=10):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
self.ndims = ndim
|
18 |
self.conv = getattr(tf.nn, 'conv%dd' % self.ndims)
|
19 |
self.nerosions = nerosion
|
@@ -37,11 +37,12 @@ class HausdorffDistance:
|
|
37 |
er = self._erode(diff, kernel)
|
38 |
ret += tf.reduce_sum(tf.multiply(er, tf.pow(i + 1., alpha)))
|
39 |
|
40 |
-
|
|
|
41 |
|
42 |
def loss(self, y_true, y_pred):
|
43 |
batched_dist = tf.map_fn(lambda x: self._erosion_distance_single(x[0], x[1]), (y_true, y_pred),
|
44 |
dtype=tf.float32)
|
45 |
|
46 |
-
return batched_dist
|
47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import tensorflow as tf
|
2 |
from scipy.ndimage import generate_binary_structure
|
3 |
|
|
|
4 |
from DeepDeformationMapRegistration.utils.operators import soft_threshold
|
5 |
|
6 |
|
7 |
+
class HausdorffDistanceErosion:
|
8 |
def __init__(self, ndim=3, nerosion=10):
|
9 |
+
"""
|
10 |
+
Approximation of the Hausdorff distance based on erosion operations based on the work done by Karimi D., et al.
|
11 |
+
Karimi D., et al., "Reducing the Hausdorff Distance in Medical Image Segmentation with Convolutional Neural
|
12 |
+
Networks". IEEE Transactions on Medical Imaging, 39, 2020. DOI 10.1109/TMI.2019.2930068
|
13 |
+
|
14 |
+
:param ndim: Dimensionality of the images
|
15 |
+
:param nerosion: Number of erosion steps. Defaults to 10.
|
16 |
+
"""
|
17 |
self.ndims = ndim
|
18 |
self.conv = getattr(tf.nn, 'conv%dd' % self.ndims)
|
19 |
self.nerosions = nerosion
|
|
|
37 |
er = self._erode(diff, kernel)
|
38 |
ret += tf.reduce_sum(tf.multiply(er, tf.pow(i + 1., alpha)))
|
39 |
|
40 |
+
img_vol = tf.cast(tf.reduce_prod(y_true.shape), tf.float32)
|
41 |
+
return tf.divide(ret, img_vol) # Divide by the image size
|
42 |
|
43 |
def loss(self, y_true, y_pred):
|
44 |
batched_dist = tf.map_fn(lambda x: self._erosion_distance_single(x[0], x[1]), (y_true, y_pred),
|
45 |
dtype=tf.float32)
|
46 |
|
47 |
+
return batched_dist
|
48 |
|
TrainingScripts/Train_2d.py
CHANGED
@@ -14,7 +14,7 @@ from datetime import datetime
|
|
14 |
import DeepDeformationMapRegistration.utils.constants as C
|
15 |
from DeepDeformationMapRegistration.data_generator import DataGeneratorManager2D
|
16 |
from DeepDeformationMapRegistration.utils.misc import try_mkdir
|
17 |
-
from DeepDeformationMapRegistration.losses import
|
18 |
|
19 |
|
20 |
os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
|
@@ -52,7 +52,7 @@ vxm_model = vxm.networks.VxmDense(inshape=in_shape, nb_unet_features=nb_features
|
|
52 |
|
53 |
# Losses and loss weights
|
54 |
def comb_loss(y_true, y_pred):
|
55 |
-
return 1e-3 *
|
56 |
|
57 |
|
58 |
losses = [comb_loss, vxm.losses.Grad('l2').loss]
|
|
|
14 |
import DeepDeformationMapRegistration.utils.constants as C
|
15 |
from DeepDeformationMapRegistration.data_generator import DataGeneratorManager2D
|
16 |
from DeepDeformationMapRegistration.utils.misc import try_mkdir
|
17 |
+
from DeepDeformationMapRegistration.losses import HausdorffDistanceErosion
|
18 |
|
19 |
|
20 |
os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
|
|
|
52 |
|
53 |
# Losses and loss weights
|
54 |
def comb_loss(y_true, y_pred):
|
55 |
+
return 1e-3 * HausdorffDistanceErosion(ndim=2, nerosion=2).loss(y_true, y_pred) + vxm.losses.Dice().loss(y_true, y_pred)
|
56 |
|
57 |
|
58 |
losses = [comb_loss, vxm.losses.Grad('l2').loss]
|
TrainingScripts/Train_2d_uncertaintyWeighting.py
CHANGED
@@ -17,7 +17,7 @@ from datetime import datetime
|
|
17 |
import DeepDeformationMapRegistration.utils.constants as C
|
18 |
from DeepDeformationMapRegistration.data_generator import DataGeneratorManager2D
|
19 |
from DeepDeformationMapRegistration.utils.misc import try_mkdir
|
20 |
-
from DeepDeformationMapRegistration.losses import
|
21 |
from DeepDeformationMapRegistration.layers import UncertaintyWeighting
|
22 |
|
23 |
|
@@ -66,7 +66,7 @@ def dice_loss(y_true, y_pred):
|
|
66 |
#fixed_pred, dm_pred = vxm_model([moving, fixed])
|
67 |
multiLoss = UncertaintyWeighting(num_loss_fns=2,
|
68 |
num_reg_fns=1,
|
69 |
-
loss_fns=[
|
70 |
reg_fns=[vxm.losses.Grad('l2').loss],
|
71 |
prior_loss_w=[1., 1.],
|
72 |
prior_reg_w=[0.01],
|
|
|
17 |
import DeepDeformationMapRegistration.utils.constants as C
|
18 |
from DeepDeformationMapRegistration.data_generator import DataGeneratorManager2D
|
19 |
from DeepDeformationMapRegistration.utils.misc import try_mkdir
|
20 |
+
from DeepDeformationMapRegistration.losses import HausdorffDistanceErosion
|
21 |
from DeepDeformationMapRegistration.layers import UncertaintyWeighting
|
22 |
|
23 |
|
|
|
66 |
#fixed_pred, dm_pred = vxm_model([moving, fixed])
|
67 |
multiLoss = UncertaintyWeighting(num_loss_fns=2,
|
68 |
num_reg_fns=1,
|
69 |
+
loss_fns=[HausdorffDistanceErosion(2, 2).loss, dice_loss],
|
70 |
reg_fns=[vxm.losses.Grad('l2').loss],
|
71 |
prior_loss_w=[1., 1.],
|
72 |
prior_reg_w=[0.01],
|
TrainingScripts/Train_3d_weaklySupervised.py
CHANGED
@@ -17,7 +17,7 @@ import DeepDeformationMapRegistration.utils.constants as C
|
|
17 |
from DeepDeformationMapRegistration.data_generator import DataGeneratorManager
|
18 |
from DeepDeformationMapRegistration.utils.misc import try_mkdir
|
19 |
from DeepDeformationMapRegistration.networks import WeaklySupervised
|
20 |
-
from DeepDeformationMapRegistration.losses import
|
21 |
from DeepDeformationMapRegistration.layers import UncertaintyWeighting
|
22 |
|
23 |
|
@@ -49,16 +49,16 @@ vxm_model = WeaklySupervised(inshape=in_shape, all_labels=[1], nb_unet_features=
|
|
49 |
# Losses and loss weights
|
50 |
|
51 |
grad = tf.keras.Input(shape=(*in_shape, 3), name='multiLoss_grad_input', dtype=tf.float32)
|
52 |
-
fix_img = tf.keras.Input(shape=(*in_shape, 1), name='multiLoss_fix_img_input', dtype=tf.float32)
|
53 |
def dice_loss(y_true, y_pred):
|
54 |
# Dice().loss returns -Dice score
|
55 |
return 1 + vxm.losses.Dice().loss(y_true, y_pred)
|
56 |
|
57 |
multiLoss = UncertaintyWeighting(num_loss_fns=3,
|
58 |
num_reg_fns=1,
|
59 |
-
loss_fns=[
|
60 |
reg_fns=[vxm.losses.Grad('l2').loss],
|
61 |
-
prior_loss_w=[1., 1
|
62 |
prior_reg_w=[0.01],
|
63 |
name='MultiLossLayer')
|
64 |
loss = multiLoss([vxm_model.inputs[1], vxm_model.inputs[1], fix_img,
|
@@ -66,7 +66,7 @@ loss = multiLoss([vxm_model.inputs[1], vxm_model.inputs[1], fix_img,
|
|
66 |
grad,
|
67 |
vxm_model.references.pos_flow])
|
68 |
|
69 |
-
full_model = tf.keras.Model(inputs=vxm_model.inputs + [
|
70 |
|
71 |
# Compile the model
|
72 |
full_model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-4), loss=None)
|
|
|
17 |
from DeepDeformationMapRegistration.data_generator import DataGeneratorManager
|
18 |
from DeepDeformationMapRegistration.utils.misc import try_mkdir
|
19 |
from DeepDeformationMapRegistration.networks import WeaklySupervised
|
20 |
+
from DeepDeformationMapRegistration.losses import HausdorffDistanceErosion
|
21 |
from DeepDeformationMapRegistration.layers import UncertaintyWeighting
|
22 |
|
23 |
|
|
|
49 |
# Losses and loss weights
|
50 |
|
51 |
grad = tf.keras.Input(shape=(*in_shape, 3), name='multiLoss_grad_input', dtype=tf.float32)
|
52 |
+
# fix_img = tf.keras.Input(shape=(*in_shape, 1), name='multiLoss_fix_img_input', dtype=tf.float32)
|
53 |
def dice_loss(y_true, y_pred):
|
54 |
# Dice().loss returns -Dice score
|
55 |
return 1 + vxm.losses.Dice().loss(y_true, y_pred)
|
56 |
|
57 |
multiLoss = UncertaintyWeighting(num_loss_fns=3,
|
58 |
num_reg_fns=1,
|
59 |
+
loss_fns=[HausdorffDistanceErosion(3, 5).loss, dice_loss, vxm.losses.NCC().loss],
|
60 |
reg_fns=[vxm.losses.Grad('l2').loss],
|
61 |
+
prior_loss_w=[1., 1.],
|
62 |
prior_reg_w=[0.01],
|
63 |
name='MultiLossLayer')
|
64 |
loss = multiLoss([vxm_model.inputs[1], vxm_model.inputs[1], fix_img,
|
|
|
66 |
grad,
|
67 |
vxm_model.references.pos_flow])
|
68 |
|
69 |
+
full_model = tf.keras.Model(inputs=vxm_model.inputs + [grad], outputs=vxm_model.outputs + [loss])
|
70 |
|
71 |
# Compile the model
|
72 |
full_model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-4), loss=None)
|