|
import os, sys |
|
|
|
import shutil |
|
|
|
import h5py |
|
import matplotlib.pyplot as plt |
|
|
|
# Append the directory one level above this script's folder to the module
# search path (presumably so the project packages imported further down can be
# resolved when the script is launched directly -- confirm against repo layout).
currentdir = os.path.split(os.path.realpath(__file__))[0]
parentdir = os.path.split(currentdir)[0]
sys.path.append(parentdir)
|
|
|
import tensorflow as tf |
|
|
|
|
|
import numpy as np |
|
import pandas as pd |
|
import voxelmorph as vxm |
|
|
|
import ddmr.utils.constants as C |
|
from ddmr.utils.nifti_utils import save_nifti |
|
from ddmr.layers import AugmentationLayer |
|
from ddmr.losses import StructuralSimilarity_simplified, NCC, GeneralizedDICEScore, HausdorffDistanceErosion |
|
from ddmr.ms_ssim_tf import MultiScaleStructuralSimilarity |
|
from ddmr.utils.acummulated_optimizer import AdamAccumulated |
|
from ddmr.utils.visualization import save_disp_map_img, plot_predictions |
|
from ddmr.utils.misc import segmentation_ohe_to_cardinal |
|
from EvaluationScripts.Evaluate_class import EvaluationFigures, resize_pts_to_original_space, resize_img_to_original_space, resize_transformation |
|
from scipy.interpolate import RegularGridInterpolator |
|
from tqdm import tqdm |
|
|
|
import h5py |
|
|
|
from Brain_study.data_generator import BatchGenerator |
|
|
|
import argparse |
|
|
|
from skimage.transform import warp |
|
import neurite as ne |
|
|
|
# Default locations of the data set and of the trained model. The script's
# main section reassigns DATASET, MODEL_FILE and DATA_ROOT_DIR from the
# command-line arguments before using them.
DATASET = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/training'

# Training output folder of the model; the checkpoint lives inside it.
DATA_ROOT_DIR = '/mnt/EncryptedData1/Users/javier/train_output/Brain_study/ERASE/MS_SSIM/BASELINE_L_ssim__MET_mse_ncc_ssim_162756-29062021/'
MODEL_FILE = DATA_ROOT_DIR + 'checkpoints/best_model.h5'

# Name of the sub-folder (created under DATA_ROOT_DIR) receiving the results.
OUTPUT_FOLDER_NAME = 'Evaluate'
|
|
|
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``'False'``) into ``True``; this parser maps the usual spellings explicitly.

    :param value: Token received from the command line, or an actual ``bool``.
    :return: The parsed boolean. An empty string maps to ``False``.
    :raises argparse.ArgumentTypeError: If the token is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got {!r}'.format(value))


def _resolve_model_paths(model, directory):
    """Derive the checkpoint file and the data root from the -m / -d options.

    :param model: Value of ``-m/--model`` ('' when not given).
    :param directory: Value of ``-d/--dir`` ('' when not given).
    :return: Tuple ``(model_file, data_root_dir)``.
    :raises ValueError: If neither option is given, or if a value was passed
        through the wrong option.
    """
    if model != '':
        if '.h5' not in model:
            raise ValueError('No checkpoint file provided, use -d/--dir instead')
        return model, os.path.split(model)[0]
    if directory != '':
        if directory.endswith('.h5'):
            raise ValueError('Provided checkpoint file, use -m/--model instead')
        return os.path.join(directory, 'checkpoints', 'best_model.h5'), directory
    raise ValueError("Provide either the model file or the directory containing ./checkpoints/best_model.h5")


