import os, sys

import keras

currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir)

from datetime import datetime

from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping, ReduceLROnPlateau
from tensorflow.python.keras.utils import Progbar
from tensorflow.keras import Input
from tensorflow.keras.models import Model
from tensorflow.python.framework.errors import InvalidArgumentError

import ddmr.utils.constants as C
from ddmr.losses import StructuralSimilarity_simplified, NCC, GeneralizedDICEScore, HausdorffDistanceErosion
from ddmr.ms_ssim_tf import MultiScaleStructuralSimilarity
from ddmr.ms_ssim_tf import _MSSSIM_WEIGHTS
from ddmr.utils.acummulated_optimizer import AdamAccumulated
from ddmr.utils.misc import function_decorator
from ddmr.layers import AugmentationLayer
from ddmr.utils.nifti_utils import save_nifti

from Brain_study.data_generator import BatchGenerator
from Brain_study.utils import SummaryDictionary, named_logs

import COMET.augmentation_constants as COMET_C
from COMET.utils import freeze_layers_by_group

import numpy as np
import tensorflow as tf
import voxelmorph as vxm
import h5py
import re
import itertools
import warnings


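# Training/fine-tuning entry point for a semi-supervised VoxelMorph registration network:
# an image-similarity loss acts on the warped moving image, a segmentation-overlap loss acts
# on the warped labels, and the displacement field is regularised with an L2 gradient penalty.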
def launch_train(dataset_folder, validation_folder, output_folder, model_file, gpu_num=0, lr=1e-4, rw=5e-3, simil='ssim',
                 segm='dice', max_epochs=C.EPOCHS, early_stop_patience=1000, freeze_layers=None,
                 acc_gradients=1, batch_size=16, image_size=64,
                 unet=[16, 32, 64, 128, 256], head=[16, 16], resume=None):

    assert dataset_folder is not None and output_folder is not None
    if model_file != '':
        assert '.h5' in model_file, 'The model must be an H5 file'

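    # Pin the process to the requested GPU before TensorFlow initialises any device.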
    os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_num)
    C.GPU_NUM = str(gpu_num)

    if batch_size != 1 and acc_gradients != 1:
        warnings.warn('WARNING: Both batch size and gradient accumulation steps are set. '
                      'The batch size will be forced to 1 and gradient accumulation will be used instead.')

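    # Reuse an existing run directory when resuming (it must contain a 'checkpoints' folder);
    # otherwise create a fresh, timestamped output folder.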
    if resume is not None:
        try:
            assert os.path.exists(resume) and len(os.listdir(os.path.join(resume, 'checkpoints'))), 'Invalid directory: ' + resume
            output_folder = resume
            resume = True
        except AssertionError:
            output_folder = os.path.join(output_folder + '_' + datetime.now().strftime("%H%M%S-%d%m%Y"))
            resume = False
    else:
        resume = False

    os.makedirs(output_folder, exist_ok=True)

    log_file = open(os.path.join(output_folder, 'log.txt'), 'w')
    C.TRAINING_DATASET = dataset_folder
    C.VALIDATION_DATASET = validation_folder
    C.ACCUM_GRADIENT_STEP = acc_gradients
    C.BATCH_SIZE = batch_size if C.ACCUM_GRADIENT_STEP == 1 else 1
    C.EARLY_STOP_PATIENCE = early_stop_patience
    C.LEARNING_RATE = lr
    C.LIMIT_NUM_SAMPLES = None
    C.EPOCHS = max_epochs

    aux = "[{}]\tINFO:\nTRAIN DATASET: {}\nVALIDATION DATASET: {}\n" \
          "GPU: {}\n" \
          "BATCH SIZE: {}\n" \
          "LR: {}\n" \
          "SIMILARITY: {}\n" \
          "SEGMENTATION: {}\n" \
          "REG. WEIGHT: {}\n" \
          "EPOCHS: {:d}\n" \
          "ACCUM. GRAD: {}\n" \
          "EARLY STOP PATIENCE: {}\n" \
          "FROZEN LAYERS: {}".format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'),
                                     C.TRAINING_DATASET,
                                     C.VALIDATION_DATASET,
                                     C.GPU_NUM,
                                     C.BATCH_SIZE,
                                     C.LEARNING_RATE,
                                     simil,
                                     segm,
                                     rw,
                                     C.EPOCHS,
                                     C.ACCUM_GRADIENT_STEP,
                                     C.EARLY_STOP_PATIENCE,
                                     freeze_layers)

    log_file.write(aux)
    print(aux)

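    # Build the training and validation batch generators for the selected labels.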
    used_labels = [0, 1, 2]
    data_generator = BatchGenerator(C.TRAINING_DATASET, C.BATCH_SIZE if C.ACCUM_GRADIENT_STEP == 1 else 1, True,
                                    C.TRAINING_PERC, labels=used_labels, combine_segmentations=False,
                                    directory_val=C.VALIDATION_DATASET)

    train_generator = data_generator.get_train_generator()
    validation_generator = data_generator.get_validation_generator()

    image_input_shape = train_generator.get_data_shape()[-1][:-1]
    image_output_shape = [image_size] * 3
    nb_labels = len(train_generator.get_segmentation_labels())

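    # TF1-style session with memory growth enabled, so GPU memory is allocated on demand.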
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    sess = tf.compat.v1.Session(config=config)  # tf.Session is not available in TF2; use the compat API
    tf.compat.v1.keras.backend.set_session(sess)

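    # Loss and metric candidates; they are also passed in the custom-objects dictionary when
    # deserialising a previously saved model below.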
    loss_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).loss,
                 NCC(image_input_shape).loss,
                 vxm.losses.MSE().loss,
                 MultiScaleStructuralSimilarity(max_val=1., filter_size=3).loss,
                 HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [nb_labels]).loss,
                 GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).loss,
                 GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).loss_macro]

    metric_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).metric,
                   NCC(image_input_shape).metric,
                   vxm.losses.MSE().loss,
                   MultiScaleStructuralSimilarity(max_val=1., filter_size=3).metric,
                   GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric,
                   HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [nb_labels]).metric,
                   GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric_macro]

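    # Try to load a fully serialised model first; if that fails (e.g. the file only holds weights),
    # build the VxmDenseSemiSupervisedSeg architecture and load the weights by name.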
    try:
        network = tf.keras.models.load_model(model_file, {'VxmDense': vxm.networks.VxmDense,
                                                          'AdamAccumulated': AdamAccumulated,
                                                          'loss': loss_fncs,
                                                          'metric': metric_fncs},
                                             compile=False)
    except ValueError as e:
        enc_features = unet
        dec_features = enc_features[::-1] + head
        nb_features = [enc_features, dec_features]

        network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
                                                         nb_labels=nb_labels,
                                                         nb_unet_features=nb_features,
                                                         int_steps=0,
                                                         int_downsize=1,
                                                         seg_downsize=1)

        if model_file != '':
            network.load_weights(model_file, by_name=True)
            print('MODEL LOCATION: ', model_file)

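    # When resuming, load the newest checkpoint from <output_folder>/checkpoints and try to
    # recover the epoch number encoded in its file name.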
    resume_epoch = 0
    if resume:
        cp_dir = os.path.join(output_folder, 'checkpoints')
        cp_file_list = [os.path.join(cp_dir, f) for f in os.listdir(cp_dir) if (f.startswith('checkpoint') and f.endswith('.h5'))]
        if len(cp_file_list):
            cp_file_list.sort()
            checkpoint_file = cp_file_list[-1]
            if os.path.exists(checkpoint_file):
                network.load_weights(checkpoint_file, by_name=True)
                print('Loaded checkpoint file: ' + checkpoint_file)
                try:
                    resume_epoch = int(re.match(r'checkpoint\.(\d+)-*.h5', os.path.split(checkpoint_file)[-1])[1])
                except TypeError:
                    # The file name carries no epoch number
                    resume_epoch = 0
                print('Resuming from epoch: {:d}'.format(resume_epoch))
            else:
                warnings.warn('Checkpoint file NOT found. Training from scratch')

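    # Freeze the requested layer groups before compiling, and record which layers were frozen.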
    _, frozen_layers = freeze_layers_by_group(network, freeze_layers)
    if frozen_layers is not None:
        msg = "[INF]: Frozen layers: {}".format(', '.join([str(a) for a in frozen_layers]))
    else:
        msg = "[INF]: No frozen layers"
    print(msg)
    log_file.write(msg)

    network.summary(line_length=C.SUMMARY_LINE_LENGTH)
    network.summary(line_length=C.SUMMARY_LINE_LENGTH, print_fn=log_file.writelines)

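    # Non-trainable augmentation models: they deform and resize each incoming batch on the fly,
    # producing the (fixed, moving) image pair together with their segmentations.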
    augm_train_input_shape = train_generator.get_data_shape()[0]
    input_layer_train = Input(shape=augm_train_input_shape, name='input_train')
    augm_layer_train = AugmentationLayer(max_displacement=COMET_C.MAX_AUG_DISP,
                                         max_deformation=COMET_C.MAX_AUG_DEF,
                                         max_rotation=COMET_C.MAX_AUG_ANGLE,
                                         num_control_points=COMET_C.NUM_CONTROL_PTS_AUG,
                                         num_augmentations=COMET_C.NUM_AUGMENTATIONS,
                                         gamma_augmentation=COMET_C.GAMMA_AUGMENTATION,
                                         brightness_augmentation=COMET_C.BRIGHTNESS_AUGMENTATION,
                                         in_img_shape=image_input_shape,
                                         out_img_shape=image_output_shape,
                                         only_image=False,
                                         only_resize=False,
                                         trainable=False)
    augm_model_train = Model(inputs=input_layer_train, outputs=augm_layer_train(input_layer_train))

    input_layer_valid = Input(shape=validation_generator.get_data_shape()[0], name='input_valid')
    augm_layer_valid = AugmentationLayer(max_displacement=COMET_C.MAX_AUG_DISP,
                                         max_deformation=COMET_C.MAX_AUG_DEF,
                                         max_rotation=COMET_C.MAX_AUG_ANGLE,
                                         num_control_points=COMET_C.NUM_CONTROL_PTS_AUG,
                                         num_augmentations=COMET_C.NUM_AUGMENTATIONS,
                                         gamma_augmentation=COMET_C.GAMMA_AUGMENTATION,
                                         brightness_augmentation=COMET_C.BRIGHTNESS_AUGMENTATION,
                                         in_img_shape=image_input_shape,
                                         out_img_shape=image_output_shape,
                                         only_image=False,
                                         only_resize=False,
                                         trainable=False)
    augm_model_valid = Model(inputs=input_layer_valid, outputs=augm_layer_valid(input_layer_valid))

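    # Image-similarity loss selection. The composite options (e.g. 'mse__ms_ssim') are the plain
    # sum of the two individual losses.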
    SSIM_KER_SIZE = 5
    # Keep the first three MS-SSIM scale weights and renormalise them to sum to 1
    # (np.asarray guards against _MSSSIM_WEIGHTS being a tuple, which would not support '/=')
    MS_SSIM_WEIGHTS = np.asarray(_MSSSIM_WEIGHTS[:3], dtype=np.float32)
    MS_SSIM_WEIGHTS /= np.sum(MS_SSIM_WEIGHTS)
    if simil.lower() == 'mse':
        loss_fnc = vxm.losses.MSE().loss
    elif simil.lower() == 'ncc':
        loss_fnc = NCC(image_input_shape).loss
    elif simil.lower() == 'ssim':
        loss_fnc = StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss
    elif simil.lower() == 'ms_ssim':
        loss_fnc = MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).loss
    elif simil.lower() == 'mse__ms_ssim' or simil.lower() == 'ms_ssim__mse':
        @function_decorator('MSSSIM_MSE__loss')
        def loss_fnc(y_true, y_pred):
            return vxm.losses.MSE().loss(y_true, y_pred) + \
                   MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).loss(y_true, y_pred)
    elif simil.lower() == 'ncc__ms_ssim' or simil.lower() == 'ms_ssim__ncc':
        @function_decorator('MSSSIM_NCC__loss')
        def loss_fnc(y_true, y_pred):
            return NCC(image_input_shape).loss(y_true, y_pred) + \
                   MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).loss(y_true, y_pred)
    elif simil.lower() == 'mse__ssim' or simil.lower() == 'ssim__mse':
        @function_decorator('SSIM_MSE__loss')
        def loss_fnc(y_true, y_pred):
            return vxm.losses.MSE().loss(y_true, y_pred) + \
                   StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss(y_true, y_pred)
    elif simil.lower() == 'ncc__ssim' or simil.lower() == 'ssim__ncc':
        @function_decorator('SSIM_NCC__loss')
        def loss_fnc(y_true, y_pred):
            return NCC(image_input_shape).loss(y_true, y_pred) + \
                   StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss(y_true, y_pred)
    else:
        raise ValueError('Unknown similarity metric: ' + simil)

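    # Segmentation-overlap loss applied to the warped segmentation maps.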
    if segm == 'hd':
        loss_segm = HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [nb_labels]).loss
    elif segm == 'dice':
        loss_segm = GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).loss
    elif segm == 'dice_macro':
        loss_segm = GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).loss_macro
    else:
        raise ValueError('Unknown segmentation loss: ' + segm)

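    # The callbacks are driven manually (set_model / on_*_begin / on_*_end) because training uses
    # train_on_batch in a custom loop instead of model.fit.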
    os.makedirs(os.path.join(output_folder, 'checkpoints'), exist_ok=True)
    os.makedirs(os.path.join(output_folder, 'tensorboard'), exist_ok=True)
    callback_tensorboard = TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
                                       batch_size=C.BATCH_SIZE, write_images=False, histogram_freq=0,
                                       update_freq='epoch',
                                       write_graph=True, write_grads=True)
    callback_early_stop = EarlyStopping(monitor='val_loss', verbose=1, patience=C.EARLY_STOP_PATIENCE, min_delta=0.00001)

    callback_best_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'),
                                          save_best_only=True, monitor='val_loss', verbose=1, mode='min')
    callback_save_checkpoint = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'checkpoint.h5'),
                                               save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
    callback_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10)

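    # One loss per model output: image similarity on the warped image ('transformer'), overlap on the
    # warped segmentation ('seg_transformer'), and an L2 gradient penalty on the flow weighted by rw.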
    losses = {'transformer': loss_fnc,
              'seg_transformer': loss_segm,
              'flow': vxm.losses.Grad('l2').loss}
    metrics = {'transformer': [StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).metric,
                               MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).metric,
                               tf.keras.losses.MSE,
                               NCC(image_input_shape).metric],
               'seg_transformer': [GeneralizedDICEScore(image_output_shape + [train_generator.get_data_shape()[2][-1]], num_labels=nb_labels).metric,
                                   HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [train_generator.get_data_shape()[2][-1]]).metric,
                                   GeneralizedDICEScore(image_output_shape + [train_generator.get_data_shape()[2][-1]], num_labels=nb_labels).metric_macro],
               }
    loss_weights = {'transformer': 1.,
                    'seg_transformer': 1.,
                    'flow': rw}

    optimizer = AdamAccumulated(accumulation_steps=C.ACCUM_GRADIENT_STEP, learning_rate=C.LEARNING_RATE)
    network.compile(optimizer=optimizer,
                    loss=losses,
                    loss_weights=loss_weights,
                    metrics=metrics)

    callback_tensorboard.set_model(network)
    callback_early_stop.set_model(network)
    callback_best_model.set_model(network)
    callback_save_checkpoint.set_model(network)
    callback_lr.set_model(network)

    summary = SummaryDictionary(network, C.BATCH_SIZE)
    names = network.metrics_names
    log_file.write('\n\n[{}]\tINFO:\tStart training\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y')))

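    # Custom training loop: every batch goes through the augmentation model first and is then fed to
    # train_on_batch as inputs (moving, fixed, moving_seg) against targets (fixed, fixed, fixed_seg).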
    with sess.as_default():
        callback_tensorboard.on_train_begin()
        callback_early_stop.on_train_begin()
        callback_best_model.on_train_begin()
        callback_save_checkpoint.on_train_begin()
        callback_lr.on_train_begin()

        for epoch in range(resume_epoch, C.EPOCHS):
            callback_tensorboard.on_epoch_begin(epoch)
            callback_early_stop.on_epoch_begin(epoch)
            callback_best_model.on_epoch_begin(epoch)
            callback_save_checkpoint.on_epoch_begin(epoch)
            callback_lr.on_epoch_begin(epoch)

            print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
            print("TRAIN")

            log_file.write('\n\n[{}]\tINFO:\tTraining epoch {}\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), epoch))
            progress_bar = Progbar(len(train_generator), width=30, verbose=1)
            for step, (in_batch, _) in enumerate(train_generator, 1):
                callback_best_model.on_train_batch_begin(step)
                callback_save_checkpoint.on_train_batch_begin(step)
                callback_early_stop.on_train_batch_begin(step)
                callback_lr.on_train_batch_begin(step)

                # Augment and resample the batch; guard against NaN/Inf values coming out of the augmentation
                try:
                    fix_img, mov_img, fix_seg, mov_seg = augm_model_train.predict(in_batch)
                    np.nan_to_num(fix_img, copy=False)
                    np.nan_to_num(mov_img, copy=False)
                    if np.isnan(np.sum(mov_img)) or np.isnan(np.sum(fix_img)) or np.isinf(np.sum(mov_img)) or np.isinf(np.sum(fix_img)):
                        msg = 'CORRUPTED DATA!! Unique: Fix: {}\tMoving: {}'.format(np.unique(fix_img),
                                                                                    np.unique(mov_img))
                        print(msg)
                        log_file.write('\n\n[{}]\tWAR: {}'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), msg))
                except InvalidArgumentError as err:
                    print('TF Error : {}'.format(str(err)))
                    continue

                in_data = (mov_img, fix_img, mov_seg)
                out_data = (fix_img, fix_img, fix_seg)

                ret = network.train_on_batch(x=in_data, y=out_data)
                if np.isnan(ret).any():
                    # Dump the offending volumes to disk before aborting
                    os.makedirs(os.path.join(output_folder, 'corrupted'), exist_ok=True)
                    save_nifti(mov_img, os.path.join(output_folder, 'corrupted', 'mov_img_nan.nii.gz'))
                    save_nifti(fix_img, os.path.join(output_folder, 'corrupted', 'fix_img_nan.nii.gz'))
                    pred_img, dm = network((mov_img, fix_img))
                    save_nifti(pred_img, os.path.join(output_folder, 'corrupted', 'pred_img_nan.nii.gz'))
                    save_nifti(dm, os.path.join(output_folder, 'corrupted', 'dm_nan.nii.gz'))
                    log_file.write('\n\n[{}]\tERR: Corruption error'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y')))
                    raise ValueError('CORRUPTION ERROR: Halting training')

                summary.on_train_batch_end(ret)
                callback_best_model.on_train_batch_end(step, named_logs(network, ret))
                callback_save_checkpoint.on_train_batch_end(step, named_logs(network, ret))
                callback_early_stop.on_train_batch_end(step, named_logs(network, ret))
                callback_lr.on_train_batch_end(step, named_logs(network, ret))

                progress_bar.update(step, zip(names, ret))
                log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
            val_values = progress_bar._values.copy()
            ret = [val_values[x][0] / val_values[x][1] for x in names]

            print('\nVALIDATION')
            log_file.write('\n\n[{}]\tINFO:\tValidation epoch {}\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), epoch))
            progress_bar = Progbar(len(validation_generator), width=30, verbose=1)
            for step, (in_batch, _) in enumerate(validation_generator, 1):
                try:
                    fix_img, mov_img, fix_seg, mov_seg = augm_model_valid.predict(in_batch)
                except InvalidArgumentError as err:
                    print('TF Error : {}'.format(str(err)))
                    continue

                in_data = (mov_img, fix_img, mov_seg)
                out_data = (fix_img, fix_img, fix_seg)

                ret = network.test_on_batch(x=in_data,
                                            y=out_data)

                summary.on_validation_batch_end(ret)
                progress_bar.update(step, zip(names, ret))
                log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
            val_values = progress_bar._values.copy()
            ret = [val_values[x][0] / val_values[x][1] for x in names]

            train_generator.on_epoch_end()
            validation_generator.on_epoch_end()
            epoch_summary = summary.on_epoch_end()
            callback_tensorboard.on_epoch_end(epoch, epoch_summary)
            callback_best_model.on_epoch_end(epoch, epoch_summary)
            callback_save_checkpoint.on_epoch_end(epoch, epoch_summary)
            callback_early_stop.on_epoch_end(epoch, epoch_summary)
            callback_lr.on_epoch_end(epoch, epoch_summary)

            print('End of epoch {}: '.format(epoch), ret, '\n')

        callback_tensorboard.on_train_end()
        callback_best_model.on_train_end()
        callback_save_checkpoint.on_train_end()
        callback_early_stop.on_train_end()
        callback_lr.on_train_end()

    log_file.close()
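
# Example invocation (a minimal sketch; the paths and argument values below are placeholders,
# not part of this repository):
#
#     if __name__ == '__main__':
#         launch_train(dataset_folder='/data/COMET/train',
#                      validation_folder='/data/COMET/valid',
#                      output_folder='/results/COMET_finetune',
#                      model_file='',
#                      gpu_num=0,
#                      simil='ssim',
#                      segm='dice',
#                      max_epochs=100,
#                      freeze_layers=None)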