import os, sys currentdir = os.path.dirname(os.path.realpath(__file__)) parentdir = os.path.dirname(currentdir) sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True' import numpy as np import tensorflow as tf from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping import voxelmorph as vxm import neurite as ne import h5py from datetime import datetime import DeepDeformationMapRegistration.utils.constants as C from DeepDeformationMapRegistration.data_generator import DataGeneratorManager2D from DeepDeformationMapRegistration.utils.misc import try_mkdir from DeepDeformationMapRegistration.losses import HausdorffDistanceErosion from DeepDeformationMapRegistration.layers import UncertaintyWeighting os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER os.environ['CUDA_VISIBLE_DEVICES'] = '1' # const.GPU_NUM # Check availability before running using 'nvidia-smi' C.TRAINING_DATASET = '/mnt/EncryptedData1/Users/javier/vessel_registration/ov_dataset/training' C.BATCH_SIZE = 256 C.LIMIT_NUM_SAMPLES = None C.EPOCHS = 10000 if PYCHARM_EXEC: path_prefix = os.path.join('scripts', 'tf') else: path_prefix = '' # Load data # Build data generator sample_list = [os.path.join(C.TRAINING_DATASET, f) for f in os.listdir(C.TRAINING_DATASET) if f.startswith('sample')] sample_list.sort() data_generator = DataGeneratorManager2D(sample_list[:C.LIMIT_NUM_SAMPLES], C.BATCH_SIZE, C.TRAINING_PERC, (64, 64, 1), fix_img_tag='dilated/input/fix', mov_img_tag='dilated/input/mov', multi_loss=True, ) # Build model in_shape_img, in_shape_grad = data_generator.train_generator.input_shape enc_features = [32, 32, 32, 32, 32, 32] # const.ENCODER_FILTERS dec_features = [32, 32, 32, 32, 32, 32, 32, 16] # const.ENCODER_FILTERS[::-1] nb_features = [enc_features, dec_features] vxm_model = vxm.networks.VxmDense(inshape=in_shape_img[:-1], nb_unet_features=nb_features, int_steps=0) #moving = tf.keras.Input(shape=in_shape_img, name='multiLoss_moving_input', dtype=tf.float32) #fixed = tf.keras.Input(shape=in_shape_img, name='multiLoss_fixed_input', dtype=tf.float32) grad = tf.keras.Input(shape=(*in_shape_img[:-1], 2), name='multiLoss_grad_input', dtype=tf.float32) def dice_loss(y_true, y_pred): # Dice().loss returns -Dice score return 1 + vxm.losses.Dice().loss(y_true, y_pred) #fixed_pred, dm_pred = vxm_model([moving, fixed]) multiLoss = UncertaintyWeighting(num_loss_fns=2, num_reg_fns=1, loss_fns=[HausdorffDistanceErosion(2, 2).loss, dice_loss], reg_fns=[vxm.losses.Grad('l2').loss], prior_loss_w=[1., 1.], prior_reg_w=[0.01], name='MultiLossLayer') loss = multiLoss([vxm_model.inputs[1], vxm_model.inputs[1], vxm_model.references.y_source, vxm_model.references.y_source, grad, vxm_model.references.pos_flow]) full_model = tf.keras.Model(inputs=vxm_model.inputs + [grad], outputs=vxm_model.outputs + [loss]) # Compile the model full_model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-4), loss=None) # Train output_folder = os.path.join('train_2d_multiloss_haussdorf_dice_grad' + datetime.now().strftime("%H%M%S-%d%m%Y")) try_mkdir(output_folder) try_mkdir(os.path.join(output_folder, 'checkpoints')) try_mkdir(os.path.join(output_folder, 'tensorboard')) my_callbacks = [ # EarlyStopping(patience=const.EARLY_STOP_PATIENCE, monitor='dice', mode='max', verbose=1), ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'), save_best_only=True, monitor='val_loss', verbose=0, mode='min'), ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'weights.{epoch:05d}-{val_loss:.2f}.h5'), save_best_only=True, save_weights_only=True, monitor='val_loss', verbose=0, mode='min'), # CSVLogger(train_log_name, ';'), # UpdateLossweights([haus_weight, dice_weight], [const.MODEL+'_resampler_seg', const.MODEL+'_resampler_seg']) TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'), batch_size=C.BATCH_SIZE, write_images=True, histogram_freq=10, update_freq='epoch', write_grads=True), EarlyStopping(monitor='val_loss', verbose=1, patience=50, min_delta=0.0001) ] hist = full_model.fit_generator(data_generator.train_generator, epochs=C.EPOCHS, validation_data=data_generator.validation_generator, verbose=2, callbacks=my_callbacks)