def _metrics_to_frame(rows):
    """Build the metrics table from the list of per-sample dictionaries."""
    return pd.DataFrame(rows, columns=['File', 'SSIM', 'MS-SSIM', 'MSE', 'DICE', 'HD'])


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-m', '--model', type=str, help='.h5 of the model', default='')
    parser.add_argument('-d', '--dir', type=str, help='Directory where ./checkpoints/best_model.h5 is located', default='')
    parser.add_argument('--gpu', type=int, help='GPU', default=0)
    parser.add_argument('--dataset', type=str, help='Dataset to run predictions on', default=DATASET)
    # Accepts "--erase", "--erase True" and "--erase False"; the latter really means False.
    parser.add_argument('--erase', type=_str2bool, nargs='?', const=True, default=False,
                        help='Erase the content of the output folder')
    args = parser.parse_args()

    MODEL_FILE, DATA_ROOT_DIR = _resolve_model_paths(args.model, args.dir)

    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    DATASET = args.dataset

    print('MODEL LOCATION: ', MODEL_FILE)

    output_folder = os.path.join(DATA_ROOT_DIR, OUTPUT_FOLDER_NAME)
    if args.erase:
        shutil.rmtree(output_folder, ignore_errors=True)
    os.makedirs(output_folder, exist_ok=True)
    print('DESTINATION FOLDER: ', output_folder)

    # NOTE(review): the meaning of the positional arguments is defined by
    # Brain_study.data_generator.BatchGenerator -- confirm there.
    data_generator = BatchGenerator(DATASET, 1, False, 1.0, False, ['all'])

    img_generator = data_generator.get_train_generator()
    nb_labels = len(img_generator.get_segmentation_labels())
    # Last entry of the generator's data shapes without its final axis
    # (presumably the spatial size without channels -- TODO confirm).
    image_input_shape = img_generator.get_data_shape()[-1][:-1]
    image_output_shape = [64] * 3
    # Shape handed to the segmentation metrics: output size plus the last
    # dimension of the generator's third data shape.
    seg_shape = image_output_shape + [img_generator.get_data_shape()[2][-1]]

    # Model wrapping the augmentation layer; its prediction yields the fixed and
    # moving images and segmentations fed to the registration network.
    input_augm = tf.keras.Input(shape=img_generator.get_data_shape()[0], name='input_augm')
    augm_layer = AugmentationLayer(max_displacement=C.MAX_AUG_DISP,
                                   max_deformation=C.MAX_AUG_DEF,
                                   max_rotation=C.MAX_AUG_ANGLE,
                                   num_control_points=C.NUM_CONTROL_PTS_AUG,
                                   num_augmentations=C.NUM_AUGMENTATIONS,
                                   gamma_augmentation=C.GAMMA_AUGMENTATION,
                                   brightness_augmentation=C.BRIGHTNESS_AUGMENTATION,
                                   in_img_shape=image_input_shape,
                                   out_img_shape=image_output_shape,
                                   only_image=False,
                                   only_resize=False,
                                   trainable=False)
    augm_model = tf.keras.Model(inputs=input_augm, outputs=augm_layer(input_augm))

    # Only passed as custom objects to load_model (the model is loaded with compile=False).
    loss_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).loss,
                 NCC(image_input_shape).loss,
                 vxm.losses.MSE().loss,
                 MultiScaleStructuralSimilarity(max_val=1., filter_size=3).loss]

    metric_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).metric,
                   NCC(image_input_shape).metric,
                   vxm.losses.MSE().loss,
                   MultiScaleStructuralSimilarity(max_val=1., filter_size=3).metric,
                   GeneralizedDICEScore(seg_shape, num_labels=nb_labels).loss,
                   HausdorffDistanceErosion(3, 10, im_shape=seg_shape).loss]

    network = tf.keras.models.load_model(MODEL_FILE,
                                         {'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
                                          'VxmDense': vxm.networks.VxmDense,
                                          'AdamAccumulated': AdamAccumulated,
                                          'loss': loss_fncs,
                                          'metric': metric_fncs},
                                         compile=False)

    # Used to warp the moving segmentation when the network itself does not return one.
    warp_segmentation = vxm.networks.Transform(image_output_shape, interp_method='nearest', nb_feats=nb_labels)

    # Rows are gathered in a plain list and turned into a DataFrame once at the
    # end: DataFrame.append returns a new frame (it never modified the table in
    # place) and no longer exists in pandas >= 2.0.
    metric_rows = []

    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    config.log_device_placement = False

    sess = tf.Session(config=config)
    tf.keras.backend.set_session(sess)
    with sess.as_default():
        sess.run(tf.global_variables_initializer())
        # Weights are loaded after the initializer has run (presumably because
        # the initializer would otherwise overwrite them -- TODO confirm).
        network.load_weights(MODEL_FILE, by_name=True)

        progress_bar = tqdm(enumerate(img_generator, 1), desc='Evaluation', total=len(img_generator))
        for step, (in_batch, _) in progress_bar:
            fix_img, mov_img, fix_seg, mov_seg = augm_model.predict(in_batch)

            if network.name == 'vxm_dense_semi_supervised_seg':
                pred_img, disp_map, pred_seg = network.predict([mov_img, fix_img, mov_seg, fix_seg])
            else:
                pred_img, disp_map = network.predict([mov_img, fix_img])
                pred_seg = warp_segmentation.predict([mov_seg, disp_map])

            # NOTE(review): every .metric(...) call below is evaluated with
            # .eval(), i.e. it creates new graph ops on each iteration; consider
            # building them once on placeholders if evaluation slows down.
            # Segmentation metrics are computed before the segmentations are
            # converted by segmentation_ohe_to_cardinal.
            dice = GeneralizedDICEScore(seg_shape, num_labels=nb_labels).metric(fix_seg, pred_seg).eval()
            hd = HausdorffDistanceErosion(3, 10, im_shape=seg_shape).metric(fix_seg, pred_seg).eval()

            pred_seg = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
            mov_seg = segmentation_ohe_to_cardinal(mov_seg).astype(np.float32)
            fix_seg = segmentation_ohe_to_cardinal(fix_seg).astype(np.float32)

            ssim = StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).metric(fix_img, pred_img).eval()
            ms_ssim = MultiScaleStructuralSimilarity(max_val=1., filter_size=3).metric(fix_img, pred_img).eval()[0]
            mse = vxm.losses.MSE().loss(fix_img, pred_img).eval()

            metric_rows.append({'File': step,
                                'SSIM': ssim,
                                'MS-SSIM': ms_ssim,
                                'MSE': mse,
                                'DICE': dice,
                                'HD': hd})

            # Common file-name suffix carrying the scores of this sample.
            score_tag = 'ssim_{:.03f}_dice_{:.03f}'.format(ssim, dice)
            volumes = (('fix_img', fix_img), ('mov_img', mov_img), ('pred_img', pred_img),
                       ('fix_seg', fix_seg), ('mov_seg', mov_seg), ('pred_seg', pred_seg))
            for vol_name, volume in volumes:
                nifti_path = os.path.join(output_folder, '{:03d}_{}_{}.nii.gz'.format(step, vol_name, score_tag))
                save_nifti(volume[0, ...], nifti_path, verbose=False)

            # Histogram of the per-voxel displacement vector norms.
            magnitude = np.sqrt(np.sum(disp_map[0, ...] ** 2, axis=-1))
            plt.hist(magnitude.flatten())
            plt.title('Histogram of disp. magnitudes')
            plt.savefig(os.path.join(output_folder, '{:03d}_hist_mag_{}.png'.format(step, score_tag)))
            plt.close()

            plot_predictions(img_batches=[fix_img, mov_img, pred_img], disp_map_batch=disp_map,
                             seg_batches=[fix_seg, mov_seg, pred_seg],
                             filename=os.path.join(output_folder, '{:03d}_figures_seg.png'.format(step)), show=False)
            plot_predictions(img_batches=[fix_img, mov_img, pred_img], disp_map_batch=disp_map,
                             filename=os.path.join(output_folder, '{:03d}_figures_img.png'.format(step)), show=False)
            save_disp_map_img(disp_map, 'Displacement map',
                              os.path.join(output_folder, '{:03d}_disp_map_fig.png'.format(step)), show=False)

            progress_bar.set_description('SSIM {:.04f}\tDICE: {:.04f}'.format(ssim, dice))

    metrics = _metrics_to_frame(metric_rows)
    metrics.to_csv(os.path.join(output_folder, 'metrics.csv'))
    print('Done')
|
|