jpdefrutos commited on
Commit
e5764e7
·
1 Parent(s): 3b554c2
Brain_study/Train_SegmentationGuided.py CHANGED
@@ -173,7 +173,8 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
173
  loss_weights = {'transformer': 1,
174
  'seg_transformer': 1.,
175
  'flow': 5e-3}
176
- metrics = {'transformer': [vxm.losses.MSE().loss, NCC(image_input_shape).metric, MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).metric],
 
177
  'seg_transformer': [GeneralizedDICEScore(image_output_shape + [train_generator.get_data_shape()[2][-1]], num_labels=nb_labels).metric_macro,
178
  #HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [train_generator.get_data_shape()[2][-1]]).metric
179
  ]}
 
173
  loss_weights = {'transformer': 1,
174
  'seg_transformer': 1.,
175
  'flow': 5e-3}
176
+ metrics = {'transformer': [vxm.losses.MSE().loss, NCC(image_input_shape).metric, StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).metric,
177
+ MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).metric],
178
  'seg_transformer': [GeneralizedDICEScore(image_output_shape + [train_generator.get_data_shape()[2][-1]], num_labels=nb_labels).metric_macro,
179
  #HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [train_generator.get_data_shape()[2][-1]]).metric
180
  ]}
COMET/COMET_train.py CHANGED
@@ -34,22 +34,26 @@ import voxelmorph as vxm
34
  import h5py
35
  import re
36
  import itertools
 
37
 
38
 
39
  def launch_train(dataset_folder, validation_folder, output_folder, model_file, gpu_num=0, lr=1e-4, rw=5e-3, simil='ssim',
40
- segm='dice', max_epochs=C.EPOCHS, freeze_layers=None,
41
- acc_gradients=1, batch_size=16):
 
42
  # 0. Input checks
43
- assert dataset_folder is not None and output_folder is not None and model_file is not None
44
- assert '.h5' in model_file, 'The model must be an H5 file'
45
-
46
- USE_SEGMENTATIONS = bool(re.search('SEGGUIDED', model_file))
47
 
48
  # 1. Load variables
49
  os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
50
  os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_num) # Check availability before running using 'nvidia-smi'
51
  C.GPU_NUM = str(gpu_num)
52
 
 
 
 
53
  if freeze_layers is not None:
54
  assert all(s in ['INPUT', 'OUTPUT', 'ENCODER', 'DECODER', 'TOP', 'BOTTOM'] for s in freeze_layers), \
55
  'Invalid option for "freeze". Expected one or several of: INPUT, OUTPUT, ENCODER, DECODER, TOP, BOTTOM'
@@ -63,8 +67,8 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
63
  C.TRAINING_DATASET = dataset_folder #dataset_copy.copy_dataset()
64
  C.VALIDATION_DATASET = validation_folder
65
  C.ACCUM_GRADIENT_STEP = acc_gradients
66
- C.BATCH_SIZE = batch_size if C.ACCUM_GRADIENT_STEP == 1 else C.ACCUM_GRADIENT_STEP
67
- C.EARLY_STOP_PATIENCE = 5 * (C.ACCUM_GRADIENT_STEP / 2 if C.ACCUM_GRADIENT_STEP != 1 else 1)
68
  C.LEARNING_RATE = lr
69
  C.LIMIT_NUM_SAMPLES = None
70
  C.EPOCHS = max_epochs
@@ -97,16 +101,16 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
97
  print(aux)
98
 
99
  # 2. Data generator
100
- used_labels = 'all' if USE_SEGMENTATIONS else 'none'
101
  data_generator = BatchGenerator(C.TRAINING_DATASET, C.BATCH_SIZE if C.ACCUM_GRADIENT_STEP == 1 else 1, True,
102
- C.TRAINING_PERC, labels=[used_labels], combine_segmentations=not USE_SEGMENTATIONS,
103
  directory_val=C.VALIDATION_DATASET)
104
 
105
  train_generator = data_generator.get_train_generator()
106
  validation_generator = data_generator.get_validation_generator()
107
 
108
  image_input_shape = train_generator.get_data_shape()[-1][:-1]
109
- image_output_shape = [64] * 3
110
  nb_labels = len(train_generator.get_segmentation_labels())
111
 
112
  # 3. Load model
@@ -133,7 +137,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
133
  GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric,
134
  HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [nb_labels]).metric,
135
  GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric_macro,]
136
- print('MODEL LOCATION: ', model_file)
137
 
138
  try:
139
  network = tf.keras.models.load_model(model_file, {#'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
@@ -143,21 +147,18 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
143
  'metric': metric_fncs},
144
  compile=False)
145
  except ValueError as e:
146
- enc_features = [16, 32, 32, 32] # const.ENCODER_FILTERS
147
- dec_features = [32, 32, 32, 32, 32, 16, 16] # const.ENCODER_FILTERS[::-1]
 
 
148
  nb_features = [enc_features, dec_features]
149
- if USE_SEGMENTATIONS:
150
- network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
151
- nb_labels=nb_labels,
152
- nb_unet_features=nb_features,
153
- int_steps=0,
154
- int_downsize=1,
155
- seg_downsize=1)
156
- else:
157
- network = vxm.networks.VxmDense(inshape=image_output_shape,
158
- nb_unet_features=nb_features,
159
- int_steps=0)
160
- network.load_weights(model_file, by_name=True)
161
  # 4. Freeze/unfreeze model layers
