jpdefrutos commited on
Commit
476daa5
·
1 Parent(s): 6a4f823

Scripts for training on the COMET CT Dataset

Browse files
COMET/Build_test_set.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+
3
+ import shutil
4
+
5
+ import matplotlib.pyplot as plt
6
+
7
+ currentdir = os.path.dirname(os.path.realpath(__file__))
8
+ parentdir = os.path.dirname(currentdir)
9
+ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
10
+
11
+ import tensorflow as tf
12
+ # tf.enable_eager_execution(config=config)
13
+
14
+ import numpy as np
15
+ import h5py
16
+
17
+ import DeepDeformationMapRegistration.utils.constants as C
18
+ from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
19
+ from DeepDeformationMapRegistration.layers import AugmentationLayer
20
+ from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
21
+ from DeepDeformationMapRegistration.utils.misc import get_segmentations_centroids
22
+ from tqdm import tqdm
23
+
24
+ from Brain_study.data_generator import BatchGenerator
25
+
26
+ from skimage.measure import regionprops
27
+ from scipy.interpolate import griddata
28
+
29
+ import argparse
30
+
31
+
32
+ DATASET = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/test'
33
+
34
+ POINTS = None
35
+ MISSING_CENTROID = np.asarray([[np.nan]*3])
36
+
37
+
38
+ def get_mov_centroids(fix_seg, disp_map, nb_labels=28, brain_study=True):
39
+ fix_centroids, _ = get_segmentations_centroids(fix_seg[0, ...], ohe=True, expected_lbls=range(0, nb_labels), brain_study=brain_study)
40
+ disp = griddata(POINTS, disp_map.reshape([-1, 3]), fix_centroids, method='linear')
41
+ return fix_centroids, fix_centroids + disp, disp
42
+
43
+
44
+ if __name__ == '__main__':
45
+ parser = argparse.ArgumentParser()
46
+ parser.add_argument('-d', '--dir', type=str, help='Directory where to store the files', default='')
47
+ parser.add_argument('--reldir', type=str, help='Relative path to dataset, in where to store the files', default='')
48
+ parser.add_argument('--gpu', type=int, help='GPU', default=0)
49
+ parser.add_argument('--dataset', type=str, help='Dataset to build the test set', default='')
50
+ parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
51
+ args = parser.parse_args()
52
+
53
+ assert args.dataset != '', "Missing original dataset dataset"
54
+ if args.dir == '' and args.reldir != '':
55
+ OUTPUT_FOLDER_DIR = os.path.join(args.dataset, 'test_dataset')
56
+ elif args.dir != '' and args.reldir == '':
57
+ OUTPUT_FOLDER_DIR = args.dir
58
+ else:
59
+ raise ValueError("Either provide 'dir' or 'reldir'")
60
+
61
+ if args.erase:
62
+ shutil.rmtree(OUTPUT_FOLDER_DIR, ignore_errors=True)
63
+ os.makedirs(OUTPUT_FOLDER_DIR, exist_ok=True)
64
+ print('DESTINATION FOLDER: ', OUTPUT_FOLDER_DIR)
65
+
66
+ DATASET = args.dataset
67
+
68
+ os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
69
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) # Check availability before running using 'nvidia-smi'
70
+
71
+ data_generator = BatchGenerator(DATASET, 1, False, 1.0, False, ['all'])
72
+
73
+ img_generator = data_generator.get_train_generator()
74
+ nb_labels = len(img_generator.get_segmentation_labels())
75
+ image_input_shape = img_generator.get_data_shape()[-1][:-1]
76
+ image_output_shape = [64] * 3
77
+ # Build model
78
+
79
+ xx = np.linspace(0, image_output_shape[0], image_output_shape[0], endpoint=False)
80
+ yy = np.linspace(0, image_output_shape[1], image_output_shape[2], endpoint=False)
81
+ zz = np.linspace(0, image_output_shape[2], image_output_shape[1], endpoint=False)
82
+
83
+ xx, yy, zz = np.meshgrid(xx, yy, zz)
84
+
85
+ POINTS = np.stack([xx.flatten(), yy.flatten(), zz.flatten()], axis=0).T
86
+
87
+ input_augm = tf.keras.Input(shape=img_generator.get_data_shape()[0], name='input_augm')
88
+ augm_layer = AugmentationLayer(max_displacement=C.MAX_AUG_DISP, # Max 30 mm in isotropic space
89
+ max_deformation=C.MAX_AUG_DEF, # Max 6 mm in isotropic space
90
+ max_rotation=C.MAX_AUG_ANGLE, # Max 10 deg in isotropic space
91
+ num_control_points=C.NUM_CONTROL_PTS_AUG,
92
+ num_augmentations=C.NUM_AUGMENTATIONS,
93
+ gamma_augmentation=C.GAMMA_AUGMENTATION,
94
+ brightness_augmentation=C.BRIGHTNESS_AUGMENTATION,
95
+ in_img_shape=image_input_shape,
96
+ out_img_shape=image_output_shape,
97
+ only_image=False,
98
+ only_resize=False,
99
+ trainable=False,
100
+ return_displacement_map=True)
101
+ augm_model = tf.keras.Model(inputs=input_augm, outputs=augm_layer(input_augm))
102
+
103
+ config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
104
+ config.gpu_options.allow_growth = True
105
+ config.log_device_placement = False ## to log device placement (on which device the operation ran)
106
+
107
+ sess = tf.Session(config=config)
108
+ tf.keras.backend.set_session(sess)
109
+ with sess.as_default():
110
+ sess.run(tf.global_variables_initializer())
111
+ progress_bar = tqdm(enumerate(img_generator, 1), desc='Generating samples', total=len(img_generator))
112
+ for step, (in_batch, _) in progress_bar:
113
+ fix_img, mov_img, fix_seg, mov_seg, disp_map = augm_model.predict(in_batch)
114
+
115
+ fix_centroids, mov_centroids, disp_centroids = get_mov_centroids(fix_seg, disp_map, nb_labels, False)
116
+
117
+ out_file = os.path.join(OUTPUT_FOLDER_DIR, 'test_sample_{:04d}.h5'.format(step))
118
+ out_file_dm = os.path.join(OUTPUT_FOLDER_DIR, 'test_sample_dm_{:04d}.h5'.format(step))
119
+ img_shape = fix_img.shape
120
+ segm_shape = fix_seg.shape
121
+ disp_shape = disp_map.shape
122
+ centroids_shape = fix_centroids.shape
123
+ with h5py.File(out_file, 'w') as f:
124
+ f.create_dataset('fix_image', shape=img_shape[1:], dtype=np.float32, data=fix_img[0, ...])
125
+ f.create_dataset('mov_image', shape=img_shape[1:], dtype=np.float32, data=mov_img[0, ...])
126
+ f.create_dataset('fix_segmentations', shape=segm_shape[1:], dtype=np.uint8, data=fix_seg[0, ...])
127
+ f.create_dataset('mov_segmentations', shape=segm_shape[1:], dtype=np.uint8, data=mov_seg[0, ...])
128
+ f.create_dataset('fix_centroids', shape=centroids_shape, dtype=np.float32, data=fix_centroids)
129
+ f.create_dataset('mov_centroids', shape=centroids_shape, dtype=np.float32, data=mov_centroids)
130
+
131
+ with h5py.File(out_file_dm, 'w') as f:
132
+ f.create_dataset('disp_map', shape=disp_shape[1:], dtype=np.float32, data=disp_map)
133
+ f.create_dataset('disp_centroids', shape=centroids_shape, dtype=np.float32, data=disp_centroids)
134
+
135
+ print('Done')
COMET/COMET_train.py ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+
3
+ import keras
4
+
5
+ currentdir = os.path.dirname(os.path.realpath(__file__))
6
+ parentdir = os.path.dirname(currentdir)
7
+ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
8
+
9
+ from datetime import datetime
10
+
11
+ from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
12
+ from tensorflow.python.keras.utils import Progbar
13
+ from tensorflow.keras import Input
14
+ from tensorflow.keras.models import Model
15
+ from tensorflow.python.framework.errors import InvalidArgumentError
16
+
17
+ import DeepDeformationMapRegistration.utils.constants as C
18
+ from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplified, NCC, GeneralizedDICEScore, HausdorffDistanceErosion
19
+ from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
20
+ from DeepDeformationMapRegistration.ms_ssim_tf import _MSSSIM_WEIGHTS
21
+ from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
22
+ from DeepDeformationMapRegistration.utils.misc import function_decorator
23
+ from DeepDeformationMapRegistration.layers import AugmentationLayer
24
+ from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
25
+
26
+ from Brain_study.data_generator import BatchGenerator
27
+ from Brain_study.utils import SummaryDictionary, named_logs
28
+
29
+ import COMET.augmentation_constants as COMET_C
30
+
31
+ import numpy as np
32
+ import tensorflow as tf
33
+ import voxelmorph as vxm
34
+ import h5py
35
+ import re
36
+ import itertools
37
+
38
+
39
+ def launch_train(dataset_folder, validation_folder, output_folder, model_file, gpu_num=0, lr=1e-4, rw=5e-3, simil='ssim',
40
+ segm='dice', max_epochs=C.EPOCHS, freeze_layers=None,
41
+ acc_gradients=1, batch_size=16):
42
+ # 0. Input checks
43
+ assert dataset_folder is not None and output_folder is not None and model_file is not None
44
+ assert '.h5' in model_file, 'The model must be an H5 file'
45
+
46
+ USE_SEGMENTATIONS = bool(re.search('SEGGUIDED', model_file))
47
+
48
+ # 1. Load variables
49
+ os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
50
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_num) # Check availability before running using 'nvidia-smi'
51
+ C.GPU_NUM = str(gpu_num)
52
+
53
+ if freeze_layers is not None:
54
+ assert all(s in ['INPUT', 'OUTPUT', 'ENCODER', 'DECODER', 'TOP', 'BOTTOM'] for s in freeze_layers), \
55
+ 'Invalid option for "freeze". Expected one or several of: INPUT, OUTPUT, ENCODER, DECODER, TOP, BOTTOM'
56
+ freeze_layers = [list(COMET_C.LAYER_RANGES[l]) for l in list(set(freeze_layers))]
57
+ if len(freeze_layers) > 1:
58
+ freeze_layers = list(itertools.chain.from_iterable(freeze_layers))
59
+
60
+ os.makedirs(output_folder, exist_ok=True)
61
+ # dataset_copy = DatasetCopy(dataset_folder, os.path.join(output_folder, 'temp'))
62
+ log_file = open(os.path.join(output_folder, 'log.txt'), 'w')
63
+ C.TRAINING_DATASET = dataset_folder #dataset_copy.copy_dataset()
64
+ C.VALIDATION_DATASET = validation_folder
65
+ C.ACCUM_GRADIENT_STEP = acc_gradients
66
+ C.BATCH_SIZE = batch_size if C.ACCUM_GRADIENT_STEP == 1 else C.ACCUM_GRADIENT_STEP
67
+ C.EARLY_STOP_PATIENCE = 5 * (C.ACCUM_GRADIENT_STEP / 2 if C.ACCUM_GRADIENT_STEP != 1 else 1)
68
+ C.LEARNING_RATE = lr
69
+ C.LIMIT_NUM_SAMPLES = None
70
+ C.EPOCHS = max_epochs
71
+
72
+ aux = "[{}]\tINFO:\nTRAIN DATASET: {}\nVALIDATION DATASET: {}\n" \
73
+ "GPU: {}\n" \
74
+ "BATCH SIZE: {}\n" \
75
+ "LR: {}\n" \
76
+ "SIMILARITY: {}\n" \
77
+ "SEGMENTATION: {}\n"\
78
+ "REG. WEIGHT: {}\n" \
79
+ "EPOCHS: {:d}\n" \
80
+ "ACCUM. GRAD: {}\n" \
81
+ "EARLY STOP PATIENCE: {}\n" \
82
+ "FROZEN LAYERS: {}".format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'),
83
+ C.TRAINING_DATASET,
84
+ C.VALIDATION_DATASET,
85
+ C.GPU_NUM,
86
+ C.BATCH_SIZE,
87
+ C.LEARNING_RATE,
88
+ simil,
89
+ segm,
90
+ rw,
91
+ C.EPOCHS,
92
+ C.ACCUM_GRADIENT_STEP,
93
+ C.EARLY_STOP_PATIENCE,
94
+ freeze_layers)
95
+
96
+ log_file.write(aux)
97
+ print(aux)
98
+
99
+ # 2. Data generator
100
+ used_labels = 'all' if USE_SEGMENTATIONS else 'none'
101
+ data_generator = BatchGenerator(C.TRAINING_DATASET, C.BATCH_SIZE if C.ACCUM_GRADIENT_STEP == 1 else 1, True,
102
+ C.TRAINING_PERC, labels=[used_labels], combine_segmentations=not USE_SEGMENTATIONS,
103
+ directory_val=C.VALIDATION_DATASET)
104
+
105
+ train_generator = data_generator.get_train_generator()
106
+ validation_generator = data_generator.get_validation_generator()
107
+
108
+ image_input_shape = train_generator.get_data_shape()[-1][:-1]
109
+ image_output_shape = [64] * 3
110
+ nb_labels = len(train_generator.get_segmentation_labels())
111
+
112
+ # 3. Load model
113
+ # IMPORTANT: the mode MUST be loaded AFTER setting up the session configuration
114
+ config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
115
+ config.gpu_options.allow_growth = True
116
+ config.log_device_placement = False ## to log device placement (on which device the operation ran)
117
+ sess = tf.Session(config=config)
118
+ tf.keras.backend.set_session(sess)
119
+
120
+ loss_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).loss,
121
+ NCC(image_input_shape).loss,
122
+ vxm.losses.MSE().loss,
123
+ MultiScaleStructuralSimilarity(max_val=1., filter_size=3).loss,
124
+ HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [nb_labels]).loss,
125
+ GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).loss,
126
+ GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).loss_macro
127
+ ]
128
+
129
+ metric_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).metric,
130
+ NCC(image_input_shape).metric,
131
+ vxm.losses.MSE().loss,
132
+ MultiScaleStructuralSimilarity(max_val=1., filter_size=3).metric,
133
+ GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric,
134
+ HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [nb_labels]).metric,
135
+ GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric_macro,]
136
+ print('MODEL LOCATION: ', model_file)
137
+
138
+ try:
139
+ network = tf.keras.models.load_model(model_file, {#'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
140
+ 'VxmDense': vxm.networks.VxmDense,
141
+ 'AdamAccumulated': AdamAccumulated,
142
+ 'loss': loss_fncs,
143
+ 'metric': metric_fncs},
144
+ compile=False)
145
+ except ValueError as e:
146
+ enc_features = [16, 32, 32, 32] # const.ENCODER_FILTERS
147
+ dec_features = [32, 32, 32, 32, 32, 16, 16] # const.ENCODER_FILTERS[::-1]
148
+ nb_features = [enc_features, dec_features]
149
+ if USE_SEGMENTATIONS:
150
+ network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
151
+ nb_labels=nb_labels,
152
+ nb_unet_features=nb_features,
153
+ int_steps=0,
154
+ int_downsize=1,
155
+ seg_downsize=1)
156
+ else:
157
+ network = vxm.networks.VxmDense(inshape=image_output_shape,
158
+ nb_unet_features=nb_features,
159
+ int_steps=0)
160
+ network.load_weights(model_file, by_name=True)
161
+ # 4. Freeze/unfreeze model layers
162
+ # freeze_layers = range(0, len(network.layers) - 8) # Do not freeze the last layers after the UNet (8 last layers)
163
+ # for l in freeze_layers:
164
+ # network.layers[l].trainable = False
165
+ # msg = "[INF]: Frozen layers {} to {}".format(0, len(network.layers) - 8)
166
+ # print(msg)
167
+ # log_file.write("INF: Frozen layers {} to {}".format(0, len(network.layers) - 8))
168
+ if freeze_layers is not None:
169
+ aux = list()
170
+ for r in freeze_layers:
171
+ for l in range(*r):
172
+ network.layers[l].trainable = False
173
+ aux.append(l)
174
+ aux.sort()
175
+ msg = "[INF]: Frozen layers {}".format(', '.join([str(a) for a in aux]))
176
+ else:
177
+ msg = "[INF] None frozen layers"
178
+ print(msg)
179
+ log_file.write(msg)
180
+ # network.trainable = False # Freeze the base model
181
+ # # Create a new model on top
182
+ # input_new_model = keras.Input(network.input_shape)
183
+ # x = base_model(input_new_model, training=False)
184
+ # x =
185
+ # network = keras.Model(input_new_model, x)
186
+
187
+ network.summary()
188
+ network.summary(print_fn=log_file.writelines)
189
+ # Complete the model with the augmentation layer
190
+ augm_train_input_shape = train_generator.get_data_shape()[0] if USE_SEGMENTATIONS else train_generator.get_data_shape()[-1]
191
+ input_layer_train = Input(shape=augm_train_input_shape, name='input_train')
192
+ augm_layer_train = AugmentationLayer(max_displacement=COMET_C.MAX_AUG_DISP, # Max 30 mm in isotropic space
193
+ max_deformation=COMET_C.MAX_AUG_DEF, # Max 6 mm in isotropic space
194
+ max_rotation=COMET_C.MAX_AUG_ANGLE, # Max 10 deg in isotropic space
195
+ num_control_points=COMET_C.NUM_CONTROL_PTS_AUG,
196
+ num_augmentations=COMET_C.NUM_AUGMENTATIONS,
197
+ gamma_augmentation=COMET_C.GAMMA_AUGMENTATION,
198
+ brightness_augmentation=COMET_C.BRIGHTNESS_AUGMENTATION,
199
+ in_img_shape=image_input_shape,
200
+ out_img_shape=image_output_shape,
201
+ only_image=not USE_SEGMENTATIONS, # If baseline then True
202
+ only_resize=False,
203
+ trainable=False)
204
+ augm_model_train = Model(inputs=input_layer_train, outputs=augm_layer_train(input_layer_train))
205
+
206
+ input_layer_valid = Input(shape=validation_generator.get_data_shape()[0], name='input_valid')
207
+ augm_layer_valid = AugmentationLayer(max_displacement=COMET_C.MAX_AUG_DISP, # Max 30 mm in isotropic space
208
+ max_deformation=COMET_C.MAX_AUG_DEF, # Max 6 mm in isotropic space
209
+ max_rotation=COMET_C.MAX_AUG_ANGLE, # Max 10 deg in isotropic space
210
+ num_control_points=COMET_C.NUM_CONTROL_PTS_AUG,
211
+ num_augmentations=COMET_C.NUM_AUGMENTATIONS,
212
+ gamma_augmentation=COMET_C.GAMMA_AUGMENTATION,
213
+ brightness_augmentation=COMET_C.BRIGHTNESS_AUGMENTATION,
214
+ in_img_shape=image_input_shape,
215
+ out_img_shape=image_output_shape,
216
+ only_image=False,
217
+ only_resize=False,
218
+ trainable=False)
219
+ augm_model_valid = Model(inputs=input_layer_valid, outputs=augm_layer_valid(input_layer_valid))
220
+
221
+ # 5. Setup training environment: loss, optimizer, callbacks, evaluation
222
+
223
+ # Losses and loss weights
224
+ SSIM_KER_SIZE = 5
225
+ MS_SSIM_WEIGHTS = _MSSSIM_WEIGHTS[:3]
226
+ MS_SSIM_WEIGHTS /= np.sum(MS_SSIM_WEIGHTS)
227
+ if simil.lower() == 'mse':
228
+ loss_fnc = vxm.losses.MSE().loss
229
+ elif simil.lower() == 'ncc':
230
+ loss_fnc = NCC(image_input_shape).loss
231
+ elif simil.lower() == 'ssim':
232
+ loss_fnc = StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss
233
+ elif simil.lower() == 'ms_ssim':
234
+ loss_fnc = MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).loss
235
+ elif simil.lower() == 'mse__ms_ssim' or simil.lower() == 'ms_ssim__mse':
236
+ @function_decorator('MSSSIM_MSE__loss')
237
+ def loss_fnc(y_true, y_pred):
238
+ return vxm.losses.MSE().loss(y_true, y_pred) + \
239
+ MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).loss(y_true, y_pred)
240
+ elif simil.lower() == 'ncc__ms_ssim' or simil.lower() == 'ms_ssim__ncc':
241
+ @function_decorator('MSSSIM_NCC__loss')
242
+ def loss_fnc(y_true, y_pred):
243
+ return NCC(image_input_shape).loss(y_true, y_pred) + \
244
+ MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).loss(y_true, y_pred)
245
+ elif simil.lower() == 'mse__ssim' or simil.lower() == 'ssim__mse':
246
+ @function_decorator('SSIM_MSE__loss')
247
+ def loss_fnc(y_true, y_pred):
248
+ return vxm.losses.MSE().loss(y_true, y_pred) + \
249
+ StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss(y_true, y_pred)
250
+ elif simil.lower() == 'ncc__ssim' or simil.lower() == 'ssim__ncc':
251
+ @function_decorator('SSIM_NCC__loss')
252
+ def loss_fnc(y_true, y_pred):
253
+ return NCC(image_input_shape).loss(y_true, y_pred) + \
254
+ StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss(y_true, y_pred)
255
+ else:
256
+ raise ValueError('Unknown similarity metric: ' + simil)
257
+
258
+ if USE_SEGMENTATIONS:
259
+ if segm == 'hd':
260
+ loss_segm = HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [nb_labels]).loss
261
+ elif segm == 'dice':
262
+ loss_segm = GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).loss
263
+ elif segm == 'dice_macro':
264
+ loss_segm = GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).loss_macro
265
+ else:
266
+ raise ValueError('No valid value for segm')
267
+
268
+ os.makedirs(os.path.join(output_folder, 'checkpoints'), exist_ok=True)
269
+ os.makedirs(os.path.join(output_folder, 'tensorboard'), exist_ok=True)
270
+ callback_tensorboard = TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
271
+ batch_size=C.BATCH_SIZE, write_images=False, histogram_freq=0,
272
+ update_freq='epoch', # or 'batch' or integer
273
+ write_graph=True, write_grads=True
274
+ )
275
+ callback_early_stop = EarlyStopping(monitor='val_loss', verbose=1, patience=C.EARLY_STOP_PATIENCE, min_delta=0.00001)
276
+
277
+ callback_best_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'),
278
+ save_best_only=True, monitor='val_loss', verbose=1, mode='min')
279
+ callback_save_checkpoint = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'checkpoint.h5'),
280
+ save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
281
+ if USE_SEGMENTATIONS:
282
+ losses = {'transformer': loss_fnc,
283
+ 'seg_transformer': loss_segm,
284
+ 'flow': vxm.losses.Grad('l2').loss}
285
+ metrics = {'transformer': [StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).metric,
286
+ MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).metric,
287
+ tf.keras.losses.MSE,
288
+ NCC(image_input_shape).metric],
289
+ 'seg_transformer': [GeneralizedDICEScore(image_output_shape + [train_generator.get_data_shape()[2][-1]], num_labels=nb_labels).metric,
290
+ HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [train_generator.get_data_shape()[2][-1]]).metric,
291
+ GeneralizedDICEScore(image_output_shape + [train_generator.get_data_shape()[2][-1]], num_labels=nb_labels).metric_macro,
292
+ ],
293
+ #'flow': vxm.losses.Grad('l2').loss
294
+ }
295
+ loss_weights = {'transformer': 1.,
296
+ 'seg_transformer': 1.,
297
+ 'flow': rw}
298
+ else:
299
+ losses = {'transformer': loss_fnc,
300
+ 'flow': vxm.losses.Grad('l2').loss}
301
+ metrics = {'transformer': [StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).metric,
302
+ MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).metric,
303
+ tf.keras.losses.MSE,
304
+ NCC(image_input_shape).metric],
305
+ #'flow': vxm.losses.Grad('l2').loss
306
+ }
307
+ loss_weights = {'transformer': 1.,
308
+ 'flow': rw}
309
+
310
+ optimizer = AdamAccumulated(C.ACCUM_GRADIENT_STEP, C.LEARNING_RATE)
311
+ network.compile(optimizer=optimizer,
312
+ loss=losses,
313
+ loss_weights=loss_weights,
314
+ metrics=metrics)
315
+
316
+ # 6. Training loop
317
+ callback_tensorboard.set_model(network)
318
+ callback_early_stop.set_model(network)
319
+ callback_best_model.set_model(network)
320
+ callback_save_checkpoint.set_model(network)
321
+
322
+ summary = SummaryDictionary(network, C.BATCH_SIZE)
323
+ names = network.metrics_names
324
+ log_file.write('\n\n[{}]\tINFO:\tStart training\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y')))
325
+
326
+ with sess.as_default():
327
+ # tf.global_variables_initializer()
328
+ callback_tensorboard.on_train_begin()
329
+ callback_early_stop.on_train_begin()
330
+ callback_best_model.on_train_begin()
331
+ callback_save_checkpoint.on_train_begin()
332
+
333
+ for epoch in range(C.EPOCHS):
334
+ callback_tensorboard.on_epoch_begin(epoch)
335
+ callback_early_stop.on_epoch_begin(epoch)
336
+ callback_best_model.on_epoch_begin(epoch)
337
+ callback_save_checkpoint.on_epoch_begin(epoch)
338
+ print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
339
+ print("TRAIN")
340
+
341
+ log_file.write('\n\n[{}]\tINFO:\tTraining epoch {}\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), epoch))
342
+ progress_bar = Progbar(len(train_generator), width=30, verbose=1)
343
+ for step, (in_batch, _) in enumerate(train_generator, 1):
344
+ callback_best_model.on_train_batch_begin(step)
345
+ callback_save_checkpoint.on_train_batch_begin(step)
346
+ callback_early_stop.on_train_batch_begin(step)
347
+
348
+ try:
349
+ fix_img, mov_img, fix_seg, mov_seg = augm_model_train.predict(in_batch)
350
+ np.nan_to_num(fix_img, copy=False)
351
+ np.nan_to_num(mov_img, copy=False)
352
+ if np.isnan(np.sum(mov_img)) or np.isnan(np.sum(fix_img)) or np.isinf(np.sum(mov_img)) or np.isinf(np.sum(fix_img)):
353
+ msg = 'CORRUPTED DATA!! Unique: Fix: {}\tMoving: {}'.format(np.unique(fix_img),
354
+ np.unique(mov_img))
355
+ print(msg)
356
+ log_file.write('\n\n[{}]\tWAR: {}'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), msg))
357
+
358
+ except InvalidArgumentError as err:
359
+ print('TF Error : {}'.format(str(err)))
360
+ continue
361
+
362
+ if USE_SEGMENTATIONS:
363
+ in_data = (mov_img, fix_img, mov_seg)
364
+ out_data = (fix_img, fix_img, fix_seg)
365
+ else:
366
+ in_data = (mov_img, fix_img)
367
+ out_data = (fix_img, fix_img)
368
+ ret = network.train_on_batch(x=in_data, y=out_data) # The second element doesn't matter
369
+ if np.isnan(ret).any():
370
+ os.makedirs(os.path.join(output_folder, 'corrupted'), exist_ok=True)
371
+ save_nifti(mov_img, os.path.join(output_folder, 'corrupted', 'mov_img_nan.nii.gz'))
372
+ save_nifti(fix_img, os.path.join(output_folder, 'corrupted', 'fix_img_nan.nii.gz'))
373
+ pred_img, dm = network((mov_img, fix_img))
374
+ save_nifti(pred_img, os.path.join(output_folder, 'corrupted', 'pred_img_nan.nii.gz'))
375
+ save_nifti(dm, os.path.join(output_folder, 'corrupted', 'dm_nan.nii.gz'))
376
+ log_file.write('\n\n[{}]\tERR: Corruption error'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y')))
377
+ raise ValueError('CORRUPTION ERROR: Halting training')
378
+
379
+ summary.on_train_batch_end(ret)
380
+ callback_best_model.on_train_batch_end(step, named_logs(network, ret))
381
+ callback_save_checkpoint.on_train_batch_end(step, named_logs(network, ret))
382
+ callback_early_stop.on_train_batch_end(step, named_logs(network, ret))
383
+ progress_bar.update(step, zip(names, ret))
384
+ log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
385
+ val_values = progress_bar._values.copy()
386
+ ret = [val_values[x][0]/val_values[x][1] for x in names]
387
+
388
+ print('\nVALIDATION')
389
+ log_file.write('\n\n[{}]\tINFO:\tValidation epoch {}\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), epoch))
390
+ progress_bar = Progbar(len(validation_generator), width=30, verbose=1)
391
+ for step, (in_batch, _) in enumerate(validation_generator, 1):
392
+ try:
393
+ fix_img, mov_img, fix_seg, mov_seg = augm_model_valid.predict(in_batch)
394
+ except InvalidArgumentError as err:
395
+ print('TF Error : {}'.format(str(err)))
396
+ continue
397
+
398
+ if USE_SEGMENTATIONS:
399
+ in_data = (mov_img, fix_img, mov_seg)
400
+ out_data = (fix_img, fix_img, fix_seg)
401
+ else:
402
+ in_data = (mov_img, fix_img)
403
+ out_data = (fix_img, fix_img)
404
+ ret = network.test_on_batch(x=in_data,
405
+ y=out_data)
406
+
407
+ summary.on_validation_batch_end(ret)
408
+ progress_bar.update(step, zip(names, ret))
409
+ log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
410
+ val_values = progress_bar._values.copy()
411
+ ret = [val_values[x][0]/val_values[x][1] for x in names]
412
+
413
+ train_generator.on_epoch_end()
414
+ validation_generator.on_epoch_end()
415
+ epoch_summary = summary.on_epoch_end() # summary resets after on_epoch_end() call
416
+ callback_tensorboard.on_epoch_end(epoch, epoch_summary)
417
+ callback_best_model.on_epoch_end(epoch, epoch_summary)
418
+ callback_save_checkpoint.on_epoch_end(epoch, epoch_summary)
419
+ callback_early_stop.on_epoch_end(epoch, epoch_summary)
420
+ print('End of epoch {}: '.format(epoch), ret, '\n')
421
+
422
+ callback_tensorboard.on_train_end()
423
+ callback_best_model.on_train_end()
424
+ callback_save_checkpoint.on_train_end()
425
+ callback_early_stop.on_train_end()
426
+ # 7. Wrap up
COMET/COMET_train_UW.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+
3
+ currentdir = os.path.dirname(os.path.realpath(__file__))
4
+ parentdir = os.path.dirname(currentdir)
5
+ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
6
+
7
+ from datetime import datetime
8
+
9
+ from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
10
+ from tensorflow.python.keras.utils import Progbar
11
+ from tensorflow.keras import Input
12
+ from tensorflow.keras.models import Model
13
+ from tensorflow.python.framework.errors import InvalidArgumentError
14
+
15
+ import DeepDeformationMapRegistration.utils.constants as C
16
+ from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplified, NCC, GeneralizedDICEScore, \
17
+ HausdorffDistanceErosion
18
+ from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
19
+ from DeepDeformationMapRegistration.ms_ssim_tf import _MSSSIM_WEIGHTS
20
+ from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
21
+ from DeepDeformationMapRegistration.utils.misc import function_decorator
22
+ from DeepDeformationMapRegistration.layers import AugmentationLayer, UncertaintyWeighting
23
+ from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
24
+
25
+ from Brain_study.data_generator import BatchGenerator
26
+ from Brain_study.utils import SummaryDictionary, named_logs
27
+
28
+ import COMET.augmentation_constants as COMET_C
29
+
30
+ import numpy as np
31
+ import tensorflow as tf
32
+ import voxelmorph as vxm
33
+ import h5py
34
+ import re
35
+ import itertools
36
+
37
+
38
+ def launch_train(dataset_folder, validation_folder, output_folder, model_file, gpu_num=0, lr=1e-4, rw=5e-3,
39
+ simil=['ssim'], segm=['dice'], max_epochs=C.EPOCHS, prior_reg_w=5e-3, freeze_layers=None,
40
+ acc_gradients=1, batch_size=16):
41
+ # 0. Input checks
42
+ assert dataset_folder is not None and output_folder is not None and model_file is not None
43
+ assert '.h5' in model_file, 'The model must be an H5 file'
44
+
45
+ # 1. Load variables
46
+ os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
47
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_num) # Check availability before running using 'nvidia-smi'
48
+ C.GPU_NUM = str(gpu_num)
49
+
50
+ if freeze_layers is not None:
51
+ assert all(s in ['INPUT', 'OUTPUT', 'ENCODER', 'DECODER', 'TOP', 'BOTTOM'] for s in freeze_layers), \
52
+ 'Invalid option for "freeze". Expected one or several of: INPUT, OUTPUT, ENCODER, DECODER, TOP, BOTTOM'
53
+ multiple_ranges = 'TOP' in freeze_layers
54
+ freeze_layers = [list(COMET_C.LAYER_RANGES[l]) for l in list(set(freeze_layers))]
55
+ freeze_layers = freeze_layers[0] if multiple_ranges else freeze_layers
56
+
57
+ # if len(freeze_layers) > 1:
58
+ # freeze_layers = list(itertools.chain.from_iterable(freeze_layers))
59
+
60
+ os.makedirs(output_folder, exist_ok=True)
61
+ # dataset_copy = DatasetCopy(dataset_folder, os.path.join(output_folder, 'temp'))
62
+ log_file = open(os.path.join(output_folder, 'log.txt'), 'w')
63
+ C.TRAINING_DATASET = dataset_folder # dataset_copy.copy_dataset()
64
+ C.VALIDATION_DATASET = validation_folder
65
+ C.ACCUM_GRADIENT_STEP = acc_gradients
66
+ C.BATCH_SIZE = batch_size if C.ACCUM_GRADIENT_STEP == 1 else C.ACCUM_GRADIENT_STEP
67
+ C.EARLY_STOP_PATIENCE = 5 * (C.ACCUM_GRADIENT_STEP / 2 if C.ACCUM_GRADIENT_STEP != 1 else 1)
68
+ C.LEARNING_RATE = lr
69
+ C.LIMIT_NUM_SAMPLES = None
70
+ C.EPOCHS = max_epochs
71
+
72
+ aux = "[{}]\tINFO:\nTRAIN DATASET: {}\nVALIDATION DATASET: {}\n" \
73
+ "GPU: {}\n" \
74
+ "BATCH SIZE: {}\n" \
75
+ "LR: {}\n" \
76
+ "SIMILARITY: {}\n" \
77
+ "SEGMENTATION: {}\n" \
78
+ "REG. WEIGHT: {}\n" \
79
+ "EPOCHS: {:d}\n" \
80
+ "ACCUM. GRAD: {}\n" \
81
+ "EARLY STOP PATIENCE: {}\n" \
82
+ "FROZEN LAYERS: {}".format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'),
83
+ C.TRAINING_DATASET,
84
+ C.VALIDATION_DATASET,
85
+ C.GPU_NUM,
86
+ C.BATCH_SIZE,
87
+ C.LEARNING_RATE,
88
+ simil,
89
+ segm,
90
+ rw,
91
+ C.EPOCHS,
92
+ C.ACCUM_GRADIENT_STEP,
93
+ C.EARLY_STOP_PATIENCE,
94
+ freeze_layers)
95
+
96
+ log_file.write(aux)
97
+ print(aux)
98
+
99
+ # 2. Data generator
100
+ data_generator = BatchGenerator(C.TRAINING_DATASET, C.BATCH_SIZE if C.ACCUM_GRADIENT_STEP == 1 else 1, True,
101
+ C.TRAINING_PERC, labels=['all'], combine_segmentations=False,
102
+ directory_val=C.VALIDATION_DATASET)
103
+
104
+ train_generator = data_generator.get_train_generator()
105
+ validation_generator = data_generator.get_validation_generator()
106
+
107
+ image_input_shape = train_generator.get_data_shape()[-1][:-1]
108
+ image_output_shape = [64] * 3
109
+ nb_labels = len(train_generator.get_segmentation_labels())
110
+
111
+ # 3. Load model
112
+ # IMPORTANT: the mode MUST be loaded AFTER setting up the session configuration
113
+ config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
114
+ config.gpu_options.allow_growth = True
115
+ config.log_device_placement = False ## to log device placement (on which device the operation ran)
116
+ sess = tf.Session(config=config)
117
+ tf.keras.backend.set_session(sess)
118
+
119
+ print('MODEL LOCATION: ', model_file)
120
+
121
+ enc_features = [16, 32, 32, 32] # const.ENCODER_FILTERS
122
+ dec_features = [32, 32, 32, 32, 32, 16, 16] # const.ENCODER_FILTERS[::-1]
123
+ nb_features = [enc_features, dec_features]
124
+ network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
125
+ nb_labels=nb_labels,
126
+ nb_unet_features=nb_features,
127
+ int_steps=0,
128
+ int_downsize=1,
129
+ seg_downsize=1)
130
+ network.load_weights(model_file, by_name=True)
131
+
132
+ # 4. Freeze/unfreeze model layers
133
+ if freeze_layers is not None:
134
+ aux = list()
135
+ for r in freeze_layers:
136
+ for l in range(*r):
137
+ network.layers[l].trainable = False
138
+ aux.append(l)
139
+ aux.sort()
140
+ msg = "[INF]: Frozen layers {}".format(', '.join([str(a) for a in aux]))
141
+ else:
142
+ msg = "[INF] None frozen layers"
143
+ print(msg)
144
+ log_file.write(msg)
145
+
146
+ network.summary()
147
+ network.summary(print_fn=log_file.write)
148
+ # Complete the model with the augmentation layer
149
+ input_layer_train = Input(shape=train_generator.get_data_shape()[0], name='input_train')
150
+ augm_layer = AugmentationLayer(max_displacement=COMET_C.MAX_AUG_DISP, # Max 30 mm in isotropic space
151
+ max_deformation=COMET_C.MAX_AUG_DEF, # Max 6 mm in isotropic space
152
+ max_rotation=COMET_C.MAX_AUG_ANGLE, # Max 10 deg in isotropic space
153
+ num_control_points=COMET_C.NUM_CONTROL_PTS_AUG,
154
+ num_augmentations=COMET_C.NUM_AUGMENTATIONS,
155
+ gamma_augmentation=COMET_C.GAMMA_AUGMENTATION,
156
+ brightness_augmentation=COMET_C.BRIGHTNESS_AUGMENTATION,
157
+ in_img_shape=image_input_shape,
158
+ out_img_shape=image_output_shape,
159
+ only_image=False, # If baseline then True
160
+ only_resize=False,
161
+ trainable=False)
162
+ augm_model = Model(inputs=input_layer_train, outputs=augm_layer(input_layer_train))
163
+
164
+ # 5. Setup training environment: loss, optimizer, callbacks, evaluation
165
+
166
+ # Losses and loss weights
167
+ SSIM_KER_SIZE = 5
168
+ MS_SSIM_WEIGHTS = _MSSSIM_WEIGHTS[:3]
169
+ MS_SSIM_WEIGHTS /= np.sum(MS_SSIM_WEIGHTS)
170
+ loss_simil = []
171
+ prior_loss_w = []
172
+ for s in simil:
173
+ if s == 'ssim':
174
+ loss_simil.append(StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss)
175
+ prior_loss_w.append(1.)
176
+ elif s == 'ms_ssim':
177
+ loss_simil.append(MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE,
178
+ power_factors=MS_SSIM_WEIGHTS).loss)
179
+ prior_loss_w.append(1.)
180
+ elif s == 'ncc':
181
+ loss_simil.append(NCC(image_input_shape).loss)
182
+ prior_loss_w.append(1.)
183
+ elif s == 'mse':
184
+ loss_simil.append(vxm.losses.MSE().loss)
185
+ prior_loss_w.append(1.)
186
+ else:
187
+ raise ValueError('Unknown similarity function: ', s)
188
+
189
+ loss_segm = []
190
+ for s in segm:
191
+ if s == 'dice':
192
+ loss_segm.append(GeneralizedDICEScore(image_output_shape + [train_generator.get_data_shape()[2][-1]],
193
+ num_labels=nb_labels).loss)
194
+ prior_loss_w.append(1.)
195
+ elif s == 'dice_macro':
196
+ loss_segm.append(GeneralizedDICEScore(image_output_shape + [train_generator.get_data_shape()[2][-1]],
197
+ num_labels=nb_labels).loss_macro)
198
+ prior_loss_w.append(1.)
199
+ elif s == 'hd':
200
+ loss_segm.append(HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [
201
+ train_generator.get_data_shape()[2][-1]]).loss)
202
+ prior_loss_w.append(1.)
203
+ else:
204
+ raise ValueError('Unknown similarity function: ', s)
205
+
206
+ # Uncertainty weigthing module
207
+ grad = tf.keras.Input(shape=(*image_output_shape, 3), name='multiLoss_grad_input', dtype=tf.float32)
208
+ fix_seg = tf.keras.Input(shape=(*image_output_shape, len(train_generator.get_segmentation_labels())),
209
+ name='multiLoss_fix_seg_input', dtype=tf.float32)
210
+
211
+ multiLoss = UncertaintyWeighting(num_loss_fns=len(loss_simil) + len(loss_segm),
212
+ num_reg_fns=1,
213
+ loss_fns=[*loss_simil,
214
+ *loss_segm],
215
+ reg_fns=[vxm.losses.Grad('l2').loss],
216
+ prior_loss_w=prior_loss_w,
217
+ # prior_loss_w=[1., 0.1, 1., 1.],
218
+ prior_reg_w=[prior_reg_w],
219
+ name='MultiLossLayer')
220
+ loss = multiLoss([*[network.inputs[1]] * len(loss_simil), *[fix_seg] * len(loss_segm),
221
+ *[network.outputs[0]] * len(loss_simil), *[network.outputs[2]] * len(loss_simil),
222
+ grad,
223
+ network.outputs[1]])
224
+
225
+ # inputs = [mov_img, fix_img, mov_segm, fix_segm, zero_grads]
226
+ # outputs = [pred_img, flow, pred_segm, loss]
227
+ full_model = tf.keras.Model(inputs=network.inputs + [fix_seg, grad],
228
+ outputs=network.outputs + [loss])
229
+
230
+ os.makedirs(os.path.join(output_folder, 'checkpoints'), exist_ok=True)
231
+ os.makedirs(os.path.join(output_folder, 'tensorboard'), exist_ok=True)
232
+ callback_tensorboard = TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
233
+ batch_size=C.BATCH_SIZE, write_images=False, histogram_freq=0,
234
+ update_freq='epoch', # or 'batch' or integer
235
+ write_graph=True, write_grads=True
236
+ )
237
+ callback_early_stop = EarlyStopping(monitor='val_loss', verbose=1, patience=C.EARLY_STOP_PATIENCE,
238
+ min_delta=0.00001)
239
+
240
+ callback_best_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'),
241
+ save_best_only=True, monitor='val_loss', verbose=1, mode='min')
242
+ callback_save_checkpoint = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'checkpoint.h5'),
243
+ save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
244
+
245
+ optimizer = AdamAccumulated(C.ACCUM_GRADIENT_STEP, C.LEARNING_RATE)
246
+ full_model.compile(optimizer=optimizer,
247
+ loss=None, )
248
+
249
+ # 6. Training loop
250
+ callback_tensorboard.set_model(full_model)
251
+ callback_early_stop.set_model(full_model)
252
+ callback_best_model.set_model(network) # ONLY SAVE THE NETWORK!!!
253
+ callback_save_checkpoint.set_model(network) # ONLY SAVE THE NETWORK!!!
254
+
255
+ summary = SummaryDictionary(full_model, C.BATCH_SIZE)
256
+ names = full_model.metrics_names
257
+ zero_grads = tf.zeros_like(network.references.pos_flow, name='dummy_zero_grads') # Dummy zeros-tensor
258
+ log_file.write('\n\n[{}]\tINFO:\tStart training\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y')))
259
+
260
+ with sess.as_default():
261
+ # tf.global_variables_initializer()
262
+ callback_tensorboard.on_train_begin()
263
+ callback_early_stop.on_train_begin()
264
+ callback_best_model.on_train_begin()
265
+ callback_save_checkpoint.on_train_begin()
266
+
267
+ for epoch in range(C.EPOCHS):
268
+ callback_tensorboard.on_epoch_begin(epoch)
269
+ callback_early_stop.on_epoch_begin(epoch)
270
+ callback_best_model.on_epoch_begin(epoch)
271
+ callback_save_checkpoint.on_epoch_begin(epoch)
272
+ print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
273
+ print("TRAIN")
274
+
275
+ log_file.write(
276
+ '\n\n[{}]\tINFO:\tTraining epoch {}\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), epoch))
277
+ progress_bar = Progbar(len(train_generator), width=30, verbose=1)
278
+ for step, (in_batch, _) in enumerate(train_generator, 1):
279
+ callback_best_model.on_train_batch_begin(step)
280
+ callback_save_checkpoint.on_train_batch_begin(step)
281
+ callback_early_stop.on_train_batch_begin(step)
282
+
283
+ try:
284
+ fix_img, mov_img, fix_seg, mov_seg = augm_model.predict(in_batch)
285
+ np.nan_to_num(fix_img, copy=False)
286
+ np.nan_to_num(mov_img, copy=False)
287
+ if np.isnan(np.sum(mov_img)) or np.isnan(np.sum(fix_img)) or np.isinf(np.sum(mov_img)) or np.isinf(
288
+ np.sum(fix_img)):
289
+ msg = 'CORRUPTED DATA!! Unique: Fix: {}\tMoving: {}'.format(np.unique(fix_img),
290
+ np.unique(mov_img))
291
+ print(msg)
292
+ log_file.write('\n\n[{}]\tWAR: {}'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), msg))
293
+
294
+ except InvalidArgumentError as err:
295
+ print('TF Error : {}'.format(str(err)))
296
+ continue
297
+
298
+ ret = full_model.train_on_batch(
299
+ x=(mov_img, fix_img, mov_seg, fix_seg, zero_grads)) # The second element doesn't matter
300
+
301
+ summary.on_train_batch_end(ret)
302
+ callback_best_model.on_train_batch_end(step, named_logs(full_model, ret))
303
+ callback_save_checkpoint.on_train_batch_end(step, named_logs(full_model, ret))
304
+ callback_early_stop.on_train_batch_end(step, named_logs(full_model, ret))
305
+ progress_bar.update(step, zip(names, ret))
306
+ log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
307
+ val_values = progress_bar._values.copy()
308
+ ret = [val_values[x][0] / val_values[x][1] for x in names]
309
+
310
+ print('\nVALIDATION')
311
+ log_file.write(
312
+ '\n\n[{}]\tINFO:\tValidation epoch {}\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), epoch))
313
+ progress_bar = Progbar(len(validation_generator), width=30, verbose=1)
314
+ for step, (in_batch, _) in enumerate(validation_generator, 1):
315
+ try:
316
+ fix_img, mov_img, fix_seg, mov_seg = augm_model.predict(in_batch)
317
+ except InvalidArgumentError as err:
318
+ print('TF Error : {}'.format(str(err)))
319
+ continue
320
+
321
+ ret = full_model.test_on_batch(x=(mov_img, fix_img, mov_seg, fix_seg, zero_grads))
322
+
323
+ summary.on_validation_batch_end(ret)
324
+ progress_bar.update(step, zip(names, ret))
325
+ log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
326
+ val_values = progress_bar._values.copy()
327
+ ret = [val_values[x][0] / val_values[x][1] for x in names]
328
+
329
+ train_generator.on_epoch_end()
330
+ validation_generator.on_epoch_end()
331
+ epoch_summary = summary.on_epoch_end() # summary resets after on_epoch_end() call
332
+ callback_tensorboard.on_epoch_end(epoch, epoch_summary)
333
+ callback_best_model.on_epoch_end(epoch, epoch_summary)
334
+ callback_save_checkpoint.on_epoch_end(epoch, epoch_summary)
335
+ callback_early_stop.on_epoch_end(epoch, epoch_summary)
336
+ print('End of epoch {}: '.format(epoch), ret, '\n')
337
+
338
+ callback_tensorboard.on_train_end()
339
+ callback_best_model.on_train_end()
340
+ callback_save_checkpoint.on_train_end()
341
+ callback_early_stop.on_train_end()
342
+ # 7. Wrap up
COMET/Evaluate_network.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+
3
+ import shutil
4
+ import time
5
+ import tkinter
6
+
7
+ import h5py
8
+ import matplotlib.pyplot as plt
9
+
10
+ currentdir = os.path.dirname(os.path.realpath(__file__))
11
+ parentdir = os.path.dirname(currentdir)
12
+ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
13
+
14
+ import tensorflow as tf
15
+ # tf.enable_eager_execution(config=config)
16
+
17
+ import numpy as np
18
+ import pandas as pd
19
+ import voxelmorph as vxm
20
+
21
+ import DeepDeformationMapRegistration.utils.constants as C
22
+ from DeepDeformationMapRegistration.utils.operators import min_max_norm
23
+ from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
24
+ from DeepDeformationMapRegistration.layers import AugmentationLayer, UncertaintyWeighting
25
+ from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplified, NCC, GeneralizedDICEScore, HausdorffDistanceErosion, target_registration_error
26
+ from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
27
+ from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
28
+ from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
29
+ from DeepDeformationMapRegistration.utils.misc import DisplacementMapInterpolator, get_segmentations_centroids, segmentation_ohe_to_cardinal, segmentation_cardinal_to_ohe
30
+ from EvaluationScripts.Evaluate_class import EvaluationFigures, resize_pts_to_original_space, resize_img_to_original_space, resize_transformation
31
+ from scipy.interpolate import RegularGridInterpolator
32
+ from tqdm import tqdm
33
+
34
+ import h5py
35
+ import re
36
+ from Brain_study.data_generator import BatchGenerator
37
+
38
+ import argparse
39
+
40
+ from skimage.transform import warp
41
+ import neurite as ne
42
+
43
+ # from tkinter import filedialog as fd
44
+
45
+
46
+ DATASET = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/test_fixed'
47
+ MODEL_FILE = '/mnt/EncryptedData1/Users/javier/train_output/COMET/ERASE/COMET_L_ssim__MET_mse_ncc_ssim_141343-01122021/checkpoints/best_model.h5'
48
+ DATA_ROOT_DIR = '/mnt/EncryptedData1/Users/javier/train_output/COMET/ERASE/COMET_L_ssim__MET_mse_ncc_ssim_141343-01122021/'
49
+
50
+
51
+ if __name__ == '__main__':
52
+ parser = argparse.ArgumentParser()
53
+ parser.add_argument('-m', '--model', nargs='+', type=str, help='.h5 of the model', default=None)
54
+ parser.add_argument('-d', '--dir', nargs='+', type=str, help='Directory where ./checkpoints/best_model.h5 is located', default=None)
55
+ parser.add_argument('--gpu', type=int, help='GPU', default=0)
56
+ parser.add_argument('--dataset', type=str, help='Dataset to run predictions on', default=DATASET)
57
+ parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
58
+ parser.add_argument('--outdirname', type=str, default='Evaluate')
59
+ args = parser.parse_args()
60
+ if args.model is not None:
61
+ assert '.h5' in args.model[0], 'No checkpoint file provided, use -d/--dir instead'
62
+ MODEL_FILE_LIST = args.model
63
+ DATA_ROOT_DIR_LIST = [os.path.split(model_path)[0] for model_path in args.model]
64
+ elif args.dir is not None:
65
+ assert '.h5' not in args.dir[0], 'Provided checkpoint file, user -m/--model instead'
66
+ MODEL_FILE_LIST = [os.path.join(dir_path, 'checkpoints', 'best_model.h5') for dir_path in args.dir]
67
+ DATA_ROOT_DIR_LIST = args.dir
68
+ else:
69
+ # try:
70
+ # MODEL_FILE_LIST = fd.askopenfilenames(title='Select .h model file',
71
+ # initialdir='/mnt/EncryptedData1/Users/javier/train_output/COMET',
72
+ # filetypes=(('Model', '*.h5'),))
73
+ # if len(MODEL_FILE_LIST):
74
+ # DATA_ROOT_DIR_LIST = [os.path.split(model_path)[0] for model_path in MODEL_FILE_LIST]
75
+ # else:
76
+ # raise ValueError('No model selected')
77
+ # except tkinter.TclError as e:
78
+ # raise ValueError('Cannot launch TkInter file explorer. User the -m or -d arguments')
79
+ pass
80
+ os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
81
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) # Check availability before running using 'nvidia-smi'
82
+ DATASET = args.dataset
83
+ list_test_files = [os.path.join(DATASET, f) for f in os.listdir(DATASET) if f.endswith('h5') and 'dm' not in f]
84
+ list_test_files.sort()
85
+
86
+ with h5py.File(list_test_files[0], 'r') as f:
87
+ image_input_shape = image_output_shape = list(f['fix_image'][:].shape[:-1])
88
+ nb_labels = f['fix_segmentations'][:].shape[-1]
89
+
90
+ # Header of the metrics csv file
91
+ csv_header = ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'DICE_MACRO', 'HD', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']
92
+
93
+ # TF stuff
94
+ config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
95
+ config.gpu_options.allow_growth = True
96
+ config.log_device_placement = False ## to log device placement (on which device the operation ran)
97
+ config.allow_soft_placement = True
98
+
99
+ sess = tf.Session(config=config)
100
+ tf.keras.backend.set_session(sess)
101
+
102
+ # Loss and metric functions. Common to all models
103
+ loss_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).loss,
104
+ NCC(image_input_shape).loss,
105
+ vxm.losses.MSE().loss,
106
+ MultiScaleStructuralSimilarity(max_val=1., filter_size=3).loss]
107
+
108
+ metric_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).metric,
109
+ NCC(image_input_shape).metric,
110
+ vxm.losses.MSE().loss,
111
+ MultiScaleStructuralSimilarity(max_val=1., filter_size=3).metric,
112
+ GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric,
113
+ HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [nb_labels]).metric,
114
+ GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric_macro]
115
+
116
+ ### METRICS GRAPH ###
117
+ fix_img_ph = tf.placeholder(tf.float32, (1, *image_output_shape, 1), name='fix_img')
118
+ pred_img_ph = tf.placeholder(tf.float32, (1, *image_output_shape, 1), name='pred_img')
119
+ fix_seg_ph = tf.placeholder(tf.float32, (1, *image_output_shape, nb_labels), name='fix_seg')
120
+ pred_seg_ph = tf.placeholder(tf.float32, (1, *image_output_shape, nb_labels), name='pred_seg')
121
+
122
+ ssim_tf = metric_fncs[0](fix_img_ph, pred_img_ph)
123
+ ncc_tf = metric_fncs[1](fix_img_ph, pred_img_ph)
124
+ mse_tf = metric_fncs[2](fix_img_ph, pred_img_ph)
125
+ ms_ssim_tf = metric_fncs[3](fix_img_ph, pred_img_ph)
126
+ dice_tf = metric_fncs[4](fix_seg_ph, pred_seg_ph)
127
+ hd_tf = metric_fncs[5](fix_seg_ph, pred_seg_ph)
128
+ dice_macro_tf = metric_fncs[6](fix_seg_ph, pred_seg_ph)
129
+ # hd_exact_tf = HausdorffDistance_exact(fix_seg_ph, pred_seg_ph, ohe=True)
130
+
131
+ # Needed for VxmDense type of network
132
+ warp_segmentation = vxm.networks.Transform(image_output_shape, interp_method='nearest', nb_feats=nb_labels)
133
+
134
+ dm_interp = DisplacementMapInterpolator(image_output_shape, 'griddata')
135
+
136
+ for MODEL_FILE, DATA_ROOT_DIR in zip(MODEL_FILE_LIST, DATA_ROOT_DIR_LIST):
137
+ print('MODEL LOCATION: ', MODEL_FILE)
138
+
139
+ # data_folder = '/mnt/EncryptedData1/Users/javier/train_output/DDMR/THESIS/BASELINE_Affine_ncc___mse_ncc_160606-25022021'
140
+ output_folder = os.path.join(DATA_ROOT_DIR, args.outdirname) # '/mnt/EncryptedData1/Users/javier/train_output/DDMR/THESIS/eval/BASELINE_TRAIN_Affine_ncc_EVAL_Affine'
141
+ # os.makedirs(os.path.join(output_folder, 'images'), exist_ok=True)
142
+ if args.erase:
143
+ shutil.rmtree(output_folder, ignore_errors=True)
144
+ os.makedirs(output_folder, exist_ok=True)
145
+ print('DESTINATION FOLDER: ', output_folder)
146
+
147
+ try:
148
+ network = tf.keras.models.load_model(MODEL_FILE, {'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
149
+ 'VxmDense': vxm.networks.VxmDense,
150
+ 'AdamAccumulated': AdamAccumulated,
151
+ 'loss': loss_fncs,
152
+ 'metric': metric_fncs},
153
+ compile=False)
154
+ except ValueError as e:
155
+ enc_features = [16, 32, 32, 32] # const.ENCODER_FILTERS
156
+ dec_features = [32, 32, 32, 32, 32, 16, 16] # const.ENCODER_FILTERS[::-1]
157
+ nb_features = [enc_features, dec_features]
158
+ if re.search('^UW|SEGGUIDED_', MODEL_FILE):
159
+ network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
160
+ nb_labels=nb_labels,
161
+ nb_unet_features=nb_features,
162
+ int_steps=0,
163
+ int_downsize=1,
164
+ seg_downsize=1)
165
+ else:
166
+ network = vxm.networks.VxmDense(inshape=image_output_shape,
167
+ nb_unet_features=nb_features,
168
+ int_steps=0)
169
+ network.load_weights(MODEL_FILE, by_name=True)
170
+ # Record metrics
171
+ metrics_file = os.path.join(output_folder, 'metrics.csv')
172
+ with open(metrics_file, 'w') as f:
173
+ f.write(';'.join(csv_header)+'\n')
174
+
175
+ ssim = ncc = mse = ms_ssim = dice = hd = 0
176
+ with sess.as_default():
177
+ sess.run(tf.global_variables_initializer())
178
+ network.load_weights(MODEL_FILE, by_name=True)
179
+ progress_bar = tqdm(enumerate(list_test_files, 1), desc='Evaluation', total=len(list_test_files))
180
+ for step, in_batch in progress_bar:
181
+ with h5py.File(in_batch, 'r') as f:
182
+ fix_img = f['fix_image'][:][np.newaxis, ...] # Add batch axis
183
+ mov_img = f['mov_image'][:][np.newaxis, ...]
184
+ fix_seg = f['fix_segmentations'][:][np.newaxis, ...].astype(np.float32)
185
+ mov_seg = f['mov_segmentations'][:][np.newaxis, ...].astype(np.float32)
186
+ fix_centroids = f['fix_centroids'][:]
187
+
188
+ if network.name == 'vxm_dense_semi_supervised_seg':
189
+ t0 = time.time()
190
+ pred_img, disp_map, pred_seg = network.predict([mov_img, fix_img, mov_seg, fix_seg]) # predict([source, target])
191
+ t1 = time.time()
192
+ else:
193
+ t0 = time.time()
194
+ pred_img, disp_map = network.predict([mov_img, fix_img])
195
+ pred_seg = warp_segmentation.predict([mov_seg, disp_map])
196
+ t1 = time.time()
197
+
198
+ pred_img = min_max_norm(pred_img)
199
+ mov_centroids, missing_lbls = get_segmentations_centroids(mov_seg[0, ...], ohe=True, expected_lbls=range(0, nb_labels), brain_study=False)
200
+ # pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) # with tps, it returns the pred_centroids directly
201
+ pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) + mov_centroids
202
+
203
+ # I need the labels to be OHE to compute the segmentation metrics.
204
+ dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf], {'fix_seg:0': fix_seg, 'pred_seg:0': pred_seg})
205
+
206
+ pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
207
+ mov_seg_card = segmentation_ohe_to_cardinal(mov_seg).astype(np.float32)
208
+ fix_seg_card = segmentation_ohe_to_cardinal(fix_seg).astype(np.float32)
209
+
210
+ ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf], {'fix_img:0': fix_img, 'pred_img:0': pred_img})
211
+ ms_ssim = ms_ssim[0]
212
+
213
+ # Rescale the points back to isotropic space, where we have a correspondence voxel <-> mm
214
+ upsample_scale = 128 / 64
215
+ fix_centroids_isotropic = fix_centroids * upsample_scale
216
+ # mov_centroids_isotropic = mov_centroids * upsample_scale
217
+ pred_centroids_isotropic = pred_centroids * upsample_scale
218
+
219
+ fix_centroids_isotropic = np.divide(fix_centroids_isotropic, C.IXI_DATASET_iso_to_cubic_scales)
220
+ # mov_centroids_isotropic = np.divide(mov_centroids_isotropic, C.IXI_DATASET_iso_to_cubic_scales)
221
+ pred_centroids_isotropic = np.divide(pred_centroids_isotropic, C.IXI_DATASET_iso_to_cubic_scales)
222
+ # Now we can measure the TRE in mm
223
+ tre_array = target_registration_error(fix_centroids_isotropic, pred_centroids_isotropic, False).eval()
224
+ tre = np.mean([v for v in tre_array if not np.isnan(v)])
225
+ # ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'HD', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']
226
+
227
+ new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, t1-t0, tre, len(missing_lbls), missing_lbls]
228
+ with open(metrics_file, 'a') as f:
229
+ f.write(';'.join(map(str, new_line))+'\n')
230
+
231
+ save_nifti(fix_img[0, ...], os.path.join(output_folder, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
232
+ save_nifti(mov_img[0, ...], os.path.join(output_folder, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
233
+ save_nifti(pred_img[0, ...], os.path.join(output_folder, '{:03d}_pred_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
234
+ save_nifti(fix_seg_card[0, ...], os.path.join(output_folder, '{:03d}_fix_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
235
+ save_nifti(mov_seg_card[0, ...], os.path.join(output_folder, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
236
+ save_nifti(pred_seg_card[0, ...], os.path.join(output_folder, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
237
+
238
+ # with h5py.File(os.path.join(output_folder, '{:03d}_centroids.h5'.format(step)), 'w') as f:
239
+ # f.create_dataset('fix_centroids', dtype=np.float32, data=fix_centroids)
240
+ # f.create_dataset('mov_centroids', dtype=np.float32, data=mov_centroids)
241
+ # f.create_dataset('pred_centroids', dtype=np.float32, data=pred_centroids)
242
+ # f.create_dataset('fix_centroids_isotropic', dtype=np.float32, data=fix_centroids_isotropic)
243
+ # f.create_dataset('mov_centroids_isotropic', dtype=np.float32, data=mov_centroids_isotropic)
244
+
245
+ # magnitude = np.sqrt(np.sum(disp_map[0, ...] ** 2, axis=-1))
246
+ # _ = plt.hist(magnitude.flatten())
247
+ # plt.title('Histogram of disp. magnitudes')
248
+ # plt.show(block=False)
249
+ # plt.savefig(os.path.join(output_folder, '{:03d}_hist_mag_ssim_{:.03f}_dice_{:.03f}.png'.format(step, ssim, dice)))
250
+ # plt.close()
251
+
252
+ plot_predictions(fix_img, mov_img, disp_map, pred_img, os.path.join(output_folder, '{:03d}_figures_img.png'.format(step)), show=False)
253
+ plot_predictions(fix_seg, mov_seg, disp_map, pred_seg, os.path.join(output_folder, '{:03d}_figures_seg.png'.format(step)), show=False)
254
+ save_disp_map_img(disp_map, 'Displacement map', os.path.join(output_folder, '{:03d}_disp_map_fig.png'.format(step)), show=False)
255
+
256
+ progress_bar.set_description('SSIM {:.04f}\tG_DICE: {:.04f}\tM_DICE: {:.04f}'.format(ssim, dice, dice_macro))
257
+
258
+ print('Summary\n=======\n')
259
+ print('\nAVG:\n' + str(pd.read_csv(metrics_file, sep=';', header=0).mean(axis=0)) + '\nSTD:\n' + str(pd.read_csv(metrics_file, sep=';', header=0).std(axis=0)))
260
+ print('\n=======\n')
261
+ tf.keras.backend.clear_session()
262
+ # sess.close()
263
+ del network
264
+ print('Done')
COMET/MultiTrain_cli.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ currentdir = os.path.dirname(os.path.realpath(__file__))
3
+ parentdir = os.path.dirname(currentdir)
4
+ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
5
+
6
+ import argparse
7
+ from datetime import datetime
8
+
9
+ TRAIN_DATASET = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/train'
10
+
11
+ err = list()
12
+
13
+ if __name__ == '__main__':
14
+ parser = argparse.ArgumentParser()
15
+ parser.add_argument('--model', type=str, help='Path to the model file', required=True)
16
+ parser.add_argument('--dataset', type=str, help='Location of the training data', default=TRAIN_DATASET)
17
+ parser.add_argument('--validation', type=str, help='Location of the validation data', default=None)
18
+ parser.add_argument('--similarity', nargs='+', help='Similarity metric: mse, ncc, ssim', default=['ncc'])
19
+ parser.add_argument('--segmentation', nargs='+', help='Segmentation loss function: hd, dice', default=['dice'])
20
+ parser.add_argument('--output', type=str, help='Output directory', default=TRAIN_DATASET)
21
+ parser.add_argument('--gpu', type=str, help='GPU number', default='0')
22
+ parser.add_argument('--lr', type=float, help='Learning rate', default=1e-4)
23
+ parser.add_argument('--rw', type=float, help='Regularization weigh', default=5e-3)
24
+ parser.add_argument('--epochs', type=int, help='Max number of epochs', default=1000)
25
+ parser.add_argument('--name', type=str, default='COMET')
26
+ parser.add_argument('--uw', type=bool, default=False)
27
+ parser.add_argument('--freeze', nargs='+', help='What layers to freeze: INPUT, OUTPUT, ENCODER, DECODER, TOP, BOTTOM', default=None)
28
+ parser.add_argument('--epochs', default=1000)
29
+ parser.add_argument('--batch', default=16)
30
+ args = parser.parse_args()
31
+
32
+ print('TRAIN ' + args.dataset)
33
+
34
+ if args.uw:
35
+ from COMET.COMET_train_UW import launch_train
36
+ simil = args.similarity
37
+ segm = args.segmentation
38
+ output_folder = os.path.join(args.output, '{}_Lsim_{}__Lseg_{}'.format(args.name, '_'.join(simil), '_'.join(segm)))
39
+ else:
40
+ from COMET.COMET_train import launch_train
41
+ simil = args.similarity[0]
42
+ segm = args.segmentation[0]
43
+ output_folder = os.path.join(args.output, '{}_Lsim_{}__Lseg_{}'.format(args.name, simil, segm))
44
+ output_folder = output_folder + '_' + datetime.now().strftime("%H%M%S-%d%m%Y")
45
+
46
+ if args.freeze is not None:
47
+ assert all(s in ['INPUT', 'OUTPUT', 'ENCODER', 'DECODER', 'TOP', 'BOTTOM'] for s in args.freeze),\
48
+ 'Invalid option for "freeze". Expected one or several of: INPUT, OUTPUT, ENCODER, DECODER, TOP, BOTTOM'
49
+ args.freeze = list(set(args.freeze))
50
+
51
+ launch_train(dataset_folder=args.dataset,
52
+ validation_folder=args.validation,
53
+ output_folder=output_folder,
54
+ gpu_num=args.gpu,
55
+ lr=args.lr,
56
+ rw=args.rw,
57
+ simil=simil,
58
+ segm=segm,
59
+ max_epochs=args.epochs,
60
+ model_file=args.model,
61
+ freeze_layers=args.freeze)
COMET/MultiTrain_config.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ currentdir = os.path.dirname(os.path.realpath(__file__))
3
+ parentdir = os.path.dirname(currentdir)
4
+ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
5
+
6
+ import argparse
7
+ from configparser import ConfigParser
8
+ from shutil import copy2
9
+ import os
10
+ from datetime import datetime
11
+
12
+ TRAIN_DATASET = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/train'
13
+
14
+ err = list()
15
+
16
+ if __name__ == '__main__':
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument('--ini', help='Configuration file')
19
+ args = parser.parse_args()
20
+
21
+ configFile = ConfigParser()
22
+ configFile.read(args.ini)
23
+ print('Loaded configuration file: ' + args.ini)
24
+ print({section: dict(configFile[section]) for section in configFile.sections()})
25
+ print('\n\n')
26
+
27
+ trainConfig = configFile['TRAIN']
28
+ lossesConfig = configFile['LOSSES']
29
+ datasetConfig = configFile['DATASETS']
30
+ othersConfig = configFile['OTHERS']
31
+
32
+ print('TRAIN MODEL IN' + trainConfig['model'])
33
+
34
+ simil = lossesConfig['similarity'].split(',')
35
+ segm = lossesConfig['segmentation'].split(',')
36
+ if trainConfig['name'].lower() == 'uw':
37
+ from COMET.COMET_train_UW import launch_train
38
+ output_folder = os.path.join(othersConfig['outputFolder'], '{}_Lsim_{}__Lseg_{}'.format(trainConfig['name'], '_'.join(simil), '_'.join(segm)))
39
+ else:
40
+ from COMET.COMET_train import launch_train
41
+ simil = simil[0]
42
+ segm = segm[0]
43
+ output_folder = os.path.join(othersConfig['outputFolder'], '{}_Lsim_{}__Lseg_{}'.format(trainConfig['name'], simil, segm))
44
+ output_folder = output_folder + '_' + datetime.now().strftime("%H%M%S-%d%m%Y")
45
+
46
+ try:
47
+ froozen_layers = eval(trainConfig['freeze'])
48
+ except NameError as err:
49
+ froozen_layers = [trainConfig['freeze'].upper()]
50
+ if froozen_layers is not None:
51
+ assert all(s in ['INPUT', 'OUTPUT', 'ENCODER', 'DECODER', 'TOP', 'BOTTOM'] for s in froozen_layers),\
52
+ 'Invalid option for "freeze". Expected one or several of: INPUT, OUTPUT, ENCODER, DECODER, TOP, BOTTOM'
53
+ froozen_layers = list(set(froozen_layers)) # Unique elements
54
+
55
+ # copy the configuration file to the destionation folder
56
+ os.makedirs(output_folder, exist_ok=True)
57
+ copy2(args.ini, os.path.join(output_folder, os.path.split(args.ini)[-1]))
58
+
59
+ launch_train(dataset_folder=datasetConfig['train'],
60
+ validation_folder=datasetConfig['validation'],
61
+ output_folder=output_folder,
62
+ gpu_num=eval(trainConfig['gpu']),
63
+ lr=eval(trainConfig['learningRate']),
64
+ rw=eval(trainConfig['regularizationWeight']),
65
+ simil=simil,
66
+ segm=segm,
67
+ max_epochs=eval(trainConfig['epochs']),
68
+ model_file=trainConfig['model'],
69
+ freeze_layers=froozen_layers,
70
+ acc_gradients=eval(trainConfig['accumulativeGradients']),
71
+ batch_size=eval(trainConfig['batchSize']))
COMET/augmentation_constants.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ # Constants for augmentation layer
4
+ # .../T1/training/zoom_factors.csv contain the scale factors of all the training samples from isotropic to 128x128x128
5
+ # The augmentation values will be scaled using the average+std
6
+ ZOOM_FACTORS = np.asarray([0.5032864535069749, 0.5363100665659675, 0.6292598243796296])
7
+ MAX_AUG_DISP_ISOT = 30
8
+ MAX_AUG_DEF_ISOT = 6
9
+ MAX_AUG_DISP = np.max(MAX_AUG_DISP_ISOT * ZOOM_FACTORS) # Scaled displacements
10
+ MAX_AUG_DEF = np.max(MAX_AUG_DEF_ISOT * ZOOM_FACTORS) # Scaled deformations
11
+ MAX_AUG_ANGLE = np.max([np.arctan(np.tan(10*np.pi/180) * ZOOM_FACTORS[1] / ZOOM_FACTORS[0]) * 180 / np.pi,
12
+ np.arctan(np.tan(10*np.pi/180) * ZOOM_FACTORS[2] / ZOOM_FACTORS[1]) * 180 / np.pi,
13
+ np.arctan(np.tan(10*np.pi/180) * ZOOM_FACTORS[2] / ZOOM_FACTORS[0]) * 180 / np.pi]) # Scaled angles
14
+ GAMMA_AUGMENTATION = False
15
+ BRIGHTNESS_AUGMENTATION = False
16
+ NUM_CONTROL_PTS_AUG = 10
17
+ NUM_AUGMENTATIONS = 5
18
+
19
+ IN_LAYERS = (0, 3)
20
+ OUT_LAYERS = (33, 39)
21
+
22
+ ENCONDER_LAYERS = (3, 17)
23
+ DECODER_LAYERS = (17, 33)
24
+
25
+ TOP_LAYERS_ENC = (3, 9)
26
+ TOP_LAYERS_DEC = (22, 29)
27
+ BOTTOM_LAYERS = (9, 22)
28
+
29
+ LAYER_RANGES = {'INPUT': (IN_LAYERS),
30
+ 'OUTPUT': (OUT_LAYERS),
31
+ 'ENCODER': (ENCONDER_LAYERS),
32
+ 'DECODER': (DECODER_LAYERS),
33
+ 'TOP': (TOP_LAYERS_ENC, TOP_LAYERS_DEC),
34
+ 'BOTTOM': (BOTTOM_LAYERS)}
COMET/format_dataset.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ currentdir = os.path.dirname(os.path.realpath(__file__))
3
+ parentdir = os.path.dirname(currentdir)
4
+ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
5
+
6
+ import h5py
7
+ import nibabel as nib
8
+ from nilearn.image import resample_img
9
+ import re
10
+ import numpy as np
11
+ from scipy.ndimage import zoom
12
+ from skimage.measure import regionprops
13
+ from tqdm import tqdm
14
+ from argparse import ArgumentParser
15
+ from scipy.ndimage.morphology import binary_dilation, generate_binary_structure
16
+
17
+ import pandas as pd
18
+
19
+ from DeepDeformationMapRegistration.utils import constants as C
20
+ from DeepDeformationMapRegistration.utils.misc import segmentation_cardinal_to_ohe, segmentation_ohe_to_cardinal
21
+
22
+ SEGMENTATION_NR2LBL_LUT = {0: 'background',
23
+ 1: 'parenchyma',
24
+ 2: 'vessel'}
25
+ SEGMENTATION_LBL2NR_LUT = {v: k for k, v in SEGMENTATION_NR2LBL_LUT.items()}
26
+
27
+ IMG_DIRECTORY = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Volumes'
28
+ SEG_DIRECTORY = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Segmentations' # '/home/jpdefrutos/workspace/LiverSegmentation_UNet3D/data/prediction'
29
+
30
+ IMG_NAME_PATTERN = '(.*).nii.gz'
31
+ SEG_NAME_PATTERN = '(.*).nii.gz'
32
+
33
+ OUT_DIRECTORY = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128'
34
+
35
+
36
+ if __name__ == '__main__':
37
+ parser = ArgumentParser()
38
+ parser.add_argument('--crop', action='store_true') # If present, args.crop = True, else args.crop = False
39
+ parser.add_argument('--offset', type=int, default=C.MAX_AUG_DISP_ISOT + 10, help='Crop offset in mm')
40
+ parser.add_argument('--dilate-segmentations', type=bool, default=False)
41
+ args = parser.parse_args()
42
+
43
+ img_list = [os.path.join(IMG_DIRECTORY, f) for f in os.listdir(IMG_DIRECTORY) if f.endswith('.nii.gz')]
44
+ img_list.sort()
45
+
46
+ seg_list = [os.path.join(SEG_DIRECTORY, f) for f in os.listdir(SEG_DIRECTORY) if f.endswith('.nii.gz')]
47
+ seg_list.sort()
48
+
49
+ zoom_file = pd.DataFrame(columns=['scale_i', 'scale_j', 'scale_k'])
50
+ os.makedirs(OUT_DIRECTORY, exist_ok=True)
51
+ binary_ball = generate_binary_structure(3, 1)
52
+ for seg_file in tqdm(seg_list):
53
+ img_name = re.match(SEG_NAME_PATTERN, os.path.split(seg_file)[-1])[1]
54
+ img_file = os.path.join(IMG_DIRECTORY, img_name + '.nii.gz')
55
+
56
+ img = resample_img(nib.load(img_file), np.eye(3))
57
+ seg = resample_img(nib.load(seg_file), np.eye(3), interpolation='nearest')
58
+
59
+ img = np.asarray(img.dataobj)
60
+ seg = np.asarray(seg.dataobj)
61
+
62
+ segs_are_ohe = bool(len(seg.shape) > 3 and seg.shape[3] > 1)
63
+ if args.crop:
64
+ parenchyma = regionprops(seg[..., 0])[0]
65
+ bbox = np.asarray(parenchyma.bbox) + [*[-args.offset]*3, *[args.offset]*3]
66
+ # check that the new bbox is within the image limits!
67
+ bbox[:3] = np.maximum(bbox[:3], [0, 0, 0])
68
+ bbox[3:] = np.minimum(bbox[3:], img.shape)
69
+ img = img[bbox[0]:bbox[3], bbox[1]:bbox[4], bbox[2]:bbox[5]]
70
+ seg = seg[bbox[0]:bbox[3], bbox[1]:bbox[4], bbox[2]:bbox[5], ...]
71
+ # Resize to 128x128x128
72
+ isot_shape = img.shape
73
+
74
+ zoom_factors = (np.asarray([128]*3) / np.asarray(img.shape)).tolist()
75
+
76
+ img = zoom(img, zoom_factors, order=3)
77
+ if args.dilate_segmentations:
78
+ seg = binary_dilation(seg, binary_ball, iterations=1)
79
+ seg = zoom(seg, zoom_factors + [1]*(len(seg.shape) - len(img.shape)), order=0)
80
+ zoom_file = zoom_file.append({'scale_i': zoom_factors[0],
81
+ 'scale_j': zoom_factors[1],
82
+ 'scale_k': zoom_factors[2]}, ignore_index=True)
83
+
84
+ # seg -> cardinal
85
+ # seg_expanded -> OHE
86
+ if segs_are_ohe:
87
+ seg_expanded = seg.copy()
88
+ seg = segmentation_ohe_to_cardinal(seg) # Ordinal encoded. argmax returns the first ocurrence of the maximum. Hence the previoous multiplication operation
89
+ else:
90
+ seg_expanded = segmentation_cardinal_to_ohe(seg)
91
+
92
+ h5_file = h5py.File(os.path.join(OUT_DIRECTORY, img_name + '.h5'), 'w')
93
+
94
+ h5_file.create_dataset('image', data=img[..., np.newaxis], dtype=np.float32)
95
+ h5_file.create_dataset('segmentation', data=seg.astype(np.uint8), dtype=np.uint8)
96
+ h5_file.create_dataset('segmentation_expanded', data=seg_expanded.astype(np.uint8), dtype=np.uint8)
97
+ h5_file.create_dataset('segmentation_labels', data=np.unique(seg)[1:]) # Remove the 0 (background label)
98
+ h5_file.create_dataset('isotropic_shape', data=isot_shape)
99
+
100
+ print('{}: Segmentation labels {}'.format(img_name, np.unique(seg)[1:]))
101
+ h5_file.close()
102
+
103
+ zoom_file.to_csv(os.path.join(OUT_DIRECTORY, 'zoom_factors.csv'))
104
+ print("Average")
105
+ print(zoom_file.mean().to_list())
106
+
107
+ print("Standard deviation")
108
+ print(zoom_file.std().to_list())
109
+
110
+ print("Average + STD")
111
+ print((zoom_file.mean() + zoom_file.std()).to_list())
112
+
113
+
114
+
COMET/spit_dataset.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from shutil import move, copy2
2
+ import os
3
+
4
+ OR_DIR = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128'
5
+ val_split = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/For_validation.txt'
6
+ test_split = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/For_testing.txt'
7
+
8
+ # Create out dirs
9
+ os.makedirs(os.path.join(OR_DIR, 'train'), exist_ok=True)
10
+ os.makedirs(os.path.join(OR_DIR, 'validation'), exist_ok=True)
11
+ os.makedirs(os.path.join(OR_DIR, 'test'), exist_ok=True)
12
+
13
+ # Copy all to train and then split into validation and test
14
+ list_of_files = [os.path.join(OR_DIR, f) for f in os.listdir(OR_DIR) if f.endswith('.h5')]
15
+ list_of_files.sort()
16
+ for f in list_of_files:
17
+ copy2(f, os.path.join(OR_DIR, 'train'))
18
+
19
+ # Get the indices for the validation and test subsets
20
+ with open(val_split, 'r') as f:
21
+ val_idcs = f.readlines()[0]
22
+ val_idcs = [int(e) for e in val_idcs.split(',')]
23
+
24
+ with open(test_split, 'r') as f:
25
+ test_indcs = f.readlines()[0]
26
+ test_indcs = [int(e) for e in test_indcs.split(',')]
27
+
28
+ # move the files from train to validation and test
29
+ for i in val_idcs:
30
+ move(os.path.join(OR_DIR, 'train', '{:05d}_CT.h5'.format(i)), os.path.join(OR_DIR, 'validation'))
31
+ print('Done moving the validation subset.')
32
+
33
+ for i in test_indcs:
34
+ move(os.path.join(OR_DIR, 'train', '{:05d}_CT.h5'.format(i)), os.path.join(OR_DIR, 'test'))
35
+ print('Done moving the validation subset.')
36
+
37
+ print('Done splitting the data')
38
+ print('Training samples: '+str(len(os.listdir(os.path.join(OR_DIR, 'train')))))
39
+ print('Validation samples: '+str(len(val_idcs)))
40
+ print('Test samples: '+str(len(test_indcs)))
COMET/train_config_files/Config_BASELINE_None_froozen.ini ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [LOSSES]
2
+ similarity = ncc
3
+ segmentation =
4
+
5
+ [TRAIN]
6
+ model = /mnt/EncryptedData1/Users/javier/train_output/Brain_study/No_gamma/BASELINE_L_ncc__MET_mse_ncc_ssim_232329-01092021/checkpoints/best_model.h5
7
+ batchSize = 8
8
+ learningRate = 1e-5
9
+ accumulativeGradients = 1
10
+ gpu = 1
11
+ regularizationWeight = 1e-5
12
+ epochs = 10000
13
+ name = BASELINE
14
+ freeze = None
15
+
16
+ [DATASETS]
17
+ train = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/train
18
+ validation = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/validation
19
+
20
+ [OTHERS]
21
+ outputFolder = /mnt/EncryptedData1/Users/javier/train_output/COMET/NONE_FROZEN
COMET/train_config_files/Config_BASELINE_bottom_froozen.ini ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [LOSSES]
2
+ similarity = ncc
3
+ segmentation =
4
+
5
+ [TRAIN]
6
+ model = /mnt/EncryptedData1/Users/javier/train_output/Brain_study/No_gamma/BASELINE_L_ncc__MET_mse_ncc_ssim_232329-01092021/checkpoints/best_model.h5
7
+ batchSize = 8
8
+ learningRate = 1e-5
9
+ accumulativeGradients = 1
10
+ gpu = 0
11
+ regularizationWeight = 1e-5
12
+ epochs = 10000
13
+ name = BASELINE
14
+ freeze = BOTTOM
15
+
16
+ [DATASETS]
17
+ train = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/train
18
+ validation = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/validation
19
+
20
+ [OTHERS]
21
+ outputFolder = /mnt/EncryptedData1/Users/javier/train_output/COMET/BOTTOM_FROZEN
COMET/train_config_files/Config_BASELINE_top_froozen.ini ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [LOSSES]
2
+ similarity = ncc
3
+ segmentation =
4
+
5
+ [TRAIN]
6
+ model = /mnt/EncryptedData1/Users/javier/train_output/Brain_study/No_gamma/BASELINE_L_ncc__MET_mse_ncc_ssim_232329-01092021/checkpoints/best_model.h5
7
+ batchSize = 8
8
+ learningRate = 1e-5
9
+ accumulativeGradients = 1
10
+ gpu = 1
11
+ regularizationWeight = 1e-5
12
+ epochs = 10000
13
+ name = BASELINE
14
+ freeze = TOP
15
+
16
+ [DATASETS]
17
+ train = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/train
18
+ validation = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/validation
19
+
20
+ [OTHERS]
21
+ outputFolder = /mnt/EncryptedData1/Users/javier/train_output/COMET/TOP_FROZEN
COMET/train_config_files/Config_SEGGUIDED_None_froozen.ini ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [LOSSES]
2
+ segmentation = dice
3
+ similarity = ssim
4
+
5
+ [TRAIN]
6
+ model = /mnt/EncryptedData1/Users/javier/train_output/Brain_study/No_gamma/SEGGUIDED_Lsim_ncc__Lseg_dice__MET_mse_ncc_ssim_013319-11092021/checkpoints/best_model.h5
7
+ batchSize = 8
8
+ learningRate = 1e-5
9
+ accumulativeGradients = 1
10
+ gpu = 1
11
+ regularizationWeight = 1e-5
12
+ epochs = 10000
13
+ name = SEGGUIDED
14
+ freeze = None
15
+
16
+ [DATASETS]
17
+ train = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/train
18
+ validation = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/validation
19
+
20
+ [OTHERS]
21
+ outputFolder = /mnt/EncryptedData1/Users/javier/train_output/COMET/NONE_FROZEN
COMET/train_config_files/Config_SEGGUIDED_bottom_froozen.ini ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [LOSSES]
2
+ segmentation = dice
3
+ similarity = ssim
4
+
5
+ [TRAIN]
6
+ model = /mnt/EncryptedData1/Users/javier/train_output/Brain_study/No_gamma/SEGGUIDED_Lsim_ncc__Lseg_dice__MET_mse_ncc_ssim_013319-11092021/checkpoints/best_model.h5
7
+ batchSize = 8
8
+ learningRate = 1e-5
9
+ accumulativeGradients = 1
10
+ gpu = 2
11
+ regularizationWeight = 1e-5
12
+ epochs = 10000
13
+ name = SEGGUIDED
14
+ freeze = BOTTOM
15
+
16
+ [DATASETS]
17
+ train = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/train
18
+ validation = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/validation
19
+
20
+ [OTHERS]
21
+ outputFolder = /mnt/EncryptedData1/Users/javier/train_output/COMET/BOTTOM_FROZEN
COMET/train_config_files/Config_SEGGUIDED_top_froozen.ini ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [LOSSES]
2
+ segmentation = dice
3
+ similarity = ssim
4
+
5
+ [TRAIN]
6
+ model = /mnt/EncryptedData1/Users/javier/train_output/Brain_study/No_gamma/SEGGUIDED_Lsim_ncc__Lseg_dice__MET_mse_ncc_ssim_013319-11092021/checkpoints/best_model.h5
7
+ batchSize = 8
8
+ learningRate = 1e-5
9
+ accumulativeGradients = 1
10
+ gpu = 1
11
+ regularizationWeight = 1e-5
12
+ epochs = 10000
13
+ name = SEGGUIDED
14
+ freeze = TOP
15
+
16
+ [DATASETS]
17
+ train = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/train
18
+ validation = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/validation
19
+
20
+ [OTHERS]
21
+ outputFolder = /mnt/EncryptedData1/Users/javier/train_output/COMET/TOP_FROZEN
COMET/train_config_files/Config_UW_None_froozen.ini ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [LOSSES]
2
+ segmentation = dice,hd
3
+ similarity = ncc,ssim
4
+
5
+ [TRAIN]
6
+ model = /mnt/EncryptedData1/Users/javier/train_output/Brain_study/No_gamma/UW_Lsim_ncc__ssim__Lseg_dice__MET_mse_ncc_ssim_204557-02092021/checkpoints/best_model.h5
7
+ batchSize = 8
8
+ learningRate = 1e-5
9
+ accumulativeGradients = 1
10
+ gpu = 0
11
+ regularizationWeight = 1e-5
12
+ epochs = 10000
13
+ name = UW
14
+ freeze = None
15
+
16
+ [DATASETS]
17
+ train = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/train
18
+ validation = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/validation
19
+
20
+ [OTHERS]
21
+ outputFolder = /mnt/EncryptedData1/Users/javier/train_output/COMET/NONE_FROZEN
COMET/train_config_files/Config_UW_bottom_froozen.ini ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [LOSSES]
2
+ segmentation = dice,hd
3
+ similarity = ncc,ssim
4
+
5
+ [TRAIN]
6
+ model = /mnt/EncryptedData1/Users/javier/train_output/Brain_study/No_gamma/UW_Lsim_ncc__ssim__Lseg_dice__MET_mse_ncc_ssim_204557-02092021/checkpoints/best_model.h5
7
+ batchSize = 8
8
+ learningRate = 1e-5
9
+ accumulativeGradients = 1
10
+ gpu = 2
11
+ regularizationWeight = 1e-5
12
+ epochs = 10000
13
+ name = UW
14
+ freeze = BOTTOM
15
+
16
+ [DATASETS]
17
+ train = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/train
18
+ validation = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/validation
19
+
20
+ [OTHERS]
21
+ outputFolder = /mnt/EncryptedData1/Users/javier/train_output/COMET/BOTTOM_FROZEN
COMET/train_config_files/Config_UW_top_froozen.ini ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [LOSSES]
2
+ segmentation = dice,hd
3
+ similarity = ncc,ssim
4
+
5
+ [TRAIN]
6
+ model = /mnt/EncryptedData1/Users/javier/train_output/Brain_study/No_gamma/UW_Lsim_ncc__ssim__Lseg_dice__MET_mse_ncc_ssim_204557-02092021/checkpoints/best_model.h5
7
+ batchSize = 8
8
+ learningRate = 1e-5
9
+ accumulativeGradients = 1
10
+ gpu = 0
11
+ regularizationWeight = 1e-5
12
+ epochs = 10000
13
+ name = UW
14
+ freeze = TOP
15
+
16
+ [DATASETS]
17
+ train = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/train
18
+ validation = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/validation
19
+
20
+ [OTHERS]
21
+ outputFolder = /mnt/EncryptedData1/Users/javier/train_output/COMET/TOP_FROZEN