jpdefrutos commited on
Commit
c292437
·
1 Parent(s): c7383ff

added to flag to output the displacement map (takes long to resize back to the original resolution)

Browse files

Now the predicted image is resampled back to the original resolution instead, to have consistent results as with the downsized version

DeepDeformationMapRegistration/main.py CHANGED
@@ -92,6 +92,19 @@ def pad_images(image_1: nib.Nifti1Image, image_2: nib.Nifti1Image):
92
  return image_1_padded, image_2_padded
93
 
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  def pad_displacement_map(disp_map: np.ndarray, crop_min: np.ndarray, crop_max: np.ndarray, output_shape: (np.ndarray, list)) -> np.ndarray:
96
  ret_val = disp_map
97
  if np.all([d != i for d, i in zip(disp_map.shape[:3], output_shape)]):
@@ -183,7 +196,9 @@ def main():
183
  parser.add_argument('-d', '--debug', action='store_true', help='Produce additional debug information', default=False)
184
  parser.add_argument('-c', '--clear-outputdir', action='store_true', help='Clear output folder if this has content', default=False)
185
  parser.add_argument('--original-resolution', action='store_true',
186
- help='Re-scale the displacement map to the originla resolution and apply it to the original moving image. WARNING: longer processing time',
 
 
187
  default=False)
188
  args = parser.parse_args()
189
 
@@ -229,6 +244,7 @@ def main():
229
  LOGGER.info('Loading image files')
230
  fixed_image_or = nib.load(args.fixed)
231
  moving_image_or = nib.load(args.moving)
 
232
  image_shape_or = np.asarray(fixed_image_or.shape)
233
  fixed_image_or, moving_image_or = pad_images(fixed_image_or, moving_image_or)
234
  fixed_image_or = fixed_image_or[..., np.newaxis] # add channel dim
@@ -303,47 +319,76 @@ def main():
303
  time_disp_map_start = time.time()
304
  # disp_map = registration_model.predict([moving_image[np.newaxis, ...], fixed_image[np.newaxis, ...]])
305
  p, disp_map = network.predict([moving_image[np.newaxis, ...], fixed_image[np.newaxis, ...]])
306
- disp_map = np.squeeze(disp_map)
307
  time_disp_map_end = time.time()
308
- LOGGER.info('\t... done')
 
309
  debug_save_image(np.squeeze(disp_map), 'disp_map_0_raw', args.outputdir, args.debug)
310
  debug_save_image(p[0, ...], 'img_4_net_pred_image', args.outputdir, args.debug)
311
  # pred_image = min_max_norm(pred_image)
312
  # pred_image_isot = zoom(pred_image[0, ...], zoom_factors, order=3)[np.newaxis, ...]
313
  # fixed_image_isot = zoom(fixed_image[0, ...], zoom_factors, order=3)[np.newaxis, ...]
314
 
315
- if args.original_resolution:
316
- # Up sample the displacement map to the full res
317
- LOGGER.info('Scaling displacement map...')
318
- trf = np.eye(4)
319
- np.fill_diagonal(trf, 1/zoom_factors)
320
- disp_map = resize_displacement_map(disp_map, None, trf)
321
- debug_save_image(disp_map, 'disp_map_1_upsampled', args.outputdir, args.debug)
322
- disp_map_or = pad_displacement_map(disp_map, crop_min, crop_max, image_shape_or)
323
- debug_save_image(np.squeeze(disp_map_or), 'disp_map_2_padded', args.outputdir, args.debug)
324
- disp_map_or = gaussian_filter(disp_map_or, 5)
325
- debug_save_image(np.squeeze(disp_map_or), 'disp_map_3_smoothed', args.outputdir, args.debug)
326
- LOGGER.info('\t... done')
327
-
328
- moving_image = moving_image_or
329
- fixed_image = fixed_image_or
330
- disp_map = disp_map_or
331
-
332
  LOGGER.info('Applying displacement map...')
333
  time_pred_img_start = time.time()
334
  pred_image = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([moving_image[np.newaxis, ...], disp_map[np.newaxis, ...]]).eval()
335
  time_pred_img_end = time.time()
336
- LOGGER.info('\t... done')
 
 
 
 
 
 
 
 
 
 
337
 
338
  LOGGER.info('Computing metrics...')
339
- ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf],
340
- {'fix_img:0': fixed_image[np.newaxis, ...], 'pred_img:0': pred_image})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341
  ssim = np.mean(ssim)
342
  ms_ssim = ms_ssim[0]
343
- pred_image = pred_image[0, ...]
344
 
345
- save_nifti(pred_image, os.path.join(args.outputdir, 'pred_image.nii.gz'))
346
- np.savez_compressed(os.path.join(args.outputdir, 'displacement_map.npz'), disp_map)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
  LOGGER.info('Predicted image and displacement map saved in: '.format(args.outputdir))