162
  # freeze_layers = range(0, len(network.layers) - 8) # Do not freeze the last layers after the UNet (8 last layers)
163
  # for l in freeze_layers:
@@ -187,7 +188,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
187
  network.summary()
188
  network.summary(print_fn=log_file.writelines)
189
  # Complete the model with the augmentation layer
190
- augm_train_input_shape = train_generator.get_data_shape()[0] if USE_SEGMENTATIONS else train_generator.get_data_shape()[-1]
191
  input_layer_train = Input(shape=augm_train_input_shape, name='input_train')
192
  augm_layer_train = AugmentationLayer(max_displacement=COMET_C.MAX_AUG_DISP, # Max 30 mm in isotropic space
193
  max_deformation=COMET_C.MAX_AUG_DEF, # Max 6 mm in isotropic space
@@ -198,7 +199,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
198
  brightness_augmentation=COMET_C.BRIGHTNESS_AUGMENTATION,
199
  in_img_shape=image_input_shape,
200
  out_img_shape=image_output_shape,
201
- only_image=not USE_SEGMENTATIONS, # If baseline then True
202
  only_resize=False,
203
  trainable=False)
204
  augm_model_train = Model(inputs=input_layer_train, outputs=augm_layer_train(input_layer_train))
@@ -255,16 +256,6 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
255
  else:
256
  raise ValueError('Unknown similarity metric: ' + simil)
257
 
258
- if USE_SEGMENTATIONS:
259
- if segm == 'hd':
260
- loss_segm = HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [nb_labels]).loss
261
- elif segm == 'dice':
262
- loss_segm = GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).loss
263
- elif segm == 'dice_macro':
264
- loss_segm = GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).loss_macro
265
- else:
266
- raise ValueError('No valid value for segm')
267
-
268
  os.makedirs(os.path.join(output_folder, 'checkpoints'), exist_ok=True)
269
  os.makedirs(os.path.join(output_folder, 'tensorboard'), exist_ok=True)
270
  callback_tensorboard = TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
@@ -278,34 +269,17 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
278
  save_best_only=True, monitor='val_loss', verbose=1, mode='min')
279
  callback_save_checkpoint = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'checkpoint.h5'),
280
  save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
281
- if USE_SEGMENTATIONS:
282
- losses = {'transformer': loss_fnc,
283
- 'seg_transformer': loss_segm,
284
- 'flow': vxm.losses.Grad('l2').loss}
285
- metrics = {'transformer': [StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).metric,
286
- MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).metric,
287
- tf.keras.losses.MSE,
288
- NCC(image_input_shape).metric],
289
- 'seg_transformer': [GeneralizedDICEScore(image_output_shape + [train_generator.get_data_shape()[2][-1]], num_labels=nb_labels).metric,
290
- HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [train_generator.get_data_shape()[2][-1]]).metric,
291
- GeneralizedDICEScore(image_output_shape + [train_generator.get_data_shape()[2][-1]], num_labels=nb_labels).metric_macro,
292
- ],
293
- #'flow': vxm.losses.Grad('l2').loss
294
- }
295
- loss_weights = {'transformer': 1.,
296
- 'seg_transformer': 1.,
297
- 'flow': rw}
298
- else:
299
- losses = {'transformer': loss_fnc,
300
- 'flow': vxm.losses.Grad('l2').loss}
301
- metrics = {'transformer': [StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).metric,
302
- MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).metric,
303
- tf.keras.losses.MSE,
304
- NCC(image_input_shape).metric],
305
- #'flow': vxm.losses.Grad('l2').loss
306
- }
307
- loss_weights = {'transformer': 1.,
308
- 'flow': rw}
309
 
310
  optimizer = AdamAccumulated(C.ACCUM_GRADIENT_STEP, C.LEARNING_RATE)
311
  network.compile(optimizer=optimizer,
@@ -359,12 +333,9 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
359
  print('TF Error : {}'.format(str(err)))
360
  continue
361
 
362
- if USE_SEGMENTATIONS:
363
- in_data = (mov_img, fix_img, mov_seg)
364
- out_data = (fix_img, fix_img, fix_seg)
365
- else:
366
- in_data = (mov_img, fix_img)
367
- out_data = (fix_img, fix_img)
368
  ret = network.train_on_batch(x=in_data, y=out_data) # The second element doesn't matter
369
  if np.isnan(ret).any():
370
  os.makedirs(os.path.join(output_folder, 'corrupted'), exist_ok=True)
@@ -395,12 +366,8 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
395
  print('TF Error : {}'.format(str(err)))
396
  continue
397
 
398
- if USE_SEGMENTATIONS:
399
- in_data = (mov_img, fix_img, mov_seg)
400
- out_data = (fix_img, fix_img, fix_seg)
401
- else:
402
- in_data = (mov_img, fix_img)
403
- out_data = (fix_img, fix_img)
404
  ret = network.test_on_batch(x=in_data,
405
  y=out_data)
406
 
 
34
  import h5py
35
  import re
36
  import itertools
37
+ import warnings
38
 
39
 
40
  def launch_train(dataset_folder, validation_folder, output_folder, model_file, gpu_num=0, lr=1e-4, rw=5e-3, simil='ssim',
41
+ segm='dice', max_epochs=C.EPOCHS, early_stop_patience=1000, freeze_layers=None,
42
+ acc_gradients=1, batch_size=16, image_size=64,
43
+ unet=[16, 32, 64, 128, 256], head=[16, 16]):
44
  # 0. Input checks
45
+ assert dataset_folder is not None and output_folder is not None
46
+ if model_file != '':
47
+ assert '.h5' in model_file, 'The model must be an H5 file'
 
48
 
49
  # 1. Load variables
50
  os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
51
  os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_num) # Check availability before running using 'nvidia-smi'
