jpdefrutos committed on
Commit
92caf3a
·
1 Parent(s): 82949fb

Skip the background segmentation mask

Browse files

Compute the HD95 metric
Allow the user to choose whether to save the generated NIfTI images

Brain_study/Evaluate_network__test_fixed.py CHANGED
@@ -55,6 +55,7 @@ if __name__ == '__main__':
55
  default='/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/test')
56
  parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
57
  parser.add_argument('--outdirname', type=str, default='Evaluate')
 
58
 
59
  args = parser.parse_args()
60
  if args.model is not None:
@@ -76,10 +77,10 @@ if __name__ == '__main__':
76
 
77
  with h5py.File(list_test_files[0], 'r') as f:
78
  image_input_shape = image_output_shape = list(f['fix_image'][:].shape[:-1])
79
- nb_labels = f['fix_segmentations'][:].shape[-1]
80
 
81
  # Header of the metrics csv file
82
- csv_header = ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'DICE_MACRO', 'HD', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']
83
 
84
  # TF stuff
85
  config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
@@ -136,17 +137,17 @@ if __name__ == '__main__':
136
  print('DESTINATION FOLDER: ', output_folder)
137
 
138
  try:
139
- network = tf.keras.models.load_model(MODEL_FILE, {'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
140
  'VxmDense': vxm.networks.VxmDense,
141
  'AdamAccumulated': AdamAccumulated,
142
  'loss': loss_fncs,
143
  'metric': metric_fncs},
144
  compile=False)
145
  except ValueError as e:
146
- enc_features = [16, 32, 32, 32] # const.ENCODER_FILTERS
147
- dec_features = [32, 32, 32, 32, 32, 16, 16] # const.ENCODER_FILTERS[::-1]
148
  nb_features = [enc_features, dec_features]
149
- if re.search('^UW|SEGGUIDED_', MODEL_FILE):
150
  network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
151
  nb_labels=nb_labels,
152
  nb_unet_features=nb_features,
@@ -154,6 +155,7 @@ if __name__ == '__main__':
154
  int_downsize=1,
155
  seg_downsize=1)
156
  else:
 
157
  network = vxm.networks.VxmDense(inshape=image_output_shape,
158
  nb_unet_features=nb_features,
159
  int_steps=0)
@@ -167,14 +169,15 @@ if __name__ == '__main__':
167
  with sess.as_default():
168
  sess.run(tf.global_variables_initializer())
169
  network.load_weights(MODEL_FILE, by_name=True)
 
170
  progress_bar = tqdm(enumerate(list_test_files, 1), desc='Evaluation', total=len(list_test_files))
171
  for step, in_batch in progress_bar:
172
  with h5py.File(in_batch, 'r') as f:
173
  fix_img = f['fix_image'][:][np.newaxis, ...] # Add batch axis
174
  mov_img = f['mov_image'][:][np.newaxis, ...]
175
- fix_seg = f['fix_segmentations'][:][np.newaxis, ...].astype(np.float32)
176
- mov_seg = f['mov_segmentations'][:][np.newaxis, ...].astype(np.float32)
177
- fix_centroids = f['fix_centroids'][:]
178
  isotropic_shape = f['isotropic_shape'][:]
179
  voxel_size = np.divide(fix_img.shape[1:-1], isotropic_shape)
180
 
@@ -205,6 +208,7 @@ if __name__ == '__main__':
205
  # dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf], {'fix_seg:0': fix_seg, 'pred_seg:0': pred_seg})
206
  dice = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) / np.sum(fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
207
  hd = np.mean(safe_medpy_metric(pred_seg_isot[0, ...], fix_seg_isot[0, ...], nb_labels, medpy_metrics.hd, {'voxelspacing': voxel_size}))
 
208
  dice_macro = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
209
 
210
  pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
@@ -229,16 +233,17 @@ if __name__ == '__main__':
229
  tre = np.mean([v for v in tre_array if not np.isnan(v)])
230
  # ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'HD', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']
231
 
232
- new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, t1-t0, tre, len(missing_lbls), missing_lbls]
233
  with open(metrics_file, 'a') as f:
234
  f.write(';'.join(map(str, new_line))+'\n')
235
 
