jpdefrutos commited on
Commit
0e7de0a
·
1 Parent(s): a290524

Works when providing the segmentation masks of the images

Browse files
Files changed (1) hide show
  1. DeepDeformationMapRegistration/main.py +326 -355
DeepDeformationMapRegistration/main.py CHANGED
@@ -1,104 +1,238 @@
1
-
2
- # 1. Image files generator
3
-
4
- # timer start
5
- # 2. Preprocess the image
6
- # 3. Predict the displacement
7
- # timer stop
8
-
9
- # 4. Evaluate the registration: NCC; SSIM; DICE; HD95
10
-
11
  import os, sys
12
-
13
  import shutil
 
 
 
 
14
  import time
15
- import tkinter
16
-
17
- import h5py
18
- import matplotlib.pyplot as plt
19
 
20
  currentdir = os.path.dirname(os.path.realpath(__file__))
21
  parentdir = os.path.dirname(currentdir)
22
  sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
23
 
24
  import tensorflow as tf
25
- # tf.enable_eager_execution(config=config)
26
 
27
  import numpy as np
28
- import pandas as pd
 
 
 
 
29
  import voxelmorph as vxm
30
  from voxelmorph.tf.layers import SpatialTransformer
31
 
32
  import DeepDeformationMapRegistration.utils.constants as C
33
- from DeepDeformationMapRegistration.utils.operators import min_max_norm, safe_medpy_metric
34
  from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
35
- from DeepDeformationMapRegistration.layers import AugmentationLayer, UncertaintyWeighting
36
- from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplified, NCC, GeneralizedDICEScore, HausdorffDistanceErosion, target_registration_error
37
  from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
38
- from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
39
- from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
40
- from DeepDeformationMapRegistration.utils.misc import DisplacementMapInterpolator, get_segmentations_centroids, segmentation_ohe_to_cardinal, segmentation_cardinal_to_ohe
41
- from DeepDeformationMapRegistration.utils.misc import resize_displacement_map, scale_transformation, GaussianFilter
42
- import medpy.metric as medpy_metrics
43
- from EvaluationScripts.Evaluate_class import EvaluationFigures, resize_pts_to_original_space, resize_img_to_original_space, resize_transformation
44
- from scipy.interpolate import RegularGridInterpolator
45
- from tqdm import tqdm
46
- import nibabel as nib
47
- from scipy.ndimage import gaussian_filter, zoom
 
 
 
 
 
 
 
 
 
48
 
49
- import h5py
50
- import re
51
- from Brain_study.data_generator import BatchGenerator
52
 
53
- import argparse
54
 
55
- from skimage.transform import warp
56
- import neurite as ne
 
57
 
58
- import tempfile
 
 
 
 
 
 
 
 
59
 
60
- import logging
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
- MODELS_FILE = {'liver': {'BL-N': './models/liver/bl_ncc.h5',
63
- 'BL-S': './models/liver/bl_ssim.h5',
64
- 'SG-ND': './models/liver/sg_ncc_dsc.h5',
65
- 'SD-NSD': './models/liver/sg_ncc_ssim_dsc.h5',
66
- 'UW-NSD': './models/liver/uw_ncc_ssim_dsc.h5',
67
- 'UW-NSDH': './models/liver/uw_ncc_ssim_dsc_hd.h5',
68
- },
69
- 'brain': {'BL-N': './models/brain/bl_ncc.h5',
70
- 'BL-S': './models/brain/bl_ssim.h5',
71
- 'SG-ND': './models/brain/sg_ncc_dsc.h5',
72
- 'SD-NSD': './models/brain/sg_ncc_ssim_dsc.h5',
73
- 'UW-NSD': './models/brain/uw_ncc_ssim_dsc.h5',
74
- 'UW-NSDH': './models/brain/uw_ncc_ssim_dsc_hd.h5',
75
- }
76
- }
77
 
78
  if __name__ == '__main__':
79
  parser = argparse.ArgumentParser()
80
  parser.add_argument('-f', '--fixed', type=str, help='Path to fixed image file (NIfTI)')
81
- parser.add_argument('-m', '--moving', type=str, help='Path to oving image file (NIfTI)')
 
 
 
82
  parser.add_argument('-o', '--outputdir', type=str, help='Output directory', default='./Registration_output')
83
- parser.add_argument('--gpu', type=int, help='In case of multi-GPU systems, limits the execution to the defined GPU number', default=None)
84
- parser.add_argument('--model', type=str, help='Which model to use: BL-N, BL-S, SG-ND, SG-NSD, UW-NSD, UW-NSDH', default='UW-NSD')
 
 
 
 
 
 
 
85
  # parser.add_argument('--brain', type=bool, action='store_true', help='Perform brain MRi registration', default=False)
86
  args = parser.parse_args()
87
- logger = logging.getLogger(__name__)
88
 
89
  assert os.path.exists(args.fixed), 'Fixed image not found'
90
  assert os.path.exists(args.moving), 'Moving image not found'
91
  assert args.model in ['BL-N', 'BL-S', 'SG-ND', 'SG-NSD', 'UW-NSD', 'UW-NSDH'], 'Invalid model type'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
 
 
 
 
 
 
 
93
  if isinstance(args.gpu, int):
94
  os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
95
  os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) # Check availability before running using 'nvidia-smi'
 
 
 
96
 
97
  # Load the file and preprocess it
98
- fixed_image = nib.load(args.fixed)
99
- moving_image = nib.load(args.moving)
 
 
 
 
 
 
 
100
 
101
  # TF stuff
 
102
  config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
103
  config.gpu_options.allow_growth = True
104
  config.log_device_placement = False ## to log device placement (on which device the operation ran)
@@ -108,303 +242,140 @@ if __name__ == '__main__':
108
  tf.compat.v1.keras.backend.set_session(sess)
109
 
110
  # Preprocess data
111
- if args.erase:
112
- shutil.rmtree(args.outputdir, ignore_errors=True)
113
- os.makedirs(args.outputdir, exist_ok=True)
114
- lm_output_dir = os.path.join(args.outputdir, 'livermask')
115
- os.makedirs(lm_output_dir, exist_ok=True)
116
-
117
  # 1. Run Livermask to get the mask around the liver in both the fixed and moving image