52
  C.GPU_NUM = str(gpu_num)
53
 
54
+ if batch_size != 1 and acc_gradients != 1:
55
+ warnings.warn('WARNING: Batch size and Accumulative gradient step are set!')
56
+
57
  if freeze_layers is not None:
58
  assert all(s in ['INPUT', 'OUTPUT', 'ENCODER', 'DECODER', 'TOP', 'BOTTOM'] for s in freeze_layers), \
59
  'Invalid option for "freeze". Expected one or several of: INPUT, OUTPUT, ENCODER, DECODER, TOP, BOTTOM'
 
67
  C.TRAINING_DATASET = dataset_folder #dataset_copy.copy_dataset()
68
  C.VALIDATION_DATASET = validation_folder
69
  C.ACCUM_GRADIENT_STEP = acc_gradients
70
+ C.BATCH_SIZE = batch_size if C.ACCUM_GRADIENT_STEP == 1 else 1
71
+ C.EARLY_STOP_PATIENCE = early_stop_patience
72
  C.LEARNING_RATE = lr
73
  C.LIMIT_NUM_SAMPLES = None
74
  C.EPOCHS = max_epochs
 
101
  print(aux)
102
 
103
  # 2. Data generator
104
+ used_labels = 'none'
105
  data_generator = BatchGenerator(C.TRAINING_DATASET, C.BATCH_SIZE if C.ACCUM_GRADIENT_STEP == 1 else 1, True,
106
+ C.TRAINING_PERC, labels=[used_labels], combine_segmentations=True,
107
  directory_val=C.VALIDATION_DATASET)
108
 
109
  train_generator = data_generator.get_train_generator()
110
  validation_generator = data_generator.get_validation_generator()
111
 
112
  image_input_shape = train_generator.get_data_shape()[-1][:-1]
113
+ image_output_shape = [image_size] * 3
114
  nb_labels = len(train_generator.get_segmentation_labels())
115
 
116
  # 3. Load model
 
137
  GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric,
138
  HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [nb_labels]).metric,
139
  GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric_macro,]
140
+
141
 
142
  try:
143
  network = tf.keras.models.load_model(model_file, {#'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
 
147
  'metric': metric_fncs},
148
  compile=False)
149
  except ValueError as e:
150
+ # enc_features = [16, 32, 32, 32] # const.ENCODER_FILTERS
151
+ # dec_features = [32, 32, 32, 32, 32, 16, 16] # const.ENCODER_FILTERS[::-1]
152
+ enc_features = unet # const.ENCODER_FILTERS
153
+ dec_features = enc_features[::-1] + head # const.ENCODER_FILTERS[::-1]
154
  nb_features = [enc_features, dec_features]
155
+
156
+ network = vxm.networks.VxmDense(inshape=image_output_shape,
157
+ nb_unet_features=nb_features,
158
+ int_steps=0)
159
+ if model_file != '':
160
+ network.load_weights(model_file, by_name=True)
161
+ print('MODEL LOCATION: ', model_file)
 
 
 
 
 
162
  # 4. Freeze/unfreeze model layers
163
  # freeze_layers = range(0, len(network.layers) - 8) # Do not freeze the last layers after the UNet (8 last layers)
164
  # for l in freeze_layers:
 
188
  network.summary()
189
  network.summary(print_fn=log_file.writelines)
190
  # Complete the model with the augmentation layer
191
+ augm_train_input_shape = train_generator.get_data_shape()[-1]
192
  input_layer_train = Input(shape=augm_train_input_shape, name='input_train')
193
  augm_layer_train = AugmentationLayer(max_displacement=COMET_C.MAX_AUG_DISP, # Max 30 mm in isotropic space
194
  max_deformation=COMET_C.MAX_AUG_DEF, # Max 6 mm in isotropic space
 
199
  brightness_augmentation=COMET_C.BRIGHTNESS_AUGMENTATION,
200
  in_img_shape=image_input_shape,
201
  out_img_shape=image_output_shape,
202
+ only_image=True, # If baseline then True
203
  only_resize=False,
204
  trainable=False)
205
  augm_model_train = Model(inputs=input_layer_train, outputs=augm_layer_train(input_layer_train))
 
256
  else:
257
  raise ValueError('Unknown similarity metric: ' + simil)
258
 
 
 
 
 
 
 
 
 
 
 
259
  os.makedirs(os.path.join(output_folder, 'checkpoints'), exist_ok=True)
260
  os.makedirs(os.path.join(output_folder, 'tensorboard'), exist_ok=True)
261
  callback_tensorboard = TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
 
269
  save_best_only=True, monitor='val_loss', verbose=1, mode='min')
270
  callback_save_checkpoint = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'checkpoint.h5'),
271
  save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
