Commit
·
92caf3a
1
Parent(s):
82949fb
Skip the background segmentation mask
Browse filesCompute the HD95 metric
Allow the user to save or not the generated NIfTI images
Brain_study/Evaluate_network__test_fixed.py
CHANGED
@@ -55,6 +55,7 @@ if __name__ == '__main__':
|
|
55 |
default='/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/test')
|
56 |
parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
|
57 |
parser.add_argument('--outdirname', type=str, default='Evaluate')
|
|
|
58 |
|
59 |
args = parser.parse_args()
|
60 |
if args.model is not None:
|
@@ -76,10 +77,10 @@ if __name__ == '__main__':
|
|
76 |
|
77 |
with h5py.File(list_test_files[0], 'r') as f:
|
78 |
image_input_shape = image_output_shape = list(f['fix_image'][:].shape[:-1])
|
79 |
-
nb_labels = f['fix_segmentations'][:].shape[-1]
|
80 |
|
81 |
# Header of the metrics csv file
|
82 |
-
csv_header = ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'DICE_MACRO', 'HD', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']
|
83 |
|
84 |
# TF stuff
|
85 |
config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
|
@@ -136,17 +137,17 @@ if __name__ == '__main__':
|
|
136 |
print('DESTINATION FOLDER: ', output_folder)
|
137 |
|
138 |
try:
|
139 |
-
network = tf.keras.models.load_model(MODEL_FILE, {'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
|
140 |
'VxmDense': vxm.networks.VxmDense,
|
141 |
'AdamAccumulated': AdamAccumulated,
|
142 |
'loss': loss_fncs,
|
143 |
'metric': metric_fncs},
|
144 |
compile=False)
|
145 |
except ValueError as e:
|
146 |
-
enc_features = [
|
147 |
-
dec_features = [
|
148 |
nb_features = [enc_features, dec_features]
|
149 |
-
if re.search('^UW|SEGGUIDED_', MODEL_FILE):
|
150 |
network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
|
151 |
nb_labels=nb_labels,
|
152 |
nb_unet_features=nb_features,
|
@@ -154,6 +155,7 @@ if __name__ == '__main__':
|
|
154 |
int_downsize=1,
|
155 |
seg_downsize=1)
|
156 |
else:
|
|
|
157 |
network = vxm.networks.VxmDense(inshape=image_output_shape,
|
158 |
nb_unet_features=nb_features,
|
159 |
int_steps=0)
|
@@ -167,14 +169,15 @@ if __name__ == '__main__':
|
|
167 |
with sess.as_default():
|
168 |
sess.run(tf.global_variables_initializer())
|
169 |
network.load_weights(MODEL_FILE, by_name=True)
|
|
|
170 |
progress_bar = tqdm(enumerate(list_test_files, 1), desc='Evaluation', total=len(list_test_files))
|
171 |
for step, in_batch in progress_bar:
|
172 |
with h5py.File(in_batch, 'r') as f:
|
173 |
fix_img = f['fix_image'][:][np.newaxis, ...] # Add batch axis
|
174 |
mov_img = f['mov_image'][:][np.newaxis, ...]
|
175 |
-
fix_seg = f['fix_segmentations'][:][np.newaxis, ...].astype(np.float32)
|
176 |
-
mov_seg = f['mov_segmentations'][:][np.newaxis, ...].astype(np.float32)
|
177 |
-
fix_centroids = f['fix_centroids'][
|
178 |
isotropic_shape = f['isotropic_shape'][:]
|
179 |
voxel_size = np.divide(fix_img.shape[1:-1], isotropic_shape)
|
180 |
|
@@ -205,6 +208,7 @@ if __name__ == '__main__':
|
|
205 |
# dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf], {'fix_seg:0': fix_seg, 'pred_seg:0': pred_seg})
|
206 |
dice = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) / np.sum(fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
|
207 |
hd = np.mean(safe_medpy_metric(pred_seg_isot[0, ...], fix_seg_isot[0, ...], nb_labels, medpy_metrics.hd, {'voxelspacing': voxel_size}))
|
|
|
208 |
dice_macro = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
|
209 |
|
210 |
pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
|
@@ -229,16 +233,17 @@ if __name__ == '__main__':
|
|
229 |
tre = np.mean([v for v in tre_array if not np.isnan(v)])
|
230 |
# ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'HD', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']
|
231 |
|
232 |
-
new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, t1-t0, tre, len(missing_lbls), missing_lbls]
|
233 |
with open(metrics_file, 'a') as f:
|
234 |
f.write(';'.join(map(str, new_line))+'\n')
|
235 |
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
|
|
242 |
|
243 |
# with h5py.File(os.path.join(output_folder, '{:03d}_centroids.h5'.format(step)), 'w') as f:
|
244 |
# f.create_dataset('fix_centroids', dtype=np.float32, data=fix_centroids)
|
|
|
55 |
default='/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/test')
|
56 |
parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
|
57 |
parser.add_argument('--outdirname', type=str, default='Evaluate')
|
58 |
+
parser.add_argument('--savenifti', type=bool, default=True)
|
59 |
|
60 |
args = parser.parse_args()
|
61 |
if args.model is not None:
|
|
|
77 |
|
78 |
with h5py.File(list_test_files[0], 'r') as f:
|
79 |
image_input_shape = image_output_shape = list(f['fix_image'][:].shape[:-1])
|
80 |
+
nb_labels = f['fix_segmentations'][:].shape[-1] - 1 # Skip background label
|
81 |
|
82 |
# Header of the metrics csv file
|
83 |
+
csv_header = ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'DICE_MACRO', 'HD', 'HD95', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']
|
84 |
|
85 |
# TF stuff
|
86 |
config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
|
|
|
137 |
print('DESTINATION FOLDER: ', output_folder)
|
138 |
|
139 |
try:
|
140 |
+
network = tf.keras.models.load_model(MODEL_FILE, {#'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
|
141 |
'VxmDense': vxm.networks.VxmDense,
|
142 |
'AdamAccumulated': AdamAccumulated,
|
143 |
'loss': loss_fncs,
|
144 |
'metric': metric_fncs},
|
145 |
compile=False)
|
146 |
except ValueError as e:
|
147 |
+
enc_features = [32, 64, 128, 256, 512, 1024] # const.ENCODER_FILTERS
|
148 |
+
dec_features = enc_features[::-1] + [16, 16] # const.ENCODER_FILTERS[::-1]
|
149 |
nb_features = [enc_features, dec_features]
|
150 |
+
if False: #re.search('^UW|SEGGUIDED_', MODEL_FILE):
|
151 |
network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
|
152 |
nb_labels=nb_labels,
|
153 |
nb_unet_features=nb_features,
|
|
|
155 |
int_downsize=1,
|
156 |
seg_downsize=1)
|
157 |
else:
|
158 |
+
# only load the weights into the same model. To get the same runtime
|
159 |
network = vxm.networks.VxmDense(inshape=image_output_shape,
|
160 |
nb_unet_features=nb_features,
|
161 |
int_steps=0)
|
|
|
169 |
with sess.as_default():
|
170 |
sess.run(tf.global_variables_initializer())
|
171 |
network.load_weights(MODEL_FILE, by_name=True)
|
172 |
+
network.summary(line_length=C.SUMMARY_LINE_LENGTH)
|
173 |
progress_bar = tqdm(enumerate(list_test_files, 1), desc='Evaluation', total=len(list_test_files))
|
174 |
for step, in_batch in progress_bar:
|
175 |
with h5py.File(in_batch, 'r') as f:
|
176 |
fix_img = f['fix_image'][:][np.newaxis, ...] # Add batch axis
|
177 |
mov_img = f['mov_image'][:][np.newaxis, ...]
|
178 |
+
fix_seg = f['fix_segmentations'][..., 1:][np.newaxis, ...].astype(np.float32)
|
179 |
+
mov_seg = f['mov_segmentations'][..., 1:][np.newaxis, ...].astype(np.float32)
|
180 |
+
fix_centroids = f['fix_centroids'][1:, ...]
|
181 |
isotropic_shape = f['isotropic_shape'][:]
|
182 |
voxel_size = np.divide(fix_img.shape[1:-1], isotropic_shape)
|
183 |
|
|
|
208 |
# dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf], {'fix_seg:0': fix_seg, 'pred_seg:0': pred_seg})
|
209 |
dice = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) / np.sum(fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
|
210 |
hd = np.mean(safe_medpy_metric(pred_seg_isot[0, ...], fix_seg_isot[0, ...], nb_labels, medpy_metrics.hd, {'voxelspacing': voxel_size}))
|
211 |
+
hd95 = np.mean(safe_medpy_metric(pred_seg_isot[0, ...], fix_seg_isot[0, ...], nb_labels, medpy_metrics.hd95, {'voxelspacing': voxel_size}))
|
212 |
dice_macro = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
|
213 |
|
214 |
pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
|
|
|
233 |
tre = np.mean([v for v in tre_array if not np.isnan(v)])
|
234 |
# ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'HD', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']
|
235 |
|
236 |
+
new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, hd95, t1-t0, tre, len(missing_lbls), missing_lbls]
|
237 |
with open(metrics_file, 'a') as f:
|
238 |
f.write(';'.join(map(str, new_line))+'\n')
|
239 |
|
240 |
+
if args.savenifti:
|
241 |
+
save_nifti(fix_img[0, ...], os.path.join(output_folder, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
242 |
+
save_nifti(mov_img[0, ...], os.path.join(output_folder, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
243 |
+
save_nifti(pred_img[0, ...], os.path.join(output_folder, '{:03d}_pred_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
244 |
+
save_nifti(fix_seg[0, ...], os.path.join(output_folder, '{:03d}_fix_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
245 |
+
save_nifti(mov_seg[0, ...], os.path.join(output_folder, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
246 |
+
save_nifti(pred_seg[0, ...], os.path.join(output_folder, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
247 |
|
248 |
# with h5py.File(os.path.join(output_folder, '{:03d}_centroids.h5'.format(step)), 'w') as f:
|
249 |
# f.create_dataset('fix_centroids', dtype=np.float32, data=fix_centroids)
|