Commit 4dfbecb
Parent(s): a27e593

COMET train segmentation guided

Files changed: COMET/COMET_train_seggguided.py (+414 -0)
COMET/COMET_train_seggguided.py
ADDED
@@ -0,0 +1,414 @@
import os, sys

import keras

currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir)  # PYTHON > 3.3 does not allow relative referencing

from datetime import datetime

from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
from tensorflow.python.keras.utils import Progbar
from tensorflow.keras import Input
from tensorflow.keras.models import Model
from tensorflow.python.framework.errors import InvalidArgumentError

import DeepDeformationMapRegistration.utils.constants as C
from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplified, NCC, GeneralizedDICEScore, HausdorffDistanceErosion
from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
from DeepDeformationMapRegistration.ms_ssim_tf import _MSSSIM_WEIGHTS
from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
from DeepDeformationMapRegistration.utils.misc import function_decorator
from DeepDeformationMapRegistration.layers import AugmentationLayer
from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti

from Brain_study.data_generator import BatchGenerator
from Brain_study.utils import SummaryDictionary, named_logs

import COMET.augmentation_constants as COMET_C

import numpy as np
import tensorflow as tf
import voxelmorph as vxm
import h5py
import re
import itertools
import warnings


def launch_train(dataset_folder, validation_folder, output_folder, model_file, gpu_num=0, lr=1e-4, rw=5e-3, simil='ssim',
                 segm='dice', max_epochs=C.EPOCHS, early_stop_patience=1000, freeze_layers=None,
                 acc_gradients=1, batch_size=16, image_size=64,
                 unet=[16, 32, 64, 128, 256], head=[16, 16]):
    # 0. Input checks
    assert dataset_folder is not None and output_folder is not None
    if model_file != '':
        assert '.h5' in model_file, 'The model must be an H5 file'

    # 1. Load variables
    os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_num)  # Check availability before running using 'nvidia-smi'
    C.GPU_NUM = str(gpu_num)

    if batch_size != 1 and acc_gradients != 1:
        warnings.warn('WARNING: Batch size and Accumulative gradient step are set!')

    if freeze_layers is not None:
        assert all(s in ['INPUT', 'OUTPUT', 'ENCODER', 'DECODER', 'TOP', 'BOTTOM'] for s in freeze_layers), \
            'Invalid option for "freeze". Expected one or several of: INPUT, OUTPUT, ENCODER, DECODER, TOP, BOTTOM'
        freeze_layers = [list(COMET_C.LAYER_RANGES[l]) for l in list(set(freeze_layers))]
        if len(freeze_layers) > 1:
            freeze_layers = list(itertools.chain.from_iterable(freeze_layers))

    os.makedirs(output_folder, exist_ok=True)
    # dataset_copy = DatasetCopy(dataset_folder, os.path.join(output_folder, 'temp'))
    log_file = open(os.path.join(output_folder, 'log.txt'), 'w')
    C.TRAINING_DATASET = dataset_folder  # dataset_copy.copy_dataset()
    C.VALIDATION_DATASET = validation_folder
    C.ACCUM_GRADIENT_STEP = acc_gradients
    C.BATCH_SIZE = batch_size if C.ACCUM_GRADIENT_STEP == 1 else 1
    C.EARLY_STOP_PATIENCE = early_stop_patience
    C.LEARNING_RATE = lr
    C.LIMIT_NUM_SAMPLES = None
    C.EPOCHS = max_epochs

    aux = "[{}]\tINFO:\nTRAIN DATASET: {}\nVALIDATION DATASET: {}\n" \
          "GPU: {}\n" \
          "BATCH SIZE: {}\n" \
          "LR: {}\n" \
          "SIMILARITY: {}\n" \
          "SEGMENTATION: {}\n" \
          "REG. WEIGHT: {}\n" \
          "EPOCHS: {:d}\n" \
          "ACCUM. GRAD: {}\n" \
          "EARLY STOP PATIENCE: {}\n" \
          "FROZEN LAYERS: {}".format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'),
                                     C.TRAINING_DATASET,
                                     C.VALIDATION_DATASET,
                                     C.GPU_NUM,
                                     C.BATCH_SIZE,
                                     C.LEARNING_RATE,
                                     simil,
                                     segm,
                                     rw,
                                     C.EPOCHS,
                                     C.ACCUM_GRADIENT_STEP,
                                     C.EARLY_STOP_PATIENCE,
                                     freeze_layers)

    log_file.write(aux)
    print(aux)

    # 2. Data generator
    used_labels = 'all'
    data_generator = BatchGenerator(C.TRAINING_DATASET, C.BATCH_SIZE if C.ACCUM_GRADIENT_STEP == 1 else 1, True,
                                    C.TRAINING_PERC, labels=[used_labels], combine_segmentations=False,
                                    directory_val=C.VALIDATION_DATASET)

    train_generator = data_generator.get_train_generator()
    validation_generator = data_generator.get_validation_generator()

    image_input_shape = train_generator.get_data_shape()[-1][:-1]
    image_output_shape = [image_size] * 3
    nb_labels = len(train_generator.get_segmentation_labels())

    # 3. Load model
    # IMPORTANT: the model MUST be loaded AFTER setting up the session configuration
    config = tf.compat.v1.ConfigProto()  # device_count={'GPU':0})
    config.gpu_options.allow_growth = True
    config.log_device_placement = False  # To log device placement (on which device the operation ran)
    sess = tf.Session(config=config)
    tf.keras.backend.set_session(sess)

    loss_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).loss,
                 NCC(image_input_shape).loss,
                 vxm.losses.MSE().loss,
                 MultiScaleStructuralSimilarity(max_val=1., filter_size=3).loss,
                 HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [nb_labels]).loss,
                 GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).loss,
                 GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).loss_macro
                 ]

    metric_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).metric,
                   NCC(image_input_shape).metric,
                   vxm.losses.MSE().loss,
                   MultiScaleStructuralSimilarity(max_val=1., filter_size=3).metric,
                   GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric,
                   HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [nb_labels]).metric,
                   GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric_macro, ]


    try:
        network = tf.keras.models.load_model(model_file, {  # 'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
                                                           'VxmDense': vxm.networks.VxmDense,
                                                           'AdamAccumulated': AdamAccumulated,
                                                           'loss': loss_fncs,
                                                           'metric': metric_fncs},
                                             compile=False)
    except ValueError as e:
        # enc_features = [16, 32, 32, 32]  # const.ENCODER_FILTERS
        # dec_features = [32, 32, 32, 32, 32, 16, 16]  # const.ENCODER_FILTERS[::-1]
        enc_features = unet  # const.ENCODER_FILTERS
        dec_features = enc_features[::-1] + head  # const.ENCODER_FILTERS[::-1]
        nb_features = [enc_features, dec_features]

        network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
                                                         nb_labels=nb_labels,
                                                         nb_unet_features=nb_features,
                                                         int_steps=0,
                                                         int_downsize=1,
                                                         seg_downsize=1)

        if model_file != '':
            network.load_weights(model_file, by_name=True)
            print('MODEL LOCATION: ', model_file)
    # 4. Freeze/unfreeze model layers
    # freeze_layers = range(0, len(network.layers) - 8)  # Do not freeze the last layers after the UNet (8 last layers)
    # for l in freeze_layers:
    #     network.layers[l].trainable = False
    # msg = "[INF]: Frozen layers {} to {}".format(0, len(network.layers) - 8)
    # print(msg)
    # log_file.write("INF: Frozen layers {} to {}".format(0, len(network.layers) - 8))
    if freeze_layers is not None:
        aux = list()
        for r in freeze_layers:
            for l in range(*r):
                network.layers[l].trainable = False
                aux.append(l)
        aux.sort()
        msg = "[INF]: Frozen layers {}".format(', '.join([str(a) for a in aux]))
    else:
        msg = "[INF] No frozen layers"
    print(msg)
    log_file.write(msg)
    # network.trainable = False  # Freeze the base model
    # # Create a new model on top
    # input_new_model = keras.Input(network.input_shape)
    # x = base_model(input_new_model, training=False)
    # x =
    # network = keras.Model(input_new_model, x)

    network.summary()
    network.summary(print_fn=log_file.writelines)
    # Complete the model with the augmentation layer
    augm_train_input_shape = train_generator.get_data_shape()[0]
    input_layer_train = Input(shape=augm_train_input_shape, name='input_train')
    augm_layer_train = AugmentationLayer(max_displacement=COMET_C.MAX_AUG_DISP,  # Max 30 mm in isotropic space
                                         max_deformation=COMET_C.MAX_AUG_DEF,  # Max 6 mm in isotropic space
                                         max_rotation=COMET_C.MAX_AUG_ANGLE,  # Max 10 deg in isotropic space
                                         num_control_points=COMET_C.NUM_CONTROL_PTS_AUG,
                                         num_augmentations=COMET_C.NUM_AUGMENTATIONS,
                                         gamma_augmentation=COMET_C.GAMMA_AUGMENTATION,
                                         brightness_augmentation=COMET_C.BRIGHTNESS_AUGMENTATION,
                                         in_img_shape=image_input_shape,
                                         out_img_shape=image_output_shape,
                                         only_image=False,  # If baseline then True
                                         only_resize=False,
                                         trainable=False)
    augm_model_train = Model(inputs=input_layer_train, outputs=augm_layer_train(input_layer_train))

    input_layer_valid = Input(shape=validation_generator.get_data_shape()[0], name='input_valid')
    augm_layer_valid = AugmentationLayer(max_displacement=COMET_C.MAX_AUG_DISP,  # Max 30 mm in isotropic space
                                         max_deformation=COMET_C.MAX_AUG_DEF,  # Max 6 mm in isotropic space
                                         max_rotation=COMET_C.MAX_AUG_ANGLE,  # Max 10 deg in isotropic space
                                         num_control_points=COMET_C.NUM_CONTROL_PTS_AUG,
                                         num_augmentations=COMET_C.NUM_AUGMENTATIONS,
                                         gamma_augmentation=COMET_C.GAMMA_AUGMENTATION,
                                         brightness_augmentation=COMET_C.BRIGHTNESS_AUGMENTATION,
                                         in_img_shape=image_input_shape,
                                         out_img_shape=image_output_shape,
                                         only_image=False,
                                         only_resize=False,
                                         trainable=False)
    augm_model_valid = Model(inputs=input_layer_valid, outputs=augm_layer_valid(input_layer_valid))

    # 5. Setup training environment: loss, optimizer, callbacks, evaluation

    # Losses and loss weights
    SSIM_KER_SIZE = 5
    MS_SSIM_WEIGHTS = _MSSSIM_WEIGHTS[:3]
    MS_SSIM_WEIGHTS /= np.sum(MS_SSIM_WEIGHTS)
    if simil.lower() == 'mse':
        loss_fnc = vxm.losses.MSE().loss
    elif simil.lower() == 'ncc':
        loss_fnc = NCC(image_input_shape).loss
    elif simil.lower() == 'ssim':
        loss_fnc = StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss
    elif simil.lower() == 'ms_ssim':
        loss_fnc = MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).loss
    elif simil.lower() == 'mse__ms_ssim' or simil.lower() == 'ms_ssim__mse':
        @function_decorator('MSSSIM_MSE__loss')
        def loss_fnc(y_true, y_pred):
            return vxm.losses.MSE().loss(y_true, y_pred) + \
                   MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).loss(y_true, y_pred)
    elif simil.lower() == 'ncc__ms_ssim' or simil.lower() == 'ms_ssim__ncc':
        @function_decorator('MSSSIM_NCC__loss')
        def loss_fnc(y_true, y_pred):
            return NCC(image_input_shape).loss(y_true, y_pred) + \
                   MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).loss(y_true, y_pred)
    elif simil.lower() == 'mse__ssim' or simil.lower() == 'ssim__mse':
        @function_decorator('SSIM_MSE__loss')
        def loss_fnc(y_true, y_pred):
            return vxm.losses.MSE().loss(y_true, y_pred) + \
                   StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss(y_true, y_pred)
    elif simil.lower() == 'ncc__ssim' or simil.lower() == 'ssim__ncc':
        @function_decorator('SSIM_NCC__loss')
        def loss_fnc(y_true, y_pred):
            return NCC(image_input_shape).loss(y_true, y_pred) + \
                   StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss(y_true, y_pred)
    else:
        raise ValueError('Unknown similarity metric: ' + simil)

    if segm == 'hd':
        loss_segm = HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [nb_labels]).loss
    elif segm == 'dice':
        loss_segm = GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).loss
    elif segm == 'dice_macro':
        loss_segm = GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).loss_macro
    else:
        raise ValueError('No valid value for segm')

    os.makedirs(os.path.join(output_folder, 'checkpoints'), exist_ok=True)
    os.makedirs(os.path.join(output_folder, 'tensorboard'), exist_ok=True)
    callback_tensorboard = TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
                                       batch_size=C.BATCH_SIZE, write_images=False, histogram_freq=0,
                                       update_freq='epoch',  # or 'batch' or integer
                                       write_graph=True, write_grads=True
                                       )
    callback_early_stop = EarlyStopping(monitor='val_loss', verbose=1, patience=C.EARLY_STOP_PATIENCE, min_delta=0.00001)

    callback_best_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'),
                                          save_best_only=True, monitor='val_loss', verbose=1, mode='min')
    callback_save_checkpoint = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'checkpoint.h5'),
                                               save_weights_only=True, monitor='val_loss', verbose=0, mode='min')

    losses = {'transformer': loss_fnc,
              'seg_transformer': loss_segm,
              'flow': vxm.losses.Grad('l2').loss}
    metrics = {'transformer': [StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).metric,
                               MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).metric,
                               tf.keras.losses.MSE,
                               NCC(image_input_shape).metric],
               'seg_transformer': [GeneralizedDICEScore(image_output_shape + [train_generator.get_data_shape()[2][-1]], num_labels=nb_labels).metric,
                                   HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [train_generator.get_data_shape()[2][-1]]).metric,
                                   GeneralizedDICEScore(image_output_shape + [train_generator.get_data_shape()[2][-1]], num_labels=nb_labels).metric_macro,
                                   ],
               # 'flow': vxm.losses.Grad('l2').loss
               }
    loss_weights = {'transformer': 1.,
                    'seg_transformer': 1.,
                    'flow': rw}


    optimizer = AdamAccumulated(C.ACCUM_GRADIENT_STEP, C.LEARNING_RATE)
    network.compile(optimizer=optimizer,
                    loss=losses,
                    loss_weights=loss_weights,
                    metrics=metrics)

    # 6. Training loop
    callback_tensorboard.set_model(network)
    callback_early_stop.set_model(network)
    callback_best_model.set_model(network)
    callback_save_checkpoint.set_model(network)

    summary = SummaryDictionary(network, C.BATCH_SIZE)
    names = network.metrics_names
    log_file.write('\n\n[{}]\tINFO:\tStart training\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y')))

    with sess.as_default():
        # tf.global_variables_initializer()
        callback_tensorboard.on_train_begin()
        callback_early_stop.on_train_begin()
        callback_best_model.on_train_begin()
        callback_save_checkpoint.on_train_begin()

        for epoch in range(C.EPOCHS):
            callback_tensorboard.on_epoch_begin(epoch)
            callback_early_stop.on_epoch_begin(epoch)
            callback_best_model.on_epoch_begin(epoch)
            callback_save_checkpoint.on_epoch_begin(epoch)
            print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
            print("TRAIN")

            log_file.write('\n\n[{}]\tINFO:\tTraining epoch {}\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), epoch))
            progress_bar = Progbar(len(train_generator), width=30, verbose=1)
            for step, (in_batch, _) in enumerate(train_generator, 1):
                callback_best_model.on_train_batch_begin(step)
                callback_save_checkpoint.on_train_batch_begin(step)
                callback_early_stop.on_train_batch_begin(step)

                try:
                    fix_img, mov_img, fix_seg, mov_seg = augm_model_train.predict(in_batch)
                    np.nan_to_num(fix_img, copy=False)
                    np.nan_to_num(mov_img, copy=False)
                    if np.isnan(np.sum(mov_img)) or np.isnan(np.sum(fix_img)) or np.isinf(np.sum(mov_img)) or np.isinf(np.sum(fix_img)):
                        msg = 'CORRUPTED DATA!! Unique: Fix: {}\tMoving: {}'.format(np.unique(fix_img),
                                                                                    np.unique(mov_img))
                        print(msg)
                        log_file.write('\n\n[{}]\tWAR: {}'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), msg))

                except InvalidArgumentError as err:
                    print('TF Error : {}'.format(str(err)))
                    continue

                in_data = (mov_img, fix_img, mov_seg)
                out_data = (fix_img, fix_img, fix_seg)

                ret = network.train_on_batch(x=in_data, y=out_data)  # The second element doesn't matter
                if np.isnan(ret).any():
                    os.makedirs(os.path.join(output_folder, 'corrupted'), exist_ok=True)
                    save_nifti(mov_img, os.path.join(output_folder, 'corrupted', 'mov_img_nan.nii.gz'))
                    save_nifti(fix_img, os.path.join(output_folder, 'corrupted', 'fix_img_nan.nii.gz'))
                    pred_img, dm = network((mov_img, fix_img))
                    save_nifti(pred_img, os.path.join(output_folder, 'corrupted', 'pred_img_nan.nii.gz'))
                    save_nifti(dm, os.path.join(output_folder, 'corrupted', 'dm_nan.nii.gz'))
                    log_file.write('\n\n[{}]\tERR: Corruption error'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y')))
                    raise ValueError('CORRUPTION ERROR: Halting training')

                summary.on_train_batch_end(ret)
                callback_best_model.on_train_batch_end(step, named_logs(network, ret))
                callback_save_checkpoint.on_train_batch_end(step, named_logs(network, ret))
                callback_early_stop.on_train_batch_end(step, named_logs(network, ret))
                progress_bar.update(step, zip(names, ret))
                log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
            val_values = progress_bar._values.copy()
            ret = [val_values[x][0]/val_values[x][1] for x in names]

            print('\nVALIDATION')
            log_file.write('\n\n[{}]\tINFO:\tValidation epoch {}\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), epoch))
            progress_bar = Progbar(len(validation_generator), width=30, verbose=1)
            for step, (in_batch, _) in enumerate(validation_generator, 1):
                try:
                    fix_img, mov_img, fix_seg, mov_seg = augm_model_valid.predict(in_batch)
                except InvalidArgumentError as err:
                    print('TF Error : {}'.format(str(err)))
                    continue

                in_data = (mov_img, fix_img, mov_seg)
                out_data = (fix_img, fix_img, fix_seg)

                ret = network.test_on_batch(x=in_data,
                                            y=out_data)

                summary.on_validation_batch_end(ret)
                progress_bar.update(step, zip(names, ret))
                log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
            val_values = progress_bar._values.copy()
            ret = [val_values[x][0]/val_values[x][1] for x in names]

            train_generator.on_epoch_end()
            validation_generator.on_epoch_end()
            epoch_summary = summary.on_epoch_end()  # summary resets after on_epoch_end() call
            callback_tensorboard.on_epoch_end(epoch, epoch_summary)
            callback_best_model.on_epoch_end(epoch, epoch_summary)
            callback_save_checkpoint.on_epoch_end(epoch, epoch_summary)
            callback_early_stop.on_epoch_end(epoch, epoch_summary)
            print('End of epoch {}: '.format(epoch), ret, '\n')

        callback_tensorboard.on_train_end()
        callback_best_model.on_train_end()
        callback_save_checkpoint.on_train_end()
        callback_early_stop.on_train_end()
    # 7. Wrap up
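For context, a minimal driver sketch of how the launch_train entry point added here might be invoked; the keyword names mirror the function signature above, while the paths, GPU index, and epoch count are hypothetical placeholders, not values taken from this commit.

# Hypothetical driver script -- all paths and numeric settings below are placeholders.
from COMET.COMET_train_seggguided import launch_train

if __name__ == '__main__':
    launch_train(dataset_folder='/path/to/train_data',          # placeholder training data directory
                 validation_folder='/path/to/validation_data',  # placeholder validation data directory
                 output_folder='/path/to/output',               # receives log.txt, checkpoints/ and tensorboard/
                 model_file='',                                 # '' = no pretrained .h5 weights (see input checks)
                 gpu_num=0,
                 lr=1e-4,
                 simil='ssim',                                  # image similarity loss for 'transformer'
                 segm='dice',                                   # segmentation-guidance loss for 'seg_transformer'
                 rw=5e-3,                                       # weight of the flow-gradient regularization
                 batch_size=16,
                 max_epochs=100,
                 freeze_layers=None)                            # or e.g. ['ENCODER'] to freeze that layer range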