jpdefrutos commited on
Commit
a7b71d6
·
1 Parent(s): 7d5b555

ANTs evaluation script

Browse files
Files changed (1) hide show
  1. ANTs/eval_ants.py +72 -44
ANTs/eval_ants.py CHANGED
@@ -1,14 +1,21 @@
1
  import h5py
2
  import ants
3
  import numpy as np
 
4
  import DeepDeformationMapRegistration.utils.constants as C
5
  import os
6
- from tqdm import tqdm
7
  import re
 
 
8
 
9
  from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplified, NCC, GeneralizedDICEScore, HausdorffDistanceErosion, target_registration_error
10
  from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
11
  from DeepDeformationMapRegistration.utils.misc import DisplacementMapInterpolator, segmentation_ohe_to_cardinal
 
 
 
 
12
 
13
  from argparse import ArgumentParser
14
 
@@ -30,6 +37,7 @@ if __name__ == '__main__':
30
  parser.add_argument('--outdir', type=str, help='Output directory')
31
  args = parser.parse_args()
32
 
 
33
  dataset_files = os.listdir(args.dataset)
34
  dataset_files.sort()
35
  dataset_files = [os.path.join(args.dataset, f) for f in dataset_files if re.match(DATASET_NAMES, f)]
@@ -37,22 +45,23 @@ if __name__ == '__main__':
37
  dataset_iterator = tqdm(dataset_files)
38
 
39
  f = h5py.File(dataset_files[0], 'r')
40
- image_output_shape = list(f['fix_image'][:].shape[:-1])
 
41
  f.close()
42
 
43
  #### TF prep
44
  metric_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).metric,
45
- NCC(image_input_shape).metric,
46
  vxm.losses.MSE().loss,
47
  MultiScaleStructuralSimilarity(max_val=1., filter_size=3).metric,
48
- GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric,
49
- HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [nb_labels]).metric,
50
- GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric_macro]
51
 
52
- fix_img_ph = tf.placeholder(tf.float32, (1, *image_output_shape, 1), name='fix_img')
53
- pred_img_ph = tf.placeholder(tf.float32, (1, *image_output_shape, 1), name='pred_img')
54
- fix_seg_ph = tf.placeholder(tf.float32, (1, *image_output_shape, nb_labels), name='fix_seg')
55
- pred_seg_ph = tf.placeholder(tf.float32, (1, *image_output_shape, nb_labels), name='pred_seg')
56
 
57
  ssim_tf = metric_fncs[0](fix_img_ph, pred_img_ph)
58
  ncc_tf = metric_fncs[1](fix_img_ph, pred_img_ph)
@@ -70,14 +79,18 @@ if __name__ == '__main__':
70
  sess = tf.Session(config=config)
71
  tf.keras.backend.set_session(sess)
72
  ####
73
- dm_interp = DisplacementMapInterpolator(image_output_shape, 'griddata')
 
 
74
 
75
- metrics_file = os.path.join(output_folder, 'metrics.csv')
 
 
76
 
77
- for file_path in dataset_iterator:
78
  file_num = int(re.findall('(\d+)', os.path.split(file_path)[-1])[0])
79
 
80
- dataset_iterator.set_description('{} ({}): laoding data'.format(file_num, dataset_name))
81
  with h5py.File(file_path, 'r') as vol_file:
82
  fix_img = vol_file['fix_image'][:]
83
  mov_img = vol_file['mov_image'][:]
@@ -85,63 +98,78 @@ if __name__ == '__main__':
85
  fix_seg = vol_file['fix_segmentations'][:]
86
  mov_seg = vol_file['mov_segmentations'][:]
87
 
88
- fix_centroid = vol_file['fix_centroids'][:]
89
- mov_centroid = vol_file['mov_centroids'][:]
90
 
91
  # ndarray to ANTsImage
92
- fix_img = ants.make_image(fix_img.shape, fix_img)
93
- mov_img = ants.make_image(mov_img.shape, mov_img)
 
 
 
 
 
 
 
 
94
 
95
- reg_output_syn = ants.registration(fix_img, mov_img, 'SyN')
96
- reg_output_syncc = ants.registration(fix_img, mov_img, 'SyNCC')
97
  mov_to_fix_trf_syn = reg_output_syn[FWD_TRFS]
98
  mov_to_fix_trf_syncc = reg_output_syn[FWD_TRFS]
