import h5py
import ants
import numpy as np
import nibabel as nb
import os, sys
from tqdm import tqdm
import re
import time
import pandas as pd

currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir)  # allow absolute package imports when this file is run directly as a script

from DeepDeformationMapRegistration.losses import (StructuralSimilarity_simplified, NCC, GeneralizedDICEScore,
                                                   HausdorffDistanceErosion, target_registration_error)
from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
from DeepDeformationMapRegistration.utils.misc import DisplacementMapInterpolator, segmentation_ohe_to_cardinal
from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
import DeepDeformationMapRegistration.utils.constants as C
import shutil
import medpy.metric as medpy_metrics

import voxelmorph as vxm

from argparse import ArgumentParser

import tensorflow as tf

DATASET_LOCATION = '/mnt/EncryptedData1/Users/javier/vessel_registration/3Dirca/dataset/EVAL'
DATASET_NAMES = r'test_sample_\d{4}\.h5'
DATASET_FILENAME = 'volume'
IMGS_FOLDER = '/home/jpdefrutos/workspace/DeepDeformationMapRegistration/Centerline/imgs'

WARPED_MOV = 'warpedmovout'
WARPED_FIX = 'warpedfixout'
FWD_TRFS = 'fwdtransforms'
INV_TRFS = 'invtransforms'
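# The four names above are the keys of the dict returned by ants.registration() in ANTsPy.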


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--dataset', type=str, help='Directory with the images')
    parser.add_argument('--outdirname', type=str, help='Output directory')
    parser.add_argument('--gpu', type=int, help='GPU')
    # argparse's type=bool treats any non-empty string as True; parse the flag values explicitly instead
    parser.add_argument('--savenifti', type=lambda x: x.lower() in ('true', '1'), default=True)
    parser.add_argument('--erase', type=lambda x: x.lower() in ('true', '1'), default=False,
                        help='Erase the content of the output folder')
    args = parser.parse_args()
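    # Example invocation (script name and paths are placeholders):
    #   python evaluate_ants.py --dataset /data/EVAL --outdirname /results/ants --gpu 0 --erase True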

    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    if args.erase:
        shutil.rmtree(args.outdirname, ignore_errors=True)
    os.makedirs(args.outdirname, exist_ok=True)
    os.makedirs(os.path.join(args.outdirname, 'SyN'), exist_ok=True)
    os.makedirs(os.path.join(args.outdirname, 'SyNCC'), exist_ok=True)
    dataset_files = [os.path.join(args.dataset, f) for f in os.listdir(args.dataset) if re.match(DATASET_NAMES, f)]
    dataset_files.sort()
    dataset_iterator = tqdm(enumerate(dataset_files), desc="Running ANTs")

    f = h5py.File(dataset_files[0], 'r')
    image_shape = list(f['fix_image'][:].shape[:-1])
    nb_labels = f['fix_segmentations'][:].shape[-1] - 1
    f.close()
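    # The images are stored channel-last ([D, H, W, 1]) and the segmentations one-hot encoded with the
    # background in channel 0, which is why one label is subtracted above.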

    #### TF prep
    metric_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).metric,
                   NCC(image_shape).metric,
                   vxm.losses.MSE().loss,
                   MultiScaleStructuralSimilarity(max_val=1., filter_size=3).metric,
                   GeneralizedDICEScore(image_shape + [nb_labels], num_labels=nb_labels).metric,
                   HausdorffDistanceErosion(3, 10, im_shape=image_shape + [nb_labels]).metric,
                   GeneralizedDICEScore(image_shape + [nb_labels], num_labels=nb_labels).metric_macro]
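    # Only the image metrics (SSIM, NCC, MSE, MS-SSIM) are evaluated through the TF graph below; the
    # segmentation metrics are computed with medpy inside the main loop instead (see the commented lines).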

    fix_img_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, 1), name='fix_img')
    pred_img_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, 1), name='pred_img')
    fix_seg_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, nb_labels), name='fix_seg')
    pred_seg_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, nb_labels), name='pred_seg')

    ssim_tf = metric_fncs[0](fix_img_ph, pred_img_ph)
    ncc_tf = metric_fncs[1](fix_img_ph, pred_img_ph)
    mse_tf = metric_fncs[2](fix_img_ph, pred_img_ph)
    ms_ssim_tf = metric_fncs[3](fix_img_ph, pred_img_ph)
    # dice_tf = metric_fncs[4](fix_seg_ph, pred_seg_ph)
    # hd_tf = metric_fncs[5](fix_seg_ph, pred_seg_ph)
    # dice_macro_tf = metric_fncs[6](fix_seg_ph, pred_seg_ph)

    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    config.log_device_placement = False  # set to True to log on which device each operation runs
    config.allow_soft_placement = True

    sess = tf.compat.v1.Session(config=config)
    tf.compat.v1.keras.backend.set_session(sess)
    ####
    os.environ["ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS"] = "{:d}".format(os.cpu_count())  #https://github.com/ANTsX/ANTsPy/issues/261
    print("Running ANTs using {} threads".format(os.environ.get("ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS")))
    # dm_interp = DisplacementMapInterpolator(image_shape, 'griddata')
    # Header of the metrics csv file
    csv_header = ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'DICE_MACRO', 'HD', 'HD95', 'Time', 'TRE']

    metrics_file = {'SyN': os.path.join(args.outdirname, 'SyN', 'metrics.csv'),
                    'SyNCC': os.path.join(args.outdirname, 'SyNCC', 'metrics.csv')}
    for k in metrics_file.keys():
        with open(metrics_file[k], 'w') as f:
            f.write(';'.join(csv_header)+'\n')

    print('Starting the loop')
    for step, file_path in dataset_iterator:
        file_num = int(re.findall(r'(\d+)', os.path.split(file_path)[-1])[0])

        dataset_iterator.set_description('{} ({}): loading data'.format(file_num, file_path))
        with h5py.File(file_path, 'r') as vol_file:
            fix_img = vol_file['fix_image'][:]
            mov_img = vol_file['mov_image'][:]

            fix_seg = vol_file['fix_segmentations'][..., 1:].astype(np.float32)
            mov_seg = vol_file['mov_segmentations'][..., 1:].astype(np.float32)

            fix_centroids = vol_file['fix_centroids'][1:, ...]
            mov_centroids = vol_file['mov_centroids'][1:, ...]

            isotropic_shape = vol_file['isotropic_shape'][:]
            voxel_size = np.divide(fix_img.shape[:-1], isotropic_shape)
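            # Per-axis ratio between the stored and isotropic shapes; used further down to rescale the
            # centroid coordinates into the isotropic space before computing the TRE.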

        # ndarray to ANTsImage; squeeze the trailing channel axis since ANTs does not handle 1-channel images well
        fix_img_ants = ants.make_image(fix_img.shape[:-1], np.squeeze(fix_img))
        mov_img_ants = ants.make_image(mov_img.shape[:-1], np.squeeze(mov_img))

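        # 'SyN' is ANTs' symmetric normalization with its default similarity metric; 'SyNCC' runs the same
        # transform model driven by local cross-correlation. Both are timed for the metrics CSV.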
        dataset_iterator.set_description('{} ({}): running ANTs SyN'.format(file_num, file_path))
        t0_syn = time.time()
        reg_output_syn = ants.registration(fix_img_ants, mov_img_ants, 'SyN')
        t1_syn = time.time()

        dataset_iterator.set_description('{} ({}): running ANTs SyNCC'.format(file_num, file_path))
        t0_syncc = time.time()
        reg_output_syncc = ants.registration(fix_img_ants, mov_img_ants, 'SyNCC')
        t1_syncc = time.time()

        mov_to_fix_trf_syn = reg_output_syn[FWD_TRFS]
        mov_to_fix_trf_syncc = reg_output_syncc[FWD_TRFS]
        # Skip the case if either registration produced no forward transform
        if not len(mov_to_fix_trf_syn) or not len(mov_to_fix_trf_syncc):
            print('ERR: Registration failed for: ' + file_path)
        else:
            for reg_method, reg_output in zip(['SyN', 'SyNCC'], [reg_output_syn, reg_output_syncc]):
                mov_to_fix_trf_list = reg_output[FWD_TRFS]
                pred_img = reg_output[WARPED_MOV].numpy()
                pred_img = pred_img[..., np.newaxis]  # restore the channel axis squeezed out for ANTs

                fix_seg_ants = ants.make_image(fix_seg.shape, np.squeeze(fix_seg))
                mov_seg_ants = ants.make_image(mov_seg.shape, np.squeeze(mov_seg))
                pred_seg = ants.apply_transforms(fixed=fix_seg_ants, moving=mov_seg_ants,
                                                 transformlist=mov_to_fix_trf_list).numpy()
                pred_seg = np.squeeze(pred_seg)  # ANTs adds an extra axis which shouldn't be there

                dataset_iterator.set_description('{} ({}): Getting metrics {}'.format(file_num, file_path, reg_method))
                with sess.as_default():
                    # Per-label Dice divided by the fixed-label volume (a volume-weighted variant,
                    # presumably mirroring GeneralizedDICEScore); dice_macro is the plain per-label mean.
                    dice = np.mean([medpy_metrics.dc(pred_seg[np.newaxis, ..., l], fix_seg[np.newaxis, ..., l])
                                    / np.sum(fix_seg[np.newaxis, ..., l]) for l in range(nb_labels)])
                    hd = np.mean([medpy_metrics.hd(pred_seg[np.newaxis, ..., l], fix_seg[np.newaxis, ..., l])
                                  for l in range(nb_labels)])
                    hd95 = np.mean([medpy_metrics.hd95(pred_seg[np.newaxis, ..., l], fix_seg[np.newaxis, ..., l])
                                    for l in range(nb_labels)])
                    dice_macro = np.mean([medpy_metrics.dc(pred_seg[np.newaxis, ..., l], fix_seg[np.newaxis, ..., l])
                                          for l in range(nb_labels)])

                    # dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf],
                    #                                 {'fix_seg:0': fix_seg[np.newaxis, ...],  # Batch axis
                    #                                  'pred_seg:0': pred_seg[np.newaxis, ...]  # Batch axis
                    #                                  })

                    pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
                    mov_seg_card = segmentation_ohe_to_cardinal(mov_seg).astype(np.float32)
                    fix_seg_card = segmentation_ohe_to_cardinal(fix_seg).astype(np.float32)

                    ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf],
                                                       {'fix_img:0': fix_img[np.newaxis, ...],  # Batch axis
                                                        'pred_img:0': pred_img[np.newaxis, ...]  # Batch axis
                                                        })
                    ssim = np.mean(ssim)
                    ms_ssim = ms_ssim[0]

                    # TRE
                    disp_map = np.squeeze(np.asarray(nb.load(mov_to_fix_trf_list[0]).dataobj))
                    dm_interp = DisplacementMapInterpolator(fix_img.shape[:-1], 'griddata', step=2)
                    pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) + mov_centroids
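                    # The displacement field is sampled at the moving centroids (backwards=True) and added
                    # to them, yielding their predicted locations in the fixed image for the TRE.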
                    # Rescale the points back to isotropic space, where we have a correspondence voxel <-> mm
                    # upsample_scale = 128 / 64
                    # fix_centroids_isotropic = fix_centroids * upsample_scale
                    # pred_centroids_isotropic = pred_centroids * upsample_scale

                    fix_centroids_isotropic = fix_centroids * voxel_size
                    pred_centroids_isotropic = pred_centroids * voxel_size

                    # fix_centroids_isotropic = np.divide(fix_centroids_isotropic, C.COMET_DATASET_iso_to_cubic_scales)
                    # pred_centroids_isotropic = np.divide(pred_centroids_isotropic, C.COMET_DATASET_iso_to_cubic_scales)
                    tre_array = target_registration_error(fix_centroids, pred_centroids, False).eval()
                    tre = np.nanmean(tre_array)
                    if np.isnan(tre):
                        print('TRE is NaN for {} and file {}'.format(reg_method, step))

                dataset_iterator.set_description('{} ({}): Saving data {}'.format(file_num, file_path, reg_method))
                new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, hd95,
                            t1_syn-t0_syn if reg_method == 'SyN' else t1_syncc-t0_syncc,
                            tre]
                with open(metrics_file[reg_method], 'a') as f:
                    f.write(';'.join(map(str, new_line))+'\n')
                if args.savenifti:
                    save_nifti(fix_img, os.path.join(args.outdirname, reg_method, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
                    save_nifti(mov_img, os.path.join(args.outdirname, reg_method, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
                    save_nifti(pred_img, os.path.join(args.outdirname, reg_method, '{:03d}_pred_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
                    save_nifti(fix_seg_card, os.path.join(args.outdirname, reg_method, '{:03d}_fix_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
                    save_nifti(mov_seg_card, os.path.join(args.outdirname, reg_method, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
                    save_nifti(pred_seg_card, os.path.join(args.outdirname, reg_method, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)

                plot_predictions(img_batches=[fix_img[np.newaxis, ...], mov_img[np.newaxis, ...], pred_img[np.newaxis, ...]],
                                 disp_map_batch=disp_map[np.newaxis, ...],
                                 seg_batches=[fix_seg_card[np.newaxis, ...], mov_seg_card[np.newaxis, ...], pred_seg_card[np.newaxis, ...]],
                                 filename=os.path.join(args.outdirname, reg_method, '{:03d}_figures_seg.png'.format(step)),
                                 show=False)
                plot_predictions(img_batches=[fix_img[np.newaxis, ...], mov_img[np.newaxis, ...], pred_img[np.newaxis, ...]],
                                 disp_map_batch=disp_map[np.newaxis, ...],
                                 filename=os.path.join(args.outdirname, reg_method, '{:03d}_figures_img.png'.format(step)),
                                 show=False)
                save_disp_map_img(disp_map[np.newaxis, ...], 'Displacement map',
                                  os.path.join(args.outdirname, reg_method, '{:03d}_disp_map_fig.png'.format(step)),
                                  show=False)

    for k in metrics_file.keys():
        print('Summary {}\n=======\n'.format(k))
        metrics_df = pd.read_csv(metrics_file[k], sep=';', header=0)
        print('\nAVG:\n')
        print(metrics_df.mean(axis=0))
        print('\nSTD:\n')
        print(metrics_df.std(axis=0))
        print('\nHD 95th percentile:\n')
        print(metrics_df['HD'].describe(percentiles=[.95]))
        print('\n=======\n')