348
  LOGGER.info(f'Displacement map prediction time: {time_disp_map_end - time_disp_map_start} s')
349
  LOGGER.info(f'Predicted image time: {time_pred_img_end - time_pred_img_start} s')
 
92
  return image_1_padded, image_2_padded
93
 
94
 
95
+ def pad_crop_to_original_shape(crop_image: np.asarray, output_shape: [tuple, np.asarray], top_left_corner: [tuple, np.asarray]):
96
+ """
97
+ Pad crop_image so the output image has output_shape with the crop where it originally was found
98
+ """
99
+ output_shape = np.asarray(output_shape)
100
+ top_left_corner = np.asarray(top_left_corner)
101
+
102
+ pad = [[c, o - (c + i)] for c, o, i in zip(top_left_corner[:3], output_shape[:3], crop_image.shape[:3])]
103
+ if len(crop_image.shape) == 4:
104
+ pad += [[0, 0]]
105
+ return np.pad(crop_image, pad, mode='constant', constant_values=np.min(crop_image)).astype(crop_image.dtype)
106
+
107
+
108
  def pad_displacement_map(disp_map: np.ndarray, crop_min: np.ndarray, crop_max: np.ndarray, output_shape: (np.ndarray, list)) -> np.ndarray:
109
  ret_val = disp_map
110
  if np.all([d != i for d, i in zip(disp_map.shape[:3], output_shape)]):
 
196
  parser.add_argument('-d', '--debug', action='store_true', help='Produce additional debug information', default=False)
197
  parser.add_argument('-c', '--clear-outputdir', action='store_true', help='Clear output folder if this has content', default=False)
198
  parser.add_argument('--original-resolution', action='store_true',
199
+ help='Re-scale the displacement map to the originla resolution and apply it to the original moving image. WARNING: longer processing time.',
200
+ default=False)
201
+ parser.add_argument('--save-displacement-map', action='store_true', help='Save the displacement map. An NPZ file will be created.',
202
  default=False)
203
  args = parser.parse_args()
204
 
 
244
  LOGGER.info('Loading image files')
245
  fixed_image_or = nib.load(args.fixed)
246
  moving_image_or = nib.load(args.moving)
247
+ moving_image_header = moving_image_or.header.copy()
248
  image_shape_or = np.asarray(fixed_image_or.shape)
249
  fixed_image_or, moving_image_or = pad_images(fixed_image_or, moving_image_or)
250
  fixed_image_or = fixed_image_or[..., np.newaxis] # add channel dim
 
319
  time_disp_map_start = time.time()
320
  # disp_map = registration_model.predict([moving_image[np.newaxis, ...], fixed_image[np.newaxis, ...]])
321
  p, disp_map = network.predict([moving_image[np.newaxis, ...], fixed_image[np.newaxis, ...]])
 
322
  time_disp_map_end = time.time()
323
+ LOGGER.info(f'\t... done ({time_disp_map_end - time_disp_map_start})')
324
+ disp_map = np.squeeze(disp_map)
325
  debug_save_image(np.squeeze(disp_map), 'disp_map_0_raw', args.outputdir, args.debug)
326
  debug_save_image(p[0, ...], 'img_4_net_pred_image', args.outputdir, args.debug)
327
  # pred_image = min_max_norm(pred_image)
328
  # pred_image_isot = zoom(pred_image[0, ...], zoom_factors, order=3)[np.newaxis, ...]
329
  # fixed_image_isot = zoom(fixed_image[0, ...], zoom_factors, order=3)[np.newaxis, ...]
330
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
331
  LOGGER.info('Applying displacement map...')
332
  time_pred_img_start = time.time()
333
  pred_image = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([moving_image[np.newaxis, ...], disp_map[np.newaxis, ...]]).eval()
334
  time_pred_img_end = time.time()
335
+ LOGGER.info(f'\t... done ({time_pred_img_end - time_pred_img_start} s)')
336
+ pred_image = pred_image[0, ...]
337
+
338
+ if args.original_resolution:
339
+ LOGGER.info('Scaling predicted image...')
340
+ moving_image = moving_image_or
341
+ fixed_image = fixed_image_or
342
+ # disp_map = disp_map_or
343
+ pred_image = zoom(pred_image, 1/zoom_factors)
344
+ pred_image = pad_crop_to_original_shape(pred_image, fixed_image_or.shape, crop_min)
345
+ LOGGER.info('Done...')
346
 
347
  LOGGER.info('Computing metrics...')
