Commit
·
a3cbfc7
1
Parent(s):
cd051bb
Added flag to return the results in the original input image resolution
Browse files
DeepDeformationMapRegistration/main.py
CHANGED
@@ -6,6 +6,7 @@ import argparse
|
|
6 |
import subprocess
|
7 |
import logging
|
8 |
import time
|
|
|
9 |
|
10 |
# currentdir = os.path.dirname(os.path.realpath(__file__))
|
11 |
# parentdir = os.path.dirname(currentdir)
|
@@ -180,15 +181,17 @@ def main():
|
|
180 |
default=None)
|
181 |
parser.add_argument('--model', type=str, help='Which model to use: BL-N, BL-S, SG-ND, SG-NSD, UW-NSD, UW-NSDH',
|
182 |
default='UW-NSD')
|
183 |
-
parser.add_argument('
|
184 |
parser.add_argument('-c', '--clear-outputdir', action='store_true', help='Clear output folder if this has content', default=False)
|
|
|
|
|
|
|
185 |
args = parser.parse_args()
|
186 |
|
187 |
assert os.path.exists(args.fixed), 'Fixed image not found'
|
188 |
assert os.path.exists(args.moving), 'Moving image not found'
|
189 |
assert args.model in C.MODEL_TYPES.keys(), 'Invalid model type'
|
190 |
assert args.anatomy in C.ANATOMIES.keys(), 'Invalid anatomy option'
|
191 |
-
|
192 |
if os.path.exists(args.outputdir) and len(os.listdir(args.outputdir)):
|
193 |
if args.clear_outputdir:
|
194 |
erase = 'y'
|
@@ -217,6 +220,12 @@ def main():
|
|
217 |
LOGGER.setLevel('DEBUG')
|
218 |
LOGGER.debug('DEBUG MODE ENABLED')
|
219 |
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
# Load the file and preprocess it
|
221 |
LOGGER.info('Loading image files')
|
222 |
fixed_image_or = nib.load(args.fixed)
|
@@ -295,6 +304,7 @@ def main():
|
|
295 |
time_disp_map_start = time.time()
|
296 |
# disp_map = registration_model.predict([moving_image[np.newaxis, ...], fixed_image[np.newaxis, ...]])
|
297 |
p, disp_map = network.predict([moving_image[np.newaxis, ...], fixed_image[np.newaxis, ...]])
|
|
|
298 |
time_disp_map_end = time.time()
|
299 |
LOGGER.info('\t... done')
|
300 |
debug_save_image(np.squeeze(disp_map), 'disp_map_0_raw', args.outputdir, args.debug)
|
@@ -303,33 +313,38 @@ def main():
|
|
303 |
# pred_image_isot = zoom(pred_image[0, ...], zoom_factors, order=3)[np.newaxis, ...]
|
304 |
# fixed_image_isot = zoom(fixed_image[0, ...], zoom_factors, order=3)[np.newaxis, ...]
|
305 |
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
|
|
|
|
|
|
|
|
|
|
317 |
|
318 |
LOGGER.info('Applying displacement map...')
|
319 |
time_pred_img_start = time.time()
|
320 |
-
pred_image = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([
|
321 |
time_pred_img_end = time.time()
|
322 |
LOGGER.info('\t... done')
|
323 |
|
324 |
LOGGER.info('Computing metrics...')
|
325 |
ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf],
|
326 |
-
{'fix_img:0':
|
327 |
ssim = np.mean(ssim)
|
328 |
ms_ssim = ms_ssim[0]
|
329 |
pred_image = pred_image[0, ...]
|
330 |
|
331 |
save_nifti(pred_image, os.path.join(args.outputdir, 'pred_image.nii.gz'))
|
332 |
-
np.savez_compressed(os.path.join(args.outputdir, 'displacement_map.npz'),
|
333 |
LOGGER.info('Predicted image (full image) and displacement map saved in: '.format(args.outputdir))
|
334 |
LOGGER.info(f'Displacement map prediction time: {time_disp_map_end - time_disp_map_start} s')
|
335 |
LOGGER.info(f'Predicted image time: {time_pred_img_end - time_pred_img_start} s')
|
@@ -340,15 +355,15 @@ def main():
|
|
340 |
LOGGER.info('MSE: {:.03f}'.format(mse))
|
341 |
LOGGER.info('MS SSIM: {:.03f}'.format(ms_ssim))
|
342 |
|
343 |
-
ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf],
|
344 |
-
|
345 |
-
ssim = np.mean(ssim)
|
346 |
-
ms_ssim = ms_ssim[0]
|
347 |
-
LOGGER.info('\nSimilarity metrics (ROI)\n------------------')
|
348 |
-
LOGGER.info('SSIM: {:.03f}'.format(ssim))
|
349 |
-
LOGGER.info('NCC: {:.03f}'.format(ncc))
|
350 |
-
LOGGER.info('MSE: {:.03f}'.format(mse))
|
351 |
-
LOGGER.info('MS SSIM: {:.03f}'.format(ms_ssim))
|
352 |
|
353 |
del registration_model
|
354 |
LOGGER.info('Done')
|
|
|
6 |
import subprocess
|
7 |
import logging
|
8 |
import time
|
9 |
+
import warnings
|
10 |
|
11 |
# currentdir = os.path.dirname(os.path.realpath(__file__))
|
12 |
# parentdir = os.path.dirname(currentdir)
|
|
|
181 |
default=None)
|
182 |
parser.add_argument('--model', type=str, help='Which model to use: BL-N, BL-S, SG-ND, SG-NSD, UW-NSD, UW-NSDH',
|
183 |
default='UW-NSD')
|
184 |
+
parser.add_argument('-d', '--debug', action='store_true', help='Produce additional debug information', default=False)
|
185 |
parser.add_argument('-c', '--clear-outputdir', action='store_true', help='Clear output folder if this has content', default=False)
|
186 |
+
parser.add_argument('--original-resolution', action='store_true',
|
187 |
+
help='Re-scale the displacement map to the originla resolution and apply it to the original moving image. WARNING: longer processing time',
|
188 |
+
default=False)
|
189 |
args = parser.parse_args()
|
190 |
|
191 |
assert os.path.exists(args.fixed), 'Fixed image not found'
|
192 |
assert os.path.exists(args.moving), 'Moving image not found'
|
193 |
assert args.model in C.MODEL_TYPES.keys(), 'Invalid model type'
|
194 |
assert args.anatomy in C.ANATOMIES.keys(), 'Invalid anatomy option'
|
|
|
195 |
if os.path.exists(args.outputdir) and len(os.listdir(args.outputdir)):
|
196 |
if args.clear_outputdir:
|
197 |
erase = 'y'
|
|
|
220 |
LOGGER.setLevel('DEBUG')
|
221 |
LOGGER.debug('DEBUG MODE ENABLED')
|
222 |
|
223 |
+
if args.original_resolution:
|
224 |
+
LOGGER.info('The results will be rescaled back to the original image resolution. '
|
225 |
+
'Expect longer post-processing times.')
|
226 |
+
else:
|
227 |
+
LOGGER.info(f'The results will NOT be rescaled. Output shape will be {C.IMG_SHAPE[:3]}.')
|
228 |
+
|
229 |
# Load the file and preprocess it
|
230 |
LOGGER.info('Loading image files')
|
231 |
fixed_image_or = nib.load(args.fixed)
|
|
|
304 |
time_disp_map_start = time.time()
|
305 |
# disp_map = registration_model.predict([moving_image[np.newaxis, ...], fixed_image[np.newaxis, ...]])
|
306 |
p, disp_map = network.predict([moving_image[np.newaxis, ...], fixed_image[np.newaxis, ...]])
|
307 |
+
disp_map = np.squeeze(disp_map)
|
308 |
time_disp_map_end = time.time()
|
309 |
LOGGER.info('\t... done')
|
310 |
debug_save_image(np.squeeze(disp_map), 'disp_map_0_raw', args.outputdir, args.debug)
|
|
|
313 |
# pred_image_isot = zoom(pred_image[0, ...], zoom_factors, order=3)[np.newaxis, ...]
|
314 |
# fixed_image_isot = zoom(fixed_image[0, ...], zoom_factors, order=3)[np.newaxis, ...]
|
315 |
|
316 |
+
if args.original_resolution:
|
317 |
+
# Up sample the displacement map to the full res
|
318 |
+
LOGGER.info('Scaling displacement map...')
|
319 |
+
trf = np.eye(4)
|
320 |
+
np.fill_diagonal(trf, 1/zoom_factors)
|
321 |
+
disp_map = resize_displacement_map(disp_map, None, trf)
|
322 |
+
debug_save_image(disp_map, 'disp_map_1_upsampled', args.outputdir, args.debug)
|
323 |
+
disp_map_or = pad_displacement_map(disp_map, crop_min, crop_max, image_shape_or)
|
324 |
+
debug_save_image(np.squeeze(disp_map_or), 'disp_map_2_padded', args.outputdir, args.debug)
|
325 |
+
disp_map_or = gaussian_filter(disp_map_or, 5)
|
326 |
+
debug_save_image(np.squeeze(disp_map_or), 'disp_map_3_smoothed', args.outputdir, args.debug)
|
327 |
+
LOGGER.info('\t... done')
|
328 |
+
|
329 |
+
moving_image = moving_image_or
|
330 |
+
fixed_image = fixed_image_or
|
331 |
+
disp_map = disp_map_or
|
332 |
|
333 |
LOGGER.info('Applying displacement map...')
|
334 |
time_pred_img_start = time.time()
|
335 |
+
pred_image = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([moving_image[np.newaxis, ...], disp_map[np.newaxis, ...]]).eval()
|
336 |
time_pred_img_end = time.time()
|
337 |
LOGGER.info('\t... done')
|
338 |
|
339 |
LOGGER.info('Computing metrics...')
|
340 |
ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf],
|
341 |
+
{'fix_img:0': fixed_image[np.newaxis, ...], 'pred_img:0': pred_image})
|
342 |
ssim = np.mean(ssim)
|
343 |
ms_ssim = ms_ssim[0]
|
344 |
pred_image = pred_image[0, ...]
|
345 |
|
346 |
save_nifti(pred_image, os.path.join(args.outputdir, 'pred_image.nii.gz'))
|
347 |
+
np.savez_compressed(os.path.join(args.outputdir, 'displacement_map.npz'), disp_map)
|
348 |
LOGGER.info('Predicted image (full image) and displacement map saved in: '.format(args.outputdir))
|
349 |
LOGGER.info(f'Displacement map prediction time: {time_disp_map_end - time_disp_map_start} s')
|
350 |
LOGGER.info(f'Predicted image time: {time_pred_img_end - time_pred_img_start} s')
|
|
|
355 |
LOGGER.info('MSE: {:.03f}'.format(mse))
|
356 |
LOGGER.info('MS SSIM: {:.03f}'.format(ms_ssim))
|
357 |
|
358 |
+
# ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf],
|
359 |
+
# {'fix_img:0': fixed_image[np.newaxis, ...], 'pred_img:0': p})
|
360 |
+
# ssim = np.mean(ssim)
|
361 |
+
# ms_ssim = ms_ssim[0]
|
362 |
+
# LOGGER.info('\nSimilarity metrics (ROI)\n------------------')
|
363 |
+
# LOGGER.info('SSIM: {:.03f}'.format(ssim))
|
364 |
+
# LOGGER.info('NCC: {:.03f}'.format(ncc))
|
365 |
+
# LOGGER.info('MSE: {:.03f}'.format(mse))
|
366 |
+
# LOGGER.info('MS SSIM: {:.03f}'.format(ms_ssim))
|
367 |
|
368 |
del registration_model
|
369 |
LOGGER.info('Done')
|