File size: 8,234 Bytes
70656f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import h5py
import ants
import numpy as np
import DeepDeformationMapRegistration.utils.constants as C
import os
from tqdm import  tqdm
import re

from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplified, NCC, GeneralizedDICEScore, HausdorffDistanceErosion, target_registration_error
from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
from DeepDeformationMapRegistration.utils.misc import DisplacementMapInterpolator, segmentation_ohe_to_cardinal

from argparse import ArgumentParser

import tensorflow as tf

DATASET_LOCATION = '/mnt/EncryptedData1/Users/javier/vessel_registration/3Dirca/dataset/EVAL'
DATASET_NAMES = 'test_sample_\d{4}.h5'
DATASET_FILENAME = 'volume'
IMGS_FOLDER = '/home/jpdefrutos/workspace/DeepDeformationMapRegistration/Centerline/imgs'

WARPED_MOV = 'warpedmovout'
WARPED_FIX = 'warpedfixout'
FWD_TRFS = 'fwdtransforms'
INV_TRFS = 'invtransforms'

if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--dataset', type=str, help='Directory with the images')
    parser.add_argument('--outdir', type=str, help='Output directory')
    args = parser.parse_args()

    dataset_files = os.listdir(args.dataset)
    dataset_files.sort()
    dataset_files = [os.path.join(args.dataset, f) for f in dataset_files if re.match(DATASET_NAMES, f)]

    dataset_iterator = tqdm(dataset_files)

    f = h5py.File(dataset_files[0], 'r')
    image_output_shape = list(f['fix_image'][:].shape[:-1])
    f.close()

    #### TF prep
    metric_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).metric,
                   NCC(image_input_shape).metric,
                   vxm.losses.MSE().loss,
                   MultiScaleStructuralSimilarity(max_val=1., filter_size=3).metric,
                   GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric,
                   HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [nb_labels]).metric,
                   GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric_macro]

    fix_img_ph = tf.placeholder(tf.float32, (1, *image_output_shape, 1), name='fix_img')
    pred_img_ph = tf.placeholder(tf.float32, (1, *image_output_shape, 1), name='pred_img')
    fix_seg_ph = tf.placeholder(tf.float32, (1, *image_output_shape, nb_labels), name='fix_seg')
    pred_seg_ph = tf.placeholder(tf.float32, (1, *image_output_shape, nb_labels), name='pred_seg')

    ssim_tf = metric_fncs[0](fix_img_ph, pred_img_ph)
    ncc_tf = metric_fncs[1](fix_img_ph, pred_img_ph)
    mse_tf = metric_fncs[2](fix_img_ph, pred_img_ph)
    ms_ssim_tf = metric_fncs[3](fix_img_ph, pred_img_ph)
    dice_tf = metric_fncs[4](fix_seg_ph, pred_seg_ph)
    hd_tf = metric_fncs[5](fix_seg_ph, pred_seg_ph)
    dice_macro_tf = metric_fncs[6](fix_seg_ph, pred_seg_ph)

    config = tf.compat.v1.ConfigProto()  # device_count={'GPU':0})
    config.gpu_options.allow_growth = True
    config.log_device_placement = False  ## to log device placement (on which device the operation ran)
    config.allow_soft_placement = True

    sess = tf.Session(config=config)
    tf.keras.backend.set_session(sess)
    ####
    dm_interp = DisplacementMapInterpolator(image_output_shape, 'griddata')

    metrics_file = os.path.join(output_folder, 'metrics.csv')

    for file_path in dataset_iterator:
        file_num = int(re.findall('(\d+)', os.path.split(file_path)[-1])[0])

        dataset_iterator.set_description('{} ({}): laoding data'.format(file_num, dataset_name))
        with h5py.File(file_path, 'r') as vol_file:
            fix_img = vol_file['fix_image'][:]
            mov_img = vol_file['mov_image'][:]

            fix_seg = vol_file['fix_segmentations'][:]
            mov_seg = vol_file['mov_segmentations'][:]

            fix_centroid = vol_file['fix_centroids'][:]
            mov_centroid = vol_file['mov_centroids'][:]

        # ndarray to ANTsImage
        fix_img = ants.make_image(fix_img.shape, fix_img)
        mov_img = ants.make_image(mov_img.shape, mov_img)

        reg_output_syn = ants.registration(fix_img, mov_img, 'SyN')
        reg_output_syncc = ants.registration(fix_img, mov_img, 'SyNCC')
        mov_to_fix_trf_syn = reg_output_syn[FWD_TRFS]
        mov_to_fix_trf_syncc = reg_output_syn[FWD_TRFS]
        if not len(mov_to_fix_trf_syn) and not len(mov_to_fix_trf_syncc):
            print('ERR: Registration failed for: '+file_path)
        else:
            for reg_output in [reg_output_syn, reg_output_syncc]:
                mov_to_fix_trf = reg_output[FWD_TRFS]
                pred_img = reg_output[WARPED_MOV].numpy()
                pred_seg = mov_to_fix_trf.apply_to_image(ants.make_image(mov_seg.shape, mov_seg)).numpy()

                with sess.as_default():
                    dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf],
                                                    {'fix_seg:0': fix_seg, 'pred_seg:0': pred_seg})

                    pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
                    mov_seg_card = segmentation_ohe_to_cardinal(mov_seg).astype(np.float32)
                    fix_seg_card = segmentation_ohe_to_cardinal(fix_seg).astype(np.float32)

                    ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf],
                                                       {'fix_img:0': fix_img, 'pred_img:0': pred_img})
                    ms_ssim = ms_ssim[0]
                tf.keras.backend.clear_session()

                # TRE
                pred_centroids = dm_interp(mov_to_fix_trf.numpy(), mov_centroid, backwards=True) + mov_centroid
                upsample_scale = 128 / 64
                fix_centroids_isotropic = fix_centroids * upsample_scale
                pred_centroids_isotropic = pred_centroids * upsample_scale

                fix_centroids_isotropic = np.divide(fix_centroids_isotropic, C.IXI_DATASET_iso_to_cubic_scales)
                pred_centroids_isotropic = np.divide(pred_centroids_isotropic, C.IXI_DATASET_iso_to_cubic_scales)
                tre_array = target_registration_error(fix_centroids_isotropic, pred_centroids_isotropic, False).eval()
                tre = np.mean([v for v in tre_array if not np.isnan(v)])

                new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, t1-t0, tre, len(missing_lbls), missing_lbls]
                with open(metrics_file, 'a') as f:
                    f.write(';'.join(map(str, new_line))+'\n')

                save_nifti(fix_img[0, ...], os.path.join(output_folder, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
                save_nifti(mov_img[0, ...], os.path.join(output_folder, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
                save_nifti(pred_img[0, ...], os.path.join(output_folder, '{:03d}_pred_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
                save_nifti(fix_seg_card[0, ...], os.path.join(output_folder, '{:03d}_fix_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
                save_nifti(mov_seg_card[0, ...], os.path.join(output_folder, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
                save_nifti(pred_seg_card[0, ...], os.path.join(output_folder, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)

                plot_predictions(fix_img, mov_img, disp_map, pred_img, os.path.join(output_folder, '{:03d}_figures_img.png'.format(step)), show=False)
                plot_predictions(fix_seg, mov_seg, disp_map, pred_seg, os.path.join(output_folder, '{:03d}_figures_seg.png'.format(step)), show=False)
                save_disp_map_img(disp_map, 'Displacement map', os.path.join(output_folder, '{:03d}_disp_map_fig.png'.format(step)), show=False)

    print('Summary\n=======\n')
    print('\nAVG:\n' + str(pd.read_csv(metrics_file, sep=';', header=0).mean(axis=0)) + '\nSTD:\n' + str(
        pd.read_csv(metrics_file, sep=';', header=0).std(axis=0)))
    print('\n=======\n')