348
+ if args.original_resolution:
349
+ ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf],
350
+ {'fix_img:0': fixed_image[np.newaxis,
351
+ crop_min[0]: crop_max[0],
352
+ crop_min[1]: crop_max[1],
353
+ crop_min[2]: crop_max[2],
354
+ ...],
355
+ 'pred_img:0': pred_image[np.newaxis,
356
+ crop_min[0]: crop_max[0],
357
+ crop_min[1]: crop_max[1],
358
+ crop_min[2]: crop_max[2],
359
+ ...]}) # to only compare the deformed region!
360
+ else:
361
+
362
+ ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf],
363
+ {'fix_img:0': fixed_image[np.newaxis, ...],
364
+ 'pred_img:0': pred_image[np.newaxis, ...]})
365
  ssim = np.mean(ssim)
366
  ms_ssim = ms_ssim[0]
 
367
 
368
+ if args.original_resolution:
369
+ save_nifti(pred_image, os.path.join(args.outputdir, 'pred_image.nii.gz'), header=moving_image_header)
370
+ else:
371
+ save_nifti(pred_image, os.path.join(args.outputdir, 'pred_image.nii.gz'))
372
+ save_nifti(fixed_image, os.path.join(args.outputdir, 'fixed_image.nii.gz'))
373
+ save_nifti(moving_image, os.path.join(args.outputdir, 'moving_image.nii.gz'))
374
+
375
+ if args.save_displacement_map or args.debug:
376
+ if args.original_resolution:
377
+ # Up sample the displacement map to the full res
378
+ LOGGER.info('Scaling displacement map...')
379
+ trf = np.eye(4)
380
+ np.fill_diagonal(trf, 1 / zoom_factors)
381
+ disp_map = resize_displacement_map(disp_map, None, trf, moving_image_header.get_zooms())
382
+ debug_save_image(disp_map, 'disp_map_1_upsampled', args.outputdir, args.debug)
383
+ disp_map = pad_displacement_map(disp_map, crop_min, crop_max, image_shape_or)
384
+ debug_save_image(np.squeeze(disp_map), 'disp_map_2_padded', args.outputdir, args.debug)
385
+ disp_map = gaussian_filter(disp_map, 5)
386
+ debug_save_image(np.squeeze(disp_map), 'disp_map_3_smoothed', args.outputdir, args.debug)
387
+ LOGGER.info('\t... done')
388
+ if args.debug:
389
+ np.savez_compressed(os.path.join(args.outputdir, 'displacement_map.npz'), disp_map)
390
+ else:
391
+ np.savez_compressed(os.path.join(os.path.join(args.outputdir, 'debug'), 'displacement_map.npz'), disp_map)
392
  LOGGER.info('Predicted image and displacement map saved in: '.format(args.outputdir))
393
  LOGGER.info(f'Displacement map prediction time: {time_disp_map_end - time_disp_map_start} s')
394
  LOGGER.info(f'Predicted image time: {time_pred_img_end - time_pred_img_start} s')
DeepDeformationMapRegistration/utils/misc.py CHANGED
@@ -167,7 +167,7 @@ def segmentation_cardinal_to_ohe(segmentation, labels_list: list = None):
167
  return cpy
168
 
169
 
170
- def resize_displacement_map(displacement_map: np.ndarray, dest_shape: [list, np.ndarray, tuple], scale_trf: np.ndarray=None):
171
  if scale_trf is None:
172
  scale_trf = scale_transformation(displacement_map.shape, dest_shape)
173
  else:
@@ -175,11 +175,12 @@ def resize_displacement_map(displacement_map: np.ndarray, dest_shape: [list, np.
175
  zoom_factors = scale_trf.diagonal()
176
  # First scale the values, so we cut down the number of multiplications
177
  dm_resized = np.copy(displacement_map)
178
- dm_resized[..., 0] *= zoom_factors[0]
179
- dm_resized[..., 1] *= zoom_factors[1]
180
- dm_resized[..., 2] *= zoom_factors[2]
181
  # Then rescale using zoom
182
  dm_resized = zoom(dm_resized, zoom_factors)
 
 
 
 
183
  return dm_resized
184
 
185
 
 
167
  return cpy
168
 
169
 
170
+ def resize_displacement_map(displacement_map: np.ndarray, dest_shape: [list, np.ndarray, tuple], scale_trf: np.ndarray = None, resolution_factors: [tuple, np.ndarray] = np.ones((3,))):
171
  if scale_trf is None:
172
  scale_trf = scale_transformation(displacement_map.shape, dest_shape)
173
  else:
 
175
  zoom_factors = scale_trf.diagonal()
176
  # First scale the values, so we cut down the number of multiplications
177
  dm_resized = np.copy(displacement_map)
 
 
 
178
  # Then rescale using zoom
179
  dm_resized = zoom(dm_resized, zoom_factors)
180
+ dm_resized *= np.asarray(resolution_factors)
181
+ # dm_resized[..., 0] *= resolution_factors[0]
182
+ # dm_resized[..., 1] *= resolution_factors[1]
183
+ # dm_resized[..., 2] *= resolution_factors[2]
184
  return dm_resized
185
 
186