272
+
273
+ losses = {'transformer': loss_fnc,
274
+ 'flow': vxm.losses.Grad('l2').loss}
275
+ metrics = {'transformer': [StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).metric,
276
+ MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).metric,
277
+ tf.keras.losses.MSE,
278
+ NCC(image_input_shape).metric],
279
+ #'flow': vxm.losses.Grad('l2').loss
280
+ }
281
+ loss_weights = {'transformer': 1.,
282
+ 'flow': rw}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
 
284
  optimizer = AdamAccumulated(C.ACCUM_GRADIENT_STEP, C.LEARNING_RATE)
285
  network.compile(optimizer=optimizer,
 
333
  print('TF Error : {}'.format(str(err)))
334
  continue
335
 
336
+ in_data = (mov_img, fix_img)
337
+ out_data = (fix_img, fix_img)
338
+
 
 
 
339
  ret = network.train_on_batch(x=in_data, y=out_data) # The second element doesn't matter
340
  if np.isnan(ret).any():
341
  os.makedirs(os.path.join(output_folder, 'corrupted'), exist_ok=True)
 
366
  print('TF Error : {}'.format(str(err)))
367
  continue
368
 
369
+ in_data = (mov_img, fix_img)
370
+ out_data = (fix_img, fix_img)
 
 
 
 
371
  ret = network.test_on_batch(x=in_data,
372
  y=out_data)
373
 
COMET/COMET_train_UW.py CHANGED
@@ -33,20 +33,26 @@ import voxelmorph as vxm
33
  import h5py
34
  import re
35
  import itertools
 
36
 
37
 
38
  def launch_train(dataset_folder, validation_folder, output_folder, model_file, gpu_num=0, lr=1e-4, rw=5e-3,
39
- simil=['ssim'], segm=['dice'], max_epochs=C.EPOCHS, prior_reg_w=5e-3, freeze_layers=None,
40
- acc_gradients=1, batch_size=16):
 
41
  # 0. Input checks
42
- assert dataset_folder is not None and output_folder is not None and model_file is not None
43
- assert '.h5' in model_file, 'The model must be an H5 file'
 
44
 
45
  # 1. Load variables
46
  os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
47
  os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_num) # Check availability before running using 'nvidia-smi'
48
  C.GPU_NUM = str(gpu_num)
49
 
 
 
 
50
  if freeze_layers is not None:
51
  assert all(s in ['INPUT', 'OUTPUT', 'ENCODER', 'DECODER', 'TOP', 'BOTTOM'] for s in freeze_layers), \
52
  'Invalid option for "freeze". Expected one or several of: INPUT, OUTPUT, ENCODER, DECODER, TOP, BOTTOM'
@@ -63,8 +69,8 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
63
  C.TRAINING_DATASET = dataset_folder # dataset_copy.copy_dataset()
64
  C.VALIDATION_DATASET = validation_folder
65
  C.ACCUM_GRADIENT_STEP = acc_gradients
66
- C.BATCH_SIZE = batch_size if C.ACCUM_GRADIENT_STEP == 1 else C.ACCUM_GRADIENT_STEP
67
- C.EARLY_STOP_PATIENCE = 5 * (C.ACCUM_GRADIENT_STEP / 2 if C.ACCUM_GRADIENT_STEP != 1 else 1)
68
  C.LEARNING_RATE = lr
69
  C.LIMIT_NUM_SAMPLES = None
70
  C.EPOCHS = max_epochs
@@ -105,7 +111,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
105
  validation_generator = data_generator.get_validation_generator()
106
 
107
  image_input_shape = train_generator.get_data_shape()[-1][:-1]
108
- image_output_shape = [64] * 3
109
  nb_labels = len(train_generator.get_segmentation_labels())
110
 
111
  # 3. Load model
@@ -116,10 +122,10 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
116
  sess = tf.Session(config=config)
117
  tf.keras.backend.set_session(sess)
118
 
119
- print('MODEL LOCATION: ', model_file)
120
-
121
- enc_features = [16, 32, 32, 32] # const.ENCODER_FILTERS
122
- dec_features = [32, 32, 32, 32, 32, 16, 16] # const.ENCODER_FILTERS[::-1]
123
  nb_features = [enc_features, dec_features]
124
  network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
125
  nb_labels=nb_labels,
@@ -127,7 +133,9 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
127
  int_steps=0,
128
  int_downsize=1,
129
  seg_downsize=1)
130
- network.load_weights(model_file, by_name=True)
 
 
131
 
132
  # 4. Freeze/unfreeze model layers
133
  if freeze_layers is not None:
 
33
  import h5py
34
  import re
35
  import itertools
36
+ import warnings
37
 
38
 
39
  def launch_train(dataset_folder, validation_folder, output_folder, model_file, gpu_num=0, lr=1e-4, rw=5e-3,
40
+ simil=['ssim'], segm=['dice'], max_epochs=C.EPOCHS, early_stop_patience=1000, prior_reg_w=5e-3,
41
+ freeze_layers=None, acc_gradients=1, batch_size=16, image_size=64,
42
+ unet=[16, 32, 64, 128, 256], head=[16, 16]):
43
  # 0. Input checks
44
+ assert dataset_folder is not None and output_folder is not None
45
+ if model_file != '':
46
+ assert '.h5' in model_file, 'The model must be an H5 file'
47
 
48
  # 1. Load variables
49
  os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
50
  os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_num) # Check availability before running using 'nvidia-smi'
51
  C.GPU_NUM = str(gpu_num)
52
 
