File size: 9,363 Bytes
b10768a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import os, sys

currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir)  # PYTHON > 3.3 does not allow relative referencing

import h5py
from tqdm import tqdm
from functools import partial
import numpy as np
from scipy.spatial.distance import euclidean
import pandas as pd
from EvaluationScripts.Evaluate_class import resize_img_to_original_space, resize_pts_to_original_space
from Centerline.visualization_utils import plot_cpd_registration_step, plot_cpd
from Centerline.cpd_utils import cpd_non_rigid_transform_pt, radial_basis_function, deform_registration, rigid_registration
from scipy.spatial.distance import cdist
from skimage.morphology import skeletonize_3d
import re

# Root folder of the evaluation data; it is joined with each entry of DATASET_NAMES
# to obtain the folder that is listed for input files.
DATASET_LOCATION = '/mnt/EncryptedData1/Users/javier/vessel_registration/3Dirca/dataset/EVAL'
# Sub-folders of DATASET_LOCATION to process; one results CSV is written per name.
DATASET_NAMES = ['Affine', 'None', 'Translation']
# Only directory entries whose name contains this substring are processed.
DATASET_FILENAME = 'points'

# Destination folder for the registration plots and the per-dataset CSV files.
OUT_IMG_FOLDER = '/mnt/EncryptedData1/Users/javier/vessel_registration/Centerline/cpd/skeleton'

# Factor applied to the skeleton point coordinates before they are handed to the
# CPD registrations; registration outputs are divided by it again afterwards.
# NOTE(review): the original comment reads "mm to cm", but that conversion would
# be 1e-1, not 1e-2 -- confirm the intended units.
SCALE = 1e-2  # mm to cm
# CPD PARAMS (deform)
# These four values are forwarded to deform_registration as max_iterations,
# alpha, beta and tolerance respectively.
MAX_ITER = 200
ALPHA = 0.1
BETA = 1.0  # None = Use default
TOLERANCE = 1e-8

