jpdefrutos commited on
Commit
8c1dc9d
·
1 Parent(s): 371bf06

Removed NCC, training only on Hausdorff and DICE

Browse files
TrainingScripts/Train_3d_weaklySupervised.py CHANGED
@@ -17,7 +17,7 @@ import DeepDeformationMapRegistration.utils.constants as C
17
  from DeepDeformationMapRegistration.data_generator import DataGeneratorManager
18
  from DeepDeformationMapRegistration.utils.misc import try_mkdir
19
  from DeepDeformationMapRegistration.networks import WeaklySupervised
20
- from DeepDeformationMapRegistration.losses import HausdorffDistanceErosion
21
  from DeepDeformationMapRegistration.layers import UncertaintyWeighting
22
 
23
 
@@ -49,24 +49,24 @@ vxm_model = WeaklySupervised(inshape=in_shape, all_labels=[1], nb_unet_features=
49
  # Losses and loss weights
50
 
51
  grad = tf.keras.Input(shape=(*in_shape, 3), name='multiLoss_grad_input', dtype=tf.float32)
52
- # fix_img = tf.keras.Input(shape=(*in_shape, 1), name='multiLoss_fix_img_input', dtype=tf.float32)
53
  def dice_loss(y_true, y_pred):
54
  # Dice().loss returns -Dice score
55
  return 1 + vxm.losses.Dice().loss(y_true, y_pred)
56
 
57
- multiLoss = UncertaintyWeighting(num_loss_fns=3,
58
  num_reg_fns=1,
59
- loss_fns=[HausdorffDistanceErosion(3, 5).loss, dice_loss, vxm.losses.NCC().loss],
60
  reg_fns=[vxm.losses.Grad('l2').loss],
61
- prior_loss_w=[1., 1.],
62
  prior_reg_w=[0.01],
63
  name='MultiLossLayer')
64
- loss = multiLoss([vxm_model.inputs[1], vxm_model.inputs[1], fix_img,
65
- vxm_model.references.pred_segm, vxm_model.references.pred_segm, vxm_model.references.pred_img,
66
  grad,
67
  vxm_model.references.pos_flow])
68
 
69
- full_model = tf.keras.Model(inputs=vxm_model.inputs + [grad], outputs=vxm_model.outputs + [loss])
70
 
71
  # Compile the model
72
  full_model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-4), loss=None)
 
17
  from DeepDeformationMapRegistration.data_generator import DataGeneratorManager
18
  from DeepDeformationMapRegistration.utils.misc import try_mkdir
19
  from DeepDeformationMapRegistration.networks import WeaklySupervised
20
+ from DeepDeformationMapRegistration.losses import HausdorffDistance
21
  from DeepDeformationMapRegistration.layers import UncertaintyWeighting
22
 
23
 
 
49
  # Losses and loss weights
50
 
51
  grad = tf.keras.Input(shape=(*in_shape, 3), name='multiLoss_grad_input', dtype=tf.float32)
52
+ fix_img = tf.keras.Input(shape=(*in_shape, 1), name='multiLoss_fix_img_input', dtype=tf.float32)
53
  def dice_loss(y_true, y_pred):
54
  # Dice().loss returns -Dice score
55
  return 1 + vxm.losses.Dice().loss(y_true, y_pred)
56
 
57
+ multiLoss = UncertaintyWeighting(num_loss_fns=2,
58
  num_reg_fns=1,
59
+ loss_fns=[HausdorffDistance(3, 5).loss, dice_loss],
60
  reg_fns=[vxm.losses.Grad('l2').loss],
61
+ prior_loss_w=[1., 1., 1.],
62
  prior_reg_w=[0.01],
63
  name='MultiLossLayer')
64
+ loss = multiLoss([vxm_model.inputs[1], vxm_model.inputs[1],
65
+ vxm_model.references.pred_segm, vxm_model.references.pred_segm,
66
  grad,
67
  vxm_model.references.pos_flow])
68
 
69
+ full_model = tf.keras.Model(inputs=vxm_model.inputs + [fix_img, grad], outputs=vxm_model.outputs + [loss])
70
 
71
  # Compile the model
72
  full_model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-4), loss=None)