jpdefrutos commited on
Commit
7968536
·
1 Parent(s): 15c9383

Added --output, --erase, and --save-nifti flags, as with the other evaluation scripts

Browse files
Files changed (1) hide show
  1. SoA_methods/eval_ants.py +46 -32
SoA_methods/eval_ants.py CHANGED
@@ -18,7 +18,7 @@ from DeepDeformationMapRegistration.utils.misc import DisplacementMapInterpolato
18
  from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
19
  from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
20
  import DeepDeformationMapRegistration.utils.constants as C
21
-
22
  import medpy.metric as medpy_metrics
23
 
24
  import voxelmorph as vxm
@@ -41,16 +41,19 @@ INV_TRFS = 'invtransforms'
41
  if __name__ == '__main__':
42
  parser = ArgumentParser()
43
  parser.add_argument('--dataset', type=str, help='Directory with the images')
44
- parser.add_argument('--outdir', type=str, help='Output directory')
45
  parser.add_argument('--gpu', type=int, help='GPU')
 
 
46
  args = parser.parse_args()
47
 
48
  os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
49
  os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
50
-
51
- os.makedirs(args.outdir, exist_ok=True)
52
- os.makedirs(os.path.join(args.outdir, 'SyN'), exist_ok=True)
53
- os.makedirs(os.path.join(args.outdir, 'SyNCC'), exist_ok=True)
 
54
  dataset_files = os.listdir(args.dataset)
55
  dataset_files.sort()
56
  dataset_files = [os.path.join(args.dataset, f) for f in dataset_files if re.match(DATASET_NAMES, f)]
@@ -59,7 +62,7 @@ if __name__ == '__main__':
59
 
60
  f = h5py.File(dataset_files[0], 'r')
61
  image_shape = list(f['fix_image'][:].shape[:-1])
62
- nb_labels = f['fix_segmentations'][:].shape[-1]
63
  f.close()
64
 
65
  #### TF prep
@@ -96,10 +99,10 @@ if __name__ == '__main__':
96
  print("Running ANTs using {} threads".format(os.environ.get("ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS")))
97
  # dm_interp = DisplacementMapInterpolator(image_shape, 'griddata')
98
  # Header of the metrics csv file
99
- csv_header = ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'DICE_MACRO', 'HD', 'Time', 'TRE']
100
 
101
- metrics_file = {'SyN': os.path.join(args.outdir, 'SyN', 'metrics.csv'),
102
- 'SyNCC': os.path.join(args.outdir, 'SyNCC', 'metrics.csv')}
103
  for k in metrics_file.keys():
104
  with open(metrics_file[k], 'w') as f:
105
  f.write(';'.join(csv_header)+'\n')
@@ -113,11 +116,14 @@ if __name__ == '__main__':
113
  fix_img = vol_file['fix_image'][:]
114
  mov_img = vol_file['mov_image'][:]
115
 
116
- fix_seg = vol_file['fix_segmentations'][:]
117
- mov_seg = vol_file['mov_segmentations'][:]
 
 
 
118
 
119
- fix_centroids = vol_file['fix_centroids'][:]
120
- mov_centroids = vol_file['mov_centroids'][:]
121
 
122
  # ndarray to ANTsImage
123
  fix_img_ants = ants.make_image(fix_img.shape[:-1], np.squeeze(fix_img)) # SoA doesn't work fine with 1-ch images
@@ -155,6 +161,8 @@ if __name__ == '__main__':
155
  fix_seg[np.newaxis,..., l]) for l in range(nb_labels)])
156
  hd = np.mean(
157
  [medpy_metrics.hd(pred_seg[np.newaxis,..., l], fix_seg[np.newaxis,..., l]) for l in range(nb_labels)])
 
 
158
  dice_macro = np.mean(
159
  [medpy_metrics.dc(pred_seg[np.newaxis,..., l], fix_seg[np.newaxis,..., l]) for l in range(nb_labels)])
160
 
@@ -178,32 +186,38 @@ if __name__ == '__main__':
178
  disp_map = np.squeeze(np.asarray(nb.load(mov_to_fix_trf_list[0]).dataobj))
179
  dm_interp = DisplacementMapInterpolator(fix_img.shape[:-1], 'griddata', step=2)
180
  pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) + mov_centroids
 
181
  # upsample_scale = 128 / 64
182
  # fix_centroids_isotropic = fix_centroids * upsample_scale
183
  # pred_centroids_isotropic = pred_centroids * upsample_scale
184
 
 
 
 
185
  # fix_centroids_isotropic = np.divide(fix_centroids_isotropic, C.COMET_DATASET_iso_to_cubic_scales)
186
  # pred_centroids_isotropic = np.divide(pred_centroids_isotropic, C.COMET_DATASET_iso_to_cubic_scales)
187
  tre_array = target_registration_error(fix_centroids, pred_centroids, False).eval()
