Commit
·
92caf3a
1
Parent(s):
82949fb
Skip the background segmentation mask
Browse filesCompute the HD95 metric
Allow the user to save or not the generated NIfTI images
Brain_study/Evaluate_network__test_fixed.py
CHANGED
|
@@ -55,6 +55,7 @@ if __name__ == '__main__':
|
|
| 55 |
default='/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/test')
|
| 56 |
parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
|
| 57 |
parser.add_argument('--outdirname', type=str, default='Evaluate')
|
|
|
|
| 58 |
|
| 59 |
args = parser.parse_args()
|
| 60 |
if args.model is not None:
|
|
@@ -76,10 +77,10 @@ if __name__ == '__main__':
|
|
| 76 |
|
| 77 |
with h5py.File(list_test_files[0], 'r') as f:
|
| 78 |
image_input_shape = image_output_shape = list(f['fix_image'][:].shape[:-1])
|
| 79 |
-
nb_labels = f['fix_segmentations'][:].shape[-1]
|
| 80 |
|
| 81 |
# Header of the metrics csv file
|
| 82 |
-
csv_header = ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'DICE_MACRO', 'HD', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']
|
| 83 |
|
| 84 |
# TF stuff
|
| 85 |
config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
|
|
@@ -136,17 +137,17 @@ if __name__ == '__main__':
|
|
| 136 |
print('DESTINATION FOLDER: ', output_folder)
|
| 137 |
|
| 138 |
try:
|
| 139 |
-
network = tf.keras.models.load_model(MODEL_FILE, {'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
|
| 140 |
'VxmDense': vxm.networks.VxmDense,
|
| 141 |
'AdamAccumulated': AdamAccumulated,
|
| 142 |
'loss': loss_fncs,
|
| 143 |
'metric': metric_fncs},
|
| 144 |
compile=False)
|
| 145 |
except ValueError as e:
|
| 146 |
-
enc_features = [
|
| 147 |
-
dec_features = [
|
| 148 |
nb_features = [enc_features, dec_features]
|
| 149 |
-
if re.search('^UW|SEGGUIDED_', MODEL_FILE):
|
| 150 |
network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
|
| 151 |
nb_labels=nb_labels,
|
| 152 |
nb_unet_features=nb_features,
|
|
@@ -154,6 +155,7 @@ if __name__ == '__main__':
|
|
| 154 |
int_downsize=1,
|
| 155 |
seg_downsize=1)
|
| 156 |
else:
|
|
|
|
| 157 |
network = vxm.networks.VxmDense(inshape=image_output_shape,
|
| 158 |
nb_unet_features=nb_features,
|
| 159 |
int_steps=0)
|
|
@@ -167,14 +169,15 @@ if __name__ == '__main__':
|
|
| 167 |
with sess.as_default():
|
| 168 |
sess.run(tf.global_variables_initializer())
|
| 169 |
network.load_weights(MODEL_FILE, by_name=True)
|
|
|
|
| 170 |
progress_bar = tqdm(enumerate(list_test_files, 1), desc='Evaluation', total=len(list_test_files))
|
| 171 |
for step, in_batch in progress_bar:
|
| 172 |
with h5py.File(in_batch, 'r') as f:
|
| 173 |
fix_img = f['fix_image'][:][np.newaxis, ...] # Add batch axis
|
| 174 |
mov_img = f['mov_image'][:][np.newaxis, ...]
|
| 175 |
-
fix_seg = f['fix_segmentations'][:][np.newaxis, ...].astype(np.float32)
|
| 176 |
-
mov_seg = f['mov_segmentations'][:][np.newaxis, ...].astype(np.float32)
|
| 177 |
-
fix_centroids = f['fix_centroids'][
|
| 178 |
isotropic_shape = f['isotropic_shape'][:]
|
| 179 |
voxel_size = np.divide(fix_img.shape[1:-1], isotropic_shape)
|
| 180 |
|
|
@@ -205,6 +208,7 @@ if __name__ == '__main__':
|
|
| 205 |
# dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf], {'fix_seg:0': fix_seg, 'pred_seg:0': pred_seg})
|
| 206 |
dice = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) / np.sum(fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
|
| 207 |
hd = np.mean(safe_medpy_metric(pred_seg_isot[0, ...], fix_seg_isot[0, ...], nb_labels, medpy_metrics.hd, {'voxelspacing': voxel_size}))
|
|
|
|
| 208 |
dice_macro = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
|
| 209 |
|
| 210 |
pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
|
|
@@ -229,16 +233,17 @@ if __name__ == '__main__':
|
|
| 229 |
tre = np.mean([v for v in tre_array if not np.isnan(v)])
|
| 230 |
# ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'HD', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']
|
| 231 |
|
| 232 |
-
new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, t1-t0, tre, len(missing_lbls), missing_lbls]
|
| 233 |
with open(metrics_file, 'a') as f:
|
| 234 |
f.write(';'.join(map(str, new_line))+'\n')
|
| 235 |
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
|
|
|
| 242 |
|
| 243 |
# with h5py.File(os.path.join(output_folder, '{:03d}_centroids.h5'.format(step)), 'w') as f:
|
| 244 |
# f.create_dataset('fix_centroids', dtype=np.float32, data=fix_centroids)
|
|
|
|
| 55 |
default='/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/test')
|
| 56 |
parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
|
| 57 |
parser.add_argument('--outdirname', type=str, default='Evaluate')
|
| 58 |
+
parser.add_argument('--savenifti', type=bool, default=True)
|
| 59 |
|
| 60 |
args = parser.parse_args()
|
| 61 |
if args.model is not None:
|
|
|
|
| 77 |
|
| 78 |
with h5py.File(list_test_files[0], 'r') as f:
|
| 79 |
image_input_shape = image_output_shape = list(f['fix_image'][:].shape[:-1])
|
| 80 |
+
nb_labels = f['fix_segmentations'][:].shape[-1] - 1 # Skip background label
|
| 81 |
|
| 82 |
# Header of the metrics csv file
|
| 83 |
+
csv_header = ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'DICE_MACRO', 'HD', 'HD95', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']
|
| 84 |
|
| 85 |
# TF stuff
|
| 86 |
config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
|
|
|
|
| 137 |
print('DESTINATION FOLDER: ', output_folder)
|
| 138 |
|
| 139 |
try:
|
| 140 |
+
network = tf.keras.models.load_model(MODEL_FILE, {#'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
|
| 141 |
'VxmDense': vxm.networks.VxmDense,
|
| 142 |
'AdamAccumulated': AdamAccumulated,
|
| 143 |
'loss': loss_fncs,
|
| 144 |
'metric': metric_fncs},
|
| 145 |
compile=False)
|
| 146 |
except ValueError as e:
|
| 147 |
+
enc_features = [32, 64, 128, 256, 512, 1024] # const.ENCODER_FILTERS
|
| 148 |
+
dec_features = enc_features[::-1] + [16, 16] # const.ENCODER_FILTERS[::-1]
|
| 149 |
nb_features = [enc_features, dec_features]
|
| 150 |
+
if False: #re.search('^UW|SEGGUIDED_', MODEL_FILE):
|
| 151 |
network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
|
| 152 |
nb_labels=nb_labels,
|
| 153 |
nb_unet_features=nb_features,
|
|
|
|
| 155 |
int_downsize=1,
|
| 156 |
seg_downsize=1)
|
| 157 |
else:
|
| 158 |
+
# only load the weights into the same model. To get the same runtime
|
| 159 |
network = vxm.networks.VxmDense(inshape=image_output_shape,
|
| 160 |
nb_unet_features=nb_features,
|
| 161 |
int_steps=0)
|
|
|
|
| 169 |
with sess.as_default():
|
| 170 |
sess.run(tf.global_variables_initializer())
|
| 171 |
network.load_weights(MODEL_FILE, by_name=True)
|
| 172 |
+
network.summary(line_length=C.SUMMARY_LINE_LENGTH)
|
| 173 |
progress_bar = tqdm(enumerate(list_test_files, 1), desc='Evaluation', total=len(list_test_files))
|
| 174 |
for step, in_batch in progress_bar:
|
| 175 |
with h5py.File(in_batch, 'r') as f:
|
| 176 |
fix_img = f['fix_image'][:][np.newaxis, ...] # Add batch axis
|
| 177 |
mov_img = f['mov_image'][:][np.newaxis, ...]
|
| 178 |
+
fix_seg = f['fix_segmentations'][..., 1:][np.newaxis, ...].astype(np.float32)
|
| 179 |
+
mov_seg = f['mov_segmentations'][..., 1:][np.newaxis, ...].astype(np.float32)
|
| 180 |
+
fix_centroids = f['fix_centroids'][1:, ...]
|
| 181 |
isotropic_shape = f['isotropic_shape'][:]
|
| 182 |
voxel_size = np.divide(fix_img.shape[1:-1], isotropic_shape)
|
| 183 |
|
|
|
|
| 208 |
# dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf], {'fix_seg:0': fix_seg, 'pred_seg:0': pred_seg})
|
| 209 |
dice = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) / np.sum(fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
|
| 210 |
hd = np.mean(safe_medpy_metric(pred_seg_isot[0, ...], fix_seg_isot[0, ...], nb_labels, medpy_metrics.hd, {'voxelspacing': voxel_size}))
|
| 211 |
+
hd95 = np.mean(safe_medpy_metric(pred_seg_isot[0, ...], fix_seg_isot[0, ...], nb_labels, medpy_metrics.hd95, {'voxelspacing': voxel_size}))
|
| 212 |
dice_macro = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
|
| 213 |
|
| 214 |
pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
|
|
|
|
| 233 |
tre = np.mean([v for v in tre_array if not np.isnan(v)])
|
| 234 |
# ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'HD', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']
|
| 235 |
|
| 236 |
+
new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, hd95, t1-t0, tre, len(missing_lbls), missing_lbls]
|
| 237 |
with open(metrics_file, 'a') as f:
|
| 238 |
f.write(';'.join(map(str, new_line))+'\n')
|
| 239 |
|
| 240 |
+
if args.savenifti:
|
| 241 |
+
save_nifti(fix_img[0, ...], os.path.join(output_folder, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
| 242 |
+
save_nifti(mov_img[0, ...], os.path.join(output_folder, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
| 243 |
+
save_nifti(pred_img[0, ...], os.path.join(output_folder, '{:03d}_pred_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
| 244 |
+
save_nifti(fix_seg[0, ...], os.path.join(output_folder, '{:03d}_fix_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
| 245 |
+
save_nifti(mov_seg[0, ...], os.path.join(output_folder, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
| 246 |
+
save_nifti(pred_seg[0, ...], os.path.join(output_folder, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
| 247 |
|
| 248 |
# with h5py.File(os.path.join(output_folder, '{:03d}_centroids.h5'.format(step)), 'w') as f:
|
| 249 |
# f.create_dataset('fix_centroids', dtype=np.float32, data=fix_centroids)
|