jpdefrutos commited on
Commit
2458333
·
1 Parent(s): 8c1dc9d

Train only on segmentation data

Browse files
EvaluationScripts/Evaluate_3d_weaklySupervised.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ currentdir = os.path.dirname(os.path.realpath(__file__))
3
+ parentdir = os.path.dirname(currentdir)
4
+ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
5
+
6
+ PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'
7
+
8
+ import numpy as np
9
+ import tensorflow as tf
10
+ import voxelmorph as vxm
11
+ import neurite as ne
12
+ from datetime import datetime
13
+
14
+ import DeepDeformationMapRegistration.utils.constants as C
15
+ from DeepDeformationMapRegistration.data_generator import DataGeneratorManager
16
+ from DeepDeformationMapRegistration.utils.misc import try_mkdir
17
+ from DeepDeformationMapRegistration.utils.nifty_utils import save_nifti
18
+ from DeepDeformationMapRegistration.networks import WeaklySupervised
19
+ from DeepDeformationMapRegistration.losses import HausdorffDistanceErosion
20
+ from DeepDeformationMapRegistration.layers import UncertaintyWeighting
21
+
22
+
23
+ os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
24
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0' # Check availability before running using 'nvidia-smi'
25
+
26
+ C.TRAINING_DATASET = '/mnt/EncryptedData1/Users/javier/vessel_registration/sanity_dataset_vessels'
27
+ C.BATCH_SIZE = 2
28
+ C.LIMIT_NUM_SAMPLES = None
29
+ C.EPOCHS = 10000
30
+
31
+ # Load data
32
+ # Build data generator
33
+
34
+ data_generator = DataGeneratorManager(C.TRAINING_DATASET, C.BATCH_SIZE, True, C.LIMIT_NUM_SAMPLES,
35
+ 1 - C.TRAINING_PERC, voxelmorph=True, segmentations=True)
36
+
37
+ train_generator = data_generator.get_generator('train')
38
+ validation_generator = data_generator.get_generator('validation')
39
+
40
+ data_folder = '../train_3d_multiloss_segm_haus_dice_ncc_grad_203925-29012021'
41
+
42
+ # Build model
43
+ in_shape = train_generator.get_input_shape()[1:-1]
44
+ enc_features = [16, 32, 32, 32, 32, 32]# const.ENCODER_FILTERS
45
+ dec_features = [32, 32, 32, 32, 32, 32, 32, 16, 16]# const.ENCODER_FILTERS[::-1]
46
+ nb_features = [enc_features, dec_features]
47
+ vxm_model = WeaklySupervised(inshape=in_shape, all_labels=[1], nb_unet_features=nb_features, int_steps=5)
48
+ vxm_model.load_weights(os.path.join(data_folder, 'checkpoints', 'best_model.h5'), by_name=True)
49
+
50
+ # Get some samples and plot them
51
+ sample = validation_generator[0]
52
+
53
+ samp_id = 1
54
+ pred_img, pred_seg, pred_flow = vxm_model.predict([sample[0][0][samp_id, ...][np.newaxis, ...],
55
+ sample[0][1][samp_id, ...][np.newaxis, ...],
56
+ sample[0][2][samp_id, ...][np.newaxis, ...]])
57
+
58
+ save_nifti(np.squeeze(pred_img), os.path.join(data_folder, 'pred_img.nii.gz'))
59
+ save_nifti(np.squeeze(pred_seg), os.path.join(data_folder, 'pred_seg.nii.gz'))
60
+ save_nifti(sample[0][0][samp_id, ...], os.path.join(data_folder, 'mov_seg.nii.gz'))
61
+ save_nifti(sample[0][1][samp_id, ...], os.path.join(data_folder, 'fix_seg.nii.gz'))
62
+ save_nifti(sample[0][2][samp_id, ...], os.path.join(data_folder, 'mov_img.nii.gz'))
63
+ save_nifti(sample[0][-2][samp_id, ...], os.path.join(data_folder, 'fix_img.nii.gz'))
TrainingScripts/Train_3d_weaklySupervised.py CHANGED
@@ -17,7 +17,7 @@ import DeepDeformationMapRegistration.utils.constants as C
17
  from DeepDeformationMapRegistration.data_generator import DataGeneratorManager
