import os, sys

currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir)

from datetime import datetime

from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping, ReduceLROnPlateau
from tensorflow.python.keras.utils import Progbar
from tensorflow.keras import Input
from tensorflow.keras.models import Model
from tensorflow.python.framework.errors import InvalidArgumentError

import ddmr.utils.constants as C
from ddmr.losses import StructuralSimilarity_simplified, NCC, GeneralizedDICEScore, HausdorffDistanceErosion
from ddmr.ms_ssim_tf import MultiScaleStructuralSimilarity
from ddmr.ms_ssim_tf import _MSSSIM_WEIGHTS
from ddmr.utils.acummulated_optimizer import AdamAccumulated
from ddmr.utils.misc import function_decorator
from ddmr.layers import AugmentationLayer, UncertaintyWeighting
from ddmr.utils.nifti_utils import save_nifti

from Brain_study.data_generator import BatchGenerator
from Brain_study.utils import SummaryDictionary, named_logs

import COMET.augmentation_constants as COMET_C
from COMET.utils import freeze_layers_by_group

import numpy as np
import tensorflow as tf
import voxelmorph as vxm
import h5py
import re
import itertools
import warnings

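# Entry point: trains (or fine-tunes) a VoxelMorph semi-supervised registration network on the COMET
# data, with on-the-fly augmentation, uncertainty-weighted similarity/segmentation losses, and a
# manual session-based training loop. Assumes the ddmr, Brain_study and COMET packages are importable.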
def launch_train(dataset_folder, validation_folder, output_folder, model_file, gpu_num=0, lr=1e-4, rw=5e-3,
                 simil=['ssim'], segm=['dice'], max_epochs=C.EPOCHS, early_stop_patience=1000, prior_reg_w=5e-3,
                 freeze_layers=None, acc_gradients=1, batch_size=16, image_size=64,
                 unet=[16, 32, 64, 128, 256], head=[16, 16], resume=None):

    assert dataset_folder is not None and output_folder is not None
    if model_file != '':
        assert '.h5' in model_file, 'The model must be an H5 file'

    os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_num)
    C.GPU_NUM = str(gpu_num)

    if batch_size != 1 and acc_gradients != 1:
        warnings.warn('Both the batch size and the gradient-accumulation steps are set: the batch size will be forced to 1.')

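    # If a previous run directory containing checkpoints is given, resume into it; otherwise start a
    # fresh, time-stamped output folder.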
    if resume is not None:
        try:
            assert os.path.exists(resume) and len(os.listdir(os.path.join(resume, 'checkpoints'))), \
                'Invalid directory: ' + resume
            output_folder = resume
            resume = True
        except AssertionError:
            output_folder = output_folder + '_' + datetime.now().strftime("%H%M%S-%d%m%Y")
            resume = False
    else:
        resume = False

    os.makedirs(output_folder, exist_ok=True)

    log_file = open(os.path.join(output_folder, 'log.txt'), 'w')
    C.TRAINING_DATASET = dataset_folder
    C.VALIDATION_DATASET = validation_folder
    C.ACCUM_GRADIENT_STEP = acc_gradients
    C.BATCH_SIZE = batch_size if C.ACCUM_GRADIENT_STEP == 1 else 1
    C.EARLY_STOP_PATIENCE = early_stop_patience
    C.LEARNING_RATE = lr
    C.LIMIT_NUM_SAMPLES = None
    C.EPOCHS = max_epochs

    aux = "[{}]\tINFO:\nTRAIN DATASET: {}\nVALIDATION DATASET: {}\n" \
          "GPU: {}\n" \
          "BATCH SIZE: {}\n" \
          "LR: {}\n" \
          "SIMILARITY: {}\n" \
          "SEGMENTATION: {}\n" \
          "REG. WEIGHT: {}\n" \
          "EPOCHS: {:d}\n" \
          "ACCUM. GRAD: {}\n" \
          "EARLY STOP PATIENCE: {}\n" \
          "FROZEN LAYERS: {}".format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'),
                                     C.TRAINING_DATASET,
                                     C.VALIDATION_DATASET,
                                     C.GPU_NUM,
                                     C.BATCH_SIZE,
                                     C.LEARNING_RATE,
                                     simil,
                                     segm,
                                     rw,
                                     C.EPOCHS,
                                     C.ACCUM_GRADIENT_STEP,
                                     C.EARLY_STOP_PATIENCE,
                                     freeze_layers)

    log_file.write(aux)
    print(aux)

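    # Data generators for training and validation, built from the dataset folders stored in the
    # constants module (labels 0-2, segmentations kept as separate channels).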
    used_labels = [0, 1, 2]
    data_generator = BatchGenerator(C.TRAINING_DATASET, C.BATCH_SIZE if C.ACCUM_GRADIENT_STEP == 1 else 1, True,
                                    C.TRAINING_PERC, labels=used_labels, combine_segmentations=False,
                                    directory_val=C.VALIDATION_DATASET)

    train_generator = data_generator.get_train_generator()
    validation_generator = data_generator.get_validation_generator()

    image_input_shape = train_generator.get_data_shape()[-1][:-1]
    image_output_shape = [image_size] * 3
    nb_labels = len(train_generator.get_segmentation_labels())

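    # TF1-compatible session setup (graph mode): let the GPU allocator grow on demand instead of
    # grabbing all memory up front.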
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    config.log_device_placement = False
    sess = tf.compat.v1.Session(config=config)
    tf.compat.v1.keras.backend.set_session(sess)

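    # VoxelMorph dense registration network with an auxiliary segmentation head (semi-supervised
    # variant). The U-Net encoder/decoder feature widths come from the `unet` and `head` arguments.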
    enc_features = unet
    dec_features = enc_features[::-1] + head
    nb_features = [enc_features, dec_features]
    network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
                                                     nb_labels=nb_labels,
                                                     nb_unet_features=nb_features,
                                                     int_steps=0,
                                                     int_downsize=1,
                                                     seg_downsize=1)
    if model_file != '':
        print('MODEL LOCATION: ', model_file)
        network.load_weights(model_file, by_name=True)

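    # When resuming, reload the latest checkpoint weights and, if the file name encodes it, the epoch
    # to continue from.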
    resume_epoch = 0
    if resume:
        cp_dir = os.path.join(output_folder, 'checkpoints')
        cp_file_list = [os.path.join(cp_dir, f) for f in os.listdir(cp_dir)
                        if f.startswith('checkpoint') and f.endswith('.h5')]
        if len(cp_file_list):
            cp_file_list.sort()
            checkpoint_file = cp_file_list[-1]
            if os.path.exists(checkpoint_file):
                network.load_weights(checkpoint_file, by_name=True)
                print('Loaded checkpoint file: ' + checkpoint_file)
                try:
                    resume_epoch = int(re.match(r'checkpoint\.(\d+)-*\.h5', os.path.split(checkpoint_file)[-1])[1])
                except TypeError:
                    # The file name does not encode the epoch number, so start counting from zero.
                    resume_epoch = 0
                print('Resuming from epoch: {:d}'.format(resume_epoch))
            else:
                warnings.warn('Checkpoint file NOT found. Training from scratch')

    _, frozen_layers = freeze_layers_by_group(network, freeze_layers)
    if frozen_layers is not None:
        msg = "[INF]: Frozen layers: {}".format(', '.join([str(a) for a in frozen_layers]))
    else:
        msg = "[INF]: No frozen layers"
    print(msg)
    log_file.write(msg)

    network.summary(line_length=C.SUMMARY_LINE_LENGTH)
    network.summary(line_length=C.SUMMARY_LINE_LENGTH, print_fn=log_file.write)

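    # Non-trainable augmentation model: takes a raw training sample and returns augmented, resized
    # (fixed image, moving image, fixed segmentation, moving segmentation) tensors.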
    input_layer_train = Input(shape=train_generator.get_data_shape()[0], name='input_train')
    augm_layer = AugmentationLayer(max_displacement=COMET_C.MAX_AUG_DISP,
                                   max_deformation=COMET_C.MAX_AUG_DEF,
                                   max_rotation=COMET_C.MAX_AUG_ANGLE,
                                   num_control_points=COMET_C.NUM_CONTROL_PTS_AUG,
                                   num_augmentations=COMET_C.NUM_AUGMENTATIONS,
                                   gamma_augmentation=COMET_C.GAMMA_AUGMENTATION,
                                   brightness_augmentation=COMET_C.BRIGHTNESS_AUGMENTATION,
                                   in_img_shape=image_input_shape,
                                   out_img_shape=image_output_shape,
                                   only_image=False,
                                   only_resize=False,
                                   trainable=False)
    augm_model = Model(inputs=input_layer_train, outputs=augm_layer(input_layer_train))

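    # Build the lists of image-similarity and segmentation losses requested by the caller, together
    # with their prior weights for the uncertainty-weighted multi-loss layer.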
    SSIM_KER_SIZE = 5
    MS_SSIM_WEIGHTS = np.asarray(_MSSSIM_WEIGHTS[:3])
    MS_SSIM_WEIGHTS /= np.sum(MS_SSIM_WEIGHTS)
    loss_simil = []
    prior_loss_w = []
    for s in simil:
        if s == 'ssim':
            loss_simil.append(StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss)
            prior_loss_w.append(1.)
        elif s == 'ms_ssim':
            loss_simil.append(MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE,
                                                             power_factors=MS_SSIM_WEIGHTS).loss)
            prior_loss_w.append(1.)
        elif s == 'ncc':
            loss_simil.append(NCC(image_input_shape).loss)
            prior_loss_w.append(1.)
        elif s == 'mse':
            loss_simil.append(vxm.losses.MSE().loss)
            prior_loss_w.append(1.)
        else:
            raise ValueError('Unknown similarity function: ' + s)

    loss_segm = []
    for s in segm:
        if s == 'dice':
            loss_segm.append(GeneralizedDICEScore(image_output_shape + [train_generator.get_data_shape()[2][-1]],
                                                  num_labels=nb_labels).loss)
            prior_loss_w.append(1.)
        elif s == 'dice_macro':
            loss_segm.append(GeneralizedDICEScore(image_output_shape + [train_generator.get_data_shape()[2][-1]],
                                                  num_labels=nb_labels).loss_macro)
            prior_loss_w.append(1.)
        elif s == 'hd':
            loss_segm.append(HausdorffDistanceErosion(3, 10,
                                                      im_shape=image_output_shape + [train_generator.get_data_shape()[2][-1]]).loss)
            prior_loss_w.append(1.)
        else:
            raise ValueError('Unknown segmentation loss function: ' + s)

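    # Uncertainty-weighted multi-task loss: every similarity loss compares the fixed image against the
    # warped moving image, every segmentation loss compares the fixed segmentation against the warped
    # moving segmentation, and the displacement field is regularised with an l2 gradient penalty.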
    grad = tf.keras.Input(shape=(*image_output_shape, 3), name='multiLoss_grad_input', dtype=tf.float32)
    fix_seg = tf.keras.Input(shape=(*image_output_shape, len(train_generator.get_segmentation_labels())),
                             name='multiLoss_fix_seg_input', dtype=tf.float32)

    multiLoss = UncertaintyWeighting(num_loss_fns=len(loss_simil) + len(loss_segm),
                                     num_reg_fns=1,
                                     loss_fns=[*loss_simil, *loss_segm],
                                     reg_fns=[vxm.losses.Grad('l2').loss],
                                     prior_loss_w=prior_loss_w,
                                     prior_reg_w=[prior_reg_w],
                                     name='MultiLossLayer')
    loss = multiLoss([*[network.inputs[1]] * len(loss_simil), *[fix_seg] * len(loss_segm),
                      *[network.outputs[0]] * len(loss_simil), *[network.outputs[2]] * len(loss_segm),
                      grad,
                      network.outputs[1]])

    full_model = tf.keras.Model(inputs=network.inputs + [fix_seg, grad],
                                outputs=network.outputs + [loss])

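    # Standard Keras callbacks. The two checkpoint callbacks are later attached to the registration
    # network only (see set_model below), so the saved weights can be reloaded without the
    # augmentation/multi-loss wrappers.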
    os.makedirs(os.path.join(output_folder, 'checkpoints'), exist_ok=True)
    os.makedirs(os.path.join(output_folder, 'tensorboard'), exist_ok=True)
    callback_tensorboard = TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
                                       batch_size=C.BATCH_SIZE, write_images=False, histogram_freq=0,
                                       update_freq='epoch',
                                       write_graph=True, write_grads=True)
    callback_early_stop = EarlyStopping(monitor='val_loss', verbose=1, patience=C.EARLY_STOP_PATIENCE,
                                        min_delta=0.00001)

    callback_best_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'),
                                          save_best_only=True, monitor='val_loss', verbose=1, mode='min')
    callback_save_checkpoint = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'checkpoint.h5'),
                                               save_weights_only=True, monitor='val_loss', verbose=0, mode='min')

    callback_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10)

    # No compile-time loss: the UncertaintyWeighting layer adds the weighted loss inside the model.
    optimizer = AdamAccumulated(accumulation_steps=C.ACCUM_GRADIENT_STEP, learning_rate=C.LEARNING_RATE)
    full_model.compile(optimizer=optimizer, loss=None)

    callback_tensorboard.set_model(full_model)
    callback_early_stop.set_model(full_model)
    callback_best_model.set_model(network)
    callback_save_checkpoint.set_model(network)
    callback_lr.set_model(full_model)

    summary = SummaryDictionary(full_model, C.BATCH_SIZE)
    names = full_model.metrics_names
    # Dummy all-zero tensor fed as the "ground truth" of the displacement-field regulariser input.
    zero_grads = tf.zeros_like(network.references.pos_flow, name='dummy_zero_grads')
    log_file.write('\n\n[{}]\tINFO:\tStart training\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y')))

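    # Manual training loop inside the TF1 session: batches are run with train_on_batch/test_on_batch
    # and the Keras callbacks are driven explicitly via their on_epoch_*/on_train_batch_* hooks.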
    with sess.as_default():
        callback_tensorboard.on_train_begin()
        callback_early_stop.on_train_begin()
        callback_best_model.on_train_begin()
        callback_save_checkpoint.on_train_begin()
        callback_lr.on_train_begin()

        for epoch in range(resume_epoch, C.EPOCHS):
            callback_tensorboard.on_epoch_begin(epoch)
            callback_early_stop.on_epoch_begin(epoch)
            callback_best_model.on_epoch_begin(epoch)
            callback_save_checkpoint.on_epoch_begin(epoch)
            callback_lr.on_epoch_begin(epoch)

            print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
            print("TRAIN")

            log_file.write(
                '\n\n[{}]\tINFO:\tTraining epoch {}\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), epoch))
            progress_bar = Progbar(len(train_generator), width=30, verbose=1)
            for step, (in_batch, _) in enumerate(train_generator, 1):
                callback_best_model.on_train_batch_begin(step)
                callback_save_checkpoint.on_train_batch_begin(step)
                callback_early_stop.on_train_batch_begin(step)
                callback_lr.on_train_batch_begin(step)

                try:
                    fix_img, mov_img, fix_seg, mov_seg = augm_model.predict(in_batch)
                    np.nan_to_num(fix_img, copy=False)
                    np.nan_to_num(mov_img, copy=False)
                    if np.isnan(np.sum(mov_img)) or np.isnan(np.sum(fix_img)) or \
                            np.isinf(np.sum(mov_img)) or np.isinf(np.sum(fix_img)):
                        msg = 'CORRUPTED DATA!! Unique: Fix: {}\tMoving: {}'.format(np.unique(fix_img),
                                                                                    np.unique(mov_img))
                        print(msg)
                        log_file.write('\n\n[{}]\tWAR: {}'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), msg))
                except InvalidArgumentError as err:
                    print('TF Error: {}'.format(str(err)))
                    continue

                ret = full_model.train_on_batch(x=(mov_img, fix_img, mov_seg, fix_seg, zero_grads))

                summary.on_train_batch_end(ret)
                callback_best_model.on_train_batch_end(step, named_logs(full_model, ret))
                callback_save_checkpoint.on_train_batch_end(step, named_logs(full_model, ret))
                callback_early_stop.on_train_batch_end(step, named_logs(full_model, ret))
                callback_lr.on_train_batch_end(step, named_logs(full_model, ret))
                progress_bar.update(step, zip(names, ret))
                log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
            # Per-epoch training averages taken from the progress bar accumulators.
            val_values = progress_bar._values.copy()
            ret = [val_values[x][0] / val_values[x][1] for x in names]

            print('\nVALIDATION')
            log_file.write(
                '\n\n[{}]\tINFO:\tValidation epoch {}\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), epoch))
            progress_bar = Progbar(len(validation_generator), width=30, verbose=1)
            for step, (in_batch, _) in enumerate(validation_generator, 1):
                try:
                    fix_img, mov_img, fix_seg, mov_seg = augm_model.predict(in_batch)
                except InvalidArgumentError as err:
                    print('TF Error: {}'.format(str(err)))
                    continue

                ret = full_model.test_on_batch(x=(mov_img, fix_img, mov_seg, fix_seg, zero_grads))

                summary.on_validation_batch_end(ret)
                progress_bar.update(step, zip(names, ret))
                log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
            # Per-epoch validation averages taken from the progress bar accumulators.
            val_values = progress_bar._values.copy()
            ret = [val_values[x][0] / val_values[x][1] for x in names]

            train_generator.on_epoch_end()
            validation_generator.on_epoch_end()
            epoch_summary = summary.on_epoch_end()
            callback_tensorboard.on_epoch_end(epoch, epoch_summary)
            callback_best_model.on_epoch_end(epoch, epoch_summary)
            callback_save_checkpoint.on_epoch_end(epoch, epoch_summary)
            callback_early_stop.on_epoch_end(epoch, epoch_summary)
            callback_lr.on_epoch_end(epoch, epoch_summary)
            print('End of epoch {}: '.format(epoch), ret, '\n')

        callback_tensorboard.on_train_end()
        callback_best_model.on_train_end()
        callback_save_checkpoint.on_train_end()
        callback_early_stop.on_train_end()
        callback_lr.on_train_end()