53
+ if batch_size != 1 and acc_gradients != 1:
54
+ warnings.warn('WARNING: Batch size and Accumulative gradient step are set!')
55
+
56
  if freeze_layers is not None:
57
  assert all(s in ['INPUT', 'OUTPUT', 'ENCODER', 'DECODER', 'TOP', 'BOTTOM'] for s in freeze_layers), \
58
  'Invalid option for "freeze". Expected one or several of: INPUT, OUTPUT, ENCODER, DECODER, TOP, BOTTOM'
 
69
  C.TRAINING_DATASET = dataset_folder # dataset_copy.copy_dataset()
70
  C.VALIDATION_DATASET = validation_folder
71
  C.ACCUM_GRADIENT_STEP = acc_gradients
72
+ C.BATCH_SIZE = batch_size if C.ACCUM_GRADIENT_STEP == 1 else 1
73
+ C.EARLY_STOP_PATIENCE = early_stop_patience
74
  C.LEARNING_RATE = lr
75
  C.LIMIT_NUM_SAMPLES = None
76
  C.EPOCHS = max_epochs
 
111
  validation_generator = data_generator.get_validation_generator()
112
 
113
  image_input_shape = train_generator.get_data_shape()[-1][:-1]
114
+ image_output_shape = [image_size] * 3
115
  nb_labels = len(train_generator.get_segmentation_labels())
116
 
117
  # 3. Load model
 
122
  sess = tf.Session(config=config)
123
  tf.keras.backend.set_session(sess)
124
 
125
+ # enc_features = [16, 32, 32, 32] # const.ENCODER_FILTERS
126
+ # dec_features = [32, 32, 32, 32, 32, 16, 16] # const.ENCODER_FILTERS[::-1]
127
+ enc_features = unet # const.ENCODER_FILTERS
128
+ dec_features = enc_features[::-1] + head # const.ENCODER_FILTERS[::-1]
129
  nb_features = [enc_features, dec_features]
130
  network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
131
  nb_labels=nb_labels,
 
133
  int_steps=0,
134
  int_downsize=1,
135
  seg_downsize=1)
136
+ if model_file != '':
137
+ print('MODEL LOCATION: ', model_file)
138
+ network.load_weights(model_file, by_name=True)
139
 
140
  # 4. Freeze/unfreeze model layers
141
  if freeze_layers is not None:
DeepDeformationMapRegistration/layers/augmentation.py CHANGED
@@ -182,14 +182,15 @@ class AugmentationLayer(kl.Layer):
182
  return tf.squeeze(disp_map, axis=0)
183
 
184
  def gamma_augmentation(self, in_img: tf.Tensor):
185
- in_img += 1e-5 # To prvent NaNs
186
- gamma = tf.random.uniform((), 0.5, 2, tf.float32)
 
187
 
188
  return tf.clip_by_value(tf.pow(in_img, gamma), 0, 1)
189
 
190
  def brightness_augmentation(self, in_img: tf.Tensor):
191
- c = tf.random.uniform((), 0.5, 2, tf.float32)
192
- return tf.clip_by_value(c*in_img, 0, 1)
193
 
194
  def min_max_normalization(self, in_img: tf.Tensor):