188
  tre = np.mean([v for v in tre_array if not np.isnan(v)])
189
-
190
- # dataset_iterator.set_description('{} ({}): Saving data {}'.format(file_num, file_path, reg_method))
191
- # new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd,
192
- # t1_syn-t0_syn if reg_method == 'SyN' else t1_syncc-t0_syncc,
193
- # tre]
194
- # with open(metrics_file[reg_method], 'a') as f:
195
- # f.write(';'.join(map(str, new_line))+'\n')
196
-
197
- save_nifti(fix_img[0, ...], os.path.join(args.outdir, reg_method, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
198
- save_nifti(mov_img[0, ...], os.path.join(args.outdir, reg_method, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
199
- save_nifti(pred_img[0, ...], os.path.join(args.outdir, reg_method, '{:03d}_pred_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
200
- save_nifti(fix_seg_card[0, ...], os.path.join(args.outdir, reg_method, '{:03d}_fix_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
201
- save_nifti(mov_seg_card[0, ...], os.path.join(args.outdir, reg_method, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
202
- save_nifti(pred_seg_card[0, ...], os.path.join(args.outdir, reg_method, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
203
-
204
- plot_predictions(img_batches=[fix_img[np.newaxis, ...], mov_img[np.newaxis, ...], pred_img[np.newaxis, ...]], disp_map_batch=disp_map[np.newaxis, ...], seg_batches=[fix_seg_card[np.newaxis, ...], mov_seg_card[np.newaxis, ...], pred_seg_card[np.newaxis, ...]], filename=os.path.join(args.outdir, reg_method, '{:03d}_figures_seg.png'.format(step)), show=False)
205
- plot_predictions(img_batches=[fix_img[np.newaxis, ...], mov_img[np.newaxis, ...], pred_img[np.newaxis, ...]], disp_map_batch=disp_map[np.newaxis, ...], filename=os.path.join(args.outdir, reg_method, '{:03d}_figures_img.png'.format(step)), show=False)
206
- save_disp_map_img(disp_map[np.newaxis, ...], 'Displacement map', os.path.join(args.outdir, reg_method, '{:03d}_disp_map_fig.png'.format(step)), show=False)
 
 
207
 
208
  for k in metrics_file.keys():
209
  print('Summary {}\n=======\n'.format(k))
 
18
  from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
19
  from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
20
  import DeepDeformationMapRegistration.utils.constants as C
21
+ import shutil
22
  import medpy.metric as medpy_metrics
23
 
24
  import voxelmorph as vxm
 
41
  if __name__ == '__main__':
42
  parser = ArgumentParser()
43
  parser.add_argument('--dataset', type=str, help='Directory with the images')
44
+ parser.add_argument('--outdirname', type=str, help='Output directory')
45
  parser.add_argument('--gpu', type=int, help='GPU')
46
+ parser.add_argument('--savenifti', type=bool, default=True)
47
+ parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
48
  args = parser.parse_args()
49
 
50
  os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
51
  os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
52
+ if args.erase:
53
+ shutil.rmtree(args.outdirname, ignore_errors=True)
54
+ os.makedirs(args.outdirname, exist_ok=True)
55
+ os.makedirs(os.path.join(args.outdirname, 'SyN'), exist_ok=True)
56
+ os.makedirs(os.path.join(args.outdirname, 'SyNCC'), exist_ok=True)
57
  dataset_files = os.listdir(args.dataset)
58
  dataset_files.sort()
59
  dataset_files = [os.path.join(args.dataset, f) for f in dataset_files if re.match(DATASET_NAMES, f)]
 
62
 
63
  f = h5py.File(dataset_files[0], 'r')
64
  image_shape = list(f['fix_image'][:].shape[:-1])
65
+ nb_labels = f['fix_segmentations'][:].shape[-1] - 1
66
  f.close()
67
 
68
  #### TF prep
 
99
  print("Running ANTs using {} threads".format(os.environ.get("ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS")))
100
  # dm_interp = DisplacementMapInterpolator(image_shape, 'griddata')
101
  # Header of the metrics csv file
102
+ csv_header = ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'DICE_MACRO', 'HD', 'HD95', 'Time', 'TRE']
103
 
104
+ metrics_file = {'SyN': os.path.join(args.outdirname, 'SyN', 'metrics.csv'),
105
+ 'SyNCC': os.path.join(args.outdirname, 'SyNCC', 'metrics.csv')}
106
  for k in metrics_file.keys():
107
  with open(metrics_file[k], 'w') as f:
108
  f.write(';'.join(csv_header)+'\n')
 
116
  fix_img = vol_file['fix_image'][:]
117
  mov_img = vol_file['mov_image'][:]
118
 
119
+ fix_seg = vol_file['fix_segmentations'][..., 1:].astype(np.float32)
120
+ mov_seg = vol_file['mov_segmentations'][..., 1:].astype(np.float32)
121
+
122
+ fix_centroids = vol_file['fix_centroids'][1:, ...]
123
+ mov_centroids = vol_file['mov_centroids'][1:, ...]
124
 
125
+ isotropic_shape = vol_file['isotropic_shape'][:]
126
+ voxel_size = np.divide(fix_img.shape[:-1], isotropic_shape)
127
 
128
  # ndarray to ANTsImage
129
  fix_img_ants = ants.make_image(fix_img.shape[:-1], np.squeeze(fix_img)) # SoA doesn't work fine with 1-ch images
 
161
  fix_seg[np.newaxis,..., l]) for l in range(nb_labels)])
162
  hd = np.mean(
163
  [medpy_metrics.hd(pred_seg[np.newaxis,..., l], fix_seg[np.newaxis,..., l]) for l in range(nb_labels)])
164
+ hd95 = np.mean(
165
+ [medpy_metrics.hd95(pred_seg[np.newaxis,..., l], fix_seg[np.newaxis,..., l]) for l in range(nb_labels)])
166
  dice_macro = np.mean(
167
  [medpy_metrics.dc(pred_seg[np.newaxis,..., l], fix_seg[np.newaxis,..., l]) for l in range(nb_labels)])
168
 
 
186
  disp_map = np.squeeze(np.asarray(nb.load(mov_to_fix_trf_list[0]).dataobj))
187
  dm_interp = DisplacementMapInterpolator(fix_img.shape[:-1], 'griddata', step=2)
188
  pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) + mov_centroids
189
+ # Rescale the points back to isotropic space, where we have a correspondence voxel <-> mm
190
  # upsample_scale = 128 / 64
191
  # fix_centroids_isotropic = fix_centroids * upsample_scale
192
  # pred_centroids_isotropic = pred_centroids * upsample_scale
193
 
194
+ fix_centroids_isotropic = fix_centroids * voxel_size
195
+ pred_centroids_isotropic = pred_centroids * voxel_size
196
+
197
  # fix_centroids_isotropic = np.divide(fix_centroids_isotropic, C.COMET_DATASET_iso_to_cubic_scales)
198
  # pred_centroids_isotropic = np.divide(pred_centroids_isotropic, C.COMET_DATASET_iso_to_cubic_scales)
199
  tre_array = target_registration_error(fix_centroids, pred_centroids, False).eval()
200
  tre = np.mean([v for v in tre_array if not np.isnan(v)])
201
+ if np.isnan(tre):
202
+ print('TRE is NaN for {} and file {}'.format(reg_method, step))
203
+
204
+ dataset_iterator.set_description('{} ({}): Saving data {}'.format(file_num, file_path, reg_method))
205
+ new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, hd95,
206
+ t1_syn-t0_syn if reg_method == 'SyN' else t1_syncc-t0_syncc,
207
+ tre]
208
+ with open(metrics_file[reg_method], 'a') as f:
209
+ f.write(';'.join(map(str, new_line))+'\n')
210
+ if args.savenifti:
211
+ save_nifti(fix_img, os.path.join(args.outdirname, reg_method, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
212
+ save_nifti(mov_img, os.path.join(args.outdirname, reg_method, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
213
+ save_nifti(pred_img, os.path.join(args.outdirname, reg_method, '{:03d}_pred_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
214
+ save_nifti(fix_seg_card, os.path.join(args.outdirname, reg_method, '{:03d}_fix_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
215
+ save_nifti(mov_seg_card, os.path.join(args.outdirname, reg_method, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
216
+ save_nifti(pred_seg_card, os.path.join(args.outdirname, reg_method, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
217
+
218
+ plot_predictions(img_batches=[fix_img[np.newaxis, ...], mov_img[np.newaxis, ...], pred_img[np.newaxis, ...]], disp_map_batch=disp_map[np.newaxis, ...], seg_batches=[fix_seg_card[np.newaxis, ...], mov_seg_card[np.newaxis, ...], pred_seg_card[np.newaxis, ...]], filename=os.path.join(args.outdirname, reg_method, '{:03d}_figures_seg.png'.format(step)), show=False)
219
+ plot_predictions(img_batches=[fix_img[np.newaxis, ...], mov_img[np.newaxis, ...], pred_img[np.newaxis, ...]], disp_map_batch=disp_map[np.newaxis, ...], filename=os.path.join(args.outdirname, reg_method, '{:03d}_figures_img.png'.format(step)), show=False)
220
+ save_disp_map_img(disp_map[np.newaxis, ...], 'Displacement map', os.path.join(args.outdirname, reg_method, '{:03d}_disp_map_fig.png'.format(step)), show=False)
221
 
222
  for k in metrics_file.keys():
223
  print('Summary {}\n=======\n'.format(k))