jpdefrutos commited on
Commit
4dfbecb
·
1 Parent(s): a27e593

COMET train segmentation guided

Browse files
Files changed (1) hide show
  1. COMET/COMET_train_seggguided.py +414 -0
COMET/COMET_train_seggguided.py ADDED
@@ -0,0 +1,414 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+
3
+ import keras
4
+
5
+ currentdir = os.path.dirname(os.path.realpath(__file__))
6
+ parentdir = os.path.dirname(currentdir)
7
+ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
8
+
9
+ from datetime import datetime
10
+
11
+ from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
12
+ from tensorflow.python.keras.utils import Progbar
13
+ from tensorflow.keras import Input
14
+ from tensorflow.keras.models import Model
15
+ from tensorflow.python.framework.errors import InvalidArgumentError
16
+
17
+ import DeepDeformationMapRegistration.utils.constants as C
18
+ from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplified, NCC, GeneralizedDICEScore, HausdorffDistanceErosion
19
+ from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
20
+ from DeepDeformationMapRegistration.ms_ssim_tf import _MSSSIM_WEIGHTS
21
+ from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
22
+ from DeepDeformationMapRegistration.utils.misc import function_decorator
23
+ from DeepDeformationMapRegistration.layers import AugmentationLayer
24
+ from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
25
+
26
+ from Brain_study.data_generator import BatchGenerator
27
+ from Brain_study.utils import SummaryDictionary, named_logs
28
+
29
+ import COMET.augmentation_constants as COMET_C
30
+
31
+ import numpy as np
32
+ import tensorflow as tf
33
+ import voxelmorph as vxm
34
+ import h5py
35
+ import re
36
+ import itertools
37
+ import warnings
38
+
39
+
40
+ def launch_train(dataset_folder, validation_folder, output_folder, model_file, gpu_num=0, lr=1e-4, rw=5e-3, simil='ssim',
41
+ segm='dice', max_epochs=C.EPOCHS, early_stop_patience=1000, freeze_layers=None,
42
+ acc_gradients=1, batch_size=16, image_size=64,
43
+ unet=[16, 32, 64, 128, 256], head=[16, 16]):
44
+ # 0. Input checks
45
+ assert dataset_folder is not None and output_folder is not None
46
+ if model_file != '':
47
+ assert '.h5' in model_file, 'The model must be an H5 file'
48
+
49
+ # 1. Load variables
50
+ os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
51
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_num) # Check availability before running using 'nvidia-smi'
52
+ C.GPU_NUM = str(gpu_num)
53
+
54
+ if batch_size != 1 and acc_gradients != 1:
55
+ warnings.warn('WARNING: Batch size and Accumulative gradient step are set!')
56
+
57
+ if freeze_layers is not None:
58
+ assert all(s in ['INPUT', 'OUTPUT', 'ENCODER', 'DECODER', 'TOP', 'BOTTOM'] for s in freeze_layers), \
59
+ 'Invalid option for "freeze". Expected one or several of: INPUT, OUTPUT, ENCODER, DECODER, TOP, BOTTOM'
60
+ freeze_layers = [list(COMET_C.LAYER_RANGES[l]) for l in list(set(freeze_layers))]
61
+ if len(freeze_layers) > 1:
62
+ freeze_layers = list(itertools.chain.from_iterable(freeze_layers))
63
+
64
+ os.makedirs(output_folder, exist_ok=True)
65
+ # dataset_copy = DatasetCopy(dataset_folder, os.path.join(output_folder, 'temp'))
66
+ log_file = open(os.path.join(output_folder, 'log.txt'), 'w')
67
+ C.TRAINING_DATASET = dataset_folder #dataset_copy.copy_dataset()
68
+ C.VALIDATION_DATASET = validation_folder
69
+ C.ACCUM_GRADIENT_STEP = acc_gradients
70
+ C.BATCH_SIZE = batch_size if C.ACCUM_GRADIENT_STEP == 1 else 1
71
+ C.EARLY_STOP_PATIENCE = early_stop_patience
72
+ C.LEARNING_RATE = lr
73
+ C.LIMIT_NUM_SAMPLES = None
74
+ C.EPOCHS = max_epochs
75
+
76
+ aux = "[{}]\tINFO:\nTRAIN DATASET: {}\nVALIDATION DATASET: {}\n" \
77
+ "GPU: {}\n" \
78
+ "BATCH SIZE: {}\n" \
79
+ "LR: {}\n" \
80
+ "SIMILARITY: {}\n" \
81
+ "SEGMENTATION: {}\n"\
82
+ "REG. WEIGHT: {}\n" \
83
+ "EPOCHS: {:d}\n" \
84
+ "ACCUM. GRAD: {}\n" \
85
+ "EARLY STOP PATIENCE: {}\n" \
86
+ "FROZEN LAYERS: {}".format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'),
87
+ C.TRAINING_DATASET,
88
+ C.VALIDATION_DATASET,
89
+ C.GPU_NUM,
90
+ C.BATCH_SIZE,
91
+ C.LEARNING_RATE,
92
+ simil,
93
+ segm,
94
+ rw,
95
+ C.EPOCHS,
96
+ C.ACCUM_GRADIENT_STEP,
97
+ C.EARLY_STOP_PATIENCE,
98
+ freeze_layers)
99
+
100
+ log_file.write(aux)
101
+ print(aux)
102
+
103
+ # 2. Data generator
104
+ used_labels = 'all'
105
+ data_generator = BatchGenerator(C.TRAINING_DATASET, C.BATCH_SIZE if C.ACCUM_GRADIENT_STEP == 1 else 1, True,
106
+ C.TRAINING_PERC, labels=[used_labels], combine_segmentations=False,
107
+ directory_val=C.VALIDATION_DATASET)
108
+
109
+ train_generator = data_generator.get_train_generator()
110
+ validation_generator = data_generator.get_validation_generator()
111
+
112
+ image_input_shape = train_generator.get_data_shape()[-1][:-1]
113
+ image_output_shape = [image_size] * 3
114
+ nb_labels = len(train_generator.get_segmentation_labels())
115
+
116
+ # 3. Load model
117
+ # IMPORTANT: the mode MUST be loaded AFTER setting up the session configuration
118
+ config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
119
+ config.gpu_options.allow_growth = True
120
+ config.log_device_placement = False ## to log device placement (on which device the operation ran)
121
+ sess = tf.Session(config=config)
122
+ tf.keras.backend.set_session(sess)
123
+
124
+ loss_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).loss,
125
+ NCC(image_input_shape).loss,
126
+ vxm.losses.MSE().loss,
127
+ MultiScaleStructuralSimilarity(max_val=1., filter_size=3).loss,
128
+ HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [nb_labels]).loss,
129
+ GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).loss,
130
+ GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).loss_macro
131
+ ]
132
+
133
+ metric_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).metric,
134
+ NCC(image_input_shape).metric,
135
+ vxm.losses.MSE().loss,
136
+ MultiScaleStructuralSimilarity(max_val=1., filter_size=3).metric,
137
+ GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric,
138
+ HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [nb_labels]).metric,
139
+ GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric_macro,]
140
+
141
+
142
+ try:
143
+ network = tf.keras.models.load_model(model_file, {#'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
144
+ 'VxmDense': vxm.networks.VxmDense,
145
+ 'AdamAccumulated': AdamAccumulated,
146
+ 'loss': loss_fncs,
147
+ 'metric': metric_fncs},
148
+ compile=False)
149
+ except ValueError as e:
150
+ # enc_features = [16, 32, 32, 32] # const.ENCODER_FILTERS
151
+ # dec_features = [32, 32, 32, 32, 32, 16, 16] # const.ENCODER_FILTERS[::-1]
152
+ enc_features = unet # const.ENCODER_FILTERS
153
+ dec_features = enc_features[::-1] + head # const.ENCODER_FILTERS[::-1]
154
+ nb_features = [enc_features, dec_features]
155
+
156
+ network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
157
+ nb_labels=nb_labels,
158
+ nb_unet_features=nb_features,
159
+ int_steps=0,
160
+ int_downsize=1,
161
+ seg_downsize=1)
162
+
163
+ if model_file != '':
164
+ network.load_weights(model_file, by_name=True)
165
+ print('MODEL LOCATION: ', model_file)
166
+ # 4. Freeze/unfreeze model layers
167
+ # freeze_layers = range(0, len(network.layers) - 8) # Do not freeze the last layers after the UNet (8 last layers)
168
+ # for l in freeze_layers:
169
+ # network.layers[l].trainable = False
170
+ # msg = "[INF]: Frozen layers {} to {}".format(0, len(network.layers) - 8)
171
+ # print(msg)
172
+ # log_file.write("INF: Frozen layers {} to {}".format(0, len(network.layers) - 8))
173
+ if freeze_layers is not None:
174
+ aux = list()
175
+ for r in freeze_layers:
176
+ for l in range(*r):
177
+ network.layers[l].trainable = False
178
+ aux.append(l)
179
+ aux.sort()
180
+ msg = "[INF]: Frozen layers {}".format(', '.join([str(a) for a in aux]))
181
+ else:
182
+ msg = "[INF] None frozen layers"
183
+ print(msg)
184
+ log_file.write(msg)
185
+ # network.trainable = False # Freeze the base model
186
+ # # Create a new model on top
187
+ # input_new_model = keras.Input(network.input_shape)
188
+ # x = base_model(input_new_model, training=False)
189
+ # x =
190
+ # network = keras.Model(input_new_model, x)
191
+
192
+ network.summary()
193
+ network.summary(print_fn=log_file.writelines)
194
+ # Complete the model with the augmentation layer
195
+ augm_train_input_shape = train_generator.get_data_shape()[0]
196
+ input_layer_train = Input(shape=augm_train_input_shape, name='input_train')
197
+ augm_layer_train = AugmentationLayer(max_displacement=COMET_C.MAX_AUG_DISP, # Max 30 mm in isotropic space
198
+ max_deformation=COMET_C.MAX_AUG_DEF, # Max 6 mm in isotropic space
199
+ max_rotation=COMET_C.MAX_AUG_ANGLE, # Max 10 deg in isotropic space
200
+ num_control_points=COMET_C.NUM_CONTROL_PTS_AUG,
201
+ num_augmentations=COMET_C.NUM_AUGMENTATIONS,
202
+ gamma_augmentation=COMET_C.GAMMA_AUGMENTATION,
203
+ brightness_augmentation=COMET_C.BRIGHTNESS_AUGMENTATION,
204
+ in_img_shape=image_input_shape,
205
+ out_img_shape=image_output_shape,
206
+ only_image=False, # If baseline then True
207
+ only_resize=False,
208
+ trainable=False)
209
+ augm_model_train = Model(inputs=input_layer_train, outputs=augm_layer_train(input_layer_train))
210
+
211
+ input_layer_valid = Input(shape=validation_generator.get_data_shape()[0], name='input_valid')
212
+ augm_layer_valid = AugmentationLayer(max_displacement=COMET_C.MAX_AUG_DISP, # Max 30 mm in isotropic space
213
+ max_deformation=COMET_C.MAX_AUG_DEF, # Max 6 mm in isotropic space
214
+ max_rotation=COMET_C.MAX_AUG_ANGLE, # Max 10 deg in isotropic space
215
+ num_control_points=COMET_C.NUM_CONTROL_PTS_AUG,
216
+ num_augmentations=COMET_C.NUM_AUGMENTATIONS,
217
+ gamma_augmentation=COMET_C.GAMMA_AUGMENTATION,
218
+ brightness_augmentation=COMET_C.BRIGHTNESS_AUGMENTATION,
219
+ in_img_shape=image_input_shape,
220
+ out_img_shape=image_output_shape,
221
+ only_image=False,
222
+ only_resize=False,
223
+ trainable=False)
224
+ augm_model_valid = Model(inputs=input_layer_valid, outputs=augm_layer_valid(input_layer_valid))
225
+
226
+ # 5. Setup training environment: loss, optimizer, callbacks, evaluation
227
+
228
+ # Losses and loss weights
229
+ SSIM_KER_SIZE = 5
230
+ MS_SSIM_WEIGHTS = _MSSSIM_WEIGHTS[:3]
231
+ MS_SSIM_WEIGHTS /= np.sum(MS_SSIM_WEIGHTS)
232
+ if simil.lower() == 'mse':
233
+ loss_fnc = vxm.losses.MSE().loss
234
+ elif simil.lower() == 'ncc':
235
+ loss_fnc = NCC(image_input_shape).loss
236
+ elif simil.lower() == 'ssim':
237
+ loss_fnc = StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss
238
+ elif simil.lower() == 'ms_ssim':
239
+ loss_fnc = MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).loss
240
+ elif simil.lower() == 'mse__ms_ssim' or simil.lower() == 'ms_ssim__mse':
241
+ @function_decorator('MSSSIM_MSE__loss')
242
+ def loss_fnc(y_true, y_pred):
243
+ return vxm.losses.MSE().loss(y_true, y_pred) + \
244
+ MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).loss(y_true, y_pred)
245
+ elif simil.lower() == 'ncc__ms_ssim' or simil.lower() == 'ms_ssim__ncc':
246
+ @function_decorator('MSSSIM_NCC__loss')
247
+ def loss_fnc(y_true, y_pred):
248
+ return NCC(image_input_shape).loss(y_true, y_pred) + \
249
+ MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).loss(y_true, y_pred)
250
+ elif simil.lower() == 'mse__ssim' or simil.lower() == 'ssim__mse':
251
+ @function_decorator('SSIM_MSE__loss')
252
+ def loss_fnc(y_true, y_pred):
253
+ return vxm.losses.MSE().loss(y_true, y_pred) + \
254
+ StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss(y_true, y_pred)
255
+ elif simil.lower() == 'ncc__ssim' or simil.lower() == 'ssim__ncc':
256
+ @function_decorator('SSIM_NCC__loss')
257
+ def loss_fnc(y_true, y_pred):
258
+ return NCC(image_input_shape).loss(y_true, y_pred) + \
259
+ StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss(y_true, y_pred)
260
+ else:
261
+ raise ValueError('Unknown similarity metric: ' + simil)
262
+
263
+ if segm == 'hd':
264
+ loss_segm = HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [nb_labels]).loss
265
+ elif segm == 'dice':
266
+ loss_segm = GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).loss
267
+ elif segm == 'dice_macro':
268
+ loss_segm = GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).loss_macro
269
+ else:
270
+ raise ValueError('No valid value for segm')
271
+
272
+ os.makedirs(os.path.join(output_folder, 'checkpoints'), exist_ok=True)
273
+ os.makedirs(os.path.join(output_folder, 'tensorboard'), exist_ok=True)
274
+ callback_tensorboard = TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
275
+ batch_size=C.BATCH_SIZE, write_images=False, histogram_freq=0,
276
+ update_freq='epoch', # or 'batch' or integer
277
+ write_graph=True, write_grads=True
278
+ )
279
+ callback_early_stop = EarlyStopping(monitor='val_loss', verbose=1, patience=C.EARLY_STOP_PATIENCE, min_delta=0.00001)
280
+
281
+ callback_best_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'),
282
+ save_best_only=True, monitor='val_loss', verbose=1, mode='min')
283
+ callback_save_checkpoint = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'checkpoint.h5'),
284
+ save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
285
+
286
+ losses = {'transformer': loss_fnc,
287
+ 'seg_transformer': loss_segm,
288
+ 'flow': vxm.losses.Grad('l2').loss}
289
+ metrics = {'transformer': [StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).metric,
290
+ MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).metric,
291
+ tf.keras.losses.MSE,
292
+ NCC(image_input_shape).metric],
293
+ 'seg_transformer': [GeneralizedDICEScore(image_output_shape + [train_generator.get_data_shape()[2][-1]], num_labels=nb_labels).metric,
294
+ HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [train_generator.get_data_shape()[2][-1]]).metric,
295
+ GeneralizedDICEScore(image_output_shape + [train_generator.get_data_shape()[2][-1]], num_labels=nb_labels).metric_macro,
296
+ ],
297
+ #'flow': vxm.losses.Grad('l2').loss
298
+ }
299
+ loss_weights = {'transformer': 1.,
300
+ 'seg_transformer': 1.,
301
+ 'flow': rw}
302
+
303
+
304
+ optimizer = AdamAccumulated(C.ACCUM_GRADIENT_STEP, C.LEARNING_RATE)
305
+ network.compile(optimizer=optimizer,
306
+ loss=losses,
307
+ loss_weights=loss_weights,
308
+ metrics=metrics)
309
+
310
+ # 6. Training loop
311
+ callback_tensorboard.set_model(network)
312
+ callback_early_stop.set_model(network)
313
+ callback_best_model.set_model(network)
314
+ callback_save_checkpoint.set_model(network)
315
+
316
+ summary = SummaryDictionary(network, C.BATCH_SIZE)
317
+ names = network.metrics_names
318
+ log_file.write('\n\n[{}]\tINFO:\tStart training\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y')))
319
+
320
+ with sess.as_default():
321
+ # tf.global_variables_initializer()
322
+ callback_tensorboard.on_train_begin()
323
+ callback_early_stop.on_train_begin()
324
+ callback_best_model.on_train_begin()
325
+ callback_save_checkpoint.on_train_begin()
326
+
327
+ for epoch in range(C.EPOCHS):
328
+ callback_tensorboard.on_epoch_begin(epoch)
329
+ callback_early_stop.on_epoch_begin(epoch)
330
+ callback_best_model.on_epoch_begin(epoch)
331
+ callback_save_checkpoint.on_epoch_begin(epoch)
332
+ print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
333
+ print("TRAIN")
334
+
335
+ log_file.write('\n\n[{}]\tINFO:\tTraining epoch {}\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), epoch))
336
+ progress_bar = Progbar(len(train_generator), width=30, verbose=1)
337
+ for step, (in_batch, _) in enumerate(train_generator, 1):
338
+ callback_best_model.on_train_batch_begin(step)
339
+ callback_save_checkpoint.on_train_batch_begin(step)
340
+ callback_early_stop.on_train_batch_begin(step)
341
+
342
+ try:
343
+ fix_img, mov_img, fix_seg, mov_seg = augm_model_train.predict(in_batch)
344
+ np.nan_to_num(fix_img, copy=False)
345
+ np.nan_to_num(mov_img, copy=False)
346
+ if np.isnan(np.sum(mov_img)) or np.isnan(np.sum(fix_img)) or np.isinf(np.sum(mov_img)) or np.isinf(np.sum(fix_img)):
347
+ msg = 'CORRUPTED DATA!! Unique: Fix: {}\tMoving: {}'.format(np.unique(fix_img),
348
+ np.unique(mov_img))
349
+ print(msg)
350
+ log_file.write('\n\n[{}]\tWAR: {}'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), msg))
351
+
352
+ except InvalidArgumentError as err:
353
+ print('TF Error : {}'.format(str(err)))
354
+ continue
355
+
356
+ in_data = (mov_img, fix_img, mov_seg)
357
+ out_data = (fix_img, fix_img, fix_seg)
358
+
359
+ ret = network.train_on_batch(x=in_data, y=out_data) # The second element doesn't matter
360
+ if np.isnan(ret).any():
361
+ os.makedirs(os.path.join(output_folder, 'corrupted'), exist_ok=True)
362
+ save_nifti(mov_img, os.path.join(output_folder, 'corrupted', 'mov_img_nan.nii.gz'))
363
+ save_nifti(fix_img, os.path.join(output_folder, 'corrupted', 'fix_img_nan.nii.gz'))
364
+ pred_img, dm = network((mov_img, fix_img))
365
+ save_nifti(pred_img, os.path.join(output_folder, 'corrupted', 'pred_img_nan.nii.gz'))
366
+ save_nifti(dm, os.path.join(output_folder, 'corrupted', 'dm_nan.nii.gz'))
367
+ log_file.write('\n\n[{}]\tERR: Corruption error'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y')))
368
+ raise ValueError('CORRUPTION ERROR: Halting training')
369
+
370
+ summary.on_train_batch_end(ret)
371
+ callback_best_model.on_train_batch_end(step, named_logs(network, ret))
372
+ callback_save_checkpoint.on_train_batch_end(step, named_logs(network, ret))
373
+ callback_early_stop.on_train_batch_end(step, named_logs(network, ret))
374
+ progress_bar.update(step, zip(names, ret))
375
+ log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
376
+ val_values = progress_bar._values.copy()
377
+ ret = [val_values[x][0]/val_values[x][1] for x in names]
378
+
379
+ print('\nVALIDATION')
380
+ log_file.write('\n\n[{}]\tINFO:\tValidation epoch {}\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), epoch))
381
+ progress_bar = Progbar(len(validation_generator), width=30, verbose=1)
382
+ for step, (in_batch, _) in enumerate(validation_generator, 1):
383
+ try:
384
+ fix_img, mov_img, fix_seg, mov_seg = augm_model_valid.predict(in_batch)
385
+ except InvalidArgumentError as err:
386
+ print('TF Error : {}'.format(str(err)))
387
+ continue
388
+
389
+ in_data = (mov_img, fix_img, mov_seg)
390
+ out_data = (fix_img, fix_img, fix_seg)
391
+
392
+ ret = network.test_on_batch(x=in_data,
393
+ y=out_data)
394
+
395
+ summary.on_validation_batch_end(ret)
396
+ progress_bar.update(step, zip(names, ret))
397
+ log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
398
+ val_values = progress_bar._values.copy()
399
+ ret = [val_values[x][0]/val_values[x][1] for x in names]
400
+
401
+ train_generator.on_epoch_end()
402
+ validation_generator.on_epoch_end()
403
+ epoch_summary = summary.on_epoch_end() # summary resets after on_epoch_end() call
404
+ callback_tensorboard.on_epoch_end(epoch, epoch_summary)
405
+ callback_best_model.on_epoch_end(epoch, epoch_summary)
406
+ callback_save_checkpoint.on_epoch_end(epoch, epoch_summary)
407
+ callback_early_stop.on_epoch_end(epoch, epoch_summary)
408
+ print('End of epoch {}: '.format(epoch), ret, '\n')
409
+
410
+ callback_tensorboard.on_train_end()
411
+ callback_best_model.on_train_end()
412
+ callback_save_checkpoint.on_train_end()
413
+ callback_early_stop.on_train_end()
414
+ # 7. Wrap up