jpdefrutos commited on
Commit
f915f2e
·
1 Parent(s): ca1d395

Added the Method column to the CSV file, to distinguish better between SyN and SyNCC results

Browse files

Setup the # threads to improve computation speed. See: https://github.com/ANTsX/ANTsPy/issues/261

Files changed (1) hide show
  1. SoA_methods/eval_ants.py +14 -7
SoA_methods/eval_ants.py CHANGED
@@ -2,18 +2,22 @@ import h5py
2
  import ants
3
  import numpy as np
4
  import nibabel as nb
5
- import DeepDeformationMapRegistration.utils.constants as C
6
- import os
7
  from tqdm import tqdm
8
  import re
9
  import time
10
  import pandas as pd
11
 
 
 
 
 
12
  from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplified, NCC, GeneralizedDICEScore, HausdorffDistanceErosion, target_registration_error
13
  from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
14
  from DeepDeformationMapRegistration.utils.misc import DisplacementMapInterpolator, segmentation_ohe_to_cardinal
15
  from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
16
  from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
 
17
 
18
  import voxelmorph as vxm
19
 
@@ -79,18 +83,21 @@ if __name__ == '__main__':
79
  sess = tf.Session(config=config)
80
  tf.keras.backend.set_session(sess)
81
  ####
 
 
82
  dm_interp = DisplacementMapInterpolator(image_shape, 'griddata')
83
  # Header of the metrics csv file
84
- csv_header = ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'DICE_MACRO', 'HD', 'Time_SyN', 'Time_SyNCC', 'TRE']
85
 
86
  metrics_file = os.path.join(args.outdir, 'metrics.csv')
87
  with open(metrics_file, 'w') as f:
88
  f.write(';'.join(csv_header)+'\n')
89
 
90
- for step, file_path in tqdm(enumerate(dataset_iterator)):
 
91
  file_num = int(re.findall('(\d+)', os.path.split(file_path)[-1])[0])
92
 
93
- dataset_iterator.set_description('{} ({}): laoding data'.format(file_num, args.dataset))
94
  with h5py.File(file_path, 'r') as vol_file:
95
  fix_img = vol_file['fix_image'][:]
96
  mov_img = vol_file['mov_image'][:]
@@ -118,7 +125,7 @@ if __name__ == '__main__':
118
  if not len(mov_to_fix_trf_syn) and not len(mov_to_fix_trf_syncc):
119
  print('ERR: Registration failed for: '+file_path)
120
  else:
121
- for reg_output in [reg_output_syn, reg_output_syncc]:
122
  mov_to_fix_trf_list = reg_output[FWD_TRFS]
123
  pred_img = reg_output[WARPED_MOV].numpy()
124
  pred_img = pred_img[..., np.newaxis] # SoA doesn't work fine with 1-ch images
@@ -156,7 +163,7 @@ if __name__ == '__main__':
156
  tre_array = target_registration_error(fix_centroids_isotropic, pred_centroids_isotropic, False).eval()
157
  tre = np.mean([v for v in tre_array if not np.isnan(v)])
158
 
159
- new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, t1_syn-t0_syn, t1_syncc-t0_syncc, tre]
160
  with open(metrics_file, 'a') as f:
161
  f.write(';'.join(map(str, new_line))+'\n')
162
 
 
2
  import ants
3
  import numpy as np
4
  import nibabel as nb
5
+ import os, sys
 
6
  from tqdm import tqdm
7
  import re
8
  import time
9
  import pandas as pd
10
 
11
+ currentdir = os.path.dirname(os.path.realpath(__file__))
12
+ parentdir = os.path.dirname(currentdir)
13
+ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
14
+
15
  from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplified, NCC, GeneralizedDICEScore, HausdorffDistanceErosion, target_registration_error
16
  from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
17
  from DeepDeformationMapRegistration.utils.misc import DisplacementMapInterpolator, segmentation_ohe_to_cardinal
18
  from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
19
  from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
20
+ import DeepDeformationMapRegistration.utils.constants as C
21
 
22
  import voxelmorph as vxm
23
 
 
83
  sess = tf.Session(config=config)
84
  tf.keras.backend.set_session(sess)
85
  ####
86
+ os.environ["ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS"] = "12" #https://github.com/ANTsX/ANTsPy/issues/261
87
+ print("Running ANTs using {} threads".format(os.environ.get("ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS")))
88
  dm_interp = DisplacementMapInterpolator(image_shape, 'griddata')
89
  # Header of the metrics csv file
90
+ csv_header = ['File', 'Method', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'DICE_MACRO', 'HD', 'Time_SyN', 'Time_SyNCC', 'TRE']
91
 
92
  metrics_file = os.path.join(args.outdir, 'metrics.csv')
93
  with open(metrics_file, 'w') as f:
94
  f.write(';'.join(csv_header)+'\n')
95
 
96
+ print('Starting the loop')
97
+ for step, file_path in tqdm(enumerate(dataset_iterator), desc="Running ANTs"):
98
  file_num = int(re.findall('(\d+)', os.path.split(file_path)[-1])[0])
99
 
100
+ dataset_iterator.set_description('{} ({}): loading data'.format(file_num, args.dataset))
101
  with h5py.File(file_path, 'r') as vol_file:
102
  fix_img = vol_file['fix_image'][:]
103
  mov_img = vol_file['mov_image'][:]
 
125
  if not len(mov_to_fix_trf_syn) and not len(mov_to_fix_trf_syncc):
126
  print('ERR: Registration failed for: '+file_path)
127
  else:
128
+ for reg_method, reg_output in zip(['SyN', 'SyNCC'], [reg_output_syn, reg_output_syncc]):
129
  mov_to_fix_trf_list = reg_output[FWD_TRFS]
130
  pred_img = reg_output[WARPED_MOV].numpy()
131
  pred_img = pred_img[..., np.newaxis] # SoA doesn't work fine with 1-ch images
 
163
  tre_array = target_registration_error(fix_centroids_isotropic, pred_centroids_isotropic, False).eval()
164
  tre = np.mean([v for v in tre_array if not np.isnan(v)])
165
 
166
+ new_line = [step, reg_method, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, t1_syn-t0_syn, t1_syncc-t0_syncc, tre]
167
  with open(metrics_file, 'a') as f:
168
  f.write(';'.join(map(str, new_line))+'\n')
169