18
  from DeepDeformationMapRegistration.utils.misc import try_mkdir
19
  from DeepDeformationMapRegistration.networks import WeaklySupervised
20
- from DeepDeformationMapRegistration.losses import HausdorffDistance
21
  from DeepDeformationMapRegistration.layers import UncertaintyWeighting
22
 
23
 
@@ -49,16 +49,16 @@ vxm_model = WeaklySupervised(inshape=in_shape, all_labels=[1], nb_unet_features=
49
  # Losses and loss weights
50
 
51
  grad = tf.keras.Input(shape=(*in_shape, 3), name='multiLoss_grad_input', dtype=tf.float32)
52
- fix_img = tf.keras.Input(shape=(*in_shape, 1), name='multiLoss_fix_img_input', dtype=tf.float32)
53
  def dice_loss(y_true, y_pred):
54
  # Dice().loss returns -Dice score
55
  return 1 + vxm.losses.Dice().loss(y_true, y_pred)
56
 
57
  multiLoss = UncertaintyWeighting(num_loss_fns=2,
58
  num_reg_fns=1,
59
- loss_fns=[HausdorffDistance(3, 5).loss, dice_loss],
60
  reg_fns=[vxm.losses.Grad('l2').loss],
61
- prior_loss_w=[1., 1., 1.],
62
  prior_reg_w=[0.01],
63
  name='MultiLossLayer')
64
  loss = multiLoss([vxm_model.inputs[1], vxm_model.inputs[1],
@@ -66,7 +66,7 @@ loss = multiLoss([vxm_model.inputs[1], vxm_model.inputs[1],
66
  grad,
67
  vxm_model.references.pos_flow])
68
 
69
- full_model = tf.keras.Model(inputs=vxm_model.inputs + [fix_img, grad], outputs=vxm_model.outputs + [loss])
70
 
71
  # Compile the model
72
  full_model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-4), loss=None)
 
17
  from DeepDeformationMapRegistration.data_generator import DataGeneratorManager
18
  from DeepDeformationMapRegistration.utils.misc import try_mkdir
19
  from DeepDeformationMapRegistration.networks import WeaklySupervised
20
+ from DeepDeformationMapRegistration.losses import HausdorffDistanceErosion
21
  from DeepDeformationMapRegistration.layers import UncertaintyWeighting
22
 
23
 
 
49
  # Losses and loss weights
50
 
51
  grad = tf.keras.Input(shape=(*in_shape, 3), name='multiLoss_grad_input', dtype=tf.float32)
52
+ # fix_img = tf.keras.Input(shape=(*in_shape, 1), name='multiLoss_fix_img_input', dtype=tf.float32)
53
  def dice_loss(y_true, y_pred):
54
  # Dice().loss returns -Dice score
55
  return 1 + vxm.losses.Dice().loss(y_true, y_pred)
56
 
57
  multiLoss = UncertaintyWeighting(num_loss_fns=2,
58
  num_reg_fns=1,
59
+ loss_fns=[HausdorffDistanceErosion(3, 5).loss, dice_loss],
60
  reg_fns=[vxm.losses.Grad('l2').loss],
61
+ prior_loss_w=[1., 1.],
62
  prior_reg_w=[0.01],
63
  name='MultiLossLayer')
64
  loss = multiLoss([vxm_model.inputs[1], vxm_model.inputs[1],
 
66
  grad,
67
  vxm_model.references.pos_flow])
68
 
69
+ full_model = tf.keras.Model(inputs=vxm_model.inputs + [grad], outputs=vxm_model.outputs + [loss])
70
 
71
  # Compile the model
72
  full_model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-4), loss=None)