118
- logger.info('Getting parenchyma segmentations...')
119
- livermask_cmd = "python -m livermaks.livermask --input {} --output {}".format(args.fixed, os.path.join(lm_output_dir, 'fixed.nii.gz'))
120
- os.system(livermask_cmd)
121
- logger.info('... fixed image done...')
122
- livermask_cmd = "python -m livermaks.livermask --input {} --output {}".format(args.moving, os.path.join(lm_output_dir, 'moving.nii.gz'))
123
- os.system(livermask_cmd)
124
- logger.info('... moving image done.')
125
-
126
- # 2. Crop around the liver
127
- # 2.1 Load the segmentations
128
- # 2.2 Find the outermost box containing both boxes
129
- # 2.3 Crop the fixed and moving images using such boxes
130
- # 2.4 Resize the images to the expected input size
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
  # 3. Build the whole graph
133
-
134
-
135
- # Loss and metric functions. Common to all models
136
- # loss_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).loss,
137
- # NCC(image_input_shape).loss,
138
- # vxm.losses.MSE().loss,
139
- # MultiScaleStructuralSimilarity(max_val=1., filter_size=3).loss]
140
- #
141
- # metric_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).metric,
142
- # NCC(image_input_shape).metric,
143
- # vxm.losses.MSE().loss,
144
- # MultiScaleStructuralSimilarity(max_val=1., filter_size=3).metric,
145
- # GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric,
146
- # HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [nb_labels]).metric,
147
- # GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric_macro]
148
-
149
  ### METRICS GRAPH ###
