Commit
·
f915f2e
1
Parent(s):
ca1d395
Added the Method column to the CSV file, to distinguish better between SyN and SyNCC results
Browse filesSetup the # threads to improve computation speed. See: https://github.com/ANTsX/ANTsPy/issues/261
- SoA_methods/eval_ants.py +14 -7
SoA_methods/eval_ants.py
CHANGED
@@ -2,18 +2,22 @@ import h5py
|
|
2 |
import ants
|
3 |
import numpy as np
|
4 |
import nibabel as nb
|
5 |
-
import
|
6 |
-
import os
|
7 |
from tqdm import tqdm
|
8 |
import re
|
9 |
import time
|
10 |
import pandas as pd
|
11 |
|
|
|
|
|
|
|
|
|
12 |
from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplified, NCC, GeneralizedDICEScore, HausdorffDistanceErosion, target_registration_error
|
13 |
from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
|
14 |
from DeepDeformationMapRegistration.utils.misc import DisplacementMapInterpolator, segmentation_ohe_to_cardinal
|
15 |
from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
|
16 |
from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
|
|
|
17 |
|
18 |
import voxelmorph as vxm
|
19 |
|
@@ -79,18 +83,21 @@ if __name__ == '__main__':
|
|
79 |
sess = tf.Session(config=config)
|
80 |
tf.keras.backend.set_session(sess)
|
81 |
####
|
|
|
|
|
82 |
dm_interp = DisplacementMapInterpolator(image_shape, 'griddata')
|
83 |
# Header of the metrics csv file
|
84 |
-
csv_header = ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'DICE_MACRO', 'HD', 'Time_SyN', 'Time_SyNCC', 'TRE']
|
85 |
|
86 |
metrics_file = os.path.join(args.outdir, 'metrics.csv')
|
87 |
with open(metrics_file, 'w') as f:
|
88 |
f.write(';'.join(csv_header)+'\n')
|
89 |
|
90 |
-
|
|
|
91 |
file_num = int(re.findall('(\d+)', os.path.split(file_path)[-1])[0])
|
92 |
|
93 |
-
dataset_iterator.set_description('{} ({}):
|
94 |
with h5py.File(file_path, 'r') as vol_file:
|
95 |
fix_img = vol_file['fix_image'][:]
|
96 |
mov_img = vol_file['mov_image'][:]
|
@@ -118,7 +125,7 @@ if __name__ == '__main__':
|
|
118 |
if not len(mov_to_fix_trf_syn) and not len(mov_to_fix_trf_syncc):
|
119 |
print('ERR: Registration failed for: '+file_path)
|
120 |
else:
|
121 |
-
for reg_output in [reg_output_syn, reg_output_syncc]:
|
122 |
mov_to_fix_trf_list = reg_output[FWD_TRFS]
|
123 |
pred_img = reg_output[WARPED_MOV].numpy()
|
124 |
pred_img = pred_img[..., np.newaxis] # SoA doesn't work fine with 1-ch images
|
@@ -156,7 +163,7 @@ if __name__ == '__main__':
|
|
156 |
tre_array = target_registration_error(fix_centroids_isotropic, pred_centroids_isotropic, False).eval()
|
157 |
tre = np.mean([v for v in tre_array if not np.isnan(v)])
|
158 |
|
159 |
-
new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, t1_syn-t0_syn, t1_syncc-t0_syncc, tre]
|
160 |
with open(metrics_file, 'a') as f:
|
161 |
f.write(';'.join(map(str, new_line))+'\n')
|
162 |
|
|
|
2 |
import ants
|
3 |
import numpy as np
|
4 |
import nibabel as nb
|
5 |
+
import os, sys
|
|
|
6 |
from tqdm import tqdm
|
7 |
import re
|
8 |
import time
|
9 |
import pandas as pd
|
10 |
|
11 |
+
currentdir = os.path.dirname(os.path.realpath(__file__))
|
12 |
+
parentdir = os.path.dirname(currentdir)
|
13 |
+
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
14 |
+
|
15 |
from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplified, NCC, GeneralizedDICEScore, HausdorffDistanceErosion, target_registration_error
|
16 |
from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
|
17 |
from DeepDeformationMapRegistration.utils.misc import DisplacementMapInterpolator, segmentation_ohe_to_cardinal
|
18 |
from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
|
19 |
from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
|
20 |
+
import DeepDeformationMapRegistration.utils.constants as C
|
21 |
|
22 |
import voxelmorph as vxm
|
23 |
|
|
|
83 |
sess = tf.Session(config=config)
|
84 |
tf.keras.backend.set_session(sess)
|
85 |
####
|
86 |
+
os.environ["ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS"] = "12" #https://github.com/ANTsX/ANTsPy/issues/261
|
87 |
+
print("Running ANTs using {} threads".format(os.environ.get("ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS")))
|
88 |
dm_interp = DisplacementMapInterpolator(image_shape, 'griddata')
|
89 |
# Header of the metrics csv file
|
90 |
+
csv_header = ['File', 'Method', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'DICE_MACRO', 'HD', 'Time_SyN', 'Time_SyNCC', 'TRE']
|
91 |
|
92 |
metrics_file = os.path.join(args.outdir, 'metrics.csv')
|
93 |
with open(metrics_file, 'w') as f:
|
94 |
f.write(';'.join(csv_header)+'\n')
|
95 |
|
96 |
+
print('Starting the loop')
|
97 |
+
for step, file_path in tqdm(enumerate(dataset_iterator), desc="Running ANTs"):
|
98 |
file_num = int(re.findall('(\d+)', os.path.split(file_path)[-1])[0])
|
99 |
|
100 |
+
dataset_iterator.set_description('{} ({}): loading data'.format(file_num, args.dataset))
|
101 |
with h5py.File(file_path, 'r') as vol_file:
|
102 |
fix_img = vol_file['fix_image'][:]
|
103 |
mov_img = vol_file['mov_image'][:]
|
|
|
125 |
if not len(mov_to_fix_trf_syn) and not len(mov_to_fix_trf_syncc):
|
126 |
print('ERR: Registration failed for: '+file_path)
|
127 |
else:
|
128 |
+
for reg_method, reg_output in zip(['SyN', 'SyNCC'], [reg_output_syn, reg_output_syncc]):
|
129 |
mov_to_fix_trf_list = reg_output[FWD_TRFS]
|
130 |
pred_img = reg_output[WARPED_MOV].numpy()
|
131 |
pred_img = pred_img[..., np.newaxis] # SoA doesn't work fine with 1-ch images
|
|
|
163 |
tre_array = target_registration_error(fix_centroids_isotropic, pred_centroids_isotropic, False).eval()
|
164 |
tre = np.mean([v for v in tre_array if not np.isnan(v)])
|
165 |
|
166 |
+
new_line = [step, reg_method, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, t1_syn-t0_syn, t1_syncc-t0_syncc, tre]
|
167 |
with open(metrics_file, 'a') as f:
|
168 |
f.write(';'.join(map(str, new_line))+'\n')
|
169 |
|