jpdefrutos commited on
Commit
a3cbfc7
·
1 Parent(s): cd051bb

Added flag to return the results in the original input image resolution

Browse files
DeepDeformationMapRegistration/main.py CHANGED
@@ -6,6 +6,7 @@ import argparse
6
  import subprocess
7
  import logging
8
  import time
 
9
 
10
  # currentdir = os.path.dirname(os.path.realpath(__file__))
11
  # parentdir = os.path.dirname(currentdir)
@@ -180,15 +181,17 @@ def main():
180
  default=None)
181
  parser.add_argument('--model', type=str, help='Which model to use: BL-N, BL-S, SG-ND, SG-NSD, UW-NSD, UW-NSDH',
182
  default='UW-NSD')
183
- parser.add_argument('--debug', '-d', action='store_true', help='Produce additional debug information', default=False)
184
  parser.add_argument('-c', '--clear-outputdir', action='store_true', help='Clear output folder if this has content', default=False)
 
 
 
185
  args = parser.parse_args()
186
 
187
  assert os.path.exists(args.fixed), 'Fixed image not found'
188
  assert os.path.exists(args.moving), 'Moving image not found'
189
  assert args.model in C.MODEL_TYPES.keys(), 'Invalid model type'
190
  assert args.anatomy in C.ANATOMIES.keys(), 'Invalid anatomy option'
191
-
192
  if os.path.exists(args.outputdir) and len(os.listdir(args.outputdir)):
193
  if args.clear_outputdir:
194
  erase = 'y'
@@ -217,6 +220,12 @@ def main():
217
  LOGGER.setLevel('DEBUG')
218
  LOGGER.debug('DEBUG MODE ENABLED')
219
 
 
 
 
 
 
 
220
  # Load the file and preprocess it
221
  LOGGER.info('Loading image files')
222
  fixed_image_or = nib.load(args.fixed)
@@ -295,6 +304,7 @@ def main():
295
  time_disp_map_start = time.time()
296
  # disp_map = registration_model.predict([moving_image[np.newaxis, ...], fixed_image[np.newaxis, ...]])
297
  p, disp_map = network.predict([moving_image[np.newaxis, ...], fixed_image[np.newaxis, ...]])
 
298
  time_disp_map_end = time.time()
299
  LOGGER.info('\t... done')
300
  debug_save_image(np.squeeze(disp_map), 'disp_map_0_raw', args.outputdir, args.debug)
@@ -303,33 +313,38 @@ def main():
303
  # pred_image_isot = zoom(pred_image[0, ...], zoom_factors, order=3)[np.newaxis, ...]
304
  # fixed_image_isot = zoom(fixed_image[0, ...], zoom_factors, order=3)[np.newaxis, ...]
305
 
306
- # Up sample the displacement map to the full res
307
- LOGGER.info('Scaling displacement map...')
308
- trf = np.eye(4)
309
- np.fill_diagonal(trf, 1/zoom_factors)
310
- disp_map = resize_displacement_map(np.squeeze(disp_map), None, trf)
311
- debug_save_image(np.squeeze(disp_map), 'disp_map_1_upsampled', args.outputdir, args.debug)
312
- disp_map_or = pad_displacement_map(disp_map, crop_min, crop_max, image_shape_or)
313
- debug_save_image(np.squeeze(disp_map_or), 'disp_map_2_padded', args.outputdir, args.debug)
314
- disp_map_or = gaussian_filter(disp_map_or, 5)
315
- debug_save_image(np.squeeze(disp_map_or), 'disp_map_3_smoothed', args.outputdir, args.debug)
316
- LOGGER.info('\t... done')
 
 
 
 
 
317
 
318
  LOGGER.info('Applying displacement map...')
319
  time_pred_img_start = time.time()
320
- pred_image = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([moving_image_or[np.newaxis, ...], disp_map_or[np.newaxis, ...]]).eval()
321
  time_pred_img_end = time.time()
