jpdefrutos committed on
Commit 82949fb · 1 Parent(s): 7bfdced

Skip the background segmentation mask


Compute the HD95 metric
Allow the user to choose whether to save the generated NIfTI images
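
For context on the new metric, below is a minimal standalone sketch of a per-label 95th-percentile Hausdorff distance that skips the background channel, roughly what the diff delegates to the repository's safe_medpy_metric helper. The function name per_label_hd95 and the array shapes are illustrative assumptions, not code from the repository:

    import numpy as np
    from medpy.metric import binary as medpy_metrics

    def per_label_hd95(pred_ohe, ref_ohe, voxel_size):
        # pred_ohe / ref_ohe: one-hot masks of shape (H, W, D, C); channel 0 is assumed to be background
        values = []
        for lbl in range(1, pred_ohe.shape[-1]):  # skip the background channel
            pred_lbl = pred_ohe[..., lbl].astype(bool)
            ref_lbl = ref_ohe[..., lbl].astype(bool)
            try:
                values.append(medpy_metrics.hd95(pred_lbl, ref_lbl, voxelspacing=voxel_size))
            except RuntimeError:  # medpy raises when either mask is empty
                values.append(np.nan)
        return np.nanmean(values)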

Files changed (1)
  1. COMET/Evaluate_network.py +20 -16
COMET/Evaluate_network.py CHANGED
@@ -63,6 +63,7 @@ if __name__ == '__main__':
     parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
     parser.add_argument('--outdirname', type=str, default='Evaluate')
     parser.add_argument('--fullres', action='store_true', default=False)
+    parser.add_argument('--savenifti', type=bool, default=True)
     args = parser.parse_args()
     if args.model is not None:
         assert '.h5' in args.model[0], 'No checkpoint file provided, use -d/--dir instead'
@@ -95,10 +96,10 @@ if __name__ == '__main__':

     with h5py.File(list_test_files[0], 'r') as f:
         image_input_shape = image_output_shape = list(f['fix_image'][:].shape[:-1])
-        nb_labels = f['fix_segmentations'][:].shape[-1]
+        nb_labels = f['fix_segmentations'][:].shape[-1] - 1  # Skip background label

     # Header of the metrics csv file
-    csv_header = ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'DICE_MACRO', 'HD', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']
+    csv_header = ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'DICE_MACRO', 'HD', 'HD95', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']

     # TF stuff
     config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
@@ -163,17 +164,17 @@ if __name__ == '__main__':
     print('DESTINATION FOLDER FULL RESOLUTION: ', output_folder_fr)

     try:
-        network = tf.keras.models.load_model(MODEL_FILE, {'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
+        network = tf.keras.models.load_model(MODEL_FILE, {#'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
                                                           'VxmDense': vxm.networks.VxmDense,
                                                           'AdamAccumulated': AdamAccumulated,
                                                           'loss': loss_fncs,
                                                           'metric': metric_fncs},
                                              compile=False)
     except ValueError as e:
-        enc_features = [16, 32, 32, 32] # const.ENCODER_FILTERS
-        dec_features = [32, 32, 32, 32, 32, 16, 16] # const.ENCODER_FILTERS[::-1]
+        enc_features = [32, 64, 128, 256, 512, 1024] # const.ENCODER_FILTERS
+        dec_features = enc_features[::-1] + [16, 16] # const.ENCODER_FILTERS[::-1]
         nb_features = [enc_features, dec_features]
-        if re.search('^UW|SEGGUIDED_', MODEL_FILE):
+        if False: #re.search('^UW|SEGGUIDED_', MODEL_FILE):
            network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
                                                             nb_labels=nb_labels,
                                                             nb_unet_features=nb_features,
@@ -181,6 +182,7 @@ if __name__ == '__main__':
                                                             int_downsize=1,
                                                             seg_downsize=1)
        else:
+           # only load the weights into the same model. To get the same runtime
           network = vxm.networks.VxmDense(inshape=image_output_shape,
                                           nb_unet_features=nb_features,
                                           int_steps=0)
@@ -205,9 +207,9 @@ if __name__ == '__main__':
        with h5py.File(in_batch, 'r') as f:
            fix_img = f['fix_image'][:][np.newaxis, ...] # Add batch axis
            mov_img = f['mov_image'][:][np.newaxis, ...]
-           fix_seg = f['fix_segmentations'][:][np.newaxis, ...].astype(np.float32)
-           mov_seg = f['mov_segmentations'][:][np.newaxis, ...].astype(np.float32)
-           fix_centroids = f['fix_centroids'][:]
+           fix_seg = f['fix_segmentations'][..., 1:][np.newaxis, ...].astype(np.float32)
+           mov_seg = f['mov_segmentations'][..., 1:][np.newaxis, ...].astype(np.float32)
+           fix_centroids = f['fix_centroids'][1:, ...]
            isotropic_shape = f['isotropic_shape'][:]
            voxel_size = np.divide(fix_img.shape[1:-1], isotropic_shape)

@@ -238,6 +240,7 @@ if __name__ == '__main__':
        # dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf], {'fix_seg:0': fix_seg, 'pred_seg:0': pred_seg})
        dice = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) / np.sum(fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
        hd = np.mean(safe_medpy_metric(pred_seg_isot[0, ...], fix_seg_isot[0, ...], nb_labels, medpy_metrics.hd, {'voxelspacing': voxel_size}))
+       hd95 = np.mean(safe_medpy_metric(pred_seg_isot[0, ...], fix_seg_isot[0, ...], nb_labels, medpy_metrics.hd95, {'voxelspacing': voxel_size}))
        dice_macro = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) for l in range(nb_labels)])

        pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
@@ -261,7 +264,7 @@ if __name__ == '__main__':
        tre = np.mean([v for v in tre_array if not np.isnan(v)])
        # ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'HD', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']

-       new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, t1-t0, tre, len(missing_lbls), missing_lbls]
+       new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, hd95, t1-t0, tre, len(missing_lbls), missing_lbls]
        with open(metrics_file, 'a') as f:
            f.write(';'.join(map(str, new_line))+'\n')

@@ -337,12 +340,13 @@ if __name__ == '__main__':
        with open(metrics_file_fr, 'a') as f:
            f.write(';'.join(map(str, new_line)) + '\n')

-       save_nifti(fix_img[0, ...], os.path.join(output_folder_fr, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
-       save_nifti(mov_img[0, ...], os.path.join(output_folder_fr, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
-       save_nifti(pred_img[0, ...], os.path.join(output_folder_fr, '{:03d}_pred_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
-       save_nifti(fix_seg_card[0, ...], os.path.join(output_folder_fr, '{:03d}_fix_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
-       save_nifti(mov_seg_card[0, ...], os.path.join(output_folder_fr, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
-       save_nifti(pred_seg_card[0, ...], os.path.join(output_folder_fr, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
+       if args.savenifti:
+           save_nifti(fix_img[0, ...], os.path.join(output_folder_fr, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
+           save_nifti(mov_img[0, ...], os.path.join(output_folder_fr, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
+           save_nifti(pred_img[0, ...], os.path.join(output_folder_fr, '{:03d}_pred_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
+           save_nifti(fix_seg_card[0, ...], os.path.join(output_folder_fr, '{:03d}_fix_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
+           save_nifti(mov_seg_card[0, ...], os.path.join(output_folder_fr, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
+           save_nifti(pred_seg_card[0, ...], os.path.join(output_folder_fr, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)

        # with h5py.File(os.path.join(output_folder, '{:03d}_centroids.h5'.format(step)), 'w') as f:
        #     f.create_dataset('fix_centroids', dtype=np.float32, data=fix_centroids)
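
A hypothetical invocation after this change, using placeholder paths and only the flags visible in this diff:

    python COMET/Evaluate_network.py --model path/to/checkpoint.h5 --outdirname Evaluate --fullres --savenifti True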