150
- # fix_img_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, 1), name='fix_img')
151
- # pred_img_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, 1), name='pred_img')
152
- # fix_seg_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, nb_labels), name='fix_seg')
153
- # pred_seg_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, nb_labels), name='pred_seg')
154
- #
155
- # ssim_tf = metric_fncs[0](fix_img_ph, pred_img_ph)
156
- # ncc_tf = metric_fncs[1](fix_img_ph, pred_img_ph)
157
- # mse_tf = metric_fncs[2](fix_img_ph, pred_img_ph)
158
- # ms_ssim_tf = metric_fncs[3](fix_img_ph, pred_img_ph)
159
- # dice_tf = metric_fncs[4](fix_seg_ph, pred_seg_ph)
160
- # hd_tf = metric_fncs[5](fix_seg_ph, pred_seg_ph)
161
- # dice_macro_tf = metric_fncs[6](fix_seg_ph, pred_seg_ph)
162
- # hd_exact_tf = HausdorffDistance_exact(fix_seg_ph, pred_seg_ph, ohe=True)
163
-
164
- # Needed for VxmDense type of network
165
- warp_segmentation = vxm.networks.Transform(image_output_shape, interp_method='nearest', nb_feats=nb_labels)
166
-
167
- dm_interp = DisplacementMapInterpolator(image_output_shape, 'griddata', step=4)
168
-
169
- for MODEL_FILE, DATA_ROOT_DIR in zip(MODEL_FILE_LIST, DATA_ROOT_DIR_LIST):
170
- print('MODEL LOCATION: ', MODEL_FILE)
171
-
172
- # data_folder = '/mnt/EncryptedData1/Users/javier/train_output/DDMR/THESIS/BASELINE_Affine_ncc___mse_ncc_160606-25022021'
173
- output_folder = os.path.join(DATA_ROOT_DIR, args.outdirname) # '/mnt/EncryptedData1/Users/javier/train_output/DDMR/THESIS/eval/BASELINE_TRAIN_Affine_ncc_EVAL_Affine'
174
- # os.makedirs(os.path.join(output_folder, 'images'), exist_ok=True)
175
-
176
- print('DESTINATION FOLDER: ', output_folder)
177
-
178
- if args.fullres:
179
- output_folder_fr = os.path.join(DATA_ROOT_DIR, args.outdirname, 'full_resolution') # '/mnt/EncryptedData1/Users/javier/train_output/DDMR/THESIS/eval/BASELINE_TRAIN_Affine_ncc_EVAL_Affine'
180
- # os.makedirs(os.path.join(output_folder, 'images'), exist_ok=True)
181
- if args.erase:
182
- shutil.rmtree(output_folder_fr, ignore_errors=True)
183
- os.makedirs(output_folder_fr, exist_ok=True)
184
- print('DESTINATION FOLDER FULL RESOLUTION: ', output_folder_fr)
185
-
186
- try:
187
- network = tf.keras.models.load_model(MODEL_FILE, {'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
188
- 'VxmDense': vxm.networks.VxmDense,
189
- 'AdamAccumulated': AdamAccumulated,
190
- 'loss': loss_fncs,
191
- 'metric': metric_fncs},
192
- compile=False)
193
- except ValueError as e:
194
- enc_features = [16, 32, 32, 32] # const.ENCODER_FILTERS
195
- dec_features = [32, 32, 32, 32, 32, 16, 16] # const.ENCODER_FILTERS[::-1]
196
- nb_features = [enc_features, dec_features]
197
- if re.search('^UW|SEGGUIDED_', MODEL_FILE):
198
- network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
199
- nb_labels=nb_labels,
200
- nb_unet_features=nb_features,
201
- int_steps=0,
202
- int_downsize=1,
203
- seg_downsize=1)
204
- else:
205
- network = vxm.networks.VxmDense(inshape=image_output_shape,
206
- nb_unet_features=nb_features,
207
- int_steps=0)
208
- network.load_weights(MODEL_FILE, by_name=True)
209
- # Record metrics
210
- metrics_file = os.path.join(output_folder, 'metrics.csv')
211
- with open(metrics_file, 'w') as f:
212
- f.write(';'.join(csv_header)+'\n')
213
-
214
- if args.fullres:
215
- metrics_file_fr = os.path.join(output_folder_fr, 'metrics.csv')
216
- with open(metrics_file_fr, 'w') as f:
217
- f.write(';'.join(csv_header) + '\n')
218
-
219
- ssim = ncc = mse = ms_ssim = dice = hd = 0
220
- with sess.as_default():
221
- sess.run(tf.global_variables_initializer())
222
- network.load_weights(MODEL_FILE, by_name=True)
223
- network.summary(line_length=C.SUMMARY_LINE_LENGTH)
224
- progress_bar = tqdm(enumerate(list_test_files, 1), desc='Evaluation', total=len(list_test_files))
225
- for step, in_batch in progress_bar:
226
- with h5py.File(in_batch, 'r') as f:
227
- fix_img = f['fix_image'][:][np.newaxis, ...] # Add batch axis
228
- mov_img = f['mov_image'][:][np.newaxis, ...]
229
- fix_seg = f['fix_segmentations'][:][np.newaxis, ...].astype(np.float32)
230
- mov_seg = f['mov_segmentations'][:][np.newaxis, ...].astype(np.float32)
231
- fix_centroids = f['fix_centroids'][:]
232
- isotropic_shape = f['isotropic_shape'][:]
233
- voxel_size = np.divide(fix_img.shape[1:-1], isotropic_shape)
234
-
235
- if network.name == 'vxm_dense_semi_supervised_seg':
236
- t0 = time.time()
237
- pred_img, disp_map, pred_seg = network.predict([mov_img, fix_img, mov_seg, fix_seg]) # predict([source, target])
238
- t1 = time.time()
239
- else:
240
- t0 = time.time()
241
- pred_img, disp_map = network.predict([mov_img, fix_img])
242
- pred_seg = warp_segmentation.predict([mov_seg, disp_map])
243
- t1 = time.time()
244
-
245
- pred_img = min_max_norm(pred_img)
246
- mov_centroids, missing_lbls = get_segmentations_centroids(mov_seg[0, ...], ohe=True, expected_lbls=range(1, nb_labels+1), brain_study=False)
247
- # pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) # with tps, it returns the pred_centroids directly
248
- pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) + mov_centroids
249
-
250
- # Up sample the segmentation masks to isotropic resolution
251
- zoom_factors = np.diag(scale_transformation(image_output_shape, isotropic_shape))
252
- pred_seg_isot = zoom(pred_seg[0, ...], zoom_factors, order=0)[np.newaxis, ...]
253
- fix_seg_isot = zoom(fix_seg[0, ...], zoom_factors, order=0)[np.newaxis, ...]
254
-
255
- pred_img_isot = zoom(pred_img[0, ...], zoom_factors, order=3)[np.newaxis, ...]
256
- fix_img_isot = zoom(fix_img[0, ...], zoom_factors, order=3)[np.newaxis, ...]
257
-
258
- # I need the labels to be OHE to compute the segmentation metrics.
259
- # dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf], {'fix_seg:0': fix_seg, 'pred_seg:0': pred_seg})
260
- dice = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) / np.sum(fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
261
- hd = np.mean(safe_medpy_metric(pred_seg_isot[0, ...], fix_seg_isot[0, ...], nb_labels, medpy_metrics.hd, {'voxelspacing': voxel_size}))
262
- dice_macro = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
263
-
264
- pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
265
- mov_seg_card = segmentation_ohe_to_cardinal(mov_seg).astype(np.float32)
266
- fix_seg_card = segmentation_ohe_to_cardinal(fix_seg).astype(np.float32)
267
-
268
- ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf], {'fix_img:0': fix_img, 'pred_img:0': pred_img})
269
- ssim = np.mean(ssim)
270
- ms_ssim = ms_ssim[0]
271
-
272
- # Rescale the points back to isotropic space, where we have a correspondence voxel <-> mm
273
- fix_centroids_isotropic = fix_centroids * voxel_size
274
- # mov_centroids_isotropic = mov_centroids * voxel_size
275
- pred_centroids_isotropic = pred_centroids * voxel_size
276
-
277
- # fix_centroids_isotropic = np.divide(fix_centroids_isotropic, C.COMET_DATASET_iso_to_cubic_scales)
278
- # # mov_centroids_isotropic = np.divide(mov_centroids_isotropic, C.COMET_DATASET_iso_to_cubic_scales)
279
- # pred_centroids_isotropic = np.divide(pred_centroids_isotropic, C.COMET_DATASET_iso_to_cubic_scales)
280
- # Now we can measure the TRE in mm
281
- tre_array = target_registration_error(fix_centroids_isotropic, pred_centroids_isotropic, False).eval()
282
- tre = np.mean([v for v in tre_array if not np.isnan(v)])
283
- # ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'HD', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']
284
-
285
- new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, t1-t0, tre, len(missing_lbls), missing_lbls]
286
- with open(metrics_file, 'a') as f:
287
- f.write(';'.join(map(str, new_line))+'\n')
288
-
289
- save_nifti(fix_img[0, ...], os.path.join(output_folder, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
290
- save_nifti(mov_img[0, ...], os.path.join(output_folder, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
291
- save_nifti(pred_img[0, ...], os.path.join(output_folder, '{:03d}_pred_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
292
- save_nifti(fix_seg_card[0, ...], os.path.join(output_folder, '{:03d}_fix_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
293
- save_nifti(mov_seg_card[0, ...], os.path.join(output_folder, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
294
- save_nifti(pred_seg_card[0, ...], os.path.join(output_folder, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
295
-
296
- # with h5py.File(os.path.join(output_folder, '{:03d}_centroids.h5'.format(step)), 'w') as f:
297
- # f.create_dataset('fix_centroids', dtype=np.float32, data=fix_centroids)
298
- # f.create_dataset('mov_centroids', dtype=np.float32, data=mov_centroids)
299
- # f.create_dataset('pred_centroids', dtype=np.float32, data=pred_centroids)
300
- # f.create_dataset('fix_centroids_isotropic', dtype=np.float32, data=fix_centroids_isotropic)
301
- # f.create_dataset('mov_centroids_isotropic', dtype=np.float32, data=mov_centroids_isotropic)
302
-
303
- # magnitude = np.sqrt(np.sum(disp_map[0, ...] ** 2, axis=-1))
304
- # _ = plt.hist(magnitude.flatten())
305
- # plt.title('Histogram of disp. magnitudes')
306
- # plt.show(block=False)
307
- # plt.savefig(os.path.join(output_folder, '{:03d}_hist_mag_ssim_{:.03f}_dice_{:.03f}.png'.format(step, ssim, dice)))
308
- # plt.close()
309
-
310
- plot_predictions(img_batches=[fix_img, mov_img, pred_img], disp_map_batch=disp_map, seg_batches=[fix_seg_card, mov_seg_card, pred_seg_card], filename=os.path.join(output_folder, '{:03d}_figures_seg.png'.format(step)), show=False, step=16)
311
- plot_predictions(img_batches=[fix_img, mov_img, pred_img], disp_map_batch=disp_map, filename=os.path.join(output_folder, '{:03d}_figures_img.png'.format(step)), show=False, step=16)
312
- save_disp_map_img(disp_map, 'Displacement map', os.path.join(output_folder, '{:03d}_disp_map_fig.png'.format(step)), show=False, step=16)
313
-
314
- progress_bar.set_description('SSIM {:.04f}\tM_DICE: {:.04f}'.format(ssim, dice_macro))
315
-
316
- if args.fullres:
317
- with h5py.File(list_test_fr_files[step - 1], 'r') as f:
318
- fix_img = f['fix_image'][:][np.newaxis, ...] # Add batch axis
319
- mov_img = f['mov_image'][:][np.newaxis, ...]
320
- fix_seg = f['fix_segmentations'][:][np.newaxis, ...].astype(np.float32)
321
- mov_seg = f['mov_segmentations'][:][np.newaxis, ...].astype(np.float32)
322
- fix_centroids = f['fix_centroids'][:]
323
-
324
- # Up sample the displacement map to the full res
325
- trf = scale_transformation(image_output_shape, fix_img.shape[1:-1])
326
- disp_map_fr = resize_displacement_map(np.squeeze(disp_map), None, trf)[np.newaxis, ...]
327
- disp_map_fr = gaussian_filter(disp_map_fr, 5)
328
- # disp_mad_fr = sess.run(smooth_filter, feed_dict={'dm:0': disp_map_fr})
329
-
330
- # Predicted image
331
- pred_img_fr = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([mov_img, disp_map_fr]).eval()
332
- pred_seg_fr = SpatialTransformer(interp_method='nearest', indexing='ij', single_transform=False)([mov_seg, disp_map_fr]).eval()
333
-
334
- # Predicted centroids
335
- dm_interp_fr = DisplacementMapInterpolator(fix_img.shape[1:-1], 'griddata', step=2)
336
- pred_centroids = dm_interp_fr(disp_map_fr, mov_centroids, backwards=True) + mov_centroids
337
-
338
- # Metrics - segmentation
339
- dice = np.mean([medpy_metrics.dc(pred_seg_fr[..., l], fix_seg[..., l]) / np.sum(fix_seg[..., l]) for l in range(nb_labels)])
340
- hd = np.mean(safe_medpy_metric(pred_seg[0, ...], fix_seg[0, ...], nb_labels, medpy_metrics.hd, {'voxelspacing': voxel_size}))
341
- dice_macro = np.mean([medpy_metrics.dc(pred_seg_fr[..., l], fix_seg[..., l]) for l in range(nb_labels)])
342
-
343
- pred_seg_card_fr = segmentation_ohe_to_cardinal(pred_seg_fr).astype(np.float32)
344
- mov_seg_card_fr = segmentation_ohe_to_cardinal(mov_seg).astype(np.float32)
345
- fix_seg_card_fr = segmentation_ohe_to_cardinal(fix_seg).astype(np.float32)
346
-
347
- # Metrics - image
348
- ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf],
349
- {'fix_img:0': fix_img, 'pred_img:0': pred_img_fr})
350
- ssim = np.mean(ssim)
351
- ms_ssim = ms_ssim[0]
352
-
353
- # Metrics - registration
354
- tre_array = target_registration_error(fix_centroids, pred_centroids, False).eval()
355
-
356
- new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, t1 - t0, tre, len(missing_lbls),
357
- missing_lbls]
358
- with open(metrics_file_fr, 'a') as f:
359
- f.write(';'.join(map(str, new_line)) + '\n')
360
-
361
- save_nifti(fix_img[0, ...], os.path.join(output_folder_fr, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
362
- save_nifti(mov_img[0, ...], os.path.join(output_folder_fr, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
363
- save_nifti(pred_img[0, ...], os.path.join(output_folder_fr, '{:03d}_pred_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
364
- save_nifti(fix_seg_card[0, ...], os.path.join(output_folder_fr, '{:03d}_fix_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
365
- save_nifti(mov_seg_card[0, ...], os.path.join(output_folder_fr, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
366
- save_nifti(pred_seg_card[0, ...], os.path.join(output_folder_fr, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
367
-
368
- # with h5py.File(os.path.join(output_folder, '{:03d}_centroids.h5'.format(step)), 'w') as f:
369
- # f.create_dataset('fix_centroids', dtype=np.float32, data=fix_centroids)
370
- # f.create_dataset('mov_centroids', dtype=np.float32, data=mov_centroids)
371
- # f.create_dataset('pred_centroids', dtype=np.float32, data=pred_centroids)
372
- # f.create_dataset('fix_centroids_isotropic', dtype=np.float32, data=fix_centroids_isotropic)
373
- # f.create_dataset('mov_centroids_isotropic', dtype=np.float32, data=mov_centroids_isotropic)
374
-
375
- # magnitude = np.sqrt(np.sum(disp_map[0, ...] ** 2, axis=-1))
376
- # _ = plt.hist(magnitude.flatten())
377
- # plt.title('Histogram of disp. magnitudes')
378
- # plt.show(block=False)
379
- # plt.savefig(os.path.join(output_folder, '{:03d}_hist_mag_ssim_{:.03f}_dice_{:.03f}.png'.format(step, ssim, dice)))
380
- # plt.close()
381
-
382
- plot_predictions(img_batches=[fix_img, mov_img, pred_img_fr], disp_map_batch=disp_map_fr, seg_batches=[fix_seg_card_fr, mov_seg_card_fr, pred_seg_card_fr], filename=os.path.join(output_folder_fr, '{:03d}_figures_seg.png'.format(step)), show=False, step=10)
383
- plot_predictions(img_batches=[fix_img, mov_img, pred_img_fr], disp_map_batch=disp_map_fr, filename=os.path.join(output_folder_fr, '{:03d}_figures_img.png'.format(step)), show=False, step=10)
384
- # save_disp_map_img(disp_map_fr, 'Displacement map', os.path.join(output_folder_fr, '{:03d}_disp_map_fig.png'.format(step)), show=False, step=10)
385
-
386
- progress_bar.set_description('[FR] SSIM {:.04f}\tM_DICE: {:.04f}'.format(ssim, dice_macro))
387
-
388
- print('Summary\n=======\n')
389
- metrics_df = pd.read_csv(metrics_file, sep=';', header=0)
390
- print('\nAVG:\n')
391
- print(metrics_df.mean(axis=0))
392
- print('\nSTD:\n')
393
- print(metrics_df.std(axis=0))
394
- print('\nHD95perc:\n')
395
- print(metrics_df['HD'].describe(percentiles=[.95]))
396
- print('\n=======\n')
397
- if args.fullres:
398
- print('Summary full resolution\n=======\n')
399
- metrics_df = pd.read_csv(metrics_file_fr, sep=';', header=0)
400
- print('\nAVG:\n')
401
- print(metrics_df.mean(axis=0))
402
- print('\nSTD:\n')
403
- print(metrics_df.std(axis=0))
404
- print('\nHD95perc:\n')
405
- print(metrics_df['HD'].describe(percentiles=[.95]))
406
- print('\n=======\n')
407
- tf.keras.backend.clear_session()
408
- # sess.close()
409
- del network
410
- print('Done')
 
1
+ import datetime
 
 
 
 
 
 
 
 
 
2
  import os, sys
 
3
  import shutil
4
+ import re
5
+ import argparse
6
+ import subprocess
7
+ import logging
8
  import time
 
 
 
 
9
 
10
  currentdir = os.path.dirname(os.path.realpath(__file__))
11
  parentdir = os.path.dirname(currentdir)
12
  sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
13
 
14
  import tensorflow as tf
 
15
 
16
  import numpy as np
17
+ import nibabel as nib
18
+ from scipy.ndimage import gaussian_filter, zoom
19
+ from skimage.measure import regionprops
20
+ import SimpleITK as sitk
21
+
22
  import voxelmorph as vxm
23
  from voxelmorph.tf.layers import SpatialTransformer
24
 
25
  import DeepDeformationMapRegistration.utils.constants as C
 
26
  from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
27
+ from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplified, NCC
 
28
  from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
29
+ from DeepDeformationMapRegistration.utils.operators import min_max_norm
30
+ from DeepDeformationMapRegistration.utils.misc import resize_displacement_map
31
+
32
+
33
+ MODELS_FILE = {'L': {'BL-N': './models/liver/bl_ncc.h5',
34
+ 'BL-S': './models/liver/bl_ncc_ssim.h5',
35
+ 'SG-ND': './models/liver/sg_ncc_dsc.h5',
36
+ 'SD-NSD': './models/liver/sg_ncc_ssim_dsc.h5',
37
+ 'UW-NSD': './models/liver/uw_ncc_ssim_dsc.h5',
38
+ 'UW-NSDH': './models/liver/uw_ncc_ssim_dsc_hd.h5',
39
+ },
40
+ 'B': {'BL-N': './models/brain/bl_ncc.h5',
41
+ 'BL-S': './models/brain/bl_ncc_ssim.h5',
42
+ 'SG-ND': './models/brain/sg_ncc_dsc.h5',
43
+ 'SD-NSD': './models/brain/sg_ncc_ssim_dsc.h5',
44
+ 'UW-NSD': './models/brain/uw_ncc_ssim_dsc.h5',
45
+ 'UW-NSDH': './models/brain/uw_ncc_ssim_dsc_hd.h5',
46
+ }
47
+ }
48
 
49
+ IMAGE_INTPUT_SHAPE = np.asarray([128, 128, 128, 1])
 
 
50
 
 
51
 
52
+ def rigidly_align_images(image_1: str, image_2: str) -> nib.Nifti1Image:
53
+ """
54
+ Rigidly align the images and resample to the same array size, to the dense displacement map is correct
55
 
56
+ """
57
+ def resample_to_isotropic(image: sitk.Image) -> sitk.Image:
58
+ spacing = image.GetSpacing()
59
+ spacing = min(spacing)
60
+ resamp_spacing = [spacing] * image.GetDimension()
61
+ resamp_size = [int(round(or_size*or_space/spacing)) for or_size, or_space in zip(image.GetSize(), image.GetSpacing())]
62
+ return sitk.Resample(image,
63
+ resamp_size, sitk.Transform(), sitk.sitkLinear,image.GetOrigin(),
64
+ resamp_spacing, image.GetDirection(), 0, image.GetPixelID())
65
 
66
+ image_1 = sitk.ReadImage(image_1, sitk.sitkFloat32)
67
+ image_2 = sitk.ReadImage(image_2, sitk.sitkFloat32)
68
+
69
+ image_1 = resample_to_isotropic(image_1)
70
+ image_2 = resample_to_isotropic(image_2)
71
+
72
+ rig_reg = sitk.ImageRegistrationMethod()
73
+ rig_reg.SetMetricAsMeanSquares()
74
+ rig_reg.SetOptimizerAsRegularStepGradientDescent(4.0, 0.01, 200)
75
+ rig_reg.SetInitialTransform(sitk.TranslationTransform(image_1.GetDimension()))
76
+ rig_reg.SetInterpolator(sitk.sitkLinear)
77
+
78
+ print('Running rigid registration...')
79
+ rig_reg_trf = rig_reg.Execute(image_1, image_2)
80
+ print('Rigid registration completed\n----------------------------')
81
+ print('Optimizer stop condition: {}'.format(rig_reg.GetOptimizerStopConditionDescription()))
82
+ print('Iteration: {}'.format(rig_reg.GetOptimizerIteration()))
83
+ print('Metric value: {}'.format(rig_reg.GetMetricValue()))
84
+
85
+ resampler = sitk.ResampleImageFilter()
86
+ resampler.SetReferenceImage(image_1)
87
+ resampler.SetInterpolator(sitk.sitkLinear)
88
+ resampler.SetDefaultPixelValue(100)
89
+ resampler.SetTransform(rig_reg_trf)
90
+
91
+ image_2 = resampler.Execute(image_2)
92
+
93
+ # TODO: Build a common image to hold both image_1 and image_2
94
+
95
+
96
+ def pad_images(image_1: nib.Nifti1Image, image_2: nib.Nifti1Image):
97
+ """
98
+ Align image_1 and image_2 by the top left corner and pad them to the largest dimensions along the three axes
99
+ """
100
+ joint_image_shape = np.maximum(image_1.shape, image_2.shape)
101
+ pad_1 = [[0, p] for p in joint_image_shape - image_1.shape]
102
+ pad_2 = [[0, p] for p in joint_image_shape - image_2.shape]
103
+ image_1_padded = np.pad(image_1.dataobj, pad_1, mode='edge')
104
+ image_2_padded = np.pad(image_2.dataobj, pad_2, mode='edge')
105
+
106
+ return image_1_padded, image_2_padded
107
+
108
+
109
+ def pad_displacement_map(disp_map: np.ndarray, crop_min: np.ndarray, crop_max: np.ndarray, output_shape: (np.ndarray, list)) -> np.ndarray:
110
+ padding = [[crop_min[i], image_shape_or[i] - crop_max[i]] for i in range(3)] + [[0, 0]]
111
+ return np.pad(disp_map, padding, mode='constant')
112
+
113
+
114
+ def run_livermask(input_image_path, outputdir, filename: str = 'segmentation') -> np.ndarray:
115
+ logger.info('Getting parenchyma segmentations...')
116
+ shutil.copy2(input_image_path, os.path.join(outputdir, f'{filename}.nii.gz'))
117
+ livermask_cmd = "{} -m livermask.livermask --input {} --output {}".format(sys.executable,
118
+ input_image_path,
119
+ os.path.join(outputdir,
120
+ f'{filename}.nii.gz'))
121
+ subprocess.run(livermask_cmd)
122
+ logger.info('done!')
123
+ segmentation_path = os.path.join(outputdir, f'{filename}.nii.gz')
124
+ return np.asarray(nib.load(segmentation_path).dataobj, dtype=int)
125
+
126
+
127
+ def debug_save_image(image: (np.ndarray, nib.Nifti1Image), filename: str, outputdir: str, debug: bool = True):
128
+ def disp_map_modulus(disp_map, scale: float = None):
129
+ disp_map_mod = np.sqrt(np.sum(np.power(disp_map, 2), -1))
130
+ if scale:
131
+ min_disp = np.min(disp_map_mod)
132
+ max_disp = np.max(disp_map_mod)
133
+ disp_map_mod = disp_map_mod - min_disp / (max_disp - min_disp)
134
+ disp_map_mod *= scale
135
+ logger.debug('Scaled displacement map to [0., 1.] range')
136
+ return disp_map_mod
137
+
138
+ if debug:
139
+ os.makedirs(os.path.join(outputdir, 'debug'), exist_ok=True)
140
+ if image.shape[-1] > 1:
141
+ image = disp_map_modulus(image, 1.)
142
+ save_nifti(image, os.path.join(outputdir, 'debug', filename+'.nii.gz'), verbose=False)
143
+ logger.debug(f'Saved {filename} at {os.path.join(outputdir, filename+".nii.gz")}')
144
+
145
+
146
+ def get_roi(image_filepath: str,
147
+ anatomy: str,
148
+ outputdir: str,
149
+ filename_filepath: str = 'segmentation',
150
+ segmentation_file: str = None,
151
+ debug: bool = False) -> list:
152
+ segm = None
153
+ if segmentation_file is None and anatomy == 'L':
154
+ segm = run_livermask(image_filepath, outputdir, filename_filepath)
155
+ logger.info(f'Loaded segmentation using livermask from {os.path.join(outputdir, filename_filepath)}')
156
+ elif segmentation_file is not None:
157
+ segm = np.asarray(nib.load(segmentation_file).dataobj, dtype=int)
158
+ logger.info(f'Loaded fixed segmentation from {segmentation_file}')
159
+ else:
160
+ logger.warning('No segmentation provided! Using the full volume')
161
+ if segm is not None:
162
+ segm[segm > 0] = 1
163
+ ret_val = regionprops(segm)[0].bbox
164
+ debug_save_image(segm, f'img_1_{filename_filepath}', outputdir, debug)
165
+ else:
166
+ ret_val = [0, 0, 0] + list(nib.load(image_filepath).shape[:3])
167
+ logger.debug(f'ROI found at coordinates {ret_val}')
168
+ return ret_val
169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
  if __name__ == '__main__':
172
  parser = argparse.ArgumentParser()
173
  parser.add_argument('-f', '--fixed', type=str, help='Path to fixed image file (NIfTI)')
174
+ parser.add_argument('-m', '--moving', type=str, help='Path to moving segmentation image file (NIfTI)', default=None)
175
+ parser.add_argument('-F', '--fixedsegm', type=str, help='Path to fixed image segmentation file(NIfTI)',
176
+ default=None)
177
+ parser.add_argument('-M', '--movingsegm', type=str, help='Path to moving image file (NIfTI)')
178
  parser.add_argument('-o', '--outputdir', type=str, help='Output directory', default='./Registration_output')
179
+ parser.add_argument('-a', '--anatomy', type=str, help='Anatomical structure: liver (L) (Default) or brain (B)',
180
+ default='L')
181
+ parser.add_argument('--gpu', type=int,
182
+ help='In case of multi-GPU systems, limits the execution to the defined GPU number',
183
+ default=None)
184
+ parser.add_argument('--model', type=str, help='Which model to use: BL-N, BL-S, SG-ND, SG-NSD, UW-NSD, UW-NSDH',
185
+ default='UW-NSD')
186
+ parser.add_argument('--debug', '-d', action='store_true', help='Produce additional debug information', default=False)
187
+ parser.add_argument('-y', action='store_true', help='Erase output folder if this has content', default=False)
188
  # parser.add_argument('--brain', type=bool, action='store_true', help='Perform brain MRi registration', default=False)
189
  args = parser.parse_args()
 
190
 
191
  assert os.path.exists(args.fixed), 'Fixed image not found'
192
  assert os.path.exists(args.moving), 'Moving image not found'
193
  assert args.model in ['BL-N', 'BL-S', 'SG-ND', 'SG-NSD', 'UW-NSD', 'UW-NSDH'], 'Invalid model type'
194
+ assert args.anatomy in ['L', 'B'], 'Invalid anatomy option'
195
+
196
+ if os.path.exists(args.outputdir) and len(os.listdir(args.outputdir)):
197
+ if args.y:
198
+ erase = 'y'
199
+ else:
200
+ erase = input('Output directory is not empty, erase content? (y/n)')
201
+ if erase.lower() in ['y', 'yes']:
202
+ shutil.rmtree(args.outputdir, ignore_errors=True)
203
+ print('Erased directory: ' + args.outputdir)
204
+ elif erase.lower() in ['n', 'no']:
205
+ args.outputdir = os.path.join(args.outputdir, datetime.datetime.now().strftime('%H%M%S_%Y%m%d'))
206
+ print('New output directory: ' + args.outputdir)
207
+ os.makedirs(args.outputdir, exist_ok=True)
208
 
209
+ log_format = '%(asctime)s [%(levelname)s]:\t%(message)s'
210
+ logging.basicConfig(filename=os.path.join(args.outputdir, 'log.log'), filemode='w',
211
+ format=log_format, datefmt='%Y-%m-%d %H:%M:%S')
212
+ logger = logging.getLogger(__name__)
213
+ stdout_handler = logging.StreamHandler(sys.stdout)
214
+ stdout_handler.setFormatter(logging.Formatter(log_format, datefmt='%Y-%m-%d %H:%M:%S'))
215
+ logger.addHandler(stdout_handler)
216
  if isinstance(args.gpu, int):
217
  os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
218
  os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) # Check availability before running using 'nvidia-smi'
219
+ if args.debug:
220
+ logger.setLevel('DEBUG')
221
+ logger.debug('DEBUG MODE ENABLED')
222
 
223
  # Load the file and preprocess it
224
+ logger.info('Loading image files')
225
+ fixed_image_or = nib.load(args.fixed)
226
+ moving_image_or = nib.load(args.moving)
227
+ image_shape_or = np.asarray(fixed_image_or.shape)
228
+ fixed_image_or, moving_image_or = pad_images(fixed_image_or, moving_image_or)
229
+ fixed_image_or = fixed_image_or[..., np.newaxis] # add channel dim
230
+ moving_image_or = moving_image_or[..., np.newaxis] # add channel dim
231
+ debug_save_image(fixed_image_or, 'img_0_loaded_fix_image', args.outputdir, args.debug)
232
+ debug_save_image(moving_image_or, 'img_0_loaded_moving_image', args.outputdir, args.debug)
233
 
234
  # TF stuff
235
+ logger.info('Setting up configuration')
236
  config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
237
  config.gpu_options.allow_growth = True
238
  config.log_device_placement = False ## to log device placement (on which device the operation ran)
 
242
  tf.compat.v1.keras.backend.set_session(sess)
243
 
244
  # Preprocess data
 
 
 
 
 
 
245
  # 1. Run Livermask to get the mask around the liver in both the fixed and moving image
246
+ logger.info('Getting ROI')
247
+ fixed_segm_bbox = get_roi(args.fixed, args.anatomy, args.outputdir,
248
+ 'fixed_segmentation', args.fixedsegm, args.debug)
249
+ moving_segm_bbox = get_roi(args.moving, args.anatomy, args.outputdir,
250
+ 'moving_segmentation', args.movingsegm, args.debug)
251
+
252
+ crop_min = np.amin(np.vstack([fixed_segm_bbox[:3], moving_segm_bbox[:3]]), axis=0)
253
+ crop_max = np.amax(np.vstack([fixed_segm_bbox[3:], moving_segm_bbox[3:]]), axis=0)
254
+
255
+ # 2.2 Crop the fixed and moving images using such boxes
256
+ fixed_image = fixed_image_or[crop_min[0]: crop_max[0],
257
+ crop_min[1]: crop_max[1],
258
+ crop_min[2]: crop_max[2], ...]
259
+ debug_save_image(fixed_image, 'img_2_cropped_fixed_image', args.outputdir, args.debug)
260
+
261
+ moving_image = moving_image_or[crop_min[0]: crop_max[0],
262
+ crop_min[1]: crop_max[1],
263
+ crop_min[2]: crop_max[2], ...]
264
+ debug_save_image(moving_image, 'img_2_cropped_moving_image', args.outputdir, args.debug)
265
+
266
+ image_shape_crop = fixed_image.shape
267
+ # 2.3 Resize the images to the expected input size
268
+ zoom_factors = IMAGE_INTPUT_SHAPE / image_shape_crop
269
+ fixed_image = zoom(fixed_image, zoom_factors)
270
+ moving_image = zoom(moving_image, zoom_factors)
271
+ fixed_image = min_max_norm(fixed_image)
272
+ moving_image = min_max_norm(moving_image)
273
+ debug_save_image(fixed_image, 'img_3_preproc_fixed_image', args.outputdir, args.debug)
274
+ debug_save_image(moving_image, 'img_3_preproc_moving_image', args.outputdir, args.debug)
275
 
276
  # 3. Build the whole graph
277
+ logger.info('Building TF graph')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
  ### METRICS GRAPH ###
279
+ fix_img_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, 1), name='fix_img')
280
+ pred_img_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, 1), name='pred_img')
281
+
282
+ ssim_tf = StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).metric(fix_img_ph, pred_img_ph)
283
+ ncc_tf = NCC(image_shape_or).metric(fix_img_ph, pred_img_ph)
284
+ mse_tf = vxm.losses.MSE().loss(fix_img_ph, pred_img_ph)
285
+ ms_ssim_tf = MultiScaleStructuralSimilarity(max_val=1., filter_size=3).metric(fix_img_ph, pred_img_ph)
286
+
287
+ logger.info(f'Using model: {"Brain" if args.anatomy == "B" else "Liver"} -> {args.model}')
288
+ MODEL_FILE = MODELS_FILE[args.anatomy][args.model]
289
+
290
+ # try:
291
+ # network = tf.keras.models.load_model(MODEL_FILE,
292
+ # {'VxmDense': vxm.networks.VxmDense,
293
+ # # 'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
294
+ # 'AdamAccumulated': AdamAccumulated
295
+ # },
296
+ # compile=False)
297
+ # except ValueError as e:
298
+ # enc_features = [32, 64, 128, 256, 512, 1024] # const.ENCODER_FILTERS
299
+ # dec_features = enc_features[::-1] + [16, 16] # const.ENCODER_FILTERS[::-1]
300
+ # nb_features = [enc_features, dec_features]
301
+ # if re.search('^UW|SEGGUIDED_', MODEL_FILE):
302
+ # network = vxm.networks.VxmDense(inshape=IMAGE_INTPUT_SHAPE[:-1],
303
+ # nb_unet_features=nb_features,
304
+ # int_steps=0,
305
+ # int_downsize=1,
306
+ # seg_downsize=1)
307
+ # else:
308
+ # network = vxm.networks.VxmDense(inshape=IMAGE_INTPUT_SHAPE[:-1],
309
+ # nb_unet_features=nb_features,
310
+ # int_steps=0)
311
+ # network.load_weights(MODEL_FILE, by_name=True)
312
+
313
+ enc_features = [32, 64, 128, 256, 512, 1024] # const.ENCODER_FILTERS
314
+ dec_features = enc_features[::-1] + [16, 16] # const.ENCODER_FILTERS[::-1]
315
+ nb_features = [enc_features, dec_features]
316
+ network = vxm.networks.VxmDense(inshape=IMAGE_INTPUT_SHAPE[:-1],
317
+ nb_unet_features=nb_features,
318
+ int_steps=0)
319
+ network.load_weights(MODEL_FILE, by_name=True)
320
+ network.trainable = False
321
+
322
+ registration_model = network.get_registration_model()
323
+ deb_model = network.apply_transform
324
+
325
+ logger.info('Performing registration')
326
+ with sess.as_default():
327
+ if args.debug:
328
+ registration_model.summary(line_length=C.SUMMARY_LINE_LENGTH)
329
+ time_disp_map_start = time.time()
330
+ # disp_map = registration_model.predict([moving_image[np.newaxis, ...], fixed_image[np.newaxis, ...]])
331
+ p, disp_map = network.predict([moving_image[np.newaxis, ...], fixed_image[np.newaxis, ...]])
332
+ time_disp_map_end = time.time()
333
+ debug_save_image(np.squeeze(disp_map), 'disp_map_0_raw', args.outputdir, args.debug)
334
+ debug_save_image(p[0, ...], 'img_4_net_pred_image', args.outputdir, args.debug)
335
+ # pred_image = min_max_norm(pred_image)
336
+ # pred_image_isot = zoom(pred_image[0, ...], zoom_factors, order=3)[np.newaxis, ...]
337
+ # fixed_image_isot = zoom(fixed_image[0, ...], zoom_factors, order=3)[np.newaxis, ...]
338
+
339
+ # Up sample the displacement map to the full res
340
+ trf = np.eye(4)
341
+ np.fill_diagonal(trf, 1/zoom_factors)
342
+ disp_map = resize_displacement_map(np.squeeze(disp_map), None, trf)
343
+ debug_save_image(np.squeeze(disp_map), 'disp_map_1_upsampled', args.outputdir, args.debug)
344
+ disp_map_or = pad_displacement_map(disp_map, crop_min, crop_max, image_shape_or)
345
+ debug_save_image(np.squeeze(disp_map_or), 'disp_map_2_padded', args.outputdir, args.debug)
346
+ disp_map_or = gaussian_filter(disp_map_or, 5)
347
+ debug_save_image(np.squeeze(disp_map_or), 'disp_map_3_smoothed', args.outputdir, args.debug)
348
+
349
+ time_pred_img_start = time.time()
350
+ pred_image = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([moving_image_or[np.newaxis, ...], disp_map_or[np.newaxis, ...]]).eval()
351
+ time_pred_img_end = time.time()
352
+ ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf],
353
+ {'fix_img:0': fixed_image_or[np.newaxis, ...], 'pred_img:0': pred_image})
354
+ ssim = np.mean(ssim)
355
+ ms_ssim = ms_ssim[0]
356
+ pred_image = pred_image[0, ...]
357
+
358
+ save_nifti(pred_image, os.path.join(args.outputdir, 'pred_image.nii.gz'))
359
+ np.savez_compressed(os.path.join(args.outputdir, 'displacement_map.npz'), disp_map_or)
360
+ logger.info('Predicted image (full image) and displacement map saved in: '.format(args.outputdir))
361
+ logger.info(f'Displacement map prediction time: {time_disp_map_end - time_disp_map_start} s')
362
+ logger.info(f'Predicted image time: {time_pred_img_end - time_pred_img_start} s')
363
+
364
+ logger.info('Similarity metrics (Full image)\n------------------')
365
+ logger.info('SSIM: {:.03f}'.format(ssim))
366
+ logger.info('NCC: {:.03f}'.format(ncc))
367
+ logger.info('MSE: {:.03f}'.format(mse))
368
+ logger.info('MS SSIM: {:.03f}'.format(ms_ssim))
369
+
370
+ ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf],
371
+ {'fix_img:0': fixed_image[np.newaxis, ...], 'pred_img:0': p})
372
+ ssim = np.mean(ssim)
373
+ ms_ssim = ms_ssim[0]
374
+ logger.info('\nSimilarity metrics (ROI)\n------------------')
375
+ logger.info('SSIM: {:.03f}'.format(ssim))
376
+ logger.info('NCC: {:.03f}'.format(ncc))
377
+ logger.info('MSE: {:.03f}'.format(mse))
378
+ logger.info('MS SSIM: {:.03f}'.format(ms_ssim))
379
+
380
+ del registration_model
381
+ logger.info('Done')