jpdefrutos commited on
Commit
49e1f69
·
1 Parent(s): ca253db

Updated call to DataGeneratorManager

Browse files
TrainingScripts/Train_3d.py CHANGED
@@ -15,6 +15,7 @@ from datetime import datetime
15
 
16
  import DeepDeformationMapRegistration.utils.constants as C
17
  from DeepDeformationMapRegistration.data_generator import DataGeneratorManager
 
18
  from DeepDeformationMapRegistration.utils.misc import try_mkdir
19
 
20
 
@@ -44,19 +45,14 @@ vxm_model = vxm.networks.VxmDense(inshape=in_shape, nb_unet_features=nb_features
44
 
45
 
46
  # Losses and loss weights
47
-
48
- def comb_loss(y_true, y_pred):
49
- return vxm.losses.MSE().loss(y_true, y_pred) + vxm.losses.NCC().loss(y_true, y_pred)
50
-
51
-
52
- losses = [comb_loss, vxm.losses.Grad('l2').loss]
53
  loss_weights = [1., 0.01]
54
 
55
  # Compile the model
56
  vxm_model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-4), loss=losses, loss_weights=loss_weights)
57
 
58
  # Train
59
- output_folder = os.path.join('train_3d_mse_ncc_grad_'+datetime.now().strftime("%H%M%S-%d%m%Y"))
60
  try_mkdir(output_folder)
61
  try_mkdir(os.path.join(output_folder, 'checkpoints'))
62
  try_mkdir(os.path.join(output_folder, 'tensorboard'))
 
15
 
16
  import DeepDeformationMapRegistration.utils.constants as C
17
  from DeepDeformationMapRegistration.data_generator import DataGeneratorManager
18
+ from DeepDeformationMapRegistration.losses import NCC
19
  from DeepDeformationMapRegistration.utils.misc import try_mkdir
20
 
21
 
 
45
 
46
 
47
  # Losses and loss weights
48
+ losses = [NCC(in_shape).loss, vxm.losses.Grad('l2').loss]
 
 
 
 
 
49
  loss_weights = [1., 0.01]
50
 
51
  # Compile the model
52
  vxm_model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-4), loss=losses, loss_weights=loss_weights)
53
 
54
  # Train
55
+ output_folder = os.path.join('TrainingScripts/TrainOutput/baseline_LITS_NCC_'+datetime.now().strftime("%H%M%S-%d%m%Y"))
56
  try_mkdir(output_folder)
57
  try_mkdir(os.path.join(output_folder, 'checkpoints'))
58
  try_mkdir(os.path.join(output_folder, 'tensorboard'))
TrainingScripts/Train_3d_weaklySupervised.py CHANGED
@@ -33,7 +33,9 @@ C.EPOCHS = 10000
33
  # Build data generator
34
 
35
  data_generator = DataGeneratorManager(C.TRAINING_DATASET, C.BATCH_SIZE, True, C.LIMIT_NUM_SAMPLES,
36
- 1 - C.TRAINING_PERC, voxelmorph=True, segmentations=True)
 
 
37
 
38
  train_generator = data_generator.get_generator('train')
39
  validation_generator = data_generator.get_generator('validation')
@@ -72,7 +74,7 @@ full_model = tf.keras.Model(inputs=vxm_model.inputs + [grad], outputs=vxm_model.
72
  full_model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-4), loss=None)
73
 
74
  # Train
75
- output_folder = os.path.join('train_3d_multiloss_segm_haus_dice_ncc_grad_'+datetime.now().strftime("%H%M%S-%d%m%Y"))
76
  try_mkdir(output_folder)
77
  try_mkdir(os.path.join(output_folder, 'checkpoints'))
78
  try_mkdir(os.path.join(output_folder, 'tensorboard'))
 
33
  # Build data generator
34
 
35
  data_generator = DataGeneratorManager(C.TRAINING_DATASET, C.BATCH_SIZE, True, C.LIMIT_NUM_SAMPLES,
36
+ 1 - C.TRAINING_PERC,
37
+ input_labels=[C.DG_LBL_MOV_VESSELS, C.DG_LBL_FIX_VESSELS, C.DG_LBL_MOV_IMG, C.DG_LBL_ZERO_GRADS],
38
+ output_labels=[])
39
 
40
  train_generator = data_generator.get_generator('train')
41
  validation_generator = data_generator.get_generator('validation')
 
74
  full_model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-4), loss=None)
75
 
76
  # Train
77
+ output_folder = os.path.join('TrainingScripts/TrainOutput/weaklysupervised_DCTHLN_UW_haus_dice_'+datetime.now().strftime("%H%M%S-%d%m%Y"))
78
  try_mkdir(output_folder)
79
  try_mkdir(os.path.join(output_folder, 'checkpoints'))
80
  try_mkdir(os.path.join(output_folder, 'tensorboard'))