Commit e5764e7 · Parent(s): 3b554c2 · "Update"

Changed files:
- Brain_study/Train_SegmentationGuided.py +2 -1
- COMET/COMET_train.py +45 -78
- COMET/COMET_train_UW.py +20 -12
- DeepDeformationMapRegistration/layers/augmentation.py +6 -4
- DeepDeformationMapRegistration/utils/constants.py +1 -1
- DeepDeformationMapRegistration/utils/misc.py +32 -0
- SoA_methods/eval_ants.py +39 -24
Brain_study/Train_SegmentationGuided.py
CHANGED
@@ -173,7 +173,8 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
     loss_weights = {'transformer': 1,
                     'seg_transformer': 1.,
                     'flow': 5e-3}
-    metrics = {'transformer': [vxm.losses.MSE().loss, NCC(image_input_shape).metric,
+    metrics = {'transformer': [vxm.losses.MSE().loss, NCC(image_input_shape).metric, StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).metric,
+                               MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).metric],
                'seg_transformer': [GeneralizedDICEScore(image_output_shape + [train_generator.get_data_shape()[2][-1]], num_labels=nb_labels).metric_macro,
                                    #HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [train_generator.get_data_shape()[2][-1]]).metric
                                    ]}
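The added metrics give the image branch SSIM and MS-SSIM monitoring alongside the existing MSE and NCC. As a rough illustration of what an MS-SSIM metric computes, below is a minimal sketch built on tf.image.ssim_multiscale, which only handles 2D slices of shape [batch, H, W, C]; the repository's MultiScaleStructuralSimilarity class is its own 3D-capable implementation, and the SSIM_KER_SIZE and MS_SSIM_WEIGHTS values here are placeholders, not the project's constants.

import tensorflow as tf

SSIM_KER_SIZE = 5                                            # placeholder kernel size
MS_SSIM_WEIGHTS = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)   # TF's default power factors

def ms_ssim_metric(y_true, y_pred):
    # Per-sample MS-SSIM for tensors with intensities in [0, 1]; higher is more similar.
    return tf.image.ssim_multiscale(y_true, y_pred, max_val=1.,
                                    power_factors=MS_SSIM_WEIGHTS,
                                    filter_size=SSIM_KER_SIZE)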
COMET/COMET_train.py
CHANGED
@@ -34,22 +34,26 @@ import voxelmorph as vxm
 import h5py
 import re
 import itertools
+import warnings


 def launch_train(dataset_folder, validation_folder, output_folder, model_file, gpu_num=0, lr=1e-4, rw=5e-3, simil='ssim',
-                 segm='dice', max_epochs=C.EPOCHS, freeze_layers=None,
-                 acc_gradients=1, batch_size=16
+                 segm='dice', max_epochs=C.EPOCHS, early_stop_patience=1000, freeze_layers=None,
+                 acc_gradients=1, batch_size=16, image_size=64,
+                 unet=[16, 32, 64, 128, 256], head=[16, 16]):
     # 0. Input checks
-    assert dataset_folder is not None and output_folder is not None
-
-
-    USE_SEGMENTATIONS = bool(re.search('SEGGUIDED', model_file))
+    assert dataset_folder is not None and output_folder is not None
+    if model_file != '':
+        assert '.h5' in model_file, 'The model must be an H5 file'

     # 1. Load variables
     os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
     os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_num)  # Check availability before running using 'nvidia-smi'
     C.GPU_NUM = str(gpu_num)

+    if batch_size != 1 and acc_gradients != 1:
+        warnings.warn('WARNING: Batch size and Accumulative gradient step are set!')
+
     if freeze_layers is not None:
         assert all(s in ['INPUT', 'OUTPUT', 'ENCODER', 'DECODER', 'TOP', 'BOTTOM'] for s in freeze_layers), \
             'Invalid option for "freeze". Expected one or several of: INPUT, OUTPUT, ENCODER, DECODER, TOP, BOTTOM'
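The new batch_size/acc_gradients guard reflects that the two knobs are mutually exclusive here: as the next hunk shows, enabling gradient accumulation forces C.BATCH_SIZE to 1, and the AdamAccumulated optimizer then sums gradients over C.ACCUM_GRADIENT_STEP single-sample steps before each update, so the accumulation count plays the role of the batch size. A small sketch of that relationship (the helper name is illustrative, not from the repository):

def effective_batch_size(batch_size: int, acc_gradients: int) -> int:
    # Mirrors the config logic: with accumulation enabled, each step sees one
    # sample and gradients are summed over `acc_gradients` steps per update.
    per_step = batch_size if acc_gradients == 1 else 1
    return per_step * acc_gradients

assert effective_batch_size(16, 1) == 16  # plain mini-batches
assert effective_batch_size(16, 8) == 8   # batch_size ignored; 8 accumulated steps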
@@ -63,8 +67,8 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
     C.TRAINING_DATASET = dataset_folder  #dataset_copy.copy_dataset()
     C.VALIDATION_DATASET = validation_folder
     C.ACCUM_GRADIENT_STEP = acc_gradients
-    C.BATCH_SIZE = batch_size if C.ACCUM_GRADIENT_STEP == 1 else
-    C.EARLY_STOP_PATIENCE =
+    C.BATCH_SIZE = batch_size if C.ACCUM_GRADIENT_STEP == 1 else 1
+    C.EARLY_STOP_PATIENCE = early_stop_patience
     C.LEARNING_RATE = lr
     C.LIMIT_NUM_SAMPLES = None
     C.EPOCHS = max_epochs
@@ -97,16 +101,16 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
     print(aux)

     # 2. Data generator
-    used_labels = '
+    used_labels = 'none'
     data_generator = BatchGenerator(C.TRAINING_DATASET, C.BATCH_SIZE if C.ACCUM_GRADIENT_STEP == 1 else 1, True,
-                                    C.TRAINING_PERC, labels=[used_labels], combine_segmentations=
+                                    C.TRAINING_PERC, labels=[used_labels], combine_segmentations=True,
                                     directory_val=C.VALIDATION_DATASET)

     train_generator = data_generator.get_train_generator()
     validation_generator = data_generator.get_validation_generator()

     image_input_shape = train_generator.get_data_shape()[-1][:-1]
-    image_output_shape = [
+    image_output_shape = [image_size] * 3
     nb_labels = len(train_generator.get_segmentation_labels())

     # 3. Load model
@@ -133,7 +137,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
                    GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric,
                    HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [nb_labels]).metric,
                    GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric_macro,]
-
+

     try:
         network = tf.keras.models.load_model(model_file, {#'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
@@ -143,21 +147,18 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
                                             'metric': metric_fncs},
                                            compile=False)
     except ValueError as e:
-        enc_features = [16, 32, 32, 32]  # const.ENCODER_FILTERS
-        dec_features = [32, 32, 32, 32, 32, 16, 16]  # const.ENCODER_FILTERS[::-1]
+        # enc_features = [16, 32, 32, 32]  # const.ENCODER_FILTERS
+        # dec_features = [32, 32, 32, 32, 32, 16, 16]  # const.ENCODER_FILTERS[::-1]
+        enc_features = unet  # const.ENCODER_FILTERS
+        dec_features = enc_features[::-1] + head  # const.ENCODER_FILTERS[::-1]
         nb_features = [enc_features, dec_features]
-
-
-
-
-
-
-
-        else:
-            network = vxm.networks.VxmDense(inshape=image_output_shape,
-                                            nb_unet_features=nb_features,
-                                            int_steps=0)
-            network.load_weights(model_file, by_name=True)
+
+        network = vxm.networks.VxmDense(inshape=image_output_shape,
+                                        nb_unet_features=nb_features,
+                                        int_steps=0)
+        if model_file != '':
+            network.load_weights(model_file, by_name=True)
+            print('MODEL LOCATION: ', model_file)
     # 4. Freeze/unfreeze model layers
     # freeze_layers = range(0, len(network.layers) - 8)  # Do not freeze the last layers after the UNet (8 last layers)
     # for l in freeze_layers:
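The new unet and head arguments replace the hard-coded filter lists: the decoder mirrors the encoder and appends extra refinement filters, matching the nb_unet_features=[enc, dec] convention that VoxelMorph's VxmDense accepts. The list arithmetic with the default arguments:

# Defaults from the new launch_train() signature.
unet = [16, 32, 64, 128, 256]   # encoder filters, one entry per resolution level
head = [16, 16]                 # extra conv filters appended after the decoder

enc_features = unet
dec_features = enc_features[::-1] + head
nb_features = [enc_features, dec_features]

print(dec_features)  # [256, 128, 64, 32, 16, 16, 16]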
@@ -187,7 +188,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
     network.summary()
     network.summary(print_fn=log_file.writelines)
     # Complete the model with the augmentation layer
-    augm_train_input_shape = train_generator.get_data_shape()[
+    augm_train_input_shape = train_generator.get_data_shape()[-1]
     input_layer_train = Input(shape=augm_train_input_shape, name='input_train')
     augm_layer_train = AugmentationLayer(max_displacement=COMET_C.MAX_AUG_DISP,  # Max 30 mm in isotropic space
                                          max_deformation=COMET_C.MAX_AUG_DEF,  # Max 6 mm in isotropic space
@@ -198,7 +199,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
                                          brightness_augmentation=COMET_C.BRIGHTNESS_AUGMENTATION,
                                          in_img_shape=image_input_shape,
                                          out_img_shape=image_output_shape,
-                                         only_image=
+                                         only_image=True,  # If baseline then True
                                          only_resize=False,
                                          trainable=False)
     augm_model_train = Model(inputs=input_layer_train, outputs=augm_layer_train(input_layer_train))
@@ -255,16 +256,6 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
     else:
         raise ValueError('Unknown similarity metric: ' + simil)

-    if USE_SEGMENTATIONS:
-        if segm == 'hd':
-            loss_segm = HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [nb_labels]).loss
-        elif segm == 'dice':
-            loss_segm = GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).loss
-        elif segm == 'dice_macro':
-            loss_segm = GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).loss_macro
-        else:
-            raise ValueError('No valid value for segm')
-
     os.makedirs(os.path.join(output_folder, 'checkpoints'), exist_ok=True)
     os.makedirs(os.path.join(output_folder, 'tensorboard'), exist_ok=True)
     callback_tensorboard = TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
@@ -278,34 +269,17 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
                                                save_best_only=True, monitor='val_loss', verbose=1, mode='min')
     callback_save_checkpoint = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'checkpoint.h5'),
                                                save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
-
-
-
-
-
-
-
-
-
-
-
-                   ],
-                   #'flow': vxm.losses.Grad('l2').loss
-                   }
-        loss_weights = {'transformer': 1.,
-                        'seg_transformer': 1.,
-                        'flow': rw}
-    else:
-        losses = {'transformer': loss_fnc,
-                  'flow': vxm.losses.Grad('l2').loss}
-        metrics = {'transformer': [StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).metric,
-                                   MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).metric,
-                                   tf.keras.losses.MSE,
-                                   NCC(image_input_shape).metric],
-                   #'flow': vxm.losses.Grad('l2').loss
-                   }
-        loss_weights = {'transformer': 1.,
-                        'flow': rw}
+
+    losses = {'transformer': loss_fnc,
+              'flow': vxm.losses.Grad('l2').loss}
+    metrics = {'transformer': [StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).metric,
+                               MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).metric,
+                               tf.keras.losses.MSE,
+                               NCC(image_input_shape).metric],
+               #'flow': vxm.losses.Grad('l2').loss
+               }
+    loss_weights = {'transformer': 1.,
+                    'flow': rw}

     optimizer = AdamAccumulated(C.ACCUM_GRADIENT_STEP, C.LEARNING_RATE)
     network.compile(optimizer=optimizer,
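With the USE_SEGMENTATIONS branch gone, a single losses/metrics/loss_weights triple is keyed by the model's output names: 'transformer' (the warped image, scored by the chosen similarity) and 'flow' (the displacement field, regularized by Grad('l2') and weighted by rw). Keras resolves these dictionaries against output layer names at compile time; a minimal self-contained sketch of the pattern with a toy two-output model (generic Keras, not the repository's network):

import tensorflow as tf
from tensorflow.keras import layers, Model

inp = layers.Input(shape=(8,))
transformer = layers.Dense(8, name='transformer')(inp)  # stands in for the warped image
flow = layers.Dense(3, name='flow')(inp)                # stands in for the displacement field
model = Model(inp, [transformer, flow])

model.compile(optimizer='adam',
              loss={'transformer': 'mse',   # similarity term (loss_fnc above)
                    'flow': 'mse'},         # regularizer (Grad('l2') above)
              loss_weights={'transformer': 1., 'flow': 5e-3})  # 5e-3 is rw's default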
@@ -359,12 +333,9 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
                 print('TF Error : {}'.format(str(err)))
                 continue

-
-
-
-            else:
-                in_data = (mov_img, fix_img)
-                out_data = (fix_img, fix_img)
+            in_data = (mov_img, fix_img)
+            out_data = (fix_img, fix_img)
+
             ret = network.train_on_batch(x=in_data, y=out_data)  # The second element doesn't matter
             if np.isnan(ret).any():
                 os.makedirs(os.path.join(output_folder, 'corrupted'), exist_ok=True)
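The training step now always uses the image-only pairing: inputs are (moving, fixed) and the targets repeat the fixed image, once for the 'transformer' output and once as a dummy for 'flow', whose Grad loss only inspects the predicted field (the validation loop in the next hunk mirrors this). A hedged sketch of one such step, with network standing in for the compiled model:

import numpy as np

# Illustrative shapes: one 64^3 single-channel volume per batch element.
mov_img = np.random.rand(1, 64, 64, 64, 1).astype('float32')
fix_img = np.random.rand(1, 64, 64, 64, 1).astype('float32')

in_data = (mov_img, fix_img)    # model inputs: moving and fixed image
out_data = (fix_img, fix_img)   # targets for 'transformer' and 'flow'; the flow
                                # target is a placeholder the Grad loss ignores
# ret = network.train_on_batch(x=in_data, y=out_data)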
@@ -395,12 +366,8 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
                 print('TF Error : {}'.format(str(err)))
                 continue

-
-
-                out_data = (fix_img, fix_img, fix_seg)
-            else:
-                in_data = (mov_img, fix_img)
-                out_data = (fix_img, fix_img)
+            in_data = (mov_img, fix_img)
+            out_data = (fix_img, fix_img)
             ret = network.test_on_batch(x=in_data,
                                         y=out_data)
COMET/COMET_train_UW.py
CHANGED
@@ -33,20 +33,26 @@ import voxelmorph as vxm
 import h5py
 import re
 import itertools
+import warnings


 def launch_train(dataset_folder, validation_folder, output_folder, model_file, gpu_num=0, lr=1e-4, rw=5e-3,
-                 simil=['ssim'], segm=['dice'], max_epochs=C.EPOCHS, prior_reg_w=5e-3,
-                 acc_gradients=1, batch_size=16
+                 simil=['ssim'], segm=['dice'], max_epochs=C.EPOCHS, early_stop_patience=1000, prior_reg_w=5e-3,
+                 freeze_layers=None, acc_gradients=1, batch_size=16, image_size=64,
+                 unet=[16, 32, 64, 128, 256], head=[16, 16]):
     # 0. Input checks
-    assert dataset_folder is not None and output_folder is not None
-
+    assert dataset_folder is not None and output_folder is not None
+    if model_file != '':
+        assert '.h5' in model_file, 'The model must be an H5 file'

     # 1. Load variables
     os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
     os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_num)  # Check availability before running using 'nvidia-smi'
     C.GPU_NUM = str(gpu_num)

+    if batch_size != 1 and acc_gradients != 1:
+        warnings.warn('WARNING: Batch size and Accumulative gradient step are set!')
+
     if freeze_layers is not None:
         assert all(s in ['INPUT', 'OUTPUT', 'ENCODER', 'DECODER', 'TOP', 'BOTTOM'] for s in freeze_layers), \
             'Invalid option for "freeze". Expected one or several of: INPUT, OUTPUT, ENCODER, DECODER, TOP, BOTTOM'
@@ -63,8 +69,8 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
     C.TRAINING_DATASET = dataset_folder  # dataset_copy.copy_dataset()
     C.VALIDATION_DATASET = validation_folder
     C.ACCUM_GRADIENT_STEP = acc_gradients
-    C.BATCH_SIZE = batch_size if C.ACCUM_GRADIENT_STEP == 1 else
-    C.EARLY_STOP_PATIENCE =
+    C.BATCH_SIZE = batch_size if C.ACCUM_GRADIENT_STEP == 1 else 1
+    C.EARLY_STOP_PATIENCE = early_stop_patience
     C.LEARNING_RATE = lr
     C.LIMIT_NUM_SAMPLES = None
     C.EPOCHS = max_epochs
@@ -105,7 +111,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
     validation_generator = data_generator.get_validation_generator()

     image_input_shape = train_generator.get_data_shape()[-1][:-1]
-    image_output_shape = [
+    image_output_shape = [image_size] * 3
     nb_labels = len(train_generator.get_segmentation_labels())

     # 3. Load model
@@ -116,10 +122,10 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
     sess = tf.Session(config=config)
     tf.keras.backend.set_session(sess)

-
-
-    enc_features =
-    dec_features = [
+    # enc_features = [16, 32, 32, 32]  # const.ENCODER_FILTERS
+    # dec_features = [32, 32, 32, 32, 32, 16, 16]  # const.ENCODER_FILTERS[::-1]
+    enc_features = unet  # const.ENCODER_FILTERS
+    dec_features = enc_features[::-1] + head  # const.ENCODER_FILTERS[::-1]
     nb_features = [enc_features, dec_features]
     network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
                                                      nb_labels=nb_labels,
@@ -127,7 +133,9 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
                                                      int_steps=0,
                                                      int_downsize=1,
                                                      seg_downsize=1)
-
+    if model_file != '':
+        print('MODEL LOCATION: ', model_file)
+        network.load_weights(model_file, by_name=True)

     # 4. Freeze/unfreeze model layers
     if freeze_layers is not None:
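Unlike COMET_train.py, this script keeps the segmentation-guided VxmDenseSemiSupervisedSeg, which takes the moving labels as a third input and emits warped labels as a third output. Under the usual VoxelMorph semi-supervised convention (an assumption here, consistent with the out_data = (fix_img, fix_img, fix_seg) line removed from COMET_train.py), a step pairs up as follows:

import numpy as np

image_size, nb_labels = 64, 3  # illustrative values
mov_img = np.random.rand(1, image_size, image_size, image_size, 1).astype('float32')
fix_img = np.random.rand(1, image_size, image_size, image_size, 1).astype('float32')
mov_seg = np.zeros((1, image_size, image_size, image_size, nb_labels), 'float32')
fix_seg = np.zeros((1, image_size, image_size, image_size, nb_labels), 'float32')

in_data = (mov_img, fix_img, mov_seg)    # images plus moving one-hot labels
out_data = (fix_img, fix_img, fix_seg)   # warped-image, dummy flow, warped-label targets
# ret = network.train_on_batch(x=in_data, y=out_data)  # hypothetical call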
DeepDeformationMapRegistration/layers/augmentation.py
CHANGED
@@ -182,14 +182,15 @@ class AugmentationLayer(kl.Layer):
         return tf.squeeze(disp_map, axis=0)

     def gamma_augmentation(self, in_img: tf.Tensor):
-        in_img += 1e-5  # To
-
+        in_img += 1e-5  # To prevent NaNs
+        f = tf.random.uniform((), -1, 1, tf.float32)  # gamma [0.5, 2]
+        gamma = tf.pow(2.0, f)

         return tf.clip_by_value(tf.pow(in_img, gamma), 0, 1)

     def brightness_augmentation(self, in_img: tf.Tensor):
-        c = tf.random.uniform((), 0.
-        return tf.clip_by_value(c
+        c = tf.random.uniform((), -0.2, 0.2, tf.float32)  # 20% shift
+        return tf.clip_by_value(c + in_img, 0, 1)

     def min_max_normalization(self, in_img: tf.Tensor):
         return tf.div(tf.subtract(in_img, tf.reduce_min(in_img)),
@@ -228,6 +229,7 @@ class AugmentationLayer(kl.Layer):
             except InvalidArgumentError as err:
                 # If the transformation raises a non-invertible error,
                 # try again until we get a valid transformation
+                tf.print('TPS non invertible matrix', output_stream=sys.stdout)
                 continue
             else:
                 valid_trf = True
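The new gamma sampling draws an exponent f uniformly from [-1, 1] and sets gamma = 2^f, giving gamma in [0.5, 2] with symmetric odds of brightening and darkening, while the 1e-5 offset keeps pow from producing NaNs at zero intensities. A NumPy sketch of the same arithmetic for a quick sanity check:

import numpy as np

rng = np.random.default_rng(0)
img = rng.random((4, 4))        # intensities in [0, 1)

f = rng.uniform(-1, 1)          # exponent in [-1, 1]
gamma = 2.0 ** f                # gamma in [0.5, 2]; f = 0 gives the identity
aug = np.clip((img + 1e-5) ** gamma, 0, 1)

assert 0.5 <= gamma <= 2.0
assert aug.min() >= 0.0 and aug.max() <= 1.0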
DeepDeformationMapRegistration/utils/constants.py
CHANGED
@@ -518,7 +518,7 @@ MAX_AUG_DEF = np.max(MAX_AUG_DEF_ISOT * IXI_DATASET_iso_to_cubic_scales) # Scal
 MAX_AUG_ANGLE = np.max([np.arctan(np.tan(10*np.pi/180) * IXI_DATASET_iso_to_cubic_scales[1] / IXI_DATASET_iso_to_cubic_scales[0]) * 180 / np.pi,
                         np.arctan(np.tan(10*np.pi/180) * IXI_DATASET_iso_to_cubic_scales[2] / IXI_DATASET_iso_to_cubic_scales[1]) * 180 / np.pi,
                         np.arctan(np.tan(10*np.pi/180) * IXI_DATASET_iso_to_cubic_scales[2] / IXI_DATASET_iso_to_cubic_scales[0]) * 180 / np.pi])  # Scaled angles
-GAMMA_AUGMENTATION =
+GAMMA_AUGMENTATION = True
 BRIGHTNESS_AUGMENTATION = False
 NUM_CONTROL_PTS_AUG = 10
 NUM_AUGMENTATIONS = 1
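For context, MAX_AUG_ANGLE above rescales a 10° rotation limit from isotropic space into the dataset's anisotropic voxel space: a rotation by θ between axes with scale factors s_a and s_b maps to arctan(tan(θ)·s_b/s_a), and the maximum over the three axis pairs is kept. A worked check with made-up scale factors (the real IXI_DATASET_iso_to_cubic_scales values are defined elsewhere in this file):

import numpy as np

scales = np.array([1.0, 1.0, 2.0])   # hypothetical iso-to-cubic scale factors
theta = 10 * np.pi / 180             # the 10-degree base limit

pairs = [(1, 0), (2, 1), (2, 0)]     # same axis pairs as in the constant above
angles = [np.arctan(np.tan(theta) * scales[b] / scales[a]) * 180 / np.pi
          for b, a in pairs]

print(angles)       # [10.0, 19.42..., 19.42...]; equal scales leave 10 unchanged
print(max(angles))  # the value MAX_AUG_ANGLE would take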
DeepDeformationMapRegistration/utils/misc.py
CHANGED
@@ -7,6 +7,7 @@ from skimage.measure import regionprops
 from DeepDeformationMapRegistration.layers.b_splines import interpolate_spline
 from DeepDeformationMapRegistration.utils.thin_plate_splines import ThinPlateSplines
 from tensorflow import squeeze
+from scipy.ndimage import zoom


 def try_mkdir(dir, verbose=True):
@@ -148,3 +149,34 @@ def segmentation_cardinal_to_ohe(segmentation)
     for ch, lbl in enumerate(np.unique(segmentation)[1:]):
         cpy[segmentation == lbl, ch] = 1
     return cpy
+
+
+def resize_displacement_map(displacement_map: np.ndarray, dest_shape: [list, np.ndarray, tuple], scale_trf: np.ndarray = None):
+    if scale_trf is None:
+        scale_trf = scale_transformation(displacement_map.shape, dest_shape)
+    else:
+        assert isinstance(scale_trf, np.ndarray) and scale_trf.shape == (4, 4), 'Invalid transformation: {}'.format(scale_trf)
+    zoom_factors = scale_trf.diagonal()
+    # First scale the values, so we cut down the number of multiplications
+    dm_resized = np.copy(displacement_map)
+    dm_resized[..., 0] *= zoom_factors[0]
+    dm_resized[..., 1] *= zoom_factors[1]
+    dm_resized[..., 2] *= zoom_factors[2]
+    # Then rescale using zoom
+    dm_resized = zoom(dm_resized, zoom_factors)
+    return dm_resized
+
+
+def scale_transformation(original_shape: [list, tuple, np.ndarray], dest_shape: [list, tuple, np.ndarray]) -> np.ndarray:
+    if isinstance(original_shape, (list, tuple)):
+        original_shape = np.asarray(original_shape, dtype=int)
+    if isinstance(dest_shape, (list, tuple)):
+        dest_shape = np.asarray(dest_shape, dtype=int)
+    original_shape = original_shape.astype(int)
+    dest_shape = dest_shape.astype(int)
+
+    trf = np.eye(4)
+    np.fill_diagonal(trf, [*np.divide(dest_shape, original_shape), 1])
+
+    return trf
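scale_transformation builds a 4x4 homogeneous matrix whose diagonal carries the per-axis zoom factors, and resize_displacement_map applies those factors twice: once to the vector values (a displacement of d voxels at one grid resolution is d·s voxels at the new one) and once to the sampling grid via scipy's zoom, whose trailing factor of 1 leaves the 3 vector channels untouched. A usage sketch, assuming the two functions above are importable and passing the 4x4 built from the spatial shapes explicitly:

import numpy as np
from DeepDeformationMapRegistration.utils.misc import resize_displacement_map, scale_transformation

# 32^3 displacement field with 3 vector components, upsampled to 64^3.
disp = np.random.rand(32, 32, 32, 3).astype('float32')
trf = scale_transformation((32, 32, 32), (64, 64, 64))
print(np.diag(trf))   # [2. 2. 2. 1.] -> spatial zoom x2, channel axis untouched

resized = resize_displacement_map(disp, dest_shape=None, scale_trf=trf)
print(resized.shape)  # (64, 64, 64, 3); vector magnitudes are doubled as well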
SoA_methods/eval_ants.py
CHANGED
@@ -35,6 +35,9 @@ WARPED_FIX = 'warpedfixout'
 FWD_TRFS = 'fwdtransforms'
 INV_TRFS = 'invtransforms'

+os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
+os.environ['CUDA_VISIBLE_DEVICES'] = '2'
+
 if __name__ == '__main__':
     parser = ArgumentParser()
     parser.add_argument('--dataset', type=str, help='Directory with the images')
@@ -42,11 +45,13 @@ if __name__ == '__main__':
     args = parser.parse_args()

     os.makedirs(args.outdir, exist_ok=True)
+    os.makedirs(os.path.join(args.outdir, 'SyN'), exist_ok=True)
+    os.makedirs(os.path.join(args.outdir, 'SyNCC'), exist_ok=True)
     dataset_files = os.listdir(args.dataset)
     dataset_files.sort()
     dataset_files = [os.path.join(args.dataset, f) for f in dataset_files if re.match(DATASET_NAMES, f)]

-    dataset_iterator = tqdm(dataset_files)
+    dataset_iterator = tqdm(enumerate(dataset_files), desc="Running ANTs")

     f = h5py.File(dataset_files[0], 'r')
     image_shape = list(f['fix_image'][:].shape[:-1])
@@ -87,17 +92,19 @@ if __name__ == '__main__':
     print("Running ANTs using {} threads".format(os.environ.get("ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS")))
     dm_interp = DisplacementMapInterpolator(image_shape, 'griddata')
     # Header of the metrics csv file
-    csv_header = ['File', '
+    csv_header = ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'DICE_MACRO', 'HD', 'Time', 'TRE']

-    metrics_file = os.path.join(args.outdir, 'metrics.csv')
-
-
+    metrics_file = {'SyN': os.path.join(args.outdir, 'SyN', 'metrics.csv'),
+                    'SyNCC': os.path.join(args.outdir, 'SyNCC', 'metrics.csv')}
+    for k in metrics_file.keys():
+        with open(metrics_file[k], 'w') as f:
+            f.write(';'.join(csv_header)+'\n')

     print('Starting the loop')
-    for step, file_path in
+    for step, file_path in dataset_iterator:
         file_num = int(re.findall('(\d+)', os.path.split(file_path)[-1])[0])

-        dataset_iterator.set_description('{} ({}): loading data'.format(file_num,
+        dataset_iterator.set_description('{} ({}): loading data'.format(file_num, file_path))
         with h5py.File(file_path, 'r') as vol_file:
             fix_img = vol_file['fix_image'][:]
             mov_img = vol_file['mov_image'][:]
@@ -112,10 +119,12 @@ if __name__ == '__main__':
         fix_img_ants = ants.make_image(fix_img.shape[:-1], np.squeeze(fix_img))  # SoA doesn't work fine with 1-ch images
         mov_img_ants = ants.make_image(mov_img.shape[:-1], np.squeeze(mov_img))  # SoA doesn't work fine with 1-ch images

+        dataset_iterator.set_description('{} ({}): running ANTs SyN'.format(file_num, file_path))
         t0_syn = time.time()
         reg_output_syn = ants.registration(fix_img_ants, mov_img_ants, 'SyN')
         t1_syn = time.time()

+        dataset_iterator.set_description('{} ({}): running ANTs SyNCC'.format(file_num, file_path))
         t0_syncc = time.time()
         reg_output_syncc = ants.registration(fix_img_ants, mov_img_ants, 'SyNCC')
         t1_syncc = time.time()
@@ -135,6 +144,8 @@ if __name__ == '__main__':
             pred_seg = ants.apply_transforms(fixed=fix_seg_ants, moving=mov_seg_ants,
                                              transformlist=mov_to_fix_trf_list).numpy()
             pred_seg = np.squeeze(pred_seg)  # SoA adds an extra axis which shouldn't be there
+
+            dataset_iterator.set_description('{} ({}): Getting metrics {}'.format(file_num, file_path, reg_method))
             with sess.as_default():
                 dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf],
                                                 {'fix_seg:0': fix_seg[np.newaxis, ...],  # Batch axis
@@ -163,22 +174,26 @@ if __name__ == '__main__':
             tre_array = target_registration_error(fix_centroids_isotropic, pred_centroids_isotropic, False).eval()
             tre = np.mean([v for v in tre_array if not np.isnan(v)])

-
-
-            f.write(';'.join(map(str, new_line))+'\n')
+            dataset_iterator.set_description('{} ({}): Saving data {}'.format(file_num, file_path, reg_method))
+            new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd,
+                        t1_syn-t0_syn if reg_method == 'SyN' else t1_syncc-t0_syncc,
+                        tre]
+            with open(metrics_file[reg_method], 'a') as f:
+                f.write(';'.join(map(str, new_line))+'\n')

-            save_nifti(fix_img[0, ...], os.path.join(args.outdir, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
-            save_nifti(mov_img[0, ...], os.path.join(args.outdir, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
-            save_nifti(pred_img[0, ...], os.path.join(args.outdir, '{:03d}_pred_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
-            save_nifti(fix_seg_card[0, ...], os.path.join(args.outdir, '{:03d}_fix_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
-            save_nifti(mov_seg_card[0, ...], os.path.join(args.outdir, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
-            save_nifti(pred_seg_card[0, ...], os.path.join(args.outdir, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
-
-            plot_predictions(fix_img[np.newaxis, ...], mov_img[np.newaxis, ...], disp_map[np.newaxis, ...], pred_img[np.newaxis, ...], os.path.join(args.outdir, '{:03d}_figures_img.png'.format(step)), show=False)
-            plot_predictions(fix_seg[np.newaxis, ...], mov_seg[np.newaxis, ...], disp_map[np.newaxis, ...], pred_seg[np.newaxis, ...], os.path.join(args.outdir, '{:03d}_figures_seg.png'.format(step)), show=False)
-            save_disp_map_img(disp_map[np.newaxis, ...], 'Displacement map', os.path.join(args.outdir, '{:03d}_disp_map_fig.png'.format(step)), show=False)
-
-
-
-    pd.read_csv(metrics_file, sep=';', header=0).
-
+            save_nifti(fix_img[0, ...], os.path.join(args.outdir, reg_method, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
+            save_nifti(mov_img[0, ...], os.path.join(args.outdir, reg_method, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
+            save_nifti(pred_img[0, ...], os.path.join(args.outdir, reg_method, '{:03d}_pred_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
+            save_nifti(fix_seg_card[0, ...], os.path.join(args.outdir, reg_method, '{:03d}_fix_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
+            save_nifti(mov_seg_card[0, ...], os.path.join(args.outdir, reg_method, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
+            save_nifti(pred_seg_card[0, ...], os.path.join(args.outdir, reg_method, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
+
+            plot_predictions(fix_img[np.newaxis, ...], mov_img[np.newaxis, ...], disp_map[np.newaxis, ...], pred_img[np.newaxis, ...], os.path.join(args.outdir, reg_method, '{:03d}_figures_img.png'.format(step)), show=False)
+            plot_predictions(fix_seg[np.newaxis, ...], mov_seg[np.newaxis, ...], disp_map[np.newaxis, ...], pred_seg[np.newaxis, ...], os.path.join(args.outdir, reg_method, '{:03d}_figures_seg.png'.format(step)), show=False)
+            save_disp_map_img(disp_map[np.newaxis, ...], 'Displacement map', os.path.join(args.outdir, reg_method, '{:03d}_disp_map_fig.png'.format(step)), show=False)
+
+    for k in metrics_file.keys():
+        print('Summary {}\n=======\n'.format(k))
+        print('\nAVG:\n' + str(pd.read_csv(metrics_file[k], sep=';', header=0).mean(axis=0)) + '\nSTD:\n' + str(
+            pd.read_csv(metrics_file[k], sep=';', header=0).std(axis=0)))
+        print('\n=======\n')
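The new TRE column logs target registration error: the residual distance between corresponding fixed and predicted centroids in isotropic millimetre space, averaged over the non-NaN entries. A NumPy stand-in for the TF target_registration_error used above (an assumption about its semantics, for illustration only):

import numpy as np

def target_registration_error(fix_pts: np.ndarray, pred_pts: np.ndarray) -> np.ndarray:
    # Euclidean distance (mm) between each pair of corresponding points.
    return np.linalg.norm(fix_pts - pred_pts, axis=-1)

fix = np.array([[10., 10., 10.], [20., 20., 20.]])
pred = np.array([[11., 10., 10.], [20., 22., 20.]])
tre_array = target_registration_error(fix, pred)       # [1., 2.]
tre = float(np.mean(tre_array[~np.isnan(tre_array)]))  # 1.5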