322
  LOGGER.info('\t... done')
323
 
324
  LOGGER.info('Computing metrics...')
325
  ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf],
326
- {'fix_img:0': fixed_image_or[np.newaxis, ...], 'pred_img:0': pred_image})
327
  ssim = np.mean(ssim)
328
  ms_ssim = ms_ssim[0]
329
  pred_image = pred_image[0, ...]
330
 
331
  save_nifti(pred_image, os.path.join(args.outputdir, 'pred_image.nii.gz'))
332
- np.savez_compressed(os.path.join(args.outputdir, 'displacement_map.npz'), disp_map_or)
333
  LOGGER.info('Predicted image (full image) and displacement map saved in: '.format(args.outputdir))
334
  LOGGER.info(f'Displacement map prediction time: {time_disp_map_end - time_disp_map_start} s')
335
  LOGGER.info(f'Predicted image time: {time_pred_img_end - time_pred_img_start} s')
@@ -340,15 +355,15 @@ def main():
340
  LOGGER.info('MSE: {:.03f}'.format(mse))
341
  LOGGER.info('MS SSIM: {:.03f}'.format(ms_ssim))
342
 
343
- ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf],
344
- {'fix_img:0': fixed_image[np.newaxis, ...], 'pred_img:0': p})
345
- ssim = np.mean(ssim)
346
- ms_ssim = ms_ssim[0]
347
- LOGGER.info('\nSimilarity metrics (ROI)\n------------------')
348
- LOGGER.info('SSIM: {:.03f}'.format(ssim))
349
- LOGGER.info('NCC: {:.03f}'.format(ncc))
350
- LOGGER.info('MSE: {:.03f}'.format(mse))
351
- LOGGER.info('MS SSIM: {:.03f}'.format(ms_ssim))
352
 
353
  del registration_model
354
  LOGGER.info('Done')
 
6
  import subprocess
7
  import logging
8
  import time
9
+ import warnings
10
 
11
  # currentdir = os.path.dirname(os.path.realpath(__file__))
12
  # parentdir = os.path.dirname(currentdir)
 
181
  default=None)
182
  parser.add_argument('--model', type=str, help='Which model to use: BL-N, BL-S, SG-ND, SG-NSD, UW-NSD, UW-NSDH',
183
  default='UW-NSD')
184
+ parser.add_argument('-d', '--debug', action='store_true', help='Produce additional debug information', default=False)
185
  parser.add_argument('-c', '--clear-outputdir', action='store_true', help='Clear output folder if this has content', default=False)
186
+ parser.add_argument('--original-resolution', action='store_true',
187
+ help='Re-scale the displacement map to the originla resolution and apply it to the original moving image. WARNING: longer processing time',
188
+ default=False)
189
  args = parser.parse_args()
190
 
191
  assert os.path.exists(args.fixed), 'Fixed image not found'
192
  assert os.path.exists(args.moving), 'Moving image not found'
193
  assert args.model in C.MODEL_TYPES.keys(), 'Invalid model type'
194
  assert args.anatomy in C.ANATOMIES.keys(), 'Invalid anatomy option'
 
195
  if os.path.exists(args.outputdir) and len(os.listdir(args.outputdir)):
196
  if args.clear_outputdir:
197
  erase = 'y'
 
220
  LOGGER.setLevel('DEBUG')
221
  LOGGER.debug('DEBUG MODE ENABLED')
222
 
223
+ if args.original_resolution:
224
+ LOGGER.info('The results will be rescaled back to the original image resolution. '
225
+ 'Expect longer post-processing times.')
226
+ else:
227
+ LOGGER.info(f'The results will NOT be rescaled. Output shape will be {C.IMG_SHAPE[:3]}.')
228
+
229
  # Load the file and preprocess it
230
  LOGGER.info('Loading image files')
231
  fixed_image_or = nib.load(args.fixed)
 
304
  time_disp_map_start = time.time()
