jpdefrutos commited on
Commit
a290524
·
1 Parent(s): dc36465

Working on the clean repo

Browse files
Files changed (2) hide show
  1. DeepDeformationMapRegistration/main.py +410 -0
  2. setup.py +26 -0
DeepDeformationMapRegistration/main.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # 1. Image files generator
3
+
4
+ # timer start
5
+ # 2. Preprocess the image
6
+ # 3. Predict the displacement
7
+ # timer stop
8
+
9
+ # 4. Evaluate the registration: NCC; SSIM; DICE; HD95
10
+
11
+ import os, sys
12
+
13
+ import shutil
14
+ import time
15
+ import tkinter
16
+
17
+ import h5py
18
+ import matplotlib.pyplot as plt
19
+
20
+ currentdir = os.path.dirname(os.path.realpath(__file__))
21
+ parentdir = os.path.dirname(currentdir)
22
+ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
23
+
24
+ import tensorflow as tf
25
+ # tf.enable_eager_execution(config=config)
26
+
27
+ import numpy as np
28
+ import pandas as pd
29
+ import voxelmorph as vxm
30
+ from voxelmorph.tf.layers import SpatialTransformer
31
+
32
+ import DeepDeformationMapRegistration.utils.constants as C
33
+ from DeepDeformationMapRegistration.utils.operators import min_max_norm, safe_medpy_metric
34
+ from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
35
+ from DeepDeformationMapRegistration.layers import AugmentationLayer, UncertaintyWeighting
36
+ from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplified, NCC, GeneralizedDICEScore, HausdorffDistanceErosion, target_registration_error
37
+ from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
38
+ from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
39
+ from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
40
+ from DeepDeformationMapRegistration.utils.misc import DisplacementMapInterpolator, get_segmentations_centroids, segmentation_ohe_to_cardinal, segmentation_cardinal_to_ohe
41
+ from DeepDeformationMapRegistration.utils.misc import resize_displacement_map, scale_transformation, GaussianFilter
42
+ import medpy.metric as medpy_metrics
43
+ from EvaluationScripts.Evaluate_class import EvaluationFigures, resize_pts_to_original_space, resize_img_to_original_space, resize_transformation
44
+ from scipy.interpolate import RegularGridInterpolator
45
+ from tqdm import tqdm
46
+ import nibabel as nib
47
+ from scipy.ndimage import gaussian_filter, zoom
48
+
49
+ import h5py
50
+ import re
51
+ from Brain_study.data_generator import BatchGenerator
52
+
53
+ import argparse
54
+
55
+ from skimage.transform import warp
56
+ import neurite as ne
57
+
58
+ import tempfile
59
+
60
+ import logging
61
+
62
+ MODELS_FILE = {'liver': {'BL-N': './models/liver/bl_ncc.h5',
63
+ 'BL-S': './models/liver/bl_ssim.h5',
64
+ 'SG-ND': './models/liver/sg_ncc_dsc.h5',
65
+ 'SD-NSD': './models/liver/sg_ncc_ssim_dsc.h5',
66
+ 'UW-NSD': './models/liver/uw_ncc_ssim_dsc.h5',
67
+ 'UW-NSDH': './models/liver/uw_ncc_ssim_dsc_hd.h5',
68
+ },
69
+ 'brain': {'BL-N': './models/brain/bl_ncc.h5',
70
+ 'BL-S': './models/brain/bl_ssim.h5',
71
+ 'SG-ND': './models/brain/sg_ncc_dsc.h5',
72
+ 'SD-NSD': './models/brain/sg_ncc_ssim_dsc.h5',
73
+ 'UW-NSD': './models/brain/uw_ncc_ssim_dsc.h5',
74
+ 'UW-NSDH': './models/brain/uw_ncc_ssim_dsc_hd.h5',
75
+ }
76
+ }
77
+
78
+ if __name__ == '__main__':
79
+ parser = argparse.ArgumentParser()
80
+ parser.add_argument('-f', '--fixed', type=str, help='Path to fixed image file (NIfTI)')
81
+ parser.add_argument('-m', '--moving', type=str, help='Path to oving image file (NIfTI)')
82
+ parser.add_argument('-o', '--outputdir', type=str, help='Output directory', default='./Registration_output')
83
+ parser.add_argument('--gpu', type=int, help='In case of multi-GPU systems, limits the execution to the defined GPU number', default=None)
84
+ parser.add_argument('--model', type=str, help='Which model to use: BL-N, BL-S, SG-ND, SG-NSD, UW-NSD, UW-NSDH', default='UW-NSD')
85
+ # parser.add_argument('--brain', type=bool, action='store_true', help='Perform brain MRi registration', default=False)
86
+ args = parser.parse_args()
87
+ logger = logging.getLogger(__name__)
88
+
89
+ assert os.path.exists(args.fixed), 'Fixed image not found'
90
+ assert os.path.exists(args.moving), 'Moving image not found'
91
+ assert args.model in ['BL-N', 'BL-S', 'SG-ND', 'SG-NSD', 'UW-NSD', 'UW-NSDH'], 'Invalid model type'
92
+
93
+ if isinstance(args.gpu, int):
94
+ os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
95
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) # Check availability before running using 'nvidia-smi'
96
+
97
+ # Load the file and preprocess it
98
+ fixed_image = nib.load(args.fixed)
99
+ moving_image = nib.load(args.moving)
100
+
101
+ # TF stuff
102
+ config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
103
+ config.gpu_options.allow_growth = True
104
+ config.log_device_placement = False ## to log device placement (on which device the operation ran)
105
+ config.allow_soft_placement = True
106
+
107
+ sess = tf.compat.v1.Session(config=config)
108
+ tf.compat.v1.keras.backend.set_session(sess)
109
+
110
+ # Preprocess data
111
+ if args.erase:
112
+ shutil.rmtree(args.outputdir, ignore_errors=True)
113
+ os.makedirs(args.outputdir, exist_ok=True)
114
+ lm_output_dir = os.path.join(args.outputdir, 'livermask')
115
+ os.makedirs(lm_output_dir, exist_ok=True)
116
+
117
+ # 1. Run Livermask to get the mask around the liver in both the fixed and moving image
118
+ logger.info('Getting parenchyma segmentations...')
119
+ livermask_cmd = "python -m livermaks.livermask --input {} --output {}".format(args.fixed, os.path.join(lm_output_dir, 'fixed.nii.gz'))
120
+ os.system(livermask_cmd)
121
+ logger.info('... fixed image done...')
122
+ livermask_cmd = "python -m livermaks.livermask --input {} --output {}".format(args.moving, os.path.join(lm_output_dir, 'moving.nii.gz'))
123
+ os.system(livermask_cmd)
124
+ logger.info('... moving image done.')
125
+
126
+ # 2. Crop around the liver
127
+ # 2.1 Load the segmentations
128
+ # 2.2 Find the outermost box containing both boxes
129
+ # 2.3 Crop the fixed and moving images using such boxes
130
+ # 2.4 Resize the images to the expected input size
131
+
132
+ # 3. Build the whole graph
133
+
134
+
135
+ # Loss and metric functions. Common to all models
136
+ # loss_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).loss,
137
+ # NCC(image_input_shape).loss,
138
+ # vxm.losses.MSE().loss,
139
+ # MultiScaleStructuralSimilarity(max_val=1., filter_size=3).loss]
140
+ #
141
+ # metric_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).metric,
142
+ # NCC(image_input_shape).metric,
143
+ # vxm.losses.MSE().loss,
144
+ # MultiScaleStructuralSimilarity(max_val=1., filter_size=3).metric,
145
+ # GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric,
146
+ # HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [nb_labels]).metric,
147
+ # GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric_macro]
148
+
149
+ ### METRICS GRAPH ###
150
+ # fix_img_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, 1), name='fix_img')
151
+ # pred_img_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, 1), name='pred_img')
152
+ # fix_seg_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, nb_labels), name='fix_seg')
153
+ # pred_seg_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, nb_labels), name='pred_seg')
154
+ #
155
+ # ssim_tf = metric_fncs[0](fix_img_ph, pred_img_ph)
156
+ # ncc_tf = metric_fncs[1](fix_img_ph, pred_img_ph)
157
+ # mse_tf = metric_fncs[2](fix_img_ph, pred_img_ph)
158
+ # ms_ssim_tf = metric_fncs[3](fix_img_ph, pred_img_ph)
159
+ # dice_tf = metric_fncs[4](fix_seg_ph, pred_seg_ph)
160
+ # hd_tf = metric_fncs[5](fix_seg_ph, pred_seg_ph)
161
+ # dice_macro_tf = metric_fncs[6](fix_seg_ph, pred_seg_ph)
162
+ # hd_exact_tf = HausdorffDistance_exact(fix_seg_ph, pred_seg_ph, ohe=True)
163
+
164
+ # Needed for VxmDense type of network
165
+ warp_segmentation = vxm.networks.Transform(image_output_shape, interp_method='nearest', nb_feats=nb_labels)
166
+
167
+ dm_interp = DisplacementMapInterpolator(image_output_shape, 'griddata', step=4)
168
+
169
+ for MODEL_FILE, DATA_ROOT_DIR in zip(MODEL_FILE_LIST, DATA_ROOT_DIR_LIST):
170
+ print('MODEL LOCATION: ', MODEL_FILE)
171
+
172
+ # data_folder = '/mnt/EncryptedData1/Users/javier/train_output/DDMR/THESIS/BASELINE_Affine_ncc___mse_ncc_160606-25022021'
173
+ output_folder = os.path.join(DATA_ROOT_DIR, args.outdirname) # '/mnt/EncryptedData1/Users/javier/train_output/DDMR/THESIS/eval/BASELINE_TRAIN_Affine_ncc_EVAL_Affine'
174
+ # os.makedirs(os.path.join(output_folder, 'images'), exist_ok=True)
175
+
176
+ print('DESTINATION FOLDER: ', output_folder)
177
+
178
+ if args.fullres:
179
+ output_folder_fr = os.path.join(DATA_ROOT_DIR, args.outdirname, 'full_resolution') # '/mnt/EncryptedData1/Users/javier/train_output/DDMR/THESIS/eval/BASELINE_TRAIN_Affine_ncc_EVAL_Affine'
180
+ # os.makedirs(os.path.join(output_folder, 'images'), exist_ok=True)
181
+ if args.erase:
182
+ shutil.rmtree(output_folder_fr, ignore_errors=True)
183
+ os.makedirs(output_folder_fr, exist_ok=True)
184
+ print('DESTINATION FOLDER FULL RESOLUTION: ', output_folder_fr)
185
+
186
+ try:
187
+ network = tf.keras.models.load_model(MODEL_FILE, {'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
188
+ 'VxmDense': vxm.networks.VxmDense,
189
+ 'AdamAccumulated': AdamAccumulated,
190
+ 'loss': loss_fncs,
191
+ 'metric': metric_fncs},
192
+ compile=False)
193
+ except ValueError as e:
194
+ enc_features = [16, 32, 32, 32] # const.ENCODER_FILTERS
195
+ dec_features = [32, 32, 32, 32, 32, 16, 16] # const.ENCODER_FILTERS[::-1]
196
+ nb_features = [enc_features, dec_features]
197
+ if re.search('^UW|SEGGUIDED_', MODEL_FILE):
198
+ network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
199
+ nb_labels=nb_labels,
200
+ nb_unet_features=nb_features,
201
+ int_steps=0,
202
+ int_downsize=1,
203
+ seg_downsize=1)
204
+ else:
205
+ network = vxm.networks.VxmDense(inshape=image_output_shape,
206
+ nb_unet_features=nb_features,
207
+ int_steps=0)
208
+ network.load_weights(MODEL_FILE, by_name=True)
209
+ # Record metrics
210
+ metrics_file = os.path.join(output_folder, 'metrics.csv')
211
+ with open(metrics_file, 'w') as f:
212
+ f.write(';'.join(csv_header)+'\n')
213
+
214
+ if args.fullres:
215
+ metrics_file_fr = os.path.join(output_folder_fr, 'metrics.csv')
216
+ with open(metrics_file_fr, 'w') as f:
217
+ f.write(';'.join(csv_header) + '\n')
218
+
219
+ ssim = ncc = mse = ms_ssim = dice = hd = 0
220
+ with sess.as_default():
221
+ sess.run(tf.global_variables_initializer())
222
+ network.load_weights(MODEL_FILE, by_name=True)
223
+ network.summary(line_length=C.SUMMARY_LINE_LENGTH)
224
+ progress_bar = tqdm(enumerate(list_test_files, 1), desc='Evaluation', total=len(list_test_files))
225
+ for step, in_batch in progress_bar:
226
+ with h5py.File(in_batch, 'r') as f:
227
+ fix_img = f['fix_image'][:][np.newaxis, ...] # Add batch axis
228
+ mov_img = f['mov_image'][:][np.newaxis, ...]
229
+ fix_seg = f['fix_segmentations'][:][np.newaxis, ...].astype(np.float32)
230
+ mov_seg = f['mov_segmentations'][:][np.newaxis, ...].astype(np.float32)
231
+ fix_centroids = f['fix_centroids'][:]
232
+ isotropic_shape = f['isotropic_shape'][:]
233
+ voxel_size = np.divide(fix_img.shape[1:-1], isotropic_shape)
234
+
235
+ if network.name == 'vxm_dense_semi_supervised_seg':
236
+ t0 = time.time()
237
+ pred_img, disp_map, pred_seg = network.predict([mov_img, fix_img, mov_seg, fix_seg]) # predict([source, target])
238
+ t1 = time.time()
239
+ else:
240
+ t0 = time.time()
241
+ pred_img, disp_map = network.predict([mov_img, fix_img])
242
+ pred_seg = warp_segmentation.predict([mov_seg, disp_map])
243
+ t1 = time.time()
244
+
245
+ pred_img = min_max_norm(pred_img)
246
+ mov_centroids, missing_lbls = get_segmentations_centroids(mov_seg[0, ...], ohe=True, expected_lbls=range(1, nb_labels+1), brain_study=False)
247
+ # pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) # with tps, it returns the pred_centroids directly
248
+ pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) + mov_centroids
249
+
250
+ # Up sample the segmentation masks to isotropic resolution
251
+ zoom_factors = np.diag(scale_transformation(image_output_shape, isotropic_shape))
252
+ pred_seg_isot = zoom(pred_seg[0, ...], zoom_factors, order=0)[np.newaxis, ...]
253
+ fix_seg_isot = zoom(fix_seg[0, ...], zoom_factors, order=0)[np.newaxis, ...]
254
+
255
+ pred_img_isot = zoom(pred_img[0, ...], zoom_factors, order=3)[np.newaxis, ...]
256
+ fix_img_isot = zoom(fix_img[0, ...], zoom_factors, order=3)[np.newaxis, ...]
257
+
258
+ # I need the labels to be OHE to compute the segmentation metrics.
259
+ # dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf], {'fix_seg:0': fix_seg, 'pred_seg:0': pred_seg})
260
+ dice = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) / np.sum(fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
261
+ hd = np.mean(safe_medpy_metric(pred_seg_isot[0, ...], fix_seg_isot[0, ...], nb_labels, medpy_metrics.hd, {'voxelspacing': voxel_size}))
262
+ dice_macro = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
263
+
264
+ pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
265
+ mov_seg_card = segmentation_ohe_to_cardinal(mov_seg).astype(np.float32)
266
+ fix_seg_card = segmentation_ohe_to_cardinal(fix_seg).astype(np.float32)
267
+
268
+ ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf], {'fix_img:0': fix_img, 'pred_img:0': pred_img})
269
+ ssim = np.mean(ssim)
270
+ ms_ssim = ms_ssim[0]
271
+
272
+ # Rescale the points back to isotropic space, where we have a correspondence voxel <-> mm
273
+ fix_centroids_isotropic = fix_centroids * voxel_size
274
+ # mov_centroids_isotropic = mov_centroids * voxel_size
275
+ pred_centroids_isotropic = pred_centroids * voxel_size
276
+
277
+ # fix_centroids_isotropic = np.divide(fix_centroids_isotropic, C.COMET_DATASET_iso_to_cubic_scales)
278
+ # # mov_centroids_isotropic = np.divide(mov_centroids_isotropic, C.COMET_DATASET_iso_to_cubic_scales)
279
+ # pred_centroids_isotropic = np.divide(pred_centroids_isotropic, C.COMET_DATASET_iso_to_cubic_scales)
280
+ # Now we can measure the TRE in mm
281
+ tre_array = target_registration_error(fix_centroids_isotropic, pred_centroids_isotropic, False).eval()
282
+ tre = np.mean([v for v in tre_array if not np.isnan(v)])
283
+ # ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'HD', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']
284
+
285
+ new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, t1-t0, tre, len(missing_lbls), missing_lbls]
286
+ with open(metrics_file, 'a') as f:
287
+ f.write(';'.join(map(str, new_line))+'\n')
288
+
289
+ save_nifti(fix_img[0, ...], os.path.join(output_folder, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
290
+ save_nifti(mov_img[0, ...], os.path.join(output_folder, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
291
+ save_nifti(pred_img[0, ...], os.path.join(output_folder, '{:03d}_pred_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
292
+ save_nifti(fix_seg_card[0, ...], os.path.join(output_folder, '{:03d}_fix_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
293
+ save_nifti(mov_seg_card[0, ...], os.path.join(output_folder, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
294
+ save_nifti(pred_seg_card[0, ...], os.path.join(output_folder, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
295
+
296
+ # with h5py.File(os.path.join(output_folder, '{:03d}_centroids.h5'.format(step)), 'w') as f:
297
+ # f.create_dataset('fix_centroids', dtype=np.float32, data=fix_centroids)
298
+ # f.create_dataset('mov_centroids', dtype=np.float32, data=mov_centroids)
299
+ # f.create_dataset('pred_centroids', dtype=np.float32, data=pred_centroids)
300
+ # f.create_dataset('fix_centroids_isotropic', dtype=np.float32, data=fix_centroids_isotropic)
301
+ # f.create_dataset('mov_centroids_isotropic', dtype=np.float32, data=mov_centroids_isotropic)
302
+
303
+ # magnitude = np.sqrt(np.sum(disp_map[0, ...] ** 2, axis=-1))
304
+ # _ = plt.hist(magnitude.flatten())
305
+ # plt.title('Histogram of disp. magnitudes')
306
+ # plt.show(block=False)
307
+ # plt.savefig(os.path.join(output_folder, '{:03d}_hist_mag_ssim_{:.03f}_dice_{:.03f}.png'.format(step, ssim, dice)))
308
+ # plt.close()
309
+
310
+ plot_predictions(img_batches=[fix_img, mov_img, pred_img], disp_map_batch=disp_map, seg_batches=[fix_seg_card, mov_seg_card, pred_seg_card], filename=os.path.join(output_folder, '{:03d}_figures_seg.png'.format(step)), show=False, step=16)
311
+ plot_predictions(img_batches=[fix_img, mov_img, pred_img], disp_map_batch=disp_map, filename=os.path.join(output_folder, '{:03d}_figures_img.png'.format(step)), show=False, step=16)
312
+ save_disp_map_img(disp_map, 'Displacement map', os.path.join(output_folder, '{:03d}_disp_map_fig.png'.format(step)), show=False, step=16)
313
+
314
+ progress_bar.set_description('SSIM {:.04f}\tM_DICE: {:.04f}'.format(ssim, dice_macro))
315
+
316
+ if args.fullres:
317
+ with h5py.File(list_test_fr_files[step - 1], 'r') as f:
318
+ fix_img = f['fix_image'][:][np.newaxis, ...] # Add batch axis
319
+ mov_img = f['mov_image'][:][np.newaxis, ...]
320
+ fix_seg = f['fix_segmentations'][:][np.newaxis, ...].astype(np.float32)
321
+ mov_seg = f['mov_segmentations'][:][np.newaxis, ...].astype(np.float32)
322
+ fix_centroids = f['fix_centroids'][:]
323
+
324
+ # Up sample the displacement map to the full res
325
+ trf = scale_transformation(image_output_shape, fix_img.shape[1:-1])
326
+ disp_map_fr = resize_displacement_map(np.squeeze(disp_map), None, trf)[np.newaxis, ...]
327
+ disp_map_fr = gaussian_filter(disp_map_fr, 5)
328
+ # disp_mad_fr = sess.run(smooth_filter, feed_dict={'dm:0': disp_map_fr})
329
+
330
+ # Predicted image
331
+ pred_img_fr = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([mov_img, disp_map_fr]).eval()
332
+ pred_seg_fr = SpatialTransformer(interp_method='nearest', indexing='ij', single_transform=False)([mov_seg, disp_map_fr]).eval()
333
+
334
+ # Predicted centroids
335
+ dm_interp_fr = DisplacementMapInterpolator(fix_img.shape[1:-1], 'griddata', step=2)
336
+ pred_centroids = dm_interp_fr(disp_map_fr, mov_centroids, backwards=True) + mov_centroids
337
+
338
+ # Metrics - segmentation
339
+ dice = np.mean([medpy_metrics.dc(pred_seg_fr[..., l], fix_seg[..., l]) / np.sum(fix_seg[..., l]) for l in range(nb_labels)])
340
+ hd = np.mean(safe_medpy_metric(pred_seg[0, ...], fix_seg[0, ...], nb_labels, medpy_metrics.hd, {'voxelspacing': voxel_size}))
341
+ dice_macro = np.mean([medpy_metrics.dc(pred_seg_fr[..., l], fix_seg[..., l]) for l in range(nb_labels)])
342
+
343
+ pred_seg_card_fr = segmentation_ohe_to_cardinal(pred_seg_fr).astype(np.float32)
344
+ mov_seg_card_fr = segmentation_ohe_to_cardinal(mov_seg).astype(np.float32)
345
+ fix_seg_card_fr = segmentation_ohe_to_cardinal(fix_seg).astype(np.float32)
346
+
347
+ # Metrics - image
348
+ ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf],
349
+ {'fix_img:0': fix_img, 'pred_img:0': pred_img_fr})
350
+ ssim = np.mean(ssim)
351
+ ms_ssim = ms_ssim[0]
352
+
353
+ # Metrics - registration
354
+ tre_array = target_registration_error(fix_centroids, pred_centroids, False).eval()
355
+
356
+ new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, t1 - t0, tre, len(missing_lbls),
357
+ missing_lbls]
358
+ with open(metrics_file_fr, 'a') as f:
359
+ f.write(';'.join(map(str, new_line)) + '\n')
360
+
361
+ save_nifti(fix_img[0, ...], os.path.join(output_folder_fr, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
362
+ save_nifti(mov_img[0, ...], os.path.join(output_folder_fr, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
363
+ save_nifti(pred_img[0, ...], os.path.join(output_folder_fr, '{:03d}_pred_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
364
+ save_nifti(fix_seg_card[0, ...], os.path.join(output_folder_fr, '{:03d}_fix_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
365
+ save_nifti(mov_seg_card[0, ...], os.path.join(output_folder_fr, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
366
+ save_nifti(pred_seg_card[0, ...], os.path.join(output_folder_fr, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
367
+
368
+ # with h5py.File(os.path.join(output_folder, '{:03d}_centroids.h5'.format(step)), 'w') as f:
369
+ # f.create_dataset('fix_centroids', dtype=np.float32, data=fix_centroids)
370
+ # f.create_dataset('mov_centroids', dtype=np.float32, data=mov_centroids)
371
+ # f.create_dataset('pred_centroids', dtype=np.float32, data=pred_centroids)
372
+ # f.create_dataset('fix_centroids_isotropic', dtype=np.float32, data=fix_centroids_isotropic)
373
+ # f.create_dataset('mov_centroids_isotropic', dtype=np.float32, data=mov_centroids_isotropic)
374
+
375
+ # magnitude = np.sqrt(np.sum(disp_map[0, ...] ** 2, axis=-1))
376
+ # _ = plt.hist(magnitude.flatten())
377
+ # plt.title('Histogram of disp. magnitudes')
378
+ # plt.show(block=False)
379
+ # plt.savefig(os.path.join(output_folder, '{:03d}_hist_mag_ssim_{:.03f}_dice_{:.03f}.png'.format(step, ssim, dice)))
380
+ # plt.close()
381
+
382
+ plot_predictions(img_batches=[fix_img, mov_img, pred_img_fr], disp_map_batch=disp_map_fr, seg_batches=[fix_seg_card_fr, mov_seg_card_fr, pred_seg_card_fr], filename=os.path.join(output_folder_fr, '{:03d}_figures_seg.png'.format(step)), show=False, step=10)
383
+ plot_predictions(img_batches=[fix_img, mov_img, pred_img_fr], disp_map_batch=disp_map_fr, filename=os.path.join(output_folder_fr, '{:03d}_figures_img.png'.format(step)), show=False, step=10)
384
+ # save_disp_map_img(disp_map_fr, 'Displacement map', os.path.join(output_folder_fr, '{:03d}_disp_map_fig.png'.format(step)), show=False, step=10)
385
+
386
+ progress_bar.set_description('[FR] SSIM {:.04f}\tM_DICE: {:.04f}'.format(ssim, dice_macro))
387
+
388
+ print('Summary\n=======\n')
389
+ metrics_df = pd.read_csv(metrics_file, sep=';', header=0)
390
+ print('\nAVG:\n')
391
+ print(metrics_df.mean(axis=0))
392
+ print('\nSTD:\n')
393
+ print(metrics_df.std(axis=0))
394
+ print('\nHD95perc:\n')
395
+ print(metrics_df['HD'].describe(percentiles=[.95]))
396
+ print('\n=======\n')
397
+ if args.fullres:
398
+ print('Summary full resolution\n=======\n')
399
+ metrics_df = pd.read_csv(metrics_file_fr, sep=';', header=0)
400
+ print('\nAVG:\n')
401
+ print(metrics_df.mean(axis=0))
402
+ print('\nSTD:\n')
403
+ print(metrics_df.std(axis=0))
404
+ print('\nHD95perc:\n')
405
+ print(metrics_df['HD'].describe(percentiles=[.95]))
406
+ print('\n=======\n')
407
+ tf.keras.backend.clear_session()
408
+ # sess.close()
409
+ del network
410
+ print('Done')
setup.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from setuptools import find_packages, setup
2
+ import os
3
+
4
+ entry_points = {'console_script':['DeepDeformationMapRegistration=DeepDeformationMapRegistration.main:main']}
5
+
6
+ setup(
7
+ name='DeepDeformationMapRegistration',
8
+ py_modules=['DeepDeformationMapRegistration'],
9
+ packages=find_packages(include=['DeepDeformationMapRegistration', 'DeepDeformationMapRegistration.*']),
10
+ version='1.0',
11
+ description='Deep-registration training toolkit',
12
+ author='Javier Pérez de Frutos',
13
+ classifiers=[
14
+ 'Programming language :: Python :: 3',
15
+ 'License :: OSI Approveed :: MIT License',
16
+ 'Operating System :: OS Independent'
17
+ ],
18
+ python_requires='>=3.6',
19
+ install_requires=[
20
+ 'tensorflow-gpu==1.14.0',
21
+ 'tensorboard==1.14.0',
22
+ 'nibabel==3.2.1',
23
+ 'numpy==1.18.5',
24
+ 'livermask'
25
+ ]
26
+ )