Commit
·
8c1dc9d
1
Parent(s):
371bf06
Removed NCC, training only on Hausdorff and DICE
Browse files
TrainingScripts/Train_3d_weaklySupervised.py
CHANGED
@@ -17,7 +17,7 @@ import DeepDeformationMapRegistration.utils.constants as C
|
|
17 |
from DeepDeformationMapRegistration.data_generator import DataGeneratorManager
|
18 |
from DeepDeformationMapRegistration.utils.misc import try_mkdir
|
19 |
from DeepDeformationMapRegistration.networks import WeaklySupervised
|
20 |
-
from DeepDeformationMapRegistration.losses import
|
21 |
from DeepDeformationMapRegistration.layers import UncertaintyWeighting
|
22 |
|
23 |
|
@@ -49,24 +49,24 @@ vxm_model = WeaklySupervised(inshape=in_shape, all_labels=[1], nb_unet_features=
|
|
49 |
# Losses and loss weights
|
50 |
|
51 |
grad = tf.keras.Input(shape=(*in_shape, 3), name='multiLoss_grad_input', dtype=tf.float32)
|
52 |
-
|
53 |
def dice_loss(y_true, y_pred):
|
54 |
# Dice().loss returns -Dice score
|
55 |
return 1 + vxm.losses.Dice().loss(y_true, y_pred)
|
56 |
|
57 |
-
multiLoss = UncertaintyWeighting(num_loss_fns=
|
58 |
num_reg_fns=1,
|
59 |
-
loss_fns=[
|
60 |
reg_fns=[vxm.losses.Grad('l2').loss],
|
61 |
-
prior_loss_w=[1., 1.],
|
62 |
prior_reg_w=[0.01],
|
63 |
name='MultiLossLayer')
|
64 |
-
loss = multiLoss([vxm_model.inputs[1], vxm_model.inputs[1],
|
65 |
-
vxm_model.references.pred_segm, vxm_model.references.pred_segm,
|
66 |
grad,
|
67 |
vxm_model.references.pos_flow])
|
68 |
|
69 |
-
full_model = tf.keras.Model(inputs=vxm_model.inputs + [grad], outputs=vxm_model.outputs + [loss])
|
70 |
|
71 |
# Compile the model
|
72 |
full_model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-4), loss=None)
|
|
|
17 |
from DeepDeformationMapRegistration.data_generator import DataGeneratorManager
|
18 |
from DeepDeformationMapRegistration.utils.misc import try_mkdir
|
19 |
from DeepDeformationMapRegistration.networks import WeaklySupervised
|
20 |
+
from DeepDeformationMapRegistration.losses import HausdorffDistance
|
21 |
from DeepDeformationMapRegistration.layers import UncertaintyWeighting
|
22 |
|
23 |
|
|
|
49 |
# Losses and loss weights
|
50 |
|
51 |
grad = tf.keras.Input(shape=(*in_shape, 3), name='multiLoss_grad_input', dtype=tf.float32)
|
52 |
+
fix_img = tf.keras.Input(shape=(*in_shape, 1), name='multiLoss_fix_img_input', dtype=tf.float32)
|
53 |
def dice_loss(y_true, y_pred):
|
54 |
# Dice().loss returns -Dice score
|
55 |
return 1 + vxm.losses.Dice().loss(y_true, y_pred)
|
56 |
|
57 |
+
multiLoss = UncertaintyWeighting(num_loss_fns=2,
|
58 |
num_reg_fns=1,
|
59 |
+
loss_fns=[HausdorffDistance(3, 5).loss, dice_loss],
|
60 |
reg_fns=[vxm.losses.Grad('l2').loss],
|
61 |
+
prior_loss_w=[1., 1., 1.],
|
62 |
prior_reg_w=[0.01],
|
63 |
name='MultiLossLayer')
|
64 |
+
loss = multiLoss([vxm_model.inputs[1], vxm_model.inputs[1],
|
65 |
+
vxm_model.references.pred_segm, vxm_model.references.pred_segm,
|
66 |
grad,
|
67 |
vxm_model.references.pos_flow])
|
68 |
|
69 |
+
full_model = tf.keras.Model(inputs=vxm_model.inputs + [fix_img, grad], outputs=vxm_model.outputs + [loss])
|
70 |
|
71 |
# Compile the model
|
72 |
full_model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-4), loss=None)
|