99
  if not len(mov_to_fix_trf_syn) and not len(mov_to_fix_trf_syncc):
100
  print('ERR: Registration failed for: '+file_path)
101
  else:
102
  for reg_output in [reg_output_syn, reg_output_syncc]:
103
- mov_to_fix_trf = reg_output[FWD_TRFS]
104
  pred_img = reg_output[WARPED_MOV].numpy()
105
- pred_seg = mov_to_fix_trf.apply_to_image(ants.make_image(mov_seg.shape, mov_seg)).numpy()
106
 
 
 
 
 
 
107
  with sess.as_default():
108
  dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf],
109
- {'fix_seg:0': fix_seg, 'pred_seg:0': pred_seg})
 
 
110
 
111
  pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
112
  mov_seg_card = segmentation_ohe_to_cardinal(mov_seg).astype(np.float32)
113
  fix_seg_card = segmentation_ohe_to_cardinal(fix_seg).astype(np.float32)
114
 
115
  ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf],
116
- {'fix_img:0': fix_img, 'pred_img:0': pred_img})
 
 
117
  ms_ssim = ms_ssim[0]
118
- tf.keras.backend.clear_session()
119
 
120
- # TRE
121
- pred_centroids = dm_interp(mov_to_fix_trf.numpy(), mov_centroid, backwards=True) + mov_centroid
122
- upsample_scale = 128 / 64
123
- fix_centroids_isotropic = fix_centroids * upsample_scale
124
- pred_centroids_isotropic = pred_centroids * upsample_scale
 
125
 
126
- fix_centroids_isotropic = np.divide(fix_centroids_isotropic, C.IXI_DATASET_iso_to_cubic_scales)
127
- pred_centroids_isotropic = np.divide(pred_centroids_isotropic, C.IXI_DATASET_iso_to_cubic_scales)
128
- tre_array = target_registration_error(fix_centroids_isotropic, pred_centroids_isotropic, False).eval()
129
- tre = np.mean([v for v in tre_array if not np.isnan(v)])
130
 
131
- new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, t1-t0, tre, len(missing_lbls), missing_lbls]
132
  with open(metrics_file, 'a') as f:
133
  f.write(';'.join(map(str, new_line))+'\n')
134
 
