Commit
·
c292437
1
Parent(s):
c7383ff
added to flag to output the displacement map (takes long to resize back to the original resolution)
Browse filesNow the predicted image is resampled back to the original resolution instead, to have consistent results as with the downsized version
DeepDeformationMapRegistration/main.py
CHANGED
@@ -92,6 +92,19 @@ def pad_images(image_1: nib.Nifti1Image, image_2: nib.Nifti1Image):
|
|
92 |
return image_1_padded, image_2_padded
|
93 |
|
94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
def pad_displacement_map(disp_map: np.ndarray, crop_min: np.ndarray, crop_max: np.ndarray, output_shape: (np.ndarray, list)) -> np.ndarray:
|
96 |
ret_val = disp_map
|
97 |
if np.all([d != i for d, i in zip(disp_map.shape[:3], output_shape)]):
|
@@ -183,7 +196,9 @@ def main():
|
|
183 |
parser.add_argument('-d', '--debug', action='store_true', help='Produce additional debug information', default=False)
|
184 |
parser.add_argument('-c', '--clear-outputdir', action='store_true', help='Clear output folder if this has content', default=False)
|
185 |
parser.add_argument('--original-resolution', action='store_true',
|
186 |
-
help='Re-scale the displacement map to the originla resolution and apply it to the original moving image. WARNING: longer processing time',
|
|
|
|
|
187 |
default=False)
|
188 |
args = parser.parse_args()
|
189 |
|
@@ -229,6 +244,7 @@ def main():
|
|
229 |
LOGGER.info('Loading image files')
|
230 |
fixed_image_or = nib.load(args.fixed)
|
231 |
moving_image_or = nib.load(args.moving)
|
|
|
232 |
image_shape_or = np.asarray(fixed_image_or.shape)
|
233 |
fixed_image_or, moving_image_or = pad_images(fixed_image_or, moving_image_or)
|
234 |
fixed_image_or = fixed_image_or[..., np.newaxis] # add channel dim
|
@@ -303,47 +319,76 @@ def main():
|
|
303 |
time_disp_map_start = time.time()
|
304 |
# disp_map = registration_model.predict([moving_image[np.newaxis, ...], fixed_image[np.newaxis, ...]])
|
305 |
p, disp_map = network.predict([moving_image[np.newaxis, ...], fixed_image[np.newaxis, ...]])
|
306 |
-
disp_map = np.squeeze(disp_map)
|
307 |
time_disp_map_end = time.time()
|
308 |
-
LOGGER.info('\t... done')
|
|
|
309 |
debug_save_image(np.squeeze(disp_map), 'disp_map_0_raw', args.outputdir, args.debug)
|
310 |
debug_save_image(p[0, ...], 'img_4_net_pred_image', args.outputdir, args.debug)
|
311 |
# pred_image = min_max_norm(pred_image)
|
312 |
# pred_image_isot = zoom(pred_image[0, ...], zoom_factors, order=3)[np.newaxis, ...]
|
313 |
# fixed_image_isot = zoom(fixed_image[0, ...], zoom_factors, order=3)[np.newaxis, ...]
|
314 |
|
315 |
-
if args.original_resolution:
|
316 |
-
# Up sample the displacement map to the full res
|
317 |
-
LOGGER.info('Scaling displacement map...')
|
318 |
-
trf = np.eye(4)
|
319 |
-
np.fill_diagonal(trf, 1/zoom_factors)
|
320 |
-
disp_map = resize_displacement_map(disp_map, None, trf)
|
321 |
-
debug_save_image(disp_map, 'disp_map_1_upsampled', args.outputdir, args.debug)
|
322 |
-
disp_map_or = pad_displacement_map(disp_map, crop_min, crop_max, image_shape_or)
|
323 |
-
debug_save_image(np.squeeze(disp_map_or), 'disp_map_2_padded', args.outputdir, args.debug)
|
324 |
-
disp_map_or = gaussian_filter(disp_map_or, 5)
|
325 |
-
debug_save_image(np.squeeze(disp_map_or), 'disp_map_3_smoothed', args.outputdir, args.debug)
|
326 |
-
LOGGER.info('\t... done')
|
327 |
-
|
328 |
-
moving_image = moving_image_or
|
329 |
-
fixed_image = fixed_image_or
|
330 |
-
disp_map = disp_map_or
|
331 |
-
|
332 |
LOGGER.info('Applying displacement map...')
|
333 |
time_pred_img_start = time.time()
|
334 |
pred_image = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([moving_image[np.newaxis, ...], disp_map[np.newaxis, ...]]).eval()
|
335 |
time_pred_img_end = time.time()
|
336 |
-
LOGGER.info('\t... done')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
337 |
|
338 |
LOGGER.info('Computing metrics...')
|
339 |
-
|
340 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
341 |
ssim = np.mean(ssim)
|
342 |
ms_ssim = ms_ssim[0]
|
343 |
-
pred_image = pred_image[0, ...]
|
344 |
|
345 |
-
|
346 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
347 |
LOGGER.info('Predicted image and displacement map saved in: '.format(args.outputdir))
|
348 |
LOGGER.info(f'Displacement map prediction time: {time_disp_map_end - time_disp_map_start} s')
|
349 |
LOGGER.info(f'Predicted image time: {time_pred_img_end - time_pred_img_start} s')
|
|
|
92 |
return image_1_padded, image_2_padded
|
93 |
|
94 |
|
95 |
+
def pad_crop_to_original_shape(crop_image: np.asarray, output_shape: [tuple, np.asarray], top_left_corner: [tuple, np.asarray]):
|
96 |
+
"""
|
97 |
+
Pad crop_image so the output image has output_shape with the crop where it originally was found
|
98 |
+
"""
|
99 |
+
output_shape = np.asarray(output_shape)
|
100 |
+
top_left_corner = np.asarray(top_left_corner)
|
101 |
+
|
102 |
+
pad = [[c, o - (c + i)] for c, o, i in zip(top_left_corner[:3], output_shape[:3], crop_image.shape[:3])]
|
103 |
+
if len(crop_image.shape) == 4:
|
104 |
+
pad += [[0, 0]]
|
105 |
+
return np.pad(crop_image, pad, mode='constant', constant_values=np.min(crop_image)).astype(crop_image.dtype)
|
106 |
+
|
107 |
+
|
108 |
def pad_displacement_map(disp_map: np.ndarray, crop_min: np.ndarray, crop_max: np.ndarray, output_shape: (np.ndarray, list)) -> np.ndarray:
|
109 |
ret_val = disp_map
|
110 |
if np.all([d != i for d, i in zip(disp_map.shape[:3], output_shape)]):
|
|
|
196 |
parser.add_argument('-d', '--debug', action='store_true', help='Produce additional debug information', default=False)
|
197 |
parser.add_argument('-c', '--clear-outputdir', action='store_true', help='Clear output folder if this has content', default=False)
|
198 |
parser.add_argument('--original-resolution', action='store_true',
|
199 |
+
help='Re-scale the displacement map to the originla resolution and apply it to the original moving image. WARNING: longer processing time.',
|
200 |
+
default=False)
|
201 |
+
parser.add_argument('--save-displacement-map', action='store_true', help='Save the displacement map. An NPZ file will be created.',
|
202 |
default=False)
|
203 |
args = parser.parse_args()
|
204 |
|
|
|
244 |
LOGGER.info('Loading image files')
|
245 |
fixed_image_or = nib.load(args.fixed)
|
246 |
moving_image_or = nib.load(args.moving)
|
247 |
+
moving_image_header = moving_image_or.header.copy()
|
248 |
image_shape_or = np.asarray(fixed_image_or.shape)
|
249 |
fixed_image_or, moving_image_or = pad_images(fixed_image_or, moving_image_or)
|
250 |
fixed_image_or = fixed_image_or[..., np.newaxis] # add channel dim
|
|
|
319 |
time_disp_map_start = time.time()
|
320 |
# disp_map = registration_model.predict([moving_image[np.newaxis, ...], fixed_image[np.newaxis, ...]])
|
321 |
p, disp_map = network.predict([moving_image[np.newaxis, ...], fixed_image[np.newaxis, ...]])
|
|
|
322 |
time_disp_map_end = time.time()
|
323 |
+
LOGGER.info(f'\t... done ({time_disp_map_end - time_disp_map_start})')
|
324 |
+
disp_map = np.squeeze(disp_map)
|
325 |
debug_save_image(np.squeeze(disp_map), 'disp_map_0_raw', args.outputdir, args.debug)
|
326 |
debug_save_image(p[0, ...], 'img_4_net_pred_image', args.outputdir, args.debug)
|
327 |
# pred_image = min_max_norm(pred_image)
|
328 |
# pred_image_isot = zoom(pred_image[0, ...], zoom_factors, order=3)[np.newaxis, ...]
|
329 |
# fixed_image_isot = zoom(fixed_image[0, ...], zoom_factors, order=3)[np.newaxis, ...]
|
330 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
331 |
LOGGER.info('Applying displacement map...')
|
332 |
time_pred_img_start = time.time()
|
333 |
pred_image = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([moving_image[np.newaxis, ...], disp_map[np.newaxis, ...]]).eval()
|
334 |
time_pred_img_end = time.time()
|
335 |
+
LOGGER.info(f'\t... done ({time_pred_img_end - time_pred_img_start} s)')
|
336 |
+
pred_image = pred_image[0, ...]
|
337 |
+
|
338 |
+
if args.original_resolution:
|
339 |
+
LOGGER.info('Scaling predicted image...')
|
340 |
+
moving_image = moving_image_or
|
341 |
+
fixed_image = fixed_image_or
|
342 |
+
# disp_map = disp_map_or
|
343 |
+
pred_image = zoom(pred_image, 1/zoom_factors)
|
344 |
+
pred_image = pad_crop_to_original_shape(pred_image, fixed_image_or.shape, crop_min)
|
345 |
+
LOGGER.info('Done...')
|
346 |
|
347 |
LOGGER.info('Computing metrics...')
|
348 |
+
if args.original_resolution:
|
349 |
+
ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf],
|
350 |
+
{'fix_img:0': fixed_image[np.newaxis,
|
351 |
+
crop_min[0]: crop_max[0],
|
352 |
+
crop_min[1]: crop_max[1],
|
353 |
+
crop_min[2]: crop_max[2],
|
354 |
+
...],
|
355 |
+
'pred_img:0': pred_image[np.newaxis,
|
356 |
+
crop_min[0]: crop_max[0],
|
357 |
+
crop_min[1]: crop_max[1],
|
358 |
+
crop_min[2]: crop_max[2],
|
359 |
+
...]}) # to only compare the deformed region!
|
360 |
+
else:
|
361 |
+
|
362 |
+
ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf],
|
363 |
+
{'fix_img:0': fixed_image[np.newaxis, ...],
|
364 |
+
'pred_img:0': pred_image[np.newaxis, ...]})
|
365 |
ssim = np.mean(ssim)
|
366 |
ms_ssim = ms_ssim[0]
|
|
|
367 |
|
368 |
+
if args.original_resolution:
|
369 |
+
save_nifti(pred_image, os.path.join(args.outputdir, 'pred_image.nii.gz'), header=moving_image_header)
|
370 |
+
else:
|
371 |
+
save_nifti(pred_image, os.path.join(args.outputdir, 'pred_image.nii.gz'))
|
372 |
+
save_nifti(fixed_image, os.path.join(args.outputdir, 'fixed_image.nii.gz'))
|
373 |
+
save_nifti(moving_image, os.path.join(args.outputdir, 'moving_image.nii.gz'))
|
374 |
+
|
375 |
+
if args.save_displacement_map or args.debug:
|
376 |
+
if args.original_resolution:
|
377 |
+
# Up sample the displacement map to the full res
|
378 |
+
LOGGER.info('Scaling displacement map...')
|
379 |
+
trf = np.eye(4)
|
380 |
+
np.fill_diagonal(trf, 1 / zoom_factors)
|
381 |
+
disp_map = resize_displacement_map(disp_map, None, trf, moving_image_header.get_zooms())
|
382 |
+
debug_save_image(disp_map, 'disp_map_1_upsampled', args.outputdir, args.debug)
|
383 |
+
disp_map = pad_displacement_map(disp_map, crop_min, crop_max, image_shape_or)
|
384 |
+
debug_save_image(np.squeeze(disp_map), 'disp_map_2_padded', args.outputdir, args.debug)
|
385 |
+
disp_map = gaussian_filter(disp_map, 5)
|
386 |
+
debug_save_image(np.squeeze(disp_map), 'disp_map_3_smoothed', args.outputdir, args.debug)
|
387 |
+
LOGGER.info('\t... done')
|
388 |
+
if args.debug:
|
389 |
+
np.savez_compressed(os.path.join(args.outputdir, 'displacement_map.npz'), disp_map)
|
390 |
+
else:
|
391 |
+
np.savez_compressed(os.path.join(os.path.join(args.outputdir, 'debug'), 'displacement_map.npz'), disp_map)
|
392 |
LOGGER.info('Predicted image and displacement map saved in: '.format(args.outputdir))
|
393 |
LOGGER.info(f'Displacement map prediction time: {time_disp_map_end - time_disp_map_start} s')
|
394 |
LOGGER.info(f'Predicted image time: {time_pred_img_end - time_pred_img_start} s')
|
DeepDeformationMapRegistration/utils/misc.py
CHANGED
@@ -167,7 +167,7 @@ def segmentation_cardinal_to_ohe(segmentation, labels_list: list = None):
|
|
167 |
return cpy
|
168 |
|
169 |
|
170 |
-
def resize_displacement_map(displacement_map: np.ndarray, dest_shape: [list, np.ndarray, tuple], scale_trf: np.ndarray=None):
|
171 |
if scale_trf is None:
|
172 |
scale_trf = scale_transformation(displacement_map.shape, dest_shape)
|
173 |
else:
|
@@ -175,11 +175,12 @@ def resize_displacement_map(displacement_map: np.ndarray, dest_shape: [list, np.
|
|
175 |
zoom_factors = scale_trf.diagonal()
|
176 |
# First scale the values, so we cut down the number of multiplications
|
177 |
dm_resized = np.copy(displacement_map)
|
178 |
-
dm_resized[..., 0] *= zoom_factors[0]
|
179 |
-
dm_resized[..., 1] *= zoom_factors[1]
|
180 |
-
dm_resized[..., 2] *= zoom_factors[2]
|
181 |
# Then rescale using zoom
|
182 |
dm_resized = zoom(dm_resized, zoom_factors)
|
|
|
|
|
|
|
|
|
183 |
return dm_resized
|
184 |
|
185 |
|
|
|
167 |
return cpy
|
168 |
|
169 |
|
170 |
+
def resize_displacement_map(displacement_map: np.ndarray, dest_shape: [list, np.ndarray, tuple], scale_trf: np.ndarray = None, resolution_factors: [tuple, np.ndarray] = np.ones((3,))):
|
171 |
if scale_trf is None:
|
172 |
scale_trf = scale_transformation(displacement_map.shape, dest_shape)
|
173 |
else:
|
|
|
175 |
zoom_factors = scale_trf.diagonal()
|
176 |
# First scale the values, so we cut down the number of multiplications
|
177 |
dm_resized = np.copy(displacement_map)
|
|
|
|
|
|
|
178 |
# Then rescale using zoom
|
179 |
dm_resized = zoom(dm_resized, zoom_factors)
|
180 |
+
dm_resized *= np.asarray(resolution_factors)
|
181 |
+
# dm_resized[..., 0] *= resolution_factors[0]
|
182 |
+
# dm_resized[..., 1] *= resolution_factors[1]
|
183 |
+
# dm_resized[..., 2] *= resolution_factors[2]
|
184 |
return dm_resized
|
185 |
|
186 |
|