Commit
·
82949fb
1
Parent(s):
7bfdced
Skip the background segmentation mask
Browse filesCompute the HD95 metric
Allow the user to save or not the generated NIfTI images
- COMET/Evaluate_network.py +20 -16
COMET/Evaluate_network.py
CHANGED
@@ -63,6 +63,7 @@ if __name__ == '__main__':
|
|
63 |
parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
|
64 |
parser.add_argument('--outdirname', type=str, default='Evaluate')
|
65 |
parser.add_argument('--fullres', action='store_true', default=False)
|
|
|
66 |
args = parser.parse_args()
|
67 |
if args.model is not None:
|
68 |
assert '.h5' in args.model[0], 'No checkpoint file provided, use -d/--dir instead'
|
@@ -95,10 +96,10 @@ if __name__ == '__main__':
|
|
95 |
|
96 |
with h5py.File(list_test_files[0], 'r') as f:
|
97 |
image_input_shape = image_output_shape = list(f['fix_image'][:].shape[:-1])
|
98 |
-
nb_labels = f['fix_segmentations'][:].shape[-1]
|
99 |
|
100 |
# Header of the metrics csv file
|
101 |
-
csv_header = ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'DICE_MACRO', 'HD', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']
|
102 |
|
103 |
# TF stuff
|
104 |
config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
|
@@ -163,17 +164,17 @@ if __name__ == '__main__':
|
|
163 |
print('DESTINATION FOLDER FULL RESOLUTION: ', output_folder_fr)
|
164 |
|
165 |
try:
|
166 |
-
network = tf.keras.models.load_model(MODEL_FILE, {'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
|
167 |
'VxmDense': vxm.networks.VxmDense,
|
168 |
'AdamAccumulated': AdamAccumulated,
|
169 |
'loss': loss_fncs,
|
170 |
'metric': metric_fncs},
|
171 |
compile=False)
|
172 |
except ValueError as e:
|
173 |
-
enc_features = [
|
174 |
-
dec_features = [
|
175 |
nb_features = [enc_features, dec_features]
|
176 |
-
if re.search('^UW|SEGGUIDED_', MODEL_FILE):
|
177 |
network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
|
178 |
nb_labels=nb_labels,
|
179 |
nb_unet_features=nb_features,
|
@@ -181,6 +182,7 @@ if __name__ == '__main__':
|
|
181 |
int_downsize=1,
|
182 |
seg_downsize=1)
|
183 |
else:
|
|
|
184 |
network = vxm.networks.VxmDense(inshape=image_output_shape,
|
185 |
nb_unet_features=nb_features,
|
186 |
int_steps=0)
|
@@ -205,9 +207,9 @@ if __name__ == '__main__':
|
|
205 |
with h5py.File(in_batch, 'r') as f:
|
206 |
fix_img = f['fix_image'][:][np.newaxis, ...] # Add batch axis
|
207 |
mov_img = f['mov_image'][:][np.newaxis, ...]
|
208 |
-
fix_seg = f['fix_segmentations'][:][np.newaxis, ...].astype(np.float32)
|
209 |
-
mov_seg = f['mov_segmentations'][:][np.newaxis, ...].astype(np.float32)
|
210 |
-
fix_centroids = f['fix_centroids'][
|
211 |
isotropic_shape = f['isotropic_shape'][:]
|
212 |
voxel_size = np.divide(fix_img.shape[1:-1], isotropic_shape)
|
213 |
|
@@ -238,6 +240,7 @@ if __name__ == '__main__':
|
|
238 |
# dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf], {'fix_seg:0': fix_seg, 'pred_seg:0': pred_seg})
|
239 |
dice = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) / np.sum(fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
|
240 |
hd = np.mean(safe_medpy_metric(pred_seg_isot[0, ...], fix_seg_isot[0, ...], nb_labels, medpy_metrics.hd, {'voxelspacing': voxel_size}))
|
|
|
241 |
dice_macro = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
|
242 |
|
243 |
pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
|
@@ -261,7 +264,7 @@ if __name__ == '__main__':
|
|
261 |
tre = np.mean([v for v in tre_array if not np.isnan(v)])
|
262 |
# ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'HD', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']
|
263 |
|
264 |
-
new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, t1-t0, tre, len(missing_lbls), missing_lbls]
|
265 |
with open(metrics_file, 'a') as f:
|
266 |
f.write(';'.join(map(str, new_line))+'\n')
|
267 |
|
@@ -337,12 +340,13 @@ if __name__ == '__main__':
|
|
337 |
with open(metrics_file_fr, 'a') as f:
|
338 |
f.write(';'.join(map(str, new_line)) + '\n')
|
339 |
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
|
|
346 |
|
347 |
# with h5py.File(os.path.join(output_folder, '{:03d}_centroids.h5'.format(step)), 'w') as f:
|
348 |
# f.create_dataset('fix_centroids', dtype=np.float32, data=fix_centroids)
|
|
|
63 |
parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
|
64 |
parser.add_argument('--outdirname', type=str, default='Evaluate')
|
65 |
parser.add_argument('--fullres', action='store_true', default=False)
|
66 |
+
parser.add_argument('--savenifti', type=bool, default=True)
|
67 |
args = parser.parse_args()
|
68 |
if args.model is not None:
|
69 |
assert '.h5' in args.model[0], 'No checkpoint file provided, use -d/--dir instead'
|
|
|
96 |
|
97 |
with h5py.File(list_test_files[0], 'r') as f:
|
98 |
image_input_shape = image_output_shape = list(f['fix_image'][:].shape[:-1])
|
99 |
+
nb_labels = f['fix_segmentations'][:].shape[-1] - 1 # Skip background label
|
100 |
|
101 |
# Header of the metrics csv file
|
102 |
+
csv_header = ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'DICE_MACRO', 'HD', 'HD95', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']
|
103 |
|
104 |
# TF stuff
|
105 |
config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
|
|
|
164 |
print('DESTINATION FOLDER FULL RESOLUTION: ', output_folder_fr)
|
165 |
|
166 |
try:
|
167 |
+
network = tf.keras.models.load_model(MODEL_FILE, {#'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
|
168 |
'VxmDense': vxm.networks.VxmDense,
|
169 |
'AdamAccumulated': AdamAccumulated,
|
170 |
'loss': loss_fncs,
|
171 |
'metric': metric_fncs},
|
172 |
compile=False)
|
173 |
except ValueError as e:
|
174 |
+
enc_features = [32, 64, 128, 256, 512, 1024] # const.ENCODER_FILTERS
|
175 |
+
dec_features = enc_features[::-1] + [16, 16] # const.ENCODER_FILTERS[::-1]
|
176 |
nb_features = [enc_features, dec_features]
|
177 |
+
if False: #re.search('^UW|SEGGUIDED_', MODEL_FILE):
|
178 |
network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
|
179 |
nb_labels=nb_labels,
|
180 |
nb_unet_features=nb_features,
|
|
|
182 |
int_downsize=1,
|
183 |
seg_downsize=1)
|
184 |
else:
|
185 |
+
# only load the weights into the same model. To get the same runtime
|
186 |
network = vxm.networks.VxmDense(inshape=image_output_shape,
|
187 |
nb_unet_features=nb_features,
|
188 |
int_steps=0)
|
|
|
207 |
with h5py.File(in_batch, 'r') as f:
|
208 |
fix_img = f['fix_image'][:][np.newaxis, ...] # Add batch axis
|
209 |
mov_img = f['mov_image'][:][np.newaxis, ...]
|
210 |
+
fix_seg = f['fix_segmentations'][..., 1:][np.newaxis, ...].astype(np.float32)
|
211 |
+
mov_seg = f['mov_segmentations'][..., 1:][np.newaxis, ...].astype(np.float32)
|
212 |
+
fix_centroids = f['fix_centroids'][1:, ...]
|
213 |
isotropic_shape = f['isotropic_shape'][:]
|
214 |
voxel_size = np.divide(fix_img.shape[1:-1], isotropic_shape)
|
215 |
|
|
|
240 |
# dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf], {'fix_seg:0': fix_seg, 'pred_seg:0': pred_seg})
|
241 |
dice = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) / np.sum(fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
|
242 |
hd = np.mean(safe_medpy_metric(pred_seg_isot[0, ...], fix_seg_isot[0, ...], nb_labels, medpy_metrics.hd, {'voxelspacing': voxel_size}))
|
243 |
+
hd95 = np.mean(safe_medpy_metric(pred_seg_isot[0, ...], fix_seg_isot[0, ...], nb_labels, medpy_metrics.hd95, {'voxelspacing': voxel_size}))
|
244 |
dice_macro = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
|
245 |
|
246 |
pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
|
|
|
264 |
tre = np.mean([v for v in tre_array if not np.isnan(v)])
|
265 |
# ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'HD', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']
|
266 |
|
267 |
+
new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, hd95, t1-t0, tre, len(missing_lbls), missing_lbls]
|
268 |
with open(metrics_file, 'a') as f:
|
269 |
f.write(';'.join(map(str, new_line))+'\n')
|
270 |
|
|
|
340 |
with open(metrics_file_fr, 'a') as f:
|
341 |
f.write(';'.join(map(str, new_line)) + '\n')
|
342 |
|
343 |
+
if args.savenifti:
|
344 |
+
save_nifti(fix_img[0, ...], os.path.join(output_folder_fr, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
345 |
+
save_nifti(mov_img[0, ...], os.path.join(output_folder_fr, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
346 |
+
save_nifti(pred_img[0, ...], os.path.join(output_folder_fr, '{:03d}_pred_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
347 |
+
save_nifti(fix_seg_card[0, ...], os.path.join(output_folder_fr, '{:03d}_fix_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
348 |
+
save_nifti(mov_seg_card[0, ...], os.path.join(output_folder_fr, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
349 |
+
save_nifti(pred_seg_card[0, ...], os.path.join(output_folder_fr, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
350 |
|
351 |
# with h5py.File(os.path.join(output_folder, '{:03d}_centroids.h5'.format(step)), 'w') as f:
|
352 |
# f.create_dataset('fix_centroids', dtype=np.float32, data=fix_centroids)
|