135
- save_nifti(fix_img[0, ...], os.path.join(output_folder, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
136
- save_nifti(mov_img[0, ...], os.path.join(output_folder, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
137
- save_nifti(pred_img[0, ...], os.path.join(output_folder, '{:03d}_pred_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
138
- save_nifti(fix_seg_card[0, ...], os.path.join(output_folder, '{:03d}_fix_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
139
- save_nifti(mov_seg_card[0, ...], os.path.join(output_folder, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
140
- save_nifti(pred_seg_card[0, ...], os.path.join(output_folder, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
141
 
142
- plot_predictions(fix_img, mov_img, disp_map, pred_img, os.path.join(output_folder, '{:03d}_figures_img.png'.format(step)), show=False)
143
- plot_predictions(fix_seg, mov_seg, disp_map, pred_seg, os.path.join(output_folder, '{:03d}_figures_seg.png'.format(step)), show=False)
144
- save_disp_map_img(disp_map, 'Displacement map', os.path.join(output_folder, '{:03d}_disp_map_fig.png'.format(step)), show=False)
145
 
146
  print('Summary\n=======\n')
147
  print('\nAVG:\n' + str(pd.read_csv(metrics_file, sep=';', header=0).mean(axis=0)) + '\nSTD:\n' + str(
 
1
  import h5py
2
  import ants
3
  import numpy as np
4
+ import nibabel as nb
5
  import DeepDeformationMapRegistration.utils.constants as C
6
  import os
7
+ from tqdm import tqdm
8
  import re
9
+ import time
10
+ import pandas as pd
11
 
12
  from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplified, NCC, GeneralizedDICEScore, HausdorffDistanceErosion, target_registration_error
13
  from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
14
  from DeepDeformationMapRegistration.utils.misc import DisplacementMapInterpolator, segmentation_ohe_to_cardinal
15
+ from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
16
+ from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
17
+
18
+ import voxelmorph as vxm
19
 
20
  from argparse import ArgumentParser
21
 
 
37
  parser.add_argument('--outdir', type=str, help='Output directory')
38
  args = parser.parse_args()
39
 
40
+ os.makedirs(args.outdir, exist_ok=True)
41
  dataset_files = os.listdir(args.dataset)
42
  dataset_files.sort()
43
  dataset_files = [os.path.join(args.dataset, f) for f in dataset_files if re.match(DATASET_NAMES, f)]
 
45
  dataset_iterator = tqdm(dataset_files)
46
 
47
  f = h5py.File(dataset_files[0], 'r')
48
+ image_shape = list(f['fix_image'][:].shape[:-1])
49
+ nb_labels = f['fix_segmentations'][:].shape[-1]
50
  f.close()
51
 
52
  #### TF prep
53
  metric_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).metric,
54
+ NCC(image_shape).metric,
55
  vxm.losses.MSE().loss,
56
  MultiScaleStructuralSimilarity(max_val=1., filter_size=3).metric,
57
+ GeneralizedDICEScore(image_shape + [nb_labels], num_labels=nb_labels).metric,
58
+ HausdorffDistanceErosion(3, 10, im_shape=image_shape + [nb_labels]).metric,
59
+ GeneralizedDICEScore(image_shape + [nb_labels], num_labels=nb_labels).metric_macro]
60
 
61
+ fix_img_ph = tf.placeholder(tf.float32, (1, *image_shape, 1), name='fix_img')
62
+ pred_img_ph = tf.placeholder(tf.float32, (1, *image_shape, 1), name='pred_img')
63
+ fix_seg_ph = tf.placeholder(tf.float32, (1, *image_shape, nb_labels), name='fix_seg')
64
+ pred_seg_ph = tf.placeholder(tf.float32, (1, *image_shape, nb_labels), name='pred_seg')
65
 
66
  ssim_tf = metric_fncs[0](fix_img_ph, pred_img_ph)
67
  ncc_tf = metric_fncs[1](fix_img_ph, pred_img_ph)
 
79
  sess = tf.Session(config=config)
80
  tf.keras.backend.set_session(sess)
81
  ####
82
+ dm_interp = DisplacementMapInterpolator(image_shape, 'griddata')
83
+ # Header of the metrics csv file
84
+ csv_header = ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'DICE_MACRO', 'HD', 'Time_SyN', 'Time_SyNCC', 'TRE']
85
 
86
+ metrics_file = os.path.join(args.outdir, 'metrics.csv')
87
+ with open(metrics_file, 'w') as f:
88
+ f.write(';'.join(csv_header)+'\n')
89
 
90
+ for step, file_path in tqdm(enumerate(dataset_iterator)):
91
  file_num = int(re.findall('(\d+)', os.path.split(file_path)[-1])[0])
92
 
93
+ dataset_iterator.set_description('{} ({}): laoding data'.format(file_num, args.dataset))
94
  with h5py.File(file_path, 'r') as vol_file:
95
  fix_img = vol_file['fix_image'][:]
96
  mov_img = vol_file['mov_image'][:]
 
98
  fix_seg = vol_file['fix_segmentations'][:]
99
  mov_seg = vol_file['mov_segmentations'][:]
100
 
101
+ fix_centroids = vol_file['fix_centroids'][:]
102
+ mov_centroids = vol_file['mov_centroids'][:]
103
 
104
  # ndarray to ANTsImage
105
+ fix_img_ants = ants.make_image(fix_img.shape[:-1], np.squeeze(fix_img)) # ANTs doesn't work fine with 1-ch images
106
+ mov_img_ants = ants.make_image(mov_img.shape[:-1], np.squeeze(mov_img)) # ANTs doesn't work fine with 1-ch images
107
+
108
+ t0_syn = time.time()
109
+ reg_output_syn = ants.registration(fix_img_ants, mov_img_ants, 'SyN')
110
+ t1_syn = time.time()
111
+
112
+ t0_syncc = time.time()
113
+ reg_output_syncc = ants.registration(fix_img_ants, mov_img_ants, 'SyNCC')
114
+ t1_syncc = time.time()
115
 
 
 
116
  mov_to_fix_trf_syn = reg_output_syn[FWD_TRFS]
117
  mov_to_fix_trf_syncc = reg_output_syn[FWD_TRFS]
118
  if not len(mov_to_fix_trf_syn) and not len(mov_to_fix_trf_syncc):
119
  print('ERR: Registration failed for: '+file_path)
120
  else:
121
  for reg_output in [reg_output_syn, reg_output_syncc]:
122
+ mov_to_fix_trf_list = reg_output[FWD_TRFS]
123
  pred_img = reg_output[WARPED_MOV].numpy()
124
+ pred_img = pred_img[..., np.newaxis] # ANTs doesn't work fine with 1-ch images
125
 
126
+ fix_seg_ants = ants.make_image(fix_seg.shape, np.squeeze(fix_seg))
127
+ mov_seg_ants = ants.make_image(mov_seg.shape, np.squeeze(mov_seg))
128
+ pred_seg = ants.apply_transforms(fixed=fix_seg_ants, moving=mov_seg_ants,
129
+ transformlist=mov_to_fix_trf_list).numpy()
130
+ pred_seg = np.squeeze(pred_seg) # ANTs adds an extra axis which shouldn't be there
131
  with sess.as_default():
132
  dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf],
133
+ {'fix_seg:0': fix_seg[np.newaxis, ...], # Batch axis
134
+ 'pred_seg:0': pred_seg[np.newaxis, ...] # Batch axis
135
+ })
136
 
137
  pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
138
  mov_seg_card = segmentation_ohe_to_cardinal(mov_seg).astype(np.float32)
139
  fix_seg_card = segmentation_ohe_to_cardinal(fix_seg).astype(np.float32)
140
 
141
  ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf],
142
+ {'fix_img:0': fix_img[np.newaxis, ...], # Batch axis
143
+ 'pred_img:0': pred_img[np.newaxis, ...] # Batch axis
144
+ })
145
  ms_ssim = ms_ssim[0]
 
146
 
147
+ # TRE
148
+ disp_map = np.squeeze(np.asarray(nb.load(mov_to_fix_trf_list[0]).dataobj))
149
+ pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) + mov_centroids
150
+ upsample_scale = 128 / 64
151
+ fix_centroids_isotropic = fix_centroids * upsample_scale
152
+ pred_centroids_isotropic = pred_centroids * upsample_scale
153
 
154
+ fix_centroids_isotropic = np.divide(fix_centroids_isotropic, C.COMET_DATASET_iso_to_cubic_scales)
155
+ pred_centroids_isotropic = np.divide(pred_centroids_isotropic, C.COMET_DATASET_iso_to_cubic_scales)
156
+ tre_array = target_registration_error(fix_centroids_isotropic, pred_centroids_isotropic, False).eval()
157
+ tre = np.mean([v for v in tre_array if not np.isnan(v)])
158
 
159
+ new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, t1_syn-t0_syn, t1_syncc-t0_syncc, tre]
160
  with open(metrics_file, 'a') as f:
161
  f.write(';'.join(map(str, new_line))+'\n')
162
 
163
+ save_nifti(fix_img[0, ...], os.path.join(args.outdir, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
164
+ save_nifti(mov_img[0, ...], os.path.join(args.outdir, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
165
+ save_nifti(pred_img[0, ...], os.path.join(args.outdir, '{:03d}_pred_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
166
+ save_nifti(fix_seg_card[0, ...], os.path.join(args.outdir, '{:03d}_fix_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
167
+ save_nifti(mov_seg_card[0, ...], os.path.join(args.outdir, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
168
+ save_nifti(pred_seg_card[0, ...], os.path.join(args.outdir, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
169
 
170
+ plot_predictions(fix_img[np.newaxis, ...], mov_img[np.newaxis, ...], disp_map[np.newaxis, ...], pred_img[np.newaxis, ...], os.path.join(args.outdir, '{:03d}_figures_img.png'.format(step)), show=False)
171
+ plot_predictions(fix_seg[np.newaxis, ...], mov_seg[np.newaxis, ...], disp_map[np.newaxis, ...], pred_seg[np.newaxis, ...], os.path.join(args.outdir, '{:03d}_figures_seg.png'.format(step)), show=False)
172
+ save_disp_map_img(disp_map[np.newaxis, ...], 'Displacement map', os.path.join(args.outdir, '{:03d}_disp_map_fig.png'.format(step)), show=False)
173
 
174
  print('Summary\n=======\n')
175
  print('\nAVG:\n' + str(pd.read_csv(metrics_file, sep=';', header=0).mean(axis=0)) + '\nSTD:\n' + str(