236
- save_nifti(fix_img[0, ...], os.path.join(output_folder, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
237
- save_nifti(mov_img[0, ...], os.path.join(output_folder, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
238
- save_nifti(pred_img[0, ...], os.path.join(output_folder, '{:03d}_pred_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
239
- save_nifti(fix_seg[0, ...], os.path.join(output_folder, '{:03d}_fix_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
240
- save_nifti(mov_seg[0, ...], os.path.join(output_folder, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
241
- save_nifti(pred_seg[0, ...], os.path.join(output_folder, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
 
242
 
243
  # with h5py.File(os.path.join(output_folder, '{:03d}_centroids.h5'.format(step)), 'w') as f:
244
  # f.create_dataset('fix_centroids', dtype=np.float32, data=fix_centroids)
 
55
  default='/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/test')
56
  parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
57
  parser.add_argument('--outdirname', type=str, default='Evaluate')
58
+ parser.add_argument('--savenifti', type=bool, default=True)
59
 
60
  args = parser.parse_args()
61
  if args.model is not None:
 
77
 
78
  with h5py.File(list_test_files[0], 'r') as f:
79
  image_input_shape = image_output_shape = list(f['fix_image'][:].shape[:-1])
80
+ nb_labels = f['fix_segmentations'][:].shape[-1] - 1 # Skip background label
81
 
82
  # Header of the metrics csv file
83
+ csv_header = ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'DICE_MACRO', 'HD', 'HD95', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']
84
 
85
  # TF stuff
86
  config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
 
137
  print('DESTINATION FOLDER: ', output_folder)
138
 
139
  try:
140
+ network = tf.keras.models.load_model(MODEL_FILE, {#'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
141
  'VxmDense': vxm.networks.VxmDense,
142
  'AdamAccumulated': AdamAccumulated,
143
  'loss': loss_fncs,
144
  'metric': metric_fncs},
145
  compile=False)
146
  except ValueError as e:
147
+ enc_features = [32, 64, 128, 256, 512, 1024] # const.ENCODER_FILTERS
148
+ dec_features = enc_features[::-1] + [16, 16] # const.ENCODER_FILTERS[::-1]
149
  nb_features = [enc_features, dec_features]
150
+ if False: #re.search('^UW|SEGGUIDED_', MODEL_FILE):
151
  network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
152
  nb_labels=nb_labels,
153
  nb_unet_features=nb_features,
 
155
  int_downsize=1,
156
  seg_downsize=1)
157
  else:
158
+ # only load the weights into the same model. To get the same runtime
159
  network = vxm.networks.VxmDense(inshape=image_output_shape,
160
  nb_unet_features=nb_features,
161
  int_steps=0)
 
169
  with sess.as_default():
170
  sess.run(tf.global_variables_initializer())
171
  network.load_weights(MODEL_FILE, by_name=True)
172
+ network.summary(line_length=C.SUMMARY_LINE_LENGTH)
173
  progress_bar = tqdm(enumerate(list_test_files, 1), desc='Evaluation', total=len(list_test_files))
174
  for step, in_batch in progress_bar:
175
  with h5py.File(in_batch, 'r') as f:
176
  fix_img = f['fix_image'][:][np.newaxis, ...] # Add batch axis
177
  mov_img = f['mov_image'][:][np.newaxis, ...]
178
+ fix_seg = f['fix_segmentations'][..., 1:][np.newaxis, ...].astype(np.float32)
179
+ mov_seg = f['mov_segmentations'][..., 1:][np.newaxis, ...].astype(np.float32)
180
+ fix_centroids = f['fix_centroids'][1:, ...]
181
  isotropic_shape = f['isotropic_shape'][:]
182
  voxel_size = np.divide(fix_img.shape[1:-1], isotropic_shape)
183
 
 
208
  # dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf], {'fix_seg:0': fix_seg, 'pred_seg:0': pred_seg})
209
  dice = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) / np.sum(fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
210
  hd = np.mean(safe_medpy_metric(pred_seg_isot[0, ...], fix_seg_isot[0, ...], nb_labels, medpy_metrics.hd, {'voxelspacing': voxel_size}))
211
+ hd95 = np.mean(safe_medpy_metric(pred_seg_isot[0, ...], fix_seg_isot[0, ...], nb_labels, medpy_metrics.hd95, {'voxelspacing': voxel_size}))
212
  dice_macro = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
213
 
214
  pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
 
233
  tre = np.mean([v for v in tre_array if not np.isnan(v)])
234
  # ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'HD', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']
235
 
236
+ new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, hd95, t1-t0, tre, len(missing_lbls), missing_lbls]
237
  with open(metrics_file, 'a') as f:
238
  f.write(';'.join(map(str, new_line))+'\n')
239
 
240
+ if args.savenifti:
241
+ save_nifti(fix_img[0, ...], os.path.join(output_folder, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
242
+ save_nifti(mov_img[0, ...], os.path.join(output_folder, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
243
+ save_nifti(pred_img[0, ...], os.path.join(output_folder, '{:03d}_pred_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
244
+ save_nifti(fix_seg[0, ...], os.path.join(output_folder, '{:03d}_fix_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
245
+ save_nifti(mov_seg[0, ...], os.path.join(output_folder, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
246
+ save_nifti(pred_seg[0, ...], os.path.join(output_folder, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
247
 
248
  # with h5py.File(os.path.join(output_folder, '{:03d}_centroids.h5'.format(step)), 'w') as f:
249
  # f.create_dataset('fix_centroids', dtype=np.float32, data=fix_centroids)