305
  # disp_map = registration_model.predict([moving_image[np.newaxis, ...], fixed_image[np.newaxis, ...]])
306
  p, disp_map = network.predict([moving_image[np.newaxis, ...], fixed_image[np.newaxis, ...]])
307
+ disp_map = np.squeeze(disp_map)
308
  time_disp_map_end = time.time()
309
  LOGGER.info('\t... done')
310
  debug_save_image(np.squeeze(disp_map), 'disp_map_0_raw', args.outputdir, args.debug)
 
313
  # pred_image_isot = zoom(pred_image[0, ...], zoom_factors, order=3)[np.newaxis, ...]
314
  # fixed_image_isot = zoom(fixed_image[0, ...], zoom_factors, order=3)[np.newaxis, ...]
315
 
316
+ if args.original_resolution:
317
+ # Up sample the displacement map to the full res
318
+ LOGGER.info('Scaling displacement map...')
319
+ trf = np.eye(4)
320
+ np.fill_diagonal(trf, 1/zoom_factors)
321
+ disp_map = resize_displacement_map(disp_map, None, trf)
322
+ debug_save_image(disp_map, 'disp_map_1_upsampled', args.outputdir, args.debug)
323
+ disp_map_or = pad_displacement_map(disp_map, crop_min, crop_max, image_shape_or)
324
+ debug_save_image(np.squeeze(disp_map_or), 'disp_map_2_padded', args.outputdir, args.debug)
325
+ disp_map_or = gaussian_filter(disp_map_or, 5)
326
+ debug_save_image(np.squeeze(disp_map_or), 'disp_map_3_smoothed', args.outputdir, args.debug)
327
+ LOGGER.info('\t... done')
328
+
329
+ moving_image = moving_image_or
330
+ fixed_image = fixed_image_or
331
+ disp_map = disp_map_or
332
 
333
  LOGGER.info('Applying displacement map...')
334
  time_pred_img_start = time.time()
335
+ pred_image = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([moving_image[np.newaxis, ...], disp_map[np.newaxis, ...]]).eval()
336
  time_pred_img_end = time.time()
337
  LOGGER.info('\t... done')
338
 
339
  LOGGER.info('Computing metrics...')
340
  ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf],
341
+ {'fix_img:0': fixed_image[np.newaxis, ...], 'pred_img:0': pred_image})
342
  ssim = np.mean(ssim)
343
  ms_ssim = ms_ssim[0]
344
  pred_image = pred_image[0, ...]
345
 
346
  save_nifti(pred_image, os.path.join(args.outputdir, 'pred_image.nii.gz'))
347
+ np.savez_compressed(os.path.join(args.outputdir, 'displacement_map.npz'), disp_map)
348
  LOGGER.info('Predicted image (full image) and displacement map saved in: '.format(args.outputdir))
349
  LOGGER.info(f'Displacement map prediction time: {time_disp_map_end - time_disp_map_start} s')
350
  LOGGER.info(f'Predicted image time: {time_pred_img_end - time_pred_img_start} s')
 
355
  LOGGER.info('MSE: {:.03f}'.format(mse))
356
  LOGGER.info('MS SSIM: {:.03f}'.format(ms_ssim))
357
 
358
+ # ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf],
359
+ # {'fix_img:0': fixed_image[np.newaxis, ...], 'pred_img:0': p})
360
+ # ssim = np.mean(ssim)
361
+ # ms_ssim = ms_ssim[0]
362
+ # LOGGER.info('\nSimilarity metrics (ROI)\n------------------')
363
+ # LOGGER.info('SSIM: {:.03f}'.format(ssim))
364
+ # LOGGER.info('NCC: {:.03f}'.format(ncc))
365
+ # LOGGER.info('MSE: {:.03f}'.format(mse))
366
+ # LOGGER.info('MS SSIM: {:.03f}'.format(ms_ssim))
367
 
368
  del registration_model
369
  LOGGER.info('Done')