195
  return tf.div(tf.subtract(in_img, tf.reduce_min(in_img)),
@@ -228,6 +229,7 @@ class AugmentationLayer(kl.Layer):
228
  except InvalidArgumentError as err:
229
  # If the transformation raises a non-invertible error,
230
  # try again until we get a valid transformation
 
231
  continue
232
  else:
233
  valid_trf = True
 
182
  return tf.squeeze(disp_map, axis=0)
183
 
184
  def gamma_augmentation(self, in_img: tf.Tensor):
185
+ in_img += 1e-5 # To prevent NaNs
186
+ f = tf.random.uniform((), -1, 1, tf.float32) # gamma [0.5, 2]
187
+ gamma = tf.pow(2.0, f)
188
 
189
  return tf.clip_by_value(tf.pow(in_img, gamma), 0, 1)
190
 
191
  def brightness_augmentation(self, in_img: tf.Tensor):
192
+ c = tf.random.uniform((), -0.2, 0.2, tf.float32) # 20% shift
193
+ return tf.clip_by_value(c + in_img, 0, 1)
194
 
195
  def min_max_normalization(self, in_img: tf.Tensor):
196
  return tf.div(tf.subtract(in_img, tf.reduce_min(in_img)),
 
229
  except InvalidArgumentError as err:
230
  # If the transformation raises a non-invertible error,
231
  # try again until we get a valid transformation
232
+ tf.print('TPS non invertible matrix', output_stream=sys.stdout)
233
  continue
234
  else:
235
  valid_trf = True
DeepDeformationMapRegistration/utils/constants.py CHANGED
@@ -518,7 +518,7 @@ MAX_AUG_DEF = np.max(MAX_AUG_DEF_ISOT * IXI_DATASET_iso_to_cubic_scales) # Scal
518
  MAX_AUG_ANGLE = np.max([np.arctan(np.tan(10*np.pi/180) * IXI_DATASET_iso_to_cubic_scales[1] / IXI_DATASET_iso_to_cubic_scales[0]) * 180 / np.pi,
519
  np.arctan(np.tan(10*np.pi/180) * IXI_DATASET_iso_to_cubic_scales[2] / IXI_DATASET_iso_to_cubic_scales[1]) * 180 / np.pi,
520
  np.arctan(np.tan(10*np.pi/180) * IXI_DATASET_iso_to_cubic_scales[2] / IXI_DATASET_iso_to_cubic_scales[0]) * 180 / np.pi]) # Scaled angles
521
- GAMMA_AUGMENTATION = False
522
  BRIGHTNESS_AUGMENTATION = False
523
  NUM_CONTROL_PTS_AUG = 10
524
  NUM_AUGMENTATIONS = 1
 
518
  MAX_AUG_ANGLE = np.max([np.arctan(np.tan(10*np.pi/180) * IXI_DATASET_iso_to_cubic_scales[1] / IXI_DATASET_iso_to_cubic_scales[0]) * 180 / np.pi,
519
  np.arctan(np.tan(10*np.pi/180) * IXI_DATASET_iso_to_cubic_scales[2] / IXI_DATASET_iso_to_cubic_scales[1]) * 180 / np.pi,
520
  np.arctan(np.tan(10*np.pi/180) * IXI_DATASET_iso_to_cubic_scales[2] / IXI_DATASET_iso_to_cubic_scales[0]) * 180 / np.pi]) # Scaled angles
521
+ GAMMA_AUGMENTATION = True
522
  BRIGHTNESS_AUGMENTATION = False
523
  NUM_CONTROL_PTS_AUG = 10
524
  NUM_AUGMENTATIONS = 1
DeepDeformationMapRegistration/utils/misc.py CHANGED
@@ -7,6 +7,7 @@ from skimage.measure import regionprops
7
  from DeepDeformationMapRegistration.layers.b_splines import interpolate_spline
8
  from DeepDeformationMapRegistration.utils.thin_plate_splines import ThinPlateSplines
9
  from tensorflow import squeeze
 
10
 
11
 
12
  def try_mkdir(dir, verbose=True):
@@ -148,3 +149,34 @@ def segmentation_cardinal_to_ohe(segmentation):
148
  for ch, lbl in enumerate(np.unique(segmentation)[1:]):
149
  cpy[segmentation == lbl, ch] = 1
150
  return cpy
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  from DeepDeformationMapRegistration.layers.b_splines import interpolate_spline
8
  from DeepDeformationMapRegistration.utils.thin_plate_splines import ThinPlateSplines
9
  from tensorflow import squeeze
10
+ from scipy.ndimage import zoom
11
 
12
 
13
  def try_mkdir(dir, verbose=True):
 
149
  for ch, lbl in enumerate(np.unique(segmentation)[1:]):
150
  cpy[segmentation == lbl, ch] = 1
151
  return cpy
152
+
153
+
154
+ def resize_displacement_map(displacement_map: np.ndarray, dest_shape: [list, np.ndarray, tuple], scale_trf: np.ndarray=None):
155
+ if scale_trf is None:
156
+ scale_trf = scale_transformation(displacement_map.shape, dest_shape)
157
+ else:
158
+ assert isinstance(scale_trf, np.ndarray) and scale_trf.shape == (4, 4), 'Invalid transformation: {}'.format(scale_trf)
159
+ zoom_factors = scale_trf.diagonal()
160
+ # First scale the values, so we cut down the number of multiplications
161
+ dm_resized = np.copy(displacement_map)
162
+ dm_resized[..., 0] *= zoom_factors[0]
163
+ dm_resized[..., 1] *= zoom_factors[1]
164
+ dm_resized[..., 2] *= zoom_factors[2]
165
+ # Then rescale using zoom
166
+ dm_resized = zoom(dm_resized, zoom_factors)
167
+ return dm_resized
168
+
169
+
170
+ def scale_transformation(original_shape: [list, tuple, np.ndarray], dest_shape: [list, tuple, np.ndarray]) -> np.ndarray:
171
+ if isinstance(original_shape, (list, tuple)):
172
+ original_shape = np.asarray(original_shape, dtype=int)
173
+ if isinstance(dest_shape, (list, tuple)):
174
+ dest_shape = np.asarray(dest_shape, dtype=int)
175
+ original_shape = original_shape.astype(int)
176
+ dest_shape = dest_shape.astype(int)
177
+
178
+ trf = np.eye(4)
179
+ np.fill_diagonal(trf, [*np.divide(dest_shape, original_shape), 1])
180
+
181
+ return trf
182
+
SoA_methods/eval_ants.py CHANGED
@@ -35,6 +35,9 @@ WARPED_FIX = 'warpedfixout'
35
  FWD_TRFS = 'fwdtransforms'
36
  INV_TRFS = 'invtransforms'
37
 
 
 
 
38
  if __name__ == '__main__':
39
  parser = ArgumentParser()
40
  parser.add_argument('--dataset', type=str, help='Directory with the images')
@@ -42,11 +45,13 @@ if __name__ == '__main__':
42
  args = parser.parse_args()
43
 
44
  os.makedirs(args.outdir, exist_ok=True)
 
 
45
  dataset_files = os.listdir(args.dataset)
46
  dataset_files.sort()
47
  dataset_files = [os.path.join(args.dataset, f) for f in dataset_files if re.match(DATASET_NAMES, f)]
48
 
49
- dataset_iterator = tqdm(dataset_files)
50
 
51
  f = h5py.File(dataset_files[0], 'r')
52
  image_shape = list(f['fix_image'][:].shape[:-1])
@@ -87,17 +92,19 @@ if __name__ == '__main__':
87
  print("Running ANTs using {} threads".format(os.environ.get("ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS")))
88
  dm_interp = DisplacementMapInterpolator(image_shape, 'griddata')
89
  # Header of the metrics csv file
90
- csv_header = ['File', 'Method', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'DICE_MACRO', 'HD', 'Time_SyN', 'Time_SyNCC', 'TRE']
91
 
92
- metrics_file = os.path.join(args.outdir, 'metrics.csv')
93
- with open(metrics_file, 'w') as f:
94
- f.write(';'.join(csv_header)+'\n')
 
 
95
 
96
  print('Starting the loop')
97
- for step, file_path in tqdm(enumerate(dataset_iterator), desc="Running ANTs"):
98
  file_num = int(re.findall('(\d+)', os.path.split(file_path)[-1])[0])
99
 
100
- dataset_iterator.set_description('{} ({}): loading data'.format(file_num, args.dataset))
101
  with h5py.File(file_path, 'r') as vol_file:
102
  fix_img = vol_file['fix_image'][:]
103
  mov_img = vol_file['mov_image'][:]
@@ -112,10 +119,12 @@ if __name__ == '__main__':
112
  fix_img_ants = ants.make_image(fix_img.shape[:-1], np.squeeze(fix_img)) # SoA doesn't work fine with 1-ch images
113
  mov_img_ants = ants.make_image(mov_img.shape[:-1], np.squeeze(mov_img)) # SoA doesn't work fine with 1-ch images
114
 
 
115
  t0_syn = time.time()
116
  reg_output_syn = ants.registration(fix_img_ants, mov_img_ants, 'SyN')
117
  t1_syn = time.time()
118
 
 
119
  t0_syncc = time.time()
120
  reg_output_syncc = ants.registration(fix_img_ants, mov_img_ants, 'SyNCC')
121
  t1_syncc = time.time()
@@ -135,6 +144,8 @@ if __name__ == '__main__':
135
  pred_seg = ants.apply_transforms(fixed=fix_seg_ants, moving=mov_seg_ants,
136
  transformlist=mov_to_fix_trf_list).numpy()
137
  pred_seg = np.squeeze(pred_seg) # SoA adds an extra axis which shouldn't be there
 
 
138
  with sess.as_default():
139
  dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf],
140
  {'fix_seg:0': fix_seg[np.newaxis, ...], # Batch axis
@@ -163,22 +174,26 @@ if __name__ == '__main__':
163
  tre_array = target_registration_error(fix_centroids_isotropic, pred_centroids_isotropic, False).eval()
164
  tre = np.mean([v for v in tre_array if not np.isnan(v)])
165
 
166
- new_line = [step, reg_method, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, t1_syn-t0_syn, t1_syncc-t0_syncc, tre]
167
- with open(metrics_file, 'a') as f:
 
 
 
168
  f.write(';'.join(map(str, new_line))+'\n')
169
 
170
- save_nifti(fix_img[0, ...], os.path.join(args.outdir, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
171
- save_nifti(mov_img[0, ...], os.path.join(args.outdir, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
172
- save_nifti(pred_img[0, ...], os.path.join(args.outdir, '{:03d}_pred_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
173
- save_nifti(fix_seg_card[0, ...], os.path.join(args.outdir, '{:03d}_fix_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
174
- save_nifti(mov_seg_card[0, ...], os.path.join(args.outdir, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
175
- save_nifti(pred_seg_card[0, ...], os.path.join(args.outdir, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
176
-
177
- plot_predictions(fix_img[np.newaxis, ...], mov_img[np.newaxis, ...], disp_map[np.newaxis, ...], pred_img[np.newaxis, ...], os.path.join(args.outdir, '{:03d}_figures_img.png'.format(step)), show=False)
178
- plot_predictions(fix_seg[np.newaxis, ...], mov_seg[np.newaxis, ...], disp_map[np.newaxis, ...], pred_seg[np.newaxis, ...], os.path.join(args.outdir, '{:03d}_figures_seg.png'.format(step)), show=False)
179
- save_disp_map_img(disp_map[np.newaxis, ...], 'Displacement map', os.path.join(args.outdir, '{:03d}_disp_map_fig.png'.format(step)), show=False)
180
-
181
- print('Summary\n=======\n')
182
- print('\nAVG:\n' + str(pd.read_csv(metrics_file, sep=';', header=0).mean(axis=0)) + '\nSTD:\n' + str(
183
- pd.read_csv(metrics_file, sep=';', header=0).std(axis=0)))
184
- print('\n=======\n')
 
 
35
  FWD_TRFS = 'fwdtransforms'
36
  INV_TRFS = 'invtransforms'
37
 
38
+ os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
39
+ os.environ['CUDA_VISIBLE_DEVICES'] = '2'
40
+
41
  if __name__ == '__main__':
42
  parser = ArgumentParser()
43
  parser.add_argument('--dataset', type=str, help='Directory with the images')
 
45
  args = parser.parse_args()
46
 
47
  os.makedirs(args.outdir, exist_ok=True)
48
+ os.makedirs(os.path.join(args.outdir, 'SyN'), exist_ok=True)
49
+ os.makedirs(os.path.join(args.outdir, 'SyNCC'), exist_ok=True)
50
  dataset_files = os.listdir(args.dataset)
51
  dataset_files.sort()
52
  dataset_files = [os.path.join(args.dataset, f) for f in dataset_files if re.match(DATASET_NAMES, f)]
53
 
54
+ dataset_iterator = tqdm(enumerate(dataset_files), desc="Running ANTs")
55
 
56
  f = h5py.File(dataset_files[0], 'r')
57
  image_shape = list(f['fix_image'][:].shape[:-1])
 
92
  print("Running ANTs using {} threads".format(os.environ.get("ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS")))
93
  dm_interp = DisplacementMapInterpolator(image_shape, 'griddata')
94
  # Header of the metrics csv file
95
+ csv_header = ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'DICE_MACRO', 'HD', 'Time', 'TRE']
96
 
97
+ metrics_file = {'SyN': os.path.join(args.outdir, 'SyN', 'metrics.csv'),
98
+ 'SyNCC': os.path.join(args.outdir, 'SyNCC', 'metrics.csv')}
99
+ for k in metrics_file.keys():
100
+ with open(metrics_file[k], 'w') as f:
101
+ f.write(';'.join(csv_header)+'\n')
102
 
103
  print('Starting the loop')
104
+ for step, file_path in dataset_iterator:
105
  file_num = int(re.findall('(\d+)', os.path.split(file_path)[-1])[0])
106
 
107
+ dataset_iterator.set_description('{} ({}): loading data'.format(file_num, file_path))
108
  with h5py.File(file_path, 'r') as vol_file:
109
  fix_img = vol_file['fix_image'][:]
110
  mov_img = vol_file['mov_image'][:]
 
119
  fix_img_ants = ants.make_image(fix_img.shape[:-1], np.squeeze(fix_img)) # SoA doesn't work fine with 1-ch images
120
  mov_img_ants = ants.make_image(mov_img.shape[:-1], np.squeeze(mov_img)) # SoA doesn't work fine with 1-ch images
121
 
122
+ dataset_iterator.set_description('{} ({}): running ANTs SyN'.format(file_num, file_path))
123
  t0_syn = time.time()
124
  reg_output_syn = ants.registration(fix_img_ants, mov_img_ants, 'SyN')
125
  t1_syn = time.time()
126
 
127
+ dataset_iterator.set_description('{} ({}): running ANTs SyN'.format(file_num, file_path))
128
  t0_syncc = time.time()
129
  reg_output_syncc = ants.registration(fix_img_ants, mov_img_ants, 'SyNCC')
130
  t1_syncc = time.time()
 
144
  pred_seg = ants.apply_transforms(fixed=fix_seg_ants, moving=mov_seg_ants,
145
  transformlist=mov_to_fix_trf_list).numpy()
146
  pred_seg = np.squeeze(pred_seg) # SoA adds an extra axis which shouldn't be there
147
+
148
+ dataset_iterator.set_description('{} ({}): Getting metrics {}'.format(file_num, file_path, reg_method))
149
  with sess.as_default():
150
  dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf],
151
  {'fix_seg:0': fix_seg[np.newaxis, ...], # Batch axis
 
174
  tre_array = target_registration_error(fix_centroids_isotropic, pred_centroids_isotropic, False).eval()
175
  tre = np.mean([v for v in tre_array if not np.isnan(v)])
176
 
177
+ dataset_iterator.set_description('{} ({}): Saving data {}'.format(file_num, file_path, reg_method))
178
+ new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd,
179
+ t1_syn-t0_syn if reg_method == 'SyN' else t1_syncc-t0_syncc,
180
+ tre]
181
+ with open(metrics_file[reg_method], 'a') as f:
182
  f.write(';'.join(map(str, new_line))+'\n')
183
 
184
+ save_nifti(fix_img[0, ...], os.path.join(args.outdir, reg_method, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
185
+ save_nifti(mov_img[0, ...], os.path.join(args.outdir, reg_method, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
186
+ save_nifti(pred_img[0, ...], os.path.join(args.outdir, reg_method, '{:03d}_pred_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
187
+ save_nifti(fix_seg_card[0, ...], os.path.join(args.outdir, reg_method, '{:03d}_fix_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
188
+ save_nifti(mov_seg_card[0, ...], os.path.join(args.outdir, reg_method, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
189
+ save_nifti(pred_seg_card[0, ...], os.path.join(args.outdir, reg_method, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
190
+
191
+ plot_predictions(fix_img[np.newaxis, ...], mov_img[np.newaxis, ...], disp_map[np.newaxis, ...], pred_img[np.newaxis, ...], os.path.join(args.outdir, reg_method, '{:03d}_figures_img.png'.format(step)), show=False)
192
+ plot_predictions(fix_seg[np.newaxis, ...], mov_seg[np.newaxis, ...], disp_map[np.newaxis, ...], pred_seg[np.newaxis, ...], os.path.join(args.outdir, reg_method, '{:03d}_figures_seg.png'.format(step)), show=False)
193
+ save_disp_map_img(disp_map[np.newaxis, ...], 'Displacement map', os.path.join(args.outdir, reg_method, '{:03d}_disp_map_fig.png'.format(step)), show=False)
194
+
195
+ for k in metrics_file.keys():
196
+ print('Summary {}\n=======\n'.format(k))
197
+ print('\nAVG:\n' + str(pd.read_csv(metrics_file[k], sep=';', header=0).mean(axis=0)) + '\nSTD:\n' + str(
198
+ pd.read_csv(metrics_file[k], sep=';', header=0).std(axis=0)))
199
+ print('\n=======\n')