|
import os, sys |
|
import warnings |
|
|
|
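# Make the parent directory (the repository root) importable when this script is run directly.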
currentdir = os.path.dirname(os.path.realpath(__file__)) |
|
parentdir = os.path.dirname(currentdir) |
|
sys.path.append(parentdir) |
|
|
|
import numpy as np |
|
import tensorflow as tf |
|
|
|
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping, ReduceLROnPlateau |
|
from tensorflow.keras import Input |
|
from tensorflow.keras.models import Model |
|
from tensorflow.python.keras.utils import Progbar |
|
from tensorflow.python.framework.errors import InvalidArgumentError |
|
|
|
import voxelmorph as vxm |
|
import neurite as ne |
|
import h5py |
|
import pickle |
|
|
|
import ddmr.utils.constants as C |
|
from ddmr.losses import NCC, StructuralSimilarity, StructuralSimilarity_simplified |
|
from ddmr.utils.misc import try_mkdir, DatasetCopy, function_decorator |
|
from ddmr.utils.acummulated_optimizer import AdamAccumulated |
|
from ddmr.layers import AugmentationLayer |
|
from ddmr.utils.nifti_utils import save_nifti |
|
from ddmr.ms_ssim_tf import MultiScaleStructuralSimilarity, _MSSSIM_WEIGHTS |
|
|
|
from Brain_study.data_generator import BatchGenerator |
|
from Brain_study.utils import SummaryDictionary, named_logs |
|
|
|
from tqdm import tqdm |
|
from datetime import datetime |
|
import re |
|
|
|
|
|
def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr=1e-4, rw=5e-3, simil='ssim', |
|
acc_gradients=16, batch_size=1, max_epochs=10000, early_stop_patience=1000, image_size=64, |
|
unet=[16, 32, 64, 128, 256], head=[16, 16], resume=None): |
|
assert dataset_folder is not None and output_folder is not None |
|
|
|
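    # Expose only the requested GPU to TensorFlow before any CUDA context is created.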
os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER |
|
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_num) |
|
C.GPU_NUM = str(gpu_num) |
|
|
|
if batch_size != 1 and acc_gradients != 1: |
|
        warnings.warn('Batch size and gradient accumulation are both set: the batch size will be forced to 1 '
                      'and gradients accumulated over {} steps'.format(acc_gradients))
|
|
|
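    # Resume inside an existing run directory if it already holds checkpoints; otherwise start a fresh, time-stamped output folder.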
if resume is not None: |
|
try: |
|
assert os.path.exists(resume) and len(os.listdir(os.path.join(resume, 'checkpoints'))), 'Invalid directory: ' + resume |
|
output_folder = resume |
|
resume = True |
|
except AssertionError: |
|
            output_folder = output_folder + '_' + datetime.now().strftime("%H%M%S-%d%m%Y")
|
resume = False |
|
else: |
|
resume = False |
|
os.makedirs(output_folder, exist_ok=True) |
|
|
|
log_file = open(os.path.join(output_folder, 'log.txt'), 'w') |
|
C.TRAINING_DATASET = dataset_folder |
|
C.VALIDATION_DATASET = validation_folder |
|
C.ACCUM_GRADIENT_STEP = acc_gradients |
|
C.BATCH_SIZE = batch_size if C.ACCUM_GRADIENT_STEP == 1 else 1 |
|
C.EARLY_STOP_PATIENCE = early_stop_patience |
|
C.LEARNING_RATE = lr |
|
C.LIMIT_NUM_SAMPLES = None |
|
C.EPOCHS = max_epochs |
|
|
|
    aux = ("[{}]\tINFO:\nTRAIN DATASET: {}\nVALIDATION DATASET: {}\n"
           "GPU: {}\n"
           "BATCH SIZE: {}\n"
           "LR: {}\n"
           "SIMILARITY: {}\n"
           "REG. WEIGHT: {}\n"
           "EPOCHS: {:d}\n"
           "ACCUM. GRAD: {}\n"
           "EARLY STOP PATIENCE: {}").format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'),
|
C.TRAINING_DATASET, |
|
C.VALIDATION_DATASET, |
|
C.GPU_NUM, |
|
C.BATCH_SIZE, |
|
C.LEARNING_RATE, |
|
simil, |
|
rw, |
|
C.EPOCHS, |
|
C.ACCUM_GRADIENT_STEP, |
|
C.EARLY_STOP_PATIENCE) |
|
|
|
log_file.write(aux) |
|
print(aux) |
|
|
|
|
|
|
|
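    # Training and validation batch generators. When gradients are accumulated, the per-step batch size is forced to 1.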
data_generator = BatchGenerator(C.TRAINING_DATASET, C.BATCH_SIZE if C.ACCUM_GRADIENT_STEP == 1 else 1, True, |
|
C.TRAINING_PERC, True, ['none'], directory_val=C.VALIDATION_DATASET) |
|
|
|
train_generator = data_generator.get_train_generator() |
|
validation_generator = data_generator.get_validation_generator() |
|
|
|
image_input_shape = train_generator.get_data_shape()[-1][:-1] |
|
image_output_shape = [image_size] * 3 |
|
|
|
|
|
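    # Graph-mode (TF1-compatible) session with GPU memory growth, used by the manual training loop below.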
config = tf.compat.v1.ConfigProto() |
|
config.gpu_options.allow_growth = True |
|
config.log_device_placement = False |
|
    sess = tf.compat.v1.Session(config=config)
    tf.compat.v1.keras.backend.set_session(sess)
|
|
|
|
|
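    # On-the-fly augmentation model for the training stream: random displacements, deformations, rotations,
    # gamma and brightness changes, resized to image_output_shape.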
input_layer_train = Input(shape=train_generator.get_data_shape()[-1], name='input_train') |
|
augm_layer_train = AugmentationLayer(max_displacement=C.MAX_AUG_DISP, |
|
max_deformation=C.MAX_AUG_DEF, |
|
max_rotation=C.MAX_AUG_ANGLE, |
|
num_control_points=C.NUM_CONTROL_PTS_AUG, |
|
num_augmentations=C.NUM_AUGMENTATIONS, |
|
gamma_augmentation=C.GAMMA_AUGMENTATION, |
|
brightness_augmentation=C.BRIGHTNESS_AUGMENTATION, |
|
in_img_shape=image_input_shape, |
|
out_img_shape=image_output_shape, |
|
only_image=True, |
|
only_resize=False, |
|
trainable=False) |
|
augm_model_train = Model(inputs=input_layer_train, outputs=augm_layer_train(input_layer_train)) |
|
|
|
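    # Counterpart augmentation model for the validation stream (note only_image=False here).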
input_layer_valid = Input(shape=validation_generator.get_data_shape()[0], name='input_valid') |
|
augm_layer_valid = AugmentationLayer(max_displacement=C.MAX_AUG_DISP, |
|
max_deformation=C.MAX_AUG_DEF, |
|
max_rotation=C.MAX_AUG_ANGLE, |
|
num_control_points=C.NUM_CONTROL_PTS_AUG, |
|
num_augmentations=C.NUM_AUGMENTATIONS, |
|
gamma_augmentation=C.GAMMA_AUGMENTATION, |
|
brightness_augmentation=C.BRIGHTNESS_AUGMENTATION, |
|
in_img_shape=image_input_shape, |
|
out_img_shape=image_output_shape, |
|
only_image=False, |
|
only_resize=False, |
|
trainable=False) |
|
augm_model_valid = Model(inputs=input_layer_valid, outputs=augm_layer_valid(input_layer_valid)) |
|
|
|
|
|
|
|
|
|
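    # VoxelMorph dense registration network: the decoder mirrors the encoder features and appends the extra head convolutions.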
enc_features = unet |
|
dec_features = enc_features[::-1] + head |
|
nb_features = [enc_features, dec_features] |
|
network = vxm.networks.VxmDense(inshape=image_output_shape, |
|
nb_unet_features=nb_features, |
|
int_steps=0) |
|
network.summary(line_length=150) |
|
|
|
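    # When resuming, load the most recent checkpoint and recover the starting epoch from its filename.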
resume_epoch = 0 |
|
if resume: |
|
cp_dir = os.path.join(output_folder, 'checkpoints') |
|
cp_file_list = [os.path.join(cp_dir, f) for f in os.listdir(cp_dir) if (f.startswith('checkpoint') and f.endswith('.h5'))] |
|
if len(cp_file_list): |
|
cp_file_list.sort() |
|
checkpoint_file = cp_file_list[-1] |
|
if os.path.exists(checkpoint_file): |
|
network.load_weights(checkpoint_file, by_name=True) |
|
print('Loaded checkpoint file: ' + checkpoint_file) |
|
try: |
|
                    resume_epoch = int(re.match(r'checkpoint\.(\d+)-.*\.h5', os.path.split(checkpoint_file)[-1])[1])
|
except TypeError: |
|
|
|
resume_epoch = 0 |
|
print('Resuming from epoch: {:d}'.format(resume_epoch)) |
|
else: |
|
warnings.warn('Checkpoint file NOT found. Training from scratch') |
|
|
|
|
|
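    # Similarity loss selection. The combined options (e.g. 'mse__ms_ssim') simply add the two individual losses.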
SSIM_KER_SIZE = 5 |
|
    # Use the first three MS-SSIM scales and renormalize their weights (copy first so the module-level weights are untouched).
    MS_SSIM_WEIGHTS = np.array(_MSSSIM_WEIGHTS[:3])
    MS_SSIM_WEIGHTS /= np.sum(MS_SSIM_WEIGHTS)
|
if simil.lower() == 'mse': |
|
loss_fnc = vxm.losses.MSE().loss |
|
elif simil.lower() == 'ncc': |
|
loss_fnc = NCC(image_input_shape).loss |
|
elif simil.lower() == 'ssim': |
|
loss_fnc = StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss |
|
elif simil.lower() == 'ms_ssim': |
|
loss_fnc = MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).loss |
|
elif simil.lower() == 'mse__ms_ssim' or simil.lower() == 'ms_ssim__mse': |
|
@function_decorator('MSSSIM_MSE__loss') |
|
def loss_fnc(y_true, y_pred): |
|
            return (vxm.losses.MSE().loss(y_true, y_pred) +
                    MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE,
                                                   power_factors=MS_SSIM_WEIGHTS).loss(y_true, y_pred))
|
elif simil.lower() == 'ncc__ms_ssim' or simil.lower() == 'ms_ssim__ncc': |
|
@function_decorator('MSSSIM_NCC__loss') |
|
def loss_fnc(y_true, y_pred): |
|
            return (NCC(image_input_shape).loss(y_true, y_pred) +
                    MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE,
                                                   power_factors=MS_SSIM_WEIGHTS).loss(y_true, y_pred))
|
elif simil.lower() == 'mse__ssim' or simil.lower() == 'ssim__mse': |
|
@function_decorator('SSIM_MSE__loss') |
|
def loss_fnc(y_true, y_pred): |
|
            return (vxm.losses.MSE().loss(y_true, y_pred) +
                    StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss(y_true, y_pred))
|
elif simil.lower() == 'ncc__ssim' or simil.lower() == 'ssim__ncc': |
|
@function_decorator('SSIM_NCC__loss') |
|
def loss_fnc(y_true, y_pred): |
|
            return (NCC(image_input_shape).loss(y_true, y_pred) +
                    StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss(y_true, y_pred))
|
else: |
|
raise ValueError('Unknown similarity metric: ' + simil) |
|
|
|
|
|
os.makedirs(output_folder, exist_ok=True) |
|
os.makedirs(os.path.join(output_folder, 'checkpoints'), exist_ok=True) |
|
os.makedirs(os.path.join(output_folder, 'tensorboard'), exist_ok=True) |
|
os.makedirs(os.path.join(output_folder, 'history'), exist_ok=True) |
|
|
|
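    # Keras callbacks, driven manually from the custom training loop below.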
callback_best_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'), |
|
save_best_only=True, monitor='val_loss', verbose=1, mode='min') |
|
callback_save_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'checkpoint.{epoch:05d}-{val_loss:.2f}.h5'), |
|
save_weights_only=True, monitor='val_loss', verbose=0, mode='min') |
|
|
|
|
|
callback_tensorboard = TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'), |
|
batch_size=C.BATCH_SIZE, write_images=False, histogram_freq=0, |
|
update_freq='epoch', |
|
write_graph=True, write_grads=True |
|
) |
|
callback_early_stop = EarlyStopping(monitor='val_loss', verbose=1, patience=C.EARLY_STOP_PATIENCE, min_delta=0.00001) |
|
callback_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10) |
|
|
|
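    # Two model outputs: 'transformer' (warped moving image) and 'flow' (displacement field, regularized with an L2 gradient penalty weighted by rw).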
losses = {'transformer': loss_fnc, |
|
'flow': vxm.losses.Grad('l2').loss} |
|
metrics = {'transformer': [StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).metric, |
|
MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).metric, |
|
tf.keras.losses.MSE, |
|
NCC(image_input_shape).metric], |
|
|
|
} |
|
loss_weights = {'transformer': 1., |
|
'flow': rw} |
|
|
|
|
|
optimizer = AdamAccumulated(C.ACCUM_GRADIENT_STEP, C.LEARNING_RATE) |
|
network.compile(optimizer=optimizer, |
|
loss=losses, |
|
loss_weights=loss_weights, |
|
metrics=metrics) |
|
|
|
callback_tensorboard.set_model(network) |
|
callback_best_model.set_model(network) |
|
callback_save_model.set_model(network) |
|
callback_early_stop.set_model(network) |
|
callback_lr.set_model(network) |
|
|
|
|
|
summary = SummaryDictionary(network, C.BATCH_SIZE) |
|
names = network.metrics_names |
|
log_file.write('\n\n[{}]\tINFO:\tStart training\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'))) |
|
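    # Manual training loop: callback hooks, progress bars and the summary dictionary are updated explicitly per batch and per epoch.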
with sess.as_default(): |
|
callback_tensorboard.on_train_begin() |
|
callback_early_stop.on_train_begin() |
|
callback_best_model.on_train_begin() |
|
callback_save_model.on_train_begin() |
|
for epoch in range(resume_epoch, C.EPOCHS): |
|
callback_tensorboard.on_epoch_begin(epoch) |
|
callback_early_stop.on_epoch_begin(epoch) |
|
callback_best_model.on_epoch_begin(epoch) |
|
callback_save_model.on_epoch_begin(epoch) |
|
callback_lr.on_epoch_begin(epoch) |
|
print("\nEpoch {}/{}".format(epoch, C.EPOCHS)) |
|
print('TRAINING') |
|
|
|
log_file.write('\n\n[{}]\tINFO:\tTraining epoch {}\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), epoch)) |
|
progress_bar = Progbar(len(train_generator), width=30, verbose=1) |
|
for step, (in_batch, _) in enumerate(train_generator, 1): |
|
|
|
callback_best_model.on_train_batch_begin(step) |
|
callback_save_model.on_train_batch_begin(step) |
|
callback_early_stop.on_train_batch_begin(step) |
|
callback_lr.on_train_batch_begin(step) |
|
try: |
|
                    fix_img, mov_img, *_ = augm_model_train.predict(in_batch)
                    # Check for corrupted samples BEFORE sanitizing them; after nan_to_num the test could never fire.
                    if np.isnan(np.sum(mov_img)) or np.isnan(np.sum(fix_img)) or np.isinf(np.sum(mov_img)) or np.isinf(np.sum(fix_img)):
                        msg = 'CORRUPTED DATA!! Unique: Fix: {}\tMoving: {}'.format(np.unique(fix_img),
                                                                                    np.unique(mov_img))
                        print(msg)
                        log_file.write('\n\n[{}]\tWAR: {}'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), msg))
                    np.nan_to_num(fix_img, copy=False)
                    np.nan_to_num(mov_img, copy=False)
|
|
|
except InvalidArgumentError as err: |
|
print('TF Error : {}'.format(str(err))) |
|
continue |
|
|
|
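                # One optimization step: inputs are (moving, fixed); the fixed image is the target for the warped
                # output and a placeholder target for the flow (the Grad loss only penalizes the predicted field).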
ret = network.train_on_batch(x=(mov_img, fix_img), |
|
y=(fix_img, fix_img)) |
|
if np.isnan(ret).any(): |
|
os.makedirs(os.path.join(output_folder, 'corrupted'), exist_ok=True) |
|
save_nifti(mov_img, os.path.join(output_folder, 'corrupted', 'mov_img_nan.nii.gz')) |
|
save_nifti(fix_img, os.path.join(output_folder, 'corrupted', 'fix_img_nan.nii.gz')) |
|
                    pred_img, dm = network.predict([mov_img, fix_img])
|
save_nifti(pred_img, os.path.join(output_folder, 'corrupted', 'pred_img_nan.nii.gz')) |
|
save_nifti(dm, os.path.join(output_folder, 'corrupted', 'dm_nan.nii.gz')) |
|
log_file.write('\n\n[{}]\tERR: Corruption error'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'))) |
|
raise ValueError('CORRUPTION ERROR: Halting training') |
|
|
|
summary.on_train_batch_end(ret) |
|
|
|
callback_best_model.on_train_batch_end(step, named_logs(network, ret)) |
|
callback_save_model.on_train_batch_end(step, named_logs(network, ret)) |
|
callback_early_stop.on_train_batch_end(step, named_logs(network, ret)) |
|
callback_lr.on_train_batch_end(step, named_logs(network, ret)) |
|
progress_bar.update(step, zip(names, ret)) |
|
log_file.write('\t\tStep {:03d}: {}'.format(step, ret)) |
|
val_values = progress_bar._values.copy() |
|
ret = [val_values[x][0]/val_values[x][1] for x in names] |
|
|
|
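            # End-of-epoch validation: evaluate on the validation generator without updating the weights.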
print('\nVALIDATION') |
|
log_file.write('\n\n[{}]\tINFO:\tValidation epoch {}\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), epoch)) |
|
progress_bar = Progbar(len(validation_generator), width=30, verbose=1) |
|
for step, (in_batch, _) in enumerate(validation_generator, 1): |
|
|
|
|
|
try: |
|
fix_img, mov_img, *_ = augm_model_valid.predict(in_batch) |
|
except InvalidArgumentError as err: |
|
print('TF Error : {}'.format(str(err))) |
|
continue |
|
|
|
ret = network.test_on_batch(x=(mov_img, fix_img), |
|
y=(fix_img, fix_img)) |
|
|
|
summary.on_validation_batch_end(ret) |
|
|
|
|
|
progress_bar.update(step, zip(names, ret)) |
|
log_file.write('\t\tStep {:03d}: {}'.format(step, ret)) |
|
val_values = progress_bar._values.copy() |
|
ret = [val_values[x][0]/val_values[x][1] for x in names] |
|
|
|
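            # End of epoch: reshuffle the generators and propagate the aggregated epoch summary to all callbacks.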
train_generator.on_epoch_end() |
|
validation_generator.on_epoch_end() |
|
epoch_summary = summary.on_epoch_end() |
|
callback_tensorboard.on_epoch_end(epoch, epoch_summary) |
|
callback_early_stop.on_epoch_end(epoch, epoch_summary) |
|
callback_best_model.on_epoch_end(epoch, epoch_summary) |
|
callback_save_model.on_epoch_end(epoch, epoch_summary) |
|
callback_lr.on_epoch_end(epoch, epoch_summary) |
|
print('End of epoch {}: '.format(epoch), ret, '\n') |
|
|
|
callback_tensorboard.on_train_end() |
|
callback_save_model.on_train_end() |
|
callback_best_model.on_train_end() |
|
callback_early_stop.on_train_end() |
|
callback_lr.on_train_end() |
|
|
|
|
|
if __name__ == '__main__': |
|
os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER |
|
os.environ['CUDA_VISIBLE_DEVICES'] = '0' |
|
|
|
config = tf.compat.v1.ConfigProto() |
|
config.gpu_options.allow_growth = True |
|
config.log_device_placement = False |
|
    tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session(config=config))
|
|
|
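    # Example invocation. The original call passed 'mse=True', which launch_train does not accept; 'simil' selects
    # the similarity loss. Point validation_folder to an actual validation dataset for a real run.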
    launch_train(dataset_folder='/mnt/EncryptedData1/Users/javier/vessel_registration/LiTS/None',
                 validation_folder=None,
                 output_folder='TrainOutput/THESIS/UW_None_mse_ssim_haus',
                 gpu_num=0,
                 simil='mse')
|
|