File size: 13,790 Bytes
70656f8 a7b71d6 f915f2e a7b71d6 70656f8 a7b71d6 70656f8 f915f2e 70656f8 a7b71d6 f915f2e 7968536 286a978 a7b71d6 70656f8 e5764e7 70656f8 7968536 286a978 7968536 70656f8 286a978 7968536 70656f8 286a978 e5764e7 70656f8 a7b71d6 7968536 70656f8 a7b71d6 70656f8 a7b71d6 70656f8 286a978 70656f8 286a978 70656f8 67a11d3 f915f2e 286a978 a7b71d6 7968536 70656f8 7968536 e5764e7 70656f8 f915f2e e5764e7 70656f8 e5764e7 70656f8 7968536 70656f8 7968536 70656f8 ca1d395 a7b71d6 e5764e7 a7b71d6 e5764e7 a7b71d6 70656f8 f915f2e a7b71d6 70656f8 ca1d395 70656f8 a7b71d6 ca1d395 e5764e7 70656f8 286a978 7968536 286a978 70656f8 a7b71d6 286a978 70656f8 a7b71d6 286a978 a7b71d6 7968536 286a978 70656f8 7968536 286a978 a7b71d6 7968536 e5764e7 286a978 e5764e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 |
import h5py
import ants
import numpy as np
import nibabel as nb
import os, sys
from tqdm import tqdm
import re
import time
import pandas as pd
# Append the parent of this script's directory to sys.path so that the absolute
# `DeepDeformationMapRegistration...` imports below can be resolved when the file
# is executed directly as a script (presumably the package lives there - confirm).
currentdir = os.path.split(os.path.realpath(__file__))[0]
parentdir = os.path.split(currentdir)[0]
sys.path.append(parentdir)
from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplified, NCC, GeneralizedDICEScore, HausdorffDistanceErosion, target_registration_error
from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
from DeepDeformationMapRegistration.utils.misc import DisplacementMapInterpolator, segmentation_ohe_to_cardinal
from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
import DeepDeformationMapRegistration.utils.constants as C
import shutil
import medpy.metric as medpy_metrics
import voxelmorph as vxm
from argparse import ArgumentParser
import tensorflow as tf
# Hard-coded dataset location. Not referenced by the visible script, which takes
# the dataset directory from the --dataset command-line argument instead.
DATASET_LOCATION = '/mnt/EncryptedData1/Users/javier/vessel_registration/3Dirca/dataset/EVAL'
# Regular expression matched (with re.match) against the file names found in the
# dataset directory. Raw string so that `\d` is not treated as a string escape,
# and the dot is escaped so that only a literal ".h5" extension is accepted.
DATASET_NAMES = r'test_sample_\d{4}\.h5'
# Not referenced by the visible script.
DATASET_FILENAME = 'volume'
# Not referenced by the visible script.
IMGS_FOLDER = '/home/jpdefrutos/workspace/DeepDeformationMapRegistration/Centerline/imgs'