if __name__ == '__main__':
    # Script entry point.
    #
    # For every dataset in DATASET_NAMES, each matching HDF5 sample is loaded,
    # the fixed/moving skeletons are brought back to the isotropic space and
    # re-skeletonized, and two CPD strategies are evaluated on the resulting
    # point clouds:
    #   1. deformable registration only ("DEF"), and
    #   2. rigid registration followed by deformable registration ("R_DEF").
    # The target registration error (TRE) is the Euclidean distance between
    # the predicted moving centroid and the fixed centroid. Before/after plots
    # are written per sample, and one CSV per dataset is written to
    # OUT_IMG_FOLDER.

    # Column order of the output CSV. The two ILL_COND_* columns are declared
    # explicitly so that every CSV has the same header, even when a dataset
    # contains no samples.
    result_columns = ['DATASET',
                      'ITERATIONS_DEF', 'ITERATIONS_R_DEF__R', 'ITERATIONS_R_DEF__DEF',
                      'TIME_DEF', 'TIME_R_DEF',
                      'Q_DEF', 'Q_R_DEF__R', 'Q_R_DEF__DEF',
                      'TRE_DEF', 'TRE_R_DEF',
                      'DS_DISP',
                      'DATA_PATH',
                      'DIST_CENTR', 'DIST_CENTR_DEF_95', 'SAMPLE_NUM',
                      'ILL_COND_DEF', 'ILL_COND_R_DEF']
    # Raw string: '\d' inside a normal string literal is an invalid escape sequence.
    sample_num_re = re.compile(r'(\d+)')

    for dataset_name in DATASET_NAMES:
        dataset_loc = os.path.join(DATASET_LOCATION, dataset_name)
        dataset_files = [os.path.join(dataset_loc, f)
                         for f in sorted(os.listdir(dataset_loc)) if DATASET_FILENAME in f]

        iterator = tqdm(dataset_files)
        # Rows are collected in a list and turned into a DataFrame once per
        # dataset: DataFrame.append was removed in pandas 2.0 and copied the
        # whole frame on every call.
        rows = []
        for file_path in iterator:
            fn = os.path.split(file_path)[-1].split('.hd5')[0]
            # The first run of digits in the file name is the sample number.
            fnum = int(sample_num_re.findall(fn)[0])
            iterator.set_description('{}: start'.format(fn))

            # Everything is copied into memory with [:], so the file can be
            # closed right away; the context manager also closes it if a read fails.
            with h5py.File(file_path, 'r') as pts_file:
                fix_skel = pts_file['fix/skeleton'][:]
                fix_centroid = pts_file['fix/centroid'][:]

                mov_skel = pts_file['mov/skeleton'][:]
                mov_centroid = pts_file['mov/centroid'][:]

                bbox = pts_file['parameters/bbox'][:]
                first_reshape = pts_file['parameters/first_reshape'][:]
                isotropic_shape = pts_file['parameters/isotropic_shape'][:]
            iterator.set_description('{}: Loaded data'.format(fn))

            # TODO: bring back to original shape!
            fix_centroid = resize_pts_to_original_space(fix_centroid, bbox, [64] * 3, first_reshape, isotropic_shape)
            fix_skel = resize_img_to_original_space(fix_skel, bbox, first_reshape, isotropic_shape)
            fix_skel = skeletonize_3d(fix_skel)
            fix_skel_pts = np.argwhere(fix_skel)

            mov_centroid = resize_pts_to_original_space(mov_centroid, bbox, [64] * 3, first_reshape, isotropic_shape)
            mov_skel = resize_img_to_original_space(mov_skel, bbox, first_reshape, isotropic_shape)
            mov_skel = skeletonize_3d(mov_skel)
            mov_skel_pts = np.argwhere(mov_skel)
            iterator.set_description('{}: reshaped data'.format(fn))

            # Registrations run on coordinates multiplied by SCALE; their
            # outputs are divided by SCALE to get back to the input space.
            fix_scaled = fix_skel_pts * SCALE
            mov_scaled = mov_skel_pts * SCALE

            ill_cond_def = False
            ill_cond_r_def = False

            # ---- Deformable only ----
            iterator.set_description('{}: Computing only deformable reg.'.format(fn))
            time_def, deform_reg_def = deform_registration(fix_scaled, mov_scaled, time_it=True,
                                                           tolerance=TOLERANCE, max_iterations=MAX_ITER,
                                                           alpha=ALPHA, beta=BETA)
            if np.isnan(deform_reg_def.diff):
                # Registration diverged: no TRE, and the centroid is left where it was.
                tre_def = np.nan
                pred_mov_centroid = mov_centroid
            else:
                # Interpolate the displacement found at the skeleton points to
                # the centroid, which is not part of the registered point cloud.
                tps, ill_cond_def = radial_basis_function(
                    mov_skel_pts, np.dot(*deform_reg_def.get_registration_parameters()) / SCALE)
                pred_mov_centroid = mov_centroid + tps(mov_centroid)
                tre_def = euclidean(pred_mov_centroid, fix_centroid)

            plot_file = os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/DEF'.format(dataset_name, fnum))
            os.makedirs(plot_file, exist_ok=True)
            plot_cpd(fix_skel_pts, mov_skel_pts, fix_centroid, mov_centroid, plot_file + '/before_registration')
            plot_cpd(fix_skel_pts, deform_reg_def.TY / SCALE, fix_centroid, pred_mov_centroid,
                     plot_file + '/after_registration')

            # ---- Rigid followed by deformable ----
            iterator.set_description('{}: Computing rigid and deform. reg.'.format(fn))
            time_r_def__r, rigid_reg_r_def = rigid_registration(fix_scaled, mov_scaled, time_it=True)
            rigid_yt = rigid_reg_r_def.TY
            time_r_def__def, deform_reg_r_def = deform_registration(fix_scaled, rigid_yt, time_it=True,
                                                                    tolerance=TOLERANCE, max_iterations=MAX_ITER,
                                                                    alpha=ALPHA, beta=BETA)

            # The rigid transform is applied to the centroid in both cases; the
            # deformable displacement is added only if that step did not diverge.
            pred_mov_centroid = rigid_reg_r_def.transform_point_cloud(mov_centroid * SCALE) / SCALE
            if not np.isnan(deform_reg_r_def.diff):
                tps, ill_cond_r_def = radial_basis_function(
                    rigid_yt / SCALE, np.dot(*deform_reg_r_def.get_registration_parameters()) / SCALE)
                pred_mov_centroid = pred_mov_centroid + tps(pred_mov_centroid)

            tre_r_def = euclidean(pred_mov_centroid, fix_centroid)
            # Distances from the moving centroid to every moving skeleton point.
            dist_centroid_to_pts = cdist(mov_centroid[np.newaxis, ...], mov_skel_pts)

            plot_file = os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/RIGID_DEF'.format(dataset_name, fnum))
            os.makedirs(plot_file, exist_ok=True)
            plot_cpd(fix_skel_pts, mov_skel_pts, fix_centroid, mov_centroid, plot_file + '/before_registration')
            plot_cpd(fix_skel_pts, deform_reg_r_def.TY / SCALE, fix_centroid, pred_mov_centroid,
                     plot_file + '/after_registration')

            iterator.set_description('{}: Saving data'.format(fn))
            rows.append({'DATASET': dataset_name,
                         'ITERATIONS_DEF': deform_reg_def.iteration,
                         'ITERATIONS_R_DEF__R': rigid_reg_r_def.iteration,
                         'ITERATIONS_R_DEF__DEF': deform_reg_r_def.iteration,
                         'TIME_DEF': time_def,
                         'TIME_R_DEF': time_r_def__r + time_r_def__def,
                         'Q_DEF': deform_reg_def.diff,
                         'Q_R_DEF__R': rigid_reg_r_def.q,
                         'Q_R_DEF__DEF': deform_reg_r_def.diff,
                         'ILL_COND_DEF': ill_cond_def,
                         'ILL_COND_R_DEF': ill_cond_r_def,
                         'TRE_DEF': tre_def, 'TRE_R_DEF': tre_r_def,
                         'DS_DISP': euclidean(mov_centroid, fix_centroid),
                         'DATA_PATH': file_path,
                         'DIST_CENTR': np.min(dist_centroid_to_pts),
                         'DIST_CENTR_DEF_95': np.percentile(dist_centroid_to_pts, 95),
                         'SAMPLE_NUM': fnum})

        df = pd.DataFrame(rows, columns=result_columns)
        # The per-sample plot folders create OUT_IMG_FOLDER implicitly, but a
        # dataset without matching files would otherwise fail on the CSV write.
        os.makedirs(OUT_IMG_FOLDER, exist_ok=True)
        df.to_csv(os.path.join(OUT_IMG_FOLDER, 'cpd_{}.csv'.format(dataset_name)))