# Keys used to index the dictionary returned by ants.registration().
WARPED_MOV = 'warpedmovout'
WARPED_FIX = 'warpedfixout'
FWD_TRFS = 'fwdtransforms'
INV_TRFS = 'invtransforms'
def str2bool(value):
    """Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``'False'``) into ``True``; this parser interprets the text instead.

    :param value: text received from the command line (or an actual bool).
    :return: the parsed boolean; the empty string maps to ``False``.
    :raises ValueError: if the text is not a recognised boolean spelling
        (argparse reports this as an invalid argument value).
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise ValueError('Expected a boolean value, got {!r}'.format(value))


# Script entry point: registers every matching sample of the dataset with ANTs
# 'SyN' and 'SyNCC', computes image/segmentation/landmark metrics for each result,
# appends them to one CSV file per method and finally prints summary statistics.
if __name__ == '__main__':
    parser = ArgumentParser()
    # Both directories are mandatory: the script cannot proceed without them.
    parser.add_argument('--dataset', type=str, required=True, help='Directory with the images')
    parser.add_argument('--outdirname', type=str, required=True, help='Output directory')
    parser.add_argument('--gpu', type=int, help='GPU')
    parser.add_argument('--savenifti', type=str2bool, default=True)
    parser.add_argument('--erase', type=str2bool, help='Erase the content of the output folder', default=False)
    args = parser.parse_args()

    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    # NOTE(review): when --gpu is omitted this becomes the literal string 'None'.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    reg_methods = ('SyN', 'SyNCC')

    if args.erase:
        shutil.rmtree(args.outdirname, ignore_errors=True)
    os.makedirs(args.outdirname, exist_ok=True)
    for reg_method in reg_methods:
        os.makedirs(os.path.join(args.outdirname, reg_method), exist_ok=True)

    dataset_files = sorted(os.path.join(args.dataset, name)
                           for name in os.listdir(args.dataset)
                           if re.match(DATASET_NAMES, name))
    if not dataset_files:
        raise FileNotFoundError(
            'No file matching {!r} found in {}'.format(DATASET_NAMES, args.dataset))
    dataset_iterator = tqdm(enumerate(dataset_files), desc="Running ANTs", total=len(dataset_files))

    # Only the shapes are needed here, so read the dataset metadata instead of the volumes.
    with h5py.File(dataset_files[0], 'r') as f:
        image_shape = list(f['fix_image'].shape[:-1])
        nb_labels = f['fix_segmentations'].shape[-1] - 1  # first channel is dropped below

    #### TF prep
    metric_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).metric,
                   NCC(image_shape).metric,
                   vxm.losses.MSE().loss,
                   MultiScaleStructuralSimilarity(max_val=1., filter_size=3).metric,
                   GeneralizedDICEScore(image_shape + [nb_labels], num_labels=nb_labels).metric,
                   HausdorffDistanceErosion(3, 10, im_shape=image_shape + [nb_labels]).metric,
                   GeneralizedDICEScore(image_shape + [nb_labels], num_labels=nb_labels).metric_macro]

    fix_img_ph = tf.placeholder(tf.float32, (1, None, None, None, 1), name='fix_img')
    pred_img_ph = tf.placeholder(tf.float32, (1, None, None, None, 1), name='pred_img')
    fix_seg_ph = tf.placeholder(tf.float32, (1, None, None, None, nb_labels), name='fix_seg')
    pred_seg_ph = tf.placeholder(tf.float32, (1, None, None, None, nb_labels), name='pred_seg')

    # Only the image-similarity metrics are evaluated with TensorFlow; the
    # segmentation metrics (metric_fncs[4:]) are computed with medpy further down.
    ssim_tf = metric_fncs[0](fix_img_ph, pred_img_ph)
    ncc_tf = metric_fncs[1](fix_img_ph, pred_img_ph)
    mse_tf = metric_fncs[2](fix_img_ph, pred_img_ph)
    ms_ssim_tf = metric_fncs[3](fix_img_ph, pred_img_ph)

    config = tf.compat.v1.ConfigProto()  # device_count={'GPU':0})
    config.gpu_options.allow_growth = True
    config.log_device_placement = False  ## to log device placement (on which device the operation ran)
    config.allow_soft_placement = True

    sess = tf.Session(config=config)
    tf.keras.backend.set_session(sess)
    ####

    # https://github.com/ANTsX/ANTsPy/issues/261
    # os.cpu_count() may return None, hence the fallback to a single thread.
    os.environ["ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS"] = "{:d}".format(os.cpu_count() or 1)
    print("Running ANTs using {} threads".format(os.environ.get("ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS")))

    # Header of the metrics csv file
    csv_header = ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'DICE_MACRO', 'HD', 'HD95', 'Time', 'TRE']

    metrics_file = {reg_method: os.path.join(args.outdirname, reg_method, 'metrics.csv')
                    for reg_method in reg_methods}
    for metrics_path in metrics_file.values():
        with open(metrics_path, 'w') as f:
            f.write(';'.join(csv_header) + '\n')

    print('Starting the loop')
    for step, file_path in dataset_iterator:
        file_num = int(re.findall(r'(\d+)', os.path.split(file_path)[-1])[0])

        dataset_iterator.set_description('{} ({}): loading data'.format(file_num, file_path))
        with h5py.File(file_path, 'r') as vol_file:
            fix_img = vol_file['fix_image'][:]
            mov_img = vol_file['mov_image'][:]
            # The first segmentation channel / centroid row is discarded.
            fix_seg = vol_file['fix_segmentations'][..., 1:].astype(np.float32)
            mov_seg = vol_file['mov_segmentations'][..., 1:].astype(np.float32)
            fix_centroids = vol_file['fix_centroids'][1:, ...]
            mov_centroids = vol_file['mov_centroids'][1:, ...]
            isotropic_shape = vol_file['isotropic_shape'][:]
        voxel_size = np.divide(fix_img.shape[:-1], isotropic_shape)

        # ndarray to ANTsImage
        fix_img_ants = ants.make_image(fix_img.shape[:-1], np.squeeze(fix_img))  # SoA doesn't work fine with 1-ch images
        mov_img_ants = ants.make_image(mov_img.shape[:-1], np.squeeze(mov_img))  # SoA doesn't work fine with 1-ch images
        # The segmentation images do not depend on the registration method: build them once.
        fix_seg_ants = ants.make_image(fix_seg.shape, np.squeeze(fix_seg))
        mov_seg_ants = ants.make_image(mov_seg.shape, np.squeeze(mov_seg))

        # Run and time every registration method; each result is kept under its
        # own key so that outputs and timings of different methods cannot be mixed up.
        reg_outputs = {}
        reg_times = {}
        for reg_method in reg_methods:
            dataset_iterator.set_description('{} ({}): running ANTs {}'.format(file_num, file_path, reg_method))
            t0 = time.time()
            reg_outputs[reg_method] = ants.registration(fix_img_ants, mov_img_ants, reg_method)
            reg_times[reg_method] = time.time() - t0

        for reg_method in reg_methods:
            reg_output = reg_outputs[reg_method]
            mov_to_fix_trf_list = reg_output[FWD_TRFS]
            if not mov_to_fix_trf_list:
                # Without forward transforms nothing below can be computed for this
                # method; report it and carry on with the other method.
                print('ERR: Registration failed for: ' + file_path + ' (' + reg_method + ')')
                continue

            pred_img = reg_output[WARPED_MOV].numpy()
            pred_img = pred_img[..., np.newaxis]  # SoA doesn't work fine with 1-ch images

            pred_seg = ants.apply_transforms(fixed=fix_seg_ants, moving=mov_seg_ants,
                                             transformlist=mov_to_fix_trf_list).numpy()
            pred_seg = np.squeeze(pred_seg)  # SoA adds an extra axis which shouldn't be there

            dataset_iterator.set_description('{} ({}): Getting metrics {}'.format(file_num, file_path, reg_method))
            with sess.as_default():
                # One (prediction, reference) pair per label, each with a leading batch axis.
                label_pairs = [(pred_seg[np.newaxis, ..., l], fix_seg[np.newaxis, ..., l])
                               for l in range(nb_labels)]
                dice_per_label = [medpy_metrics.dc(pred, ref) for pred, ref in label_pairs]
                # 'DICE' divides each label's Dice coefficient by the sum of that label's
                # reference mask before averaging; 'DICE_MACRO' is the plain per-label mean.
                dice = np.mean([dc / np.sum(ref) for dc, (_, ref) in zip(dice_per_label, label_pairs)])
                dice_macro = np.mean(dice_per_label)
                hd = np.mean([medpy_metrics.hd(pred, ref) for pred, ref in label_pairs])
                hd95 = np.mean([medpy_metrics.hd95(pred, ref) for pred, ref in label_pairs])

                pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
                mov_seg_card = segmentation_ohe_to_cardinal(mov_seg).astype(np.float32)
                fix_seg_card = segmentation_ohe_to_cardinal(fix_seg).astype(np.float32)

                ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf],
                                                   {'fix_img:0': fix_img[np.newaxis, ...],  # Batch axis
                                                    'pred_img:0': pred_img[np.newaxis, ...]  # Batch axis
                                                    })
                ssim = np.mean(ssim)
                ms_ssim = ms_ssim[0]

                # TRE
                # The first forward transform is loaded as the displacement map.
                disp_map = np.squeeze(np.asarray(nb.load(mov_to_fix_trf_list[0]).dataobj))
                dm_interp = DisplacementMapInterpolator(fix_img.shape[:-1], 'griddata', step=2)
                pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) + mov_centroids

                # Rescale the points back to isotropic space, where we have a correspondence voxel <-> mm
                fix_centroids_isotropic = fix_centroids * voxel_size
                pred_centroids_isotropic = pred_centroids * voxel_size
                # NOTE(review): the rescaled centroids above are not the ones passed to
                # target_registration_error(); the un-rescaled centroids are used, as in
                # the original script. Confirm which coordinates the TRE should use.
                # .eval() relies on the default session set by the enclosing `with`.
                tre_array = target_registration_error(fix_centroids, pred_centroids, False).eval()
                tre = np.mean([v for v in tre_array if not np.isnan(v)])
                if np.isnan(tre):
                    print('TRE is NaN for {} and file {}'.format(reg_method, step))

                dataset_iterator.set_description('{} ({}): Saving data {}'.format(file_num, file_path, reg_method))
                new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, hd95,
                            reg_times[reg_method],
                            tre]
                with open(metrics_file[reg_method], 'a') as f:
                    f.write(';'.join(map(str, new_line)) + '\n')

                method_dir = os.path.join(args.outdirname, reg_method)
                if args.savenifti:
                    volumes = [('fix_img', fix_img), ('mov_img', mov_img), ('pred_img', pred_img),
                               ('fix_seg', fix_seg_card), ('mov_seg', mov_seg_card), ('pred_seg', pred_seg_card)]
                    for tag, volume in volumes:
                        nifti_name = '{:03d}_{}_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, tag, ssim, dice)
                        save_nifti(volume, os.path.join(method_dir, nifti_name), verbose=False)

                # NOTE(review): assumes the figures are produced regardless of --savenifti
                # (the original indentation was not available) - confirm.
                img_batches = [fix_img[np.newaxis, ...], mov_img[np.newaxis, ...], pred_img[np.newaxis, ...]]
                seg_batches = [fix_seg_card[np.newaxis, ...], mov_seg_card[np.newaxis, ...], pred_seg_card[np.newaxis, ...]]
                plot_predictions(img_batches=img_batches, disp_map_batch=disp_map[np.newaxis, ...],
                                 seg_batches=seg_batches,
                                 filename=os.path.join(method_dir, '{:03d}_figures_seg.png'.format(step)), show=False)
                plot_predictions(img_batches=img_batches, disp_map_batch=disp_map[np.newaxis, ...],
                                 filename=os.path.join(method_dir, '{:03d}_figures_img.png'.format(step)), show=False)
                save_disp_map_img(disp_map[np.newaxis, ...], 'Displacement map',
                                  os.path.join(method_dir, '{:03d}_disp_map_fig.png'.format(step)), show=False)

    # Summary statistics, read back from the CSV files written above.
    for reg_method, metrics_path in metrics_file.items():
        print('Summary {}\n=======\n'.format(reg_method))
        metrics_df = pd.read_csv(metrics_path, sep=';', header=0)
        print('\nAVG:\n')
        print(metrics_df.mean(axis=0))
        print('\nSTD:\n')
        print(metrics_df.std(axis=0))
        print('\nHD95perc:\n')
        print(metrics_df['HD'].describe(percentiles=[.95]))
        print('\n=======\n')