jpdefrutos committed

Commit b10768a · 1 Parent(s): ed5ac4a

CPD scripts
Centerline/__init__.py ADDED
File without changes
Centerline/centerline.py ADDED
@@ -0,0 +1,241 @@
1
+ import os, sys
2
+
3
+ currentdir = os.path.dirname(os.path.realpath(__file__))
4
+ parentdir = os.path.dirname(currentdir)
5
+ sys.path.append(parentdir) # allow imports from the parent package when this script is run directly (relative imports fail outside a package)
6
+
7
+ # import tensorflow as tf
8
+ # tf.enable_eager_execution()
9
+ # import neurite.py.utils as neurite_utils
10
+
11
+ from skimage.morphology import skeletonize_3d, ball
12
+ from skimage.morphology import binary_closing, binary_opening
13
+ from skimage.filters import median
14
+ from skimage.measure import regionprops, label
15
+ from skimage.transform import warp
16
+
17
+ from scipy.ndimage import zoom
18
+ from scipy.interpolate import LinearNDInterpolator, Rbf
19
+
20
+ import h5py
21
+ import numpy as np
22
+ from tqdm import tqdm
23
+ import re
24
+ import nibabel as nib
25
+ from nilearn.image import resample_img
26
+
27
+ from Centerline.graph_utils import graph_to_ndarray, deform_graph, get_bifurcation_nodes, subsample_graph, \
28
+ apply_displacement
29
+ from Centerline.skeleton_to_graph import get_graph_from_skeleton
30
+ from Centerline.visualization_utils import plot_skeleton, compare_graphs
31
+
32
+ from DeepDeformationMapRegistration.utils.operators import min_max_norm
33
+ from DeepDeformationMapRegistration.utils import constants as C
34
+
35
+ import cupy
36
+ from cupyx.scipy.ndimage import zoom as zoom_gpu
37
+ from cupyx.scipy.ndimage import map_coordinates
38
+
39
+ DATASET_LOCATION = '/mnt/EncryptedData1/Users/javier/vessel_registration/3Dirca/dataset/EVAL'
40
+ DATASET_NAMES = ['Affine', 'None', 'Translation']
41
+ DATASET_FILENAME = 'volume'
42
+ IMGS_FOLDER = '/home/jpdefrutos/workspace/DeepDeformationMapRegistration/Centerline/centerlines'
43
+
44
+ DATASTE_RAW_FILES = '/mnt/EncryptedData1/Users/javier/vessel_registration/3Dirca/nifti3'
45
+ LITS_SEGMENTATION_FILE = 'segmentation'
46
+ LITS_CT_FILE = 'volume'
47
+
48
+
49
+ def warp_volume(volume, disp_map, indexing='ij'):
50
+ assert indexing in ('ij', 'xy'), 'Invalid indexing option. Only "ij" or "xy"'
51
+ grid_i = np.linspace(0, disp_map.shape[0], disp_map.shape[0], endpoint=False)
52
+ grid_j = np.linspace(0, disp_map.shape[1], disp_map.shape[1], endpoint=False)
53
+ grid_k = np.linspace(0, disp_map.shape[2], disp_map.shape[2], endpoint=False)
54
+ grid_i, grid_j, grid_k = np.meshgrid(grid_i, grid_j, grid_k, indexing=indexing)
55
+ grid_i = (grid_i.flatten() + disp_map[..., 0].flatten())[..., np.newaxis]
56
+ grid_j = (grid_j.flatten() + disp_map[..., 1].flatten())[..., np.newaxis]
57
+ grid_k = (grid_k.flatten() + disp_map[..., 2].flatten())[..., np.newaxis]
58
+ coords = np.hstack([grid_i, grid_j, grid_k]).reshape([*disp_map.shape[:-1], -1])
59
+ coords = coords.transpose((-1, 0, 1, 2))
60
+ # The returned volume has indexing xy
61
+ return warp(volume, coords)
62
+
63
+
64
+ def keep_largest_segmentation(img):
65
+ label_img = label(img)
66
+ rp = regionprops(label_img) # Regions labeled with 0 (bg) are ignored
67
+ biggest_area = (0, 0)
68
+ for l in range(0, label_img.max()):
69
+ if rp[l].area > biggest_area[1]:
70
+ biggest_area = (l + 1, rp[l].area)
71
+ img[label_img != biggest_area[0]] = 0.
72
+ return img
73
+
74
+
75
+ def preprocess_image(img, keep_largest=False):
76
+ ret = binary_closing(img, ball(1))
77
+ ret = binary_opening(ret, ball(1))
78
+ #ret = median(ret, ball(1), mode='constant')
79
+ if keep_largest:
80
+ ret = keep_largest_segmentation(ret)
81
+ return ret.astype(float)
82
+
83
+
84
+ def build_displacement_map_interpolator(disp_map, backwards=False, indexing='ij'):
85
+ grid_i = np.linspace(0, disp_map.shape[0], disp_map.shape[0], endpoint=False)
86
+ grid_j = np.linspace(0, disp_map.shape[1], disp_map.shape[1], endpoint=False)
87
+ grid_k = np.linspace(0, disp_map.shape[2], disp_map.shape[2], endpoint=False)
88
+ grid_i, grid_j, grid_k = np.meshgrid(grid_i, grid_j, grid_k, indexing=indexing)
89
+ grid_i = grid_i.flatten()
90
+ grid_j = grid_j.flatten()
91
+ grid_k = grid_k.flatten()
92
+ # To generate the moving image, we used backwards mapping where the input was the fixed image
93
+ # Now we are doing direct mapping from the fixed graph coordinates to the moving coordinates
94
+ # The application points of the displacement map are thus the transformed "moving image"-grid
95
+ # and the displacement vectors are reversed
96
+ if backwards:
97
+ coords = np.hstack([grid_i[..., np.newaxis], grid_j[..., np.newaxis], grid_k[..., np.newaxis]])
98
+ return LinearNDInterpolator(coords, np.reshape(disp_map, [-1, 3]))
99
+ else:
100
+ grid_i = (grid_i + disp_map[..., 0].flatten())
101
+ grid_j = (grid_j + disp_map[..., 1].flatten())
102
+ grid_k = (grid_k + disp_map[..., 2].flatten())
103
+
104
+ coords = np.hstack([grid_i[..., np.newaxis], grid_j[..., np.newaxis], grid_k[..., np.newaxis]])
105
+ return LinearNDInterpolator(coords, -np.reshape(disp_map, [-1, 3]))
106
+
107
+
108
+ def resample_segmentation(img, output_shape, preserve_range, threshold=None, gpu=True):
109
+ # preserve_range can be a bool (keep the original dynamic range or not) or a list giving a new dynamic range
110
+ zoom_f = np.divide(np.asarray(output_shape), np.asarray(img.shape))
111
+
112
+ if gpu:
113
+ out_img = zoom_gpu(cupy.asarray(img), zoom_f, order=1) # order = 0 or 1
114
+ else:
115
+ out_img = zoom(img, zoom_f)
116
+ if isinstance(preserve_range, bool):
117
+ if preserve_range:
118
+ range_min, range_max = np.amin(img), np.amax(img)
119
+ out_img = min_max_norm(out_img)
120
+ out_img = out_img * (range_max - range_min) + range_min
121
+ elif isinstance(preserve_range, list):
122
+ range_min, range_max = preserve_range
123
+ out_img = min_max_norm(out_img)
124
+ out_img = out_img * (range_max - range_min) + range_min
125
+
126
+ if threshold is not None and out_img.min() < threshold < out_img.max():
127
+ range_min, range_max = np.amin(out_img), np.amax(out_img)
128
+ out_img[out_img > threshold] = range_max
129
+ out_img[out_img < range_max] = range_min
130
+ return cupy.asnumpy(out_img) if gpu else out_img
131
+
132
+
133
+ if __name__ == '__main__':
134
+ for dataset_name in DATASET_NAMES:
135
+ dataset_loc = os.path.join(DATASET_LOCATION, dataset_name)
136
+ dataset_files = os.listdir(dataset_loc)
137
+ dataset_files.sort()
138
+ dataset_files = [os.path.join(dataset_loc, f) for f in dataset_files if DATASET_FILENAME in f]
139
+
140
+ iterator = tqdm(dataset_files)
141
+ for file_path in iterator:
142
+ file_num = int(re.findall(r'(\d+)', os.path.split(file_path)[-1])[0])
143
+
144
+ iterator.set_description('{} ({}): loading data'.format(file_num, dataset_name))
145
+ vol_file = h5py.File(file_path, 'r')
146
+ # fix_vessels = vol_file[C.H5_FIX_VESSELS_MASK][..., 0]
147
+ disp_map = vol_file[C.H5_GT_DISP][:]
148
+ bbox = vol_file['parameters/bbox'][:]
149
+ bbox_min = bbox[:3]
150
+ bbox_max = bbox[3:] + bbox_min
151
+
152
+ # Load vessel segmentation mask and resize to 64^3
153
+ fix_labels = nib.load(os.path.join(DATASTE_RAW_FILES, 'segmentation-{:04d}.nii.gz'.format(file_num)))
154
+ fix_vessels = fix_labels.slicer[..., 1]
155
+ fix_vessels = resample_img(fix_vessels, np.eye(3))
156
+ fix_vessels = np.asarray(fix_vessels.dataobj)
157
+ fix_vessels = preprocess_image(fix_vessels)
158
+ fix_vessels = resample_segmentation(fix_vessels, vol_file['parameters/first_reshape'][:], [0, 1], 0.3,
159
+ gpu=True)
160
+ fix_vessels = fix_vessels[bbox_min[0]:bbox_max[0], bbox_min[1]:bbox_max[1], bbox_min[2]:bbox_max[2]]
161
+ fix_vessels = resample_segmentation(fix_vessels, [64] * 3, [0, 1], 0.3, gpu=True)
162
+ fix_vessels = preprocess_image(fix_vessels)
163
+
164
+ mov_vessels = preprocess_image(warp_volume(fix_vessels, disp_map))
165
+ mov_skel = skeletonize_3d(mov_vessels)
166
+ ### Fix the incorrect scaling ###
167
+ disp_map *= 2
168
+ bbox_size = np.asarray(bbox[3:]) # Only load the bbox size
169
+ rescale_factors = 64 / bbox_size
170
+
171
+ disp_map[..., 0] = np.multiply(disp_map[..., 0], rescale_factors[0])
172
+ disp_map[..., 1] = np.multiply(disp_map[..., 1], rescale_factors[1])
173
+ disp_map[..., 2] = np.multiply(disp_map[..., 2], rescale_factors[2])
174
+ #################################
175
+
176
+ iterator.set_description('{} ({}): getting graphs'.format(file_num, dataset_name))
177
+ # Prepare displacement map
178
+ disp_map_interpolator = build_displacement_map_interpolator(disp_map, backwards=False)
179
+
180
+ # Get skeleton and graph
181
+ fix_skel = skeletonize_3d(fix_vessels)
182
+ fix_graph = get_graph_from_skeleton(fix_skel, subsample=True)
183
+ mov_graph = get_graph_from_skeleton(mov_skel, subsample=True) # deform_graph(fix_graph, disp_map_interpolator)
184
+
185
+ ##### TODO: ERASE Check the mov graph ######
186
+ # check_mov_vessels = vol_file[C.H5_MOV_VESSELS_MASK][..., 0]
187
+ # check_mov_vessels = preprocess_image(check_mov_vessels)
188
+ # check_mov_skel = skeletonize_3d(check_mov_vessels)
189
+ # check_mov_graph = get_graph_from_skeleton(check_mov_skel, subsample=True)
190
+ ###########
191
+ fix_pts, fix_nodes, fix_edges = graph_to_ndarray(fix_graph)
192
+ mov_pts, mov_nodes, mov_edges = graph_to_ndarray(mov_graph)
193
+
194
+ fix_bifur_loc, fix_bifur_id = get_bifurcation_nodes(fix_graph)
195
+ mov_bifur_loc, mov_bifur_id = get_bifurcation_nodes(mov_graph)
196
+
197
+ iterator.set_description('{} ({}): saving data'.format(file_num, dataset_name))
198
+ pts_file_path, pts_file_name = os.path.split(file_path)
199
+ pts_file_name = pts_file_name.replace(DATASET_FILENAME, 'points')
200
+ pts_file_path = os.path.join(pts_file_path, pts_file_name)
201
+ pts_file = h5py.File(pts_file_path, 'w')
202
+
203
+ pts_file.create_dataset('fix/points', data=fix_pts)
204
+ pts_file.create_dataset('fix/nodes', data=fix_nodes)
205
+ pts_file.create_dataset('fix/edges', data=fix_edges)
206
+ pts_file.create_dataset('fix/bifurcations', data=fix_bifur_loc)
207
+ pts_file.create_dataset('fix/graph', data=fix_graph)
208
+ pts_file.create_dataset('fix/img', data=fix_vessels)
209
+ pts_file.create_dataset('fix/skeleton', data=fix_skel)
210
+ pts_file.create_dataset('fix/centroid', data=vol_file[C.H5_FIX_CENTROID][:])
211
+
212
+ pts_file.create_dataset('mov/points', data=mov_pts)
213
+ pts_file.create_dataset('mov/nodes', data=mov_nodes)
214
+ pts_file.create_dataset('mov/edges', data=mov_edges)
215
+ pts_file.create_dataset('mov/bifurcations', data=mov_bifur_loc)
216
+ pts_file.create_dataset('mov/graph', data=mov_graph)
217
+ pts_file.create_dataset('mov/img', data=mov_vessels)
218
+ pts_file.create_dataset('mov/skeleton', data=mov_skel)
219
+ pts_file.create_dataset('mov/centroid', data=vol_file[C.H5_MOV_CENTROID][:])
220
+
221
+ pts_file.create_dataset('parameters/voxel_size', data=vol_file['parameters/voxel_size'][:])
222
+ pts_file.create_dataset('parameters/original_affine', data=vol_file['parameters/original_affine'][:])
223
+ pts_file.create_dataset('parameters/isotropic_affine', data=vol_file['parameters/isotropic_affine'][:])
224
+ pts_file.create_dataset('parameters/original_shape', data=vol_file['parameters/original_shape'][:])
225
+ pts_file.create_dataset('parameters/isotropic_shape', data=vol_file['parameters/isotropic_shape'][:])
226
+ pts_file.create_dataset('parameters/first_reshape', data=vol_file['parameters/first_reshape'][:])
227
+ pts_file.create_dataset('parameters/bbox', data=vol_file['parameters/bbox'][:])
228
+ pts_file.create_dataset('parameters/last_reshape', data=vol_file['parameters/last_reshape'][:])
229
+
230
+ pts_file.create_dataset('displacement_map', data=disp_map)
231
+
232
+ vol_file.close()
233
+ pts_file.close()
234
+
235
+ iterator.set_description('{} ({}): drawing plots'.format(file_num, dataset_name))
236
+ num = pts_file_name.split('-')[-1].split('.hd5')[0]
237
+ imgs_folder = os.path.join(IMGS_FOLDER, dataset_name, num)
238
+ os.makedirs(imgs_folder, exist_ok=True)
239
+ plot_skeleton(fix_vessels, fix_skel, fix_graph, os.path.join(imgs_folder, 'fix'), ['.pdf', '.png'])
240
+ plot_skeleton(mov_vessels, mov_skel, mov_graph, os.path.join(imgs_folder, 'mov'), ['.pdf', '.png'])
241
+ iterator.set_description('{} ({})'.format(file_num, dataset_name))
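
A quick way to sanity-check the backward-mapping convention used by warp_volume() above is to warp with an all-zero displacement map, which should act as the identity. The sketch below only reuses the same skimage.transform.warp call; the toy volume and shapes are illustrative and not taken from the dataset.

import numpy as np
from skimage.transform import warp

# Toy binary volume and an all-zero displacement map (identity transform)
vol = np.zeros((8, 8, 8), dtype=float)
vol[3:5, 3:5, 3:5] = 1.0
disp_map = np.zeros((8, 8, 8, 3))

# Same coordinate construction as warp_volume(): voxel grid plus displacement,
# stacked into a (ndim, i, j, k) backward map for skimage's warp()
gi, gj, gk = np.meshgrid(*(np.arange(s) for s in disp_map.shape[:-1]), indexing='ij')
coords = np.stack([gi + disp_map[..., 0], gj + disp_map[..., 1], gk + disp_map[..., 2]], axis=0)

warped = warp(vol, coords)
assert np.allclose(warped, vol)  # zero displacement leaves the volume unchanged
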
Centerline/cpd_utils.py ADDED
@@ -0,0 +1,48 @@
1
+ from pycpd import DeformableRegistration, RigidRegistration
2
+ import numpy as np
3
+ import time
4
+ from scipy.interpolate import Rbf
5
+ import warnings
6
+
7
+ def cpd_non_rigid_transform_pt(pt, Y, G, W):
8
+ from scipy.interpolate import LinearNDInterpolator
9
+ interp = LinearNDInterpolator(points=Y, values=np.dot(G, W), fill_value=0.)
10
+ return interp(pt)
11
+
12
+
13
+ def radial_basis_function(pts, vals, function='thin-plate'):
14
+ # The Rbf function does not handle n-D hyper-surfaces, so we need an interpolator per displacements. Actually it does mode='N-D'
15
+ pts_unique, idxs = np.unique(pts, return_index=True, axis=0) # Prevent singular matrices
16
+ ill_conditioned = False
17
+ with warnings.catch_warnings(record=True) as caught_warns:
18
+ warnings.simplefilter('always')
19
+ dx = Rbf(pts_unique[:, 0], pts_unique[:, 1], pts_unique[:, 2], vals[idxs][:, 0], function=function)
20
+ dy = Rbf(pts_unique[:, 0], pts_unique[:, 1], pts_unique[:, 2], vals[idxs][:, 1], function=function)
21
+ dz = Rbf(pts_unique[:, 0], pts_unique[:, 1], pts_unique[:, 2], vals[idxs][:, 2], function=function)
22
+ for w in caught_warns:
23
+ print(w)
24
+ ill_conditioned = ill_conditioned or 'ill-conditioned matrix' in str(w).lower()
25
+ return lambda int_pt: np.asarray([dx(*int_pt), dy(*int_pt), dz(*int_pt)]), ill_conditioned
26
+
27
+
28
+ def deform_registration(fix_pts, mov_pts, callback_fnc=None, time_it=False, max_iterations=100, tolerance=1e-8, alpha=None, beta=None):
29
+ deform_reg = DeformableRegistration(**{'Y': mov_pts, 'X': fix_pts},
30
+ alpha=alpha, beta=beta, tolerance=tolerance, max_iterations=max_iterations)
31
+ start_t = time.time()
32
+ trf_mov_pts, deform_p = deform_reg.register(callback_fnc)
33
+ end_t = time.time()
34
+ if time_it:
35
+ return end_t - start_t, deform_reg
36
+ else:
37
+ return trf_mov_pts, deform_p, deform_reg
38
+
39
+
40
+ def rigid_registration(fix_pts, mov_pts, callback_fnc=None, time_it=False):
41
+ rigid_reg = RigidRegistration(**{'Y': mov_pts, 'X': fix_pts})
42
+ start_t = time.time()
43
+ trf_mov_pts, trf_p = rigid_reg.register(callback_fnc)
44
+ end_t = time.time()
45
+ if time_it:
46
+ return end_t - start_t, rigid_reg
47
+ else:
48
+ return trf_mov_pts, trf_p, rigid_reg
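
A hedged usage sketch of the helpers above on synthetic point clouds. It assumes pycpd is installed and this commit's Centerline package is importable; the point clouds, seed and alpha/beta values are illustrative only.

import numpy as np
from Centerline.cpd_utils import rigid_registration, deform_registration, radial_basis_function

rng = np.random.default_rng(0)
fix_pts = rng.normal(size=(100, 3))
mov_pts = fix_pts + np.array([0.5, 0.0, -0.2])   # translated copy of the fixed points

# Rigid CPD first, then deformable CPD on the rigidly aligned points
_, _, rigid_reg = rigid_registration(fix_pts, mov_pts)
_, _, deform_reg = deform_registration(fix_pts, rigid_reg.TY,
                                       max_iterations=100, tolerance=1e-8,
                                       alpha=2., beta=2.)

# Interpolate the non-rigid displacement at an arbitrary point with the RBF helper
G, W = deform_reg.get_registration_parameters()
tps, ill_conditioned = radial_basis_function(rigid_reg.TY, np.dot(G, W))
print(tps(mov_pts.mean(axis=0)), ill_conditioned)
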
Centerline/evaluate_BayesianCPD_skeleton.py ADDED
@@ -0,0 +1,160 @@
1
+ import os, sys
2
+
3
+ currentdir = os.path.dirname(os.path.realpath(__file__))
4
+ parentdir = os.path.dirname(currentdir)
5
+ sys.path.append(parentdir) # allow imports from the parent package when this script is run directly (relative imports fail outside a package)
6
+
7
+ import h5py
8
+ from tqdm import tqdm
9
+ from functools import partial
10
+ import numpy as np
11
+ from scipy.spatial.distance import euclidean
12
+ import pandas as pd
13
+ from EvaluationScripts.Evaluate_class import resize_img_to_original_space, resize_pts_to_original_space
14
+ from Centerline.visualization_utils import plot_cpd_registration_step, plot_cpd
15
+ from Centerline.cpd_utils import cpd_non_rigid_transform_pt, radial_basis_function, deform_registration, rigid_registration
16
+ from scipy.spatial.distance import cdist
17
+ from skimage.morphology import skeletonize_3d
18
+ import re
19
+ from probreg import bcpd
20
+
21
+ DATASET_LOCATION = '/mnt/EncryptedData1/Users/javier/vessel_registration/3Dirca/dataset/EVAL'
22
+ DATASET_NAMES = ['Affine', 'None', 'Translation']
23
+ DATASET_FILENAME = 'points'
24
+
25
+ OUT_IMG_FOLDER = '/mnt/EncryptedData1/Users/javier/vessel_registration/Centerline/cpd/skeleton'
26
+
27
+ SCALE = 1e-2 # mm to cm
28
+ # CPD PARAMS (deform)
29
+ MAX_ITER = 200
30
+ ALPHA = 0.1
31
+ BETA = 1.0 # None = Use default
32
+ TOLERANCE = 1e-8
33
+
34
+ if __name__ == '__main__':
35
+ for dataset_name in DATASET_NAMES:
36
+ dataset_loc = os.path.join(DATASET_LOCATION, dataset_name)
37
+ dataset_files = os.listdir(dataset_loc)
38
+ dataset_files.sort()
39
+ dataset_files = [os.path.join(dataset_loc, f) for f in dataset_files if DATASET_FILENAME in f]
40
+
41
+ iterator = tqdm(dataset_files)
42
+ df = pd.DataFrame(columns=['DATASET',
43
+ 'ITERATIONS_DEF', 'ITERATIONS_R_DEF__R', 'ITERATIONS_R_DEF__DEF',
44
+ 'TIME_DEF', 'TIME_R_DEF',
45
+ 'Q_DEF', 'Q_R_DEF__R', 'Q_R_DEF__DEF',
46
+ 'TRE_DEF', 'TRE_R_DEF',
47
+ 'DS_DISP',
48
+ 'DATA_PATH',
49
+ 'DIST_CENTR', 'DIST_CENTR_DEF_95', 'SAMPLE_NUM'])
50
+ for i, file_path in enumerate(iterator):
51
+ fn = os.path.split(file_path)[-1].split('.hd5')[0]
52
+ fnum = int(re.findall(r'(\d+)', fn)[0])
53
+ iterator.set_description('{}: start'.format(fn))
54
+ pts_file = h5py.File(file_path, 'r')
55
+ # fix_pts = pts_file['fix/points'][:]
56
+ # fix_nodes = pts_file['fix/nodes'][:]
57
+ fix_skel = pts_file['fix/skeleton'][:]
58
+ fix_centroid = pts_file['fix/centroid'][:]
59
+
60
+ # mov_pts = pts_file['mov/points'][:]
61
+ # mov_nodes = pts_file['mov/nodes'][:]
62
+ mov_skel = pts_file['mov/skeleton'][:]
63
+ mov_centroid = pts_file['mov/centroid'][:]
64
+
65
+ bbox = pts_file['parameters/bbox'][:]
66
+ first_reshape = pts_file['parameters/first_reshape'][:]
67
+ isotropic_shape = pts_file['parameters/isotropic_shape'][:]
68
+ iterator.set_description('{}: Loaded data'.format(fn))
69
+ # TODO: bring back to original shape!
70
+ # Reshape to original_shape
71
+ # fix_nodes = resize_pts_to_original_space(fix_nodes, bbox, [64]*3, first_reshape, original_shape)
72
+ # fix_pts = resize_pts_to_original_space(fix_pts, bbox, [64] * 3, first_reshape, original_shape)
73
+ fix_centroid = resize_pts_to_original_space(fix_centroid, bbox, [64] * 3, first_reshape, isotropic_shape)
74
+ fix_skel = resize_img_to_original_space(fix_skel, bbox, first_reshape, isotropic_shape)
75
+ fix_skel = skeletonize_3d(fix_skel)
76
+ fix_skel_pts = np.argwhere(fix_skel)
77
+ # mov_nodes = resize_pts_to_original_space(mov_nodes, bbox, [64] * 3, first_reshape, original_shape)
78
+ # mov_pts = resize_pts_to_original_space(mov_pts, bbox, [64] * 3, first_reshape, original_shape)
79
+ mov_centroid = resize_pts_to_original_space(mov_centroid, bbox, [64] * 3, first_reshape, isotropic_shape)
80
+ mov_skel = resize_img_to_original_space(mov_skel, bbox, first_reshape, isotropic_shape)
81
+ mov_skel = skeletonize_3d(mov_skel)
82
+ mov_skel_pts = np.argwhere(mov_skel)
83
+ iterator.set_description('{}: reshaped data'.format(fn))
84
+
85
+ ill_cond_def = False
86
+ ill_cond_r_def = False
87
+ # Deformable only
88
+ iterator.set_description('{}: Computing only deformable reg.'.format(fn))
89
+
90
+ tf_param = bcpd.registration_bcpd(mov_skel_pts*SCALE, fix_skel_pts*SCALE)
+ # NOTE: the metrics below still rely on the pycpd deformable registration
+ # (deform_reg_def, time_def), as in evaluate_CPD_skeleton.py; the BCPD result
+ # (tf_param) is computed but not evaluated yet.
+ time_def, deform_reg_def = deform_registration(fix_skel_pts*SCALE, mov_skel_pts*SCALE, time_it=True,
+ tolerance=TOLERANCE, max_iterations=MAX_ITER,
+ alpha=ALPHA, beta=BETA)
91
+
92
+ if np.isnan(deform_reg_def.diff):
93
+ tre_def = np.nan
94
+ pred_mov_centroid = mov_centroid
95
+ else:
96
+ tps, ill_cond_def = radial_basis_function(mov_skel_pts, np.dot(*deform_reg_def.get_registration_parameters()) / SCALE)
97
+ displacement_mov_centroid = tps(mov_centroid)
98
+ pred_mov_centroid = mov_centroid + displacement_mov_centroid
99
+
100
+ tre_def = euclidean(pred_mov_centroid, fix_centroid)
101
+
102
+ plot_file = os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/DEF'.format(dataset_name, fnum))
103
+ os.makedirs(plot_file, exist_ok=True)
104
+ plot_cpd(fix_skel_pts, mov_skel_pts, fix_centroid, mov_centroid, plot_file + '/before_registration')
105
+ plot_cpd(fix_skel_pts, deform_reg_def.TY/SCALE, fix_centroid, pred_mov_centroid, plot_file + '/after_registration')
106
+
107
+ # Rigid followed by deformable
108
+ iterator.set_description('{}: Computing rigid and deform. reg.'.format(fn))
109
+
110
+ rigid_cb = partial(plot_cpd_registration_step, out_folder=os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/RIGID_DEF/rigid'.format(dataset_name, fnum)))
111
+ deform_cb = partial(plot_cpd_registration_step, out_folder=os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/RIGID_DEF/deform'.format(dataset_name, fnum)))
112
+
113
+ # rigid_yt, rigid_p, rigid_reg_r_def = rigid_registration(fix_pts, mov_pts, rigid_cb)
114
+ # deform_yt, deform_p, deform_reg_r_def = deform_registration(fix_pts, rigid_yt, deform_cb)
115
+
116
+ time_r_def__r, rigid_reg_r_def = rigid_registration(fix_skel_pts*SCALE, mov_skel_pts*SCALE, time_it=True)
117
+ rigid_yt = rigid_reg_r_def.TY
118
+ time_r_def__def, deform_reg_r_def = deform_registration(fix_skel_pts*SCALE, rigid_yt, time_it=True,
119
+ tolerance=TOLERANCE, max_iterations=MAX_ITER,
120
+ alpha=ALPHA, beta=BETA)
121
+
122
+ if np.isnan(deform_reg_r_def.diff):
123
+ pred_mov_centroid = rigid_reg_r_def.transform_point_cloud(mov_centroid*SCALE)/SCALE
124
+ else:
125
+ mov_centroid_t = rigid_reg_r_def.transform_point_cloud(mov_centroid*SCALE)/SCALE
126
+ tps, ill_cond_r_def = radial_basis_function(rigid_yt / SCALE,
127
+ np.dot(*deform_reg_r_def.get_registration_parameters()) / SCALE)
128
+ displacement_mov_centroid_t = tps(mov_centroid_t)
129
+ pred_mov_centroid = mov_centroid_t + displacement_mov_centroid_t
130
+
131
+ tre_r_def = euclidean(pred_mov_centroid, fix_centroid)
132
+ dist_centroid_to_pts = cdist(mov_centroid[np.newaxis, ...], mov_skel_pts)
133
+
134
+ plot_file = os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/RIGID_DEF'.format(dataset_name, fnum))
135
+ os.makedirs(plot_file, exist_ok=True)
136
+ plot_cpd(fix_skel_pts, mov_skel_pts, fix_centroid, mov_centroid, plot_file + '/before_registration')
137
+ plot_cpd(fix_skel_pts, deform_reg_r_def.TY/SCALE, fix_centroid, pred_mov_centroid, plot_file + '/after_registration')
138
+
139
+
140
+ iterator.set_description('{}: Saving data'.format(fn))
141
+ df = df.append({'DATASET': dataset_name,
142
+ 'ITERATIONS_DEF': deform_reg_def.iteration,
143
+ 'ITERATIONS_R_DEF__R': rigid_reg_r_def.iteration,
144
+ 'ITERATIONS_R_DEF__DEF': deform_reg_r_def.iteration,
145
+ 'TIME_DEF': time_def,
146
+ 'TIME_R_DEF': time_r_def__r + time_r_def__def,
147
+ 'Q_DEF': deform_reg_def.diff,
148
+ 'Q_R_DEF__R': rigid_reg_r_def.q,
149
+ 'Q_R_DEF__DEF': deform_reg_r_def.diff,
150
+ 'ILL_COND_DEF': ill_cond_def,
151
+ 'ILL_COND_R_DEF': ill_cond_r_def,
152
+ 'TRE_DEF': tre_def, 'TRE_R_DEF': tre_r_def,
153
+ 'DS_DISP':euclidean(mov_centroid, fix_centroid),
154
+ 'DATA_PATH': file_path,
155
+ 'DIST_CENTR': np.min(dist_centroid_to_pts),
156
+ 'DIST_CENTR_DEF_95': np.percentile(dist_centroid_to_pts, 95),
157
+ 'SAMPLE_NUM':fnum}, ignore_index=True)
158
+ pts_file.close()
159
+
160
+ df.to_csv(os.path.join(OUT_IMG_FOLDER, 'cpd_{}.csv'.format(dataset_name)))
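
For reference, a minimal sketch of the probreg BCPD call used in this script. It assumes probreg is installed; the synthetic point clouds are illustrative, and the exact attributes exposed by the returned tf_param depend on the probreg version.

import numpy as np
from probreg import bcpd

source = np.random.rand(200, 3) * 1e-2             # e.g. mov_skel_pts * SCALE
target = source + 5e-4 * np.random.rand(200, 3)    # e.g. fix_skel_pts * SCALE

tf_param = bcpd.registration_bcpd(source, target)
registered = tf_param.transform(source)            # BCPD-registered source points
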
Centerline/evaluate_CPD_dense.py ADDED
@@ -0,0 +1,156 @@
1
+ import os, sys
2
+ currentdir = os.path.dirname(os.path.realpath(__file__))
3
+ parentdir = os.path.dirname(currentdir)
4
+ sys.path.append(parentdir) # allow imports from the parent package when this script is run directly (relative imports fail outside a package)
5
+
6
+ import h5py
7
+ from tqdm import tqdm
8
+ from functools import partial
9
+ import numpy as np
10
+ from scipy.spatial.distance import euclidean
11
+ import pandas as pd
12
+ from EvaluationScripts.Evaluate_class import resize_img_to_original_space, resize_pts_to_original_space
13
+ from Centerline.visualization_utils import plot_cpd_registration_step, plot_cpd
14
+ from Centerline.cpd_utils import cpd_non_rigid_transform_pt, radial_basis_function, deform_registration, rigid_registration
15
+ from scipy.spatial.distance import cdist
16
+ import re
17
+
18
+ DATASET_LOCATION = '/mnt/EncryptedData1/Users/javier/vessel_registration/3Dirca/dataset/EVAL'
19
+ DATASET_NAMES = ['Affine', 'None', 'Translation']
20
+ DATASET_FILENAME = 'points'
21
+
22
+ OUT_IMG_FOLDER = '/mnt/EncryptedData1/Users/javier/vessel_registration/Centerline/cpd/dense_final'
23
+
24
+ SCALE = 1e-2 # mm to cm
25
+
26
+ # CPD PARAMS (deform)
27
+ MAX_ITER = 200
28
+ ALPHA = 2.
29
+ BETA = 2. # None = Use default
30
+ TOLERANCE = 1e-8
31
+ RBF_FUNCTION='thin-plate'
32
+
33
+ if __name__ == '__main__':
34
+ for dataset_name in DATASET_NAMES:
35
+ dataset_loc = os.path.join(DATASET_LOCATION, dataset_name)
36
+ dataset_files = os.listdir(dataset_loc)
37
+ dataset_files.sort()
38
+ dataset_files = [os.path.join(dataset_loc, f) for f in dataset_files if DATASET_FILENAME in f]
39
+
40
+ iterator = tqdm(dataset_files)
41
+ df = pd.DataFrame(columns=['DATASET',
42
+ 'ITERATIONS_DEF', 'ITERATIONS_R_DEF__R', 'ITERATIONS_R_DEF__DEF',
43
+ 'TIME_DEF', 'TIME_R_DEF',
44
+ 'Q_DEF', 'Q_R_DEF__R', 'Q_R_DEF__DEF',
45
+ 'TRE_DEF', 'TRE_R_DEF',
46
+ 'DS_DISP',
47
+ 'DATA_PATH',
48
+ 'DIST_CENTR', 'DIST_CENTR_DEF_95', 'SAMPLE_NUM'])
49
+ for i, file_path in enumerate(iterator):
50
+ fn = os.path.split(file_path)[-1].split('.hd5')[0]
51
+ fnum = int(re.findall(r'(\d+)', fn)[0])
52
+ iterator.set_description('{}: start'.format(fn))
53
+ pts_file = h5py.File(file_path, 'r')
54
+ fix_pts = pts_file['fix/points'][:]
55
+ # fix_nodes = pts_file['fix/nodes'][:]
56
+ fix_centroid = pts_file['fix/centroid'][:]
57
+
58
+ mov_pts = pts_file['mov/points'][:]
59
+ # mov_nodes = pts_file['mov/nodes'][:]
60
+ mov_centroid = pts_file['mov/centroid'][:]
61
+
62
+ bbox = pts_file['parameters/bbox'][:]
63
+ first_reshape = pts_file['parameters/first_reshape'][:]
64
+ original_shape = pts_file['parameters/isotropic_shape'][:]
65
+ iterator.set_description('{}: Loaded data'.format(fn))
66
+ # TODO: bring back to original shape!
67
+ # Reshape to original_shape
68
+ # fix_nodes = resize_pts_to_original_space(fix_nodes, bbox, [64]*3, first_reshape, original_shape)
69
+ fix_pts = resize_pts_to_original_space(fix_pts, bbox, [64] * 3, first_reshape, original_shape)
70
+ fix_centroid = resize_pts_to_original_space(fix_centroid, bbox, [64] * 3, first_reshape, original_shape)
71
+ # mov_nodes = resize_pts_to_original_space(mov_nodes, bbox, [64] * 3, first_reshape, original_shape)
72
+ mov_pts = resize_pts_to_original_space(mov_pts, bbox, [64] * 3, first_reshape, original_shape)
73
+ mov_centroid = resize_pts_to_original_space(mov_centroid, bbox, [64] * 3, first_reshape, original_shape)
74
+ iterator.set_description('{}: reshaped data'.format(fn))
75
+
76
+ ill_cond_def = False
77
+ ill_cond_r_def = False
78
+ # Deformable only
79
+ iterator.set_description('{}: Computing only deformable reg.'.format(fn))
80
+
81
+ # deform_cb = partial(plot_cpd_registration_step, out_folder=os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/DEF'.format(dataset_name, fnum)))
82
+
83
+ # _, _, deform_reg_def = deform_registration(fix_pts, mov_pts, deform_cb)
84
+ time_def, deform_reg_def = deform_registration(fix_pts*SCALE, mov_pts*SCALE, time_it=True,
85
+ tolerance=TOLERANCE, max_iterations=MAX_ITER,
86
+ alpha=ALPHA, beta=BETA)
87
+ if np.isnan(deform_reg_def.diff):
88
+ tre_def = np.nan
89
+ pred_mov_centroid = np.zeros((3,))
90
+ else:
91
+ tps, ill_cond_def = radial_basis_function(mov_pts, np.dot(*deform_reg_def.get_registration_parameters()) / SCALE, RBF_FUNCTION)
92
+ displacement_mov_centroid = tps(mov_centroid)
93
+ pred_mov_centroid = mov_centroid + displacement_mov_centroid
94
+
95
+ tre_def = euclidean(pred_mov_centroid, fix_centroid)
96
+
97
+ plot_file = os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/DEF'.format(dataset_name, fnum))
98
+ os.makedirs(plot_file, exist_ok=True)
99
+ plot_cpd(fix_pts, mov_pts, fix_centroid, mov_centroid, plot_file + '/before_registration')
100
+ plot_cpd(fix_pts, deform_reg_def.TY/SCALE, fix_centroid, pred_mov_centroid, plot_file + '/after_registration')
101
+
102
+ # Rigid followed by deformable
103
+ iterator.set_description('{}: Computing rigid and deform. reg.'.format(fn))
104
+
105
+ # rigid_cb = partial(plot_cpd_registration_step, out_folder=os.path.join(OUT_IMG_FOLDER,
106
+ # '{}/{:04d}/RIGID_DEF/rigid'.format(
107
+ # dataset_name, fnum)))
108
+ # deform_cb = partial(plot_cpd_registration_step, out_folder=os.path.join(OUT_IMG_FOLDER,
109
+ # '{}/{:04d}/RIGID_DEF/deform'.format(
110
+ # dataset_name, fnum)))
111
+ # rigid_yt, rigid_p, rigid_reg_r_def = rigid_registration(fix_pts, mov_pts, rigid_cb)
112
+ # deform_yt, deform_p, deform_reg_r_def = deform_registration(fix_pts, rigid_yt, deform_cb)
113
+
114
+ time_r_def__r, rigid_reg_r_def = rigid_registration(fix_pts*SCALE, mov_pts*SCALE, time_it=True)
115
+ rigid_yt = rigid_reg_r_def.TY
116
+ time_r_def__def, deform_reg_r_def = deform_registration(fix_pts*SCALE, rigid_yt, time_it=True,
117
+ tolerance=TOLERANCE, max_iterations=MAX_ITER,
118
+ alpha=ALPHA, beta=BETA)
119
+
120
+ if np.isnan(deform_reg_r_def.diff):
121
+ pred_mov_centroid = rigid_reg_r_def.transform_point_cloud(mov_centroid*SCALE)/SCALE
122
+ else:
123
+ mov_centroid_t = rigid_reg_r_def.transform_point_cloud(mov_centroid*SCALE)/SCALE
124
+ tps, ill_cond_r_def = radial_basis_function(rigid_yt / SCALE, np.dot(*deform_reg_r_def.get_registration_parameters()) / SCALE, RBF_FUNCTION)
125
+ displacement_mov_centroid_t = tps(mov_centroid_t)
126
+ pred_mov_centroid = mov_centroid_t + displacement_mov_centroid_t
127
+
128
+ tre_r_def = euclidean(pred_mov_centroid, fix_centroid)
129
+ dist_centroid_to_pts = cdist(mov_centroid[np.newaxis, ...], mov_pts)
130
+
131
+ plot_file = os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/RIGID_DEF'.format(dataset_name, fnum))
132
+ os.makedirs(plot_file, exist_ok=True)
133
+ plot_cpd(fix_pts, mov_pts, fix_centroid, mov_centroid, plot_file + '/before_registration')
134
+ plot_cpd(fix_pts, deform_reg_r_def.TY/SCALE, fix_centroid, pred_mov_centroid, plot_file + '/after_registration')
135
+
136
+ iterator.set_description('{}: Saving data'.format(fn))
137
+ df = df.append({'DATASET': dataset_name,
138
+ 'ITERATIONS_DEF': deform_reg_def.iteration,
139
+ 'ITERATIONS_R_DEF__R': rigid_reg_r_def.iteration,
140
+ 'ITERATIONS_R_DEF__DEF': deform_reg_r_def.iteration,
141
+ 'TIME_DEF': time_def,
142
+ 'TIME_R_DEF': time_r_def__r + time_r_def__def,
143
+ 'Q_DEF': deform_reg_def.diff,
144
+ 'Q_R_DEF__R': rigid_reg_r_def.q,
145
+ 'Q_R_DEF__DEF': deform_reg_r_def.diff,
146
+ 'ILL_COND_DEF': ill_cond_def,
147
+ 'ILL_COND_R_DEF': ill_cond_r_def,
148
+ 'TRE_DEF': tre_def, 'TRE_R_DEF': tre_r_def,
149
+ 'DS_DISP':euclidean(mov_centroid, fix_centroid),
150
+ 'DATA_PATH': file_path,
151
+ 'DIST_CENTR': np.min(dist_centroid_to_pts),
152
+ 'DIST_CENTR_DEF_95': np.percentile(dist_centroid_to_pts, 95),
153
+ 'SAMPLE_NUM': fnum}, ignore_index=True)
154
+ pts_file.close()
155
+
156
+ df.to_csv(os.path.join(OUT_IMG_FOLDER, 'cpd_{}.csv'.format(dataset_name)))
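
A small, hypothetical post-processing sketch for the CSV files written by these evaluation scripts; the folder and column names come from the script above, and only pandas is assumed.

import os
import pandas as pd

OUT_IMG_FOLDER = '/mnt/EncryptedData1/Users/javier/vessel_registration/Centerline/cpd/dense_final'
frames = [pd.read_csv(os.path.join(OUT_IMG_FOLDER, 'cpd_{}.csv'.format(name)), index_col=0)
          for name in ['Affine', 'None', 'Translation']]
results = pd.concat(frames, ignore_index=True)

# Summarize target registration error and runtime per evaluation dataset
print(results.groupby('DATASET')[['TRE_DEF', 'TRE_R_DEF', 'TIME_DEF', 'TIME_R_DEF']].describe())
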
Centerline/evaluate_CPD_nodes.py ADDED
@@ -0,0 +1,158 @@
1
+ import os, sys
2
+ currentdir = os.path.dirname(os.path.realpath(__file__))
3
+ parentdir = os.path.dirname(currentdir)
4
+ sys.path.append(parentdir) # allow imports from the parent package when this script is run directly (relative imports fail outside a package)
5
+
6
+ import h5py
7
+ from tqdm import tqdm
8
+ from functools import partial
9
+ import numpy as np
10
+ from scipy.spatial.distance import euclidean
11
+ import pandas as pd
12
+ from EvaluationScripts.Evaluate_class import resize_img_to_original_space, resize_pts_to_original_space
13
+ from Centerline.visualization_utils import plot_cpd_registration_step, plot_cpd
14
+ from Centerline.cpd_utils import cpd_non_rigid_transform_pt, radial_basis_function, deform_registration, rigid_registration
15
+ from scipy.spatial.distance import cdist
16
+ import re
17
+
18
+ DATASET_LOCATION = '/mnt/EncryptedData1/Users/javier/vessel_registration/3Dirca/dataset/EVAL'
19
+ DATASET_NAMES = ['Affine', 'None', 'Translation']
20
+ DATASET_FILENAME = 'points'
21
+
22
+ OUT_IMG_FOLDER = '/mnt/EncryptedData1/Users/javier/vessel_registration/Centerline/cpd/nodes_final'
23
+
24
+ SCALE = 1e-2 # mm to cm
25
+
26
+ # CPD PARAMS (deform)
27
+ MAX_ITER = 200
28
+ ALPHA = 2.
29
+ BETA = 2. # None = Use default
30
+ TOLERANCE = 1e-8
31
+ RBF_FUNCTION='thin-plate'
32
+
33
+ if __name__ == '__main__':
34
+ for dataset_name in DATASET_NAMES:
35
+ dataset_loc = os.path.join(DATASET_LOCATION, dataset_name)
36
+ dataset_files = os.listdir(dataset_loc)
37
+ dataset_files.sort()
38
+ dataset_files = [os.path.join(dataset_loc, f) for f in dataset_files if DATASET_FILENAME in f]
39
+
40
+ iterator = tqdm(dataset_files)
41
+ df = pd.DataFrame(columns=['DATASET',
42
+ 'ITERATIONS_DEF', 'ITERATIONS_R_DEF__R', 'ITERATIONS_R_DEF__DEF',
43
+ 'TIME_DEF', 'TIME_R_DEF',
44
+ 'Q_DEF', 'Q_R_DEF__R', 'Q_R_DEF__DEF',
45
+ 'TRE_DEF', 'TRE_R_DEF',
46
+ 'DS_DISP',
47
+ 'DATA_PATH',
48
+ 'DIST_CENTR', 'DIST_CENTR_DEF_95', 'SAMPLE_NUM'])
49
+ for i, file_path in enumerate(iterator):
50
+ fn = os.path.split(file_path)[-1].split('.hd5')[0]
51
+ fnum = int(re.findall(r'(\d+)', fn)[0])
52
+ iterator.set_description('{}: start'.format(fn))
53
+ pts_file = h5py.File(file_path, 'r')
54
+ fix_pts = pts_file['fix/points'][:]
55
+ fix_nodes = pts_file['fix/nodes'][:]
56
+ fix_centroid = pts_file['fix/centroid'][:]
57
+
58
+ mov_pts = pts_file['mov/points'][:]
59
+ mov_nodes = pts_file['mov/nodes'][:]
60
+ mov_centroid = pts_file['mov/centroid'][:]
61
+
62
+ bbox = pts_file['parameters/bbox'][:]
63
+ first_reshape = pts_file['parameters/first_reshape'][:]
64
+ isotropic_shape = pts_file['parameters/isotropic_shape'][:]
65
+ iterator.set_description('{}: Loaded data'.format(fn))
66
+ # TODO: bring back to original shape!
67
+ # Reshape to original_shape
68
+ fix_nodes = resize_pts_to_original_space(fix_nodes, bbox, [64] * 3, first_reshape, isotropic_shape)
69
+ fix_pts = resize_pts_to_original_space(fix_pts, bbox, [64] * 3, first_reshape, isotropic_shape)
70
+ fix_centroid = resize_pts_to_original_space(fix_centroid, bbox, [64] * 3, first_reshape, isotropic_shape)
71
+ mov_nodes = resize_pts_to_original_space(mov_nodes, bbox, [64] * 3, first_reshape, isotropic_shape)
72
+ mov_pts = resize_pts_to_original_space(mov_pts, bbox, [64] * 3, first_reshape, isotropic_shape)
73
+ mov_centroid = resize_pts_to_original_space(mov_centroid, bbox, [64] * 3, first_reshape, isotropic_shape)
74
+ iterator.set_description('{}: reshaped data'.format(fn))
75
+
76
+ if mov_nodes.shape[0] == 1:
77
+ # With a single node, CPD cannot register the point sets, so fall back to the full point sets
78
+ fix_nodes = fix_pts
79
+ mov_nodes = mov_pts
80
+
81
+ ill_cond_def = False
82
+ ill_cond_r_def = False
83
+ # Deformable only
84
+ iterator.set_description('{}: Computing only deformable reg.'.format(fn))
85
+
86
+ # deform_cb = partial(plot_cpd_registration_step,
87
+ # out_folder=os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/DEF'.format(dataset_name, fnum)))
88
+
89
+ # _, _, deform_reg_def = deform_registration(fix_nodes, mov_nodes, deform_cb)
90
+ time_def, deform_reg_def = deform_registration(fix_nodes*SCALE, mov_nodes*SCALE, time_it=True,
91
+ tolerance=TOLERANCE, max_iterations=MAX_ITER,
92
+ alpha=ALPHA, beta=BETA)
93
+ if np.isnan(deform_reg_def.diff):
94
+ tre_def = np.nan
95
+ pred_mov_centroid = np.zeros((3,))
96
+ else:
97
+ tps, ill_cond_def = radial_basis_function(mov_nodes, np.dot(*deform_reg_def.get_registration_parameters()) / SCALE, RBF_FUNCTION)
98
+ displacement_mov_centroid = tps(mov_centroid)
99
+ pred_mov_centroid = mov_centroid + displacement_mov_centroid
100
+
101
+ tre_def = euclidean(pred_mov_centroid, fix_centroid)
102
+
103
+ plot_file = os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/DEF'.format(dataset_name, fnum))
104
+ os.makedirs(plot_file, exist_ok=True)
105
+ plot_cpd(fix_nodes, mov_nodes, fix_centroid, mov_centroid, plot_file + '/before_registration')
106
+ plot_cpd(fix_nodes, deform_reg_def.TY/SCALE, fix_centroid, pred_mov_centroid, plot_file + '/after_registration')
107
+
108
+ # Rigid followed by deformable
109
+ iterator.set_description('{}: Computing rigid and deform. reg.'.format(fn))
110
+
111
+ # rigid_cb = partial(plot_cpd_registration_step, out_folder=os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/RIGID_DEF/rigid'.format(dataset_name, fnum)))
112
+ # deform_cb = partial(plot_cpd_registration_step, out_folder=os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/RIGID_DEF/deform'.format(dataset_name, fnum)))
113
+ # rigid_yt, rigid_p, rigid_reg_r_def = rigid_registration(fix_nodes, mov_nodes, rigid_cb)
114
+ # deform_yt, deform_p, deform_reg_r_def = deform_registration(fix_nodes, rigid_yt, deform_cb)
115
+
116
+ time_r_def__r, rigid_reg_r_def = rigid_registration(fix_nodes*SCALE, mov_nodes*SCALE, time_it=True)
117
+ rigid_yt = rigid_reg_r_def.TY
118
+ time_r_def__def, deform_reg_r_def = deform_registration(fix_nodes*SCALE, rigid_yt, time_it=True,
119
+ tolerance=TOLERANCE, max_iterations=MAX_ITER,
120
+ alpha=ALPHA, beta=BETA)
121
+
122
+ if np.isnan(deform_reg_r_def.diff):
123
+ pred_mov_centroid = rigid_reg_r_def.transform_point_cloud(mov_centroid*SCALE)/SCALE
124
+ else:
125
+ mov_centroid_t = rigid_reg_r_def.transform_point_cloud(mov_centroid*SCALE)/SCALE
126
+ tps, ill_cond_r_def = radial_basis_function(rigid_yt / SCALE, np.dot(*deform_reg_r_def.get_registration_parameters()) / SCALE, RBF_FUNCTION)
127
+ displacement_mov_centroid_t = tps(mov_centroid_t)
128
+ pred_mov_centroid = mov_centroid_t + displacement_mov_centroid_t
129
+
130
+ tre_r_def = euclidean(pred_mov_centroid, fix_centroid)
131
+ dist_centroid_to_pts = cdist(mov_centroid[np.newaxis, ...], mov_nodes)
132
+
133
+ plot_file = os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/RIGID_DEF'.format(dataset_name, fnum))
134
+ os.makedirs(plot_file, exist_ok=True)
135
+ plot_cpd(fix_nodes, mov_nodes, fix_centroid, mov_centroid, plot_file + '/before_registration')
136
+ plot_cpd(fix_nodes, deform_reg_r_def.TY/SCALE, fix_centroid, pred_mov_centroid, plot_file + '/after_registration')
137
+
138
+ iterator.set_description('{}: Saving data'.format(fn))
139
+ df = df.append({'DATASET':dataset_name,
140
+ 'ITERATIONS_DEF': deform_reg_def.iteration,
141
+ 'ITERATIONS_R_DEF__R': rigid_reg_r_def.iteration,
142
+ 'ITERATIONS_R_DEF__DEF': deform_reg_r_def.iteration,
143
+ 'TIME_DEF': time_def,
144
+ 'TIME_R_DEF': time_r_def__r + time_r_def__def,
145
+ 'Q_DEF': deform_reg_def.diff,
146
+ 'Q_R_DEF__R': rigid_reg_r_def.q,
147
+ 'Q_R_DEF__DEF': deform_reg_r_def.diff,
148
+ 'ILL_COND_DEF': ill_cond_def,
149
+ 'ILL_COND_R_DEF': ill_cond_r_def,
150
+ 'TRE_DEF':tre_def, 'TRE_R_DEF':tre_r_def,
151
+ 'DS_DISP':euclidean(mov_centroid, fix_centroid),
152
+ 'DATA_PATH':file_path,
153
+ 'DIST_CENTR':np.min(dist_centroid_to_pts),
154
+ 'DIST_CENTR_DEF_95':np.percentile(dist_centroid_to_pts, 95),
155
+ 'SAMPLE_NUM':fnum}, ignore_index=True)
156
+ pts_file.close()
157
+
158
+ df.to_csv(os.path.join(OUT_IMG_FOLDER, 'cpd_{}.csv'.format(dataset_name)))
Centerline/evaluate_CPD_skeleton.py ADDED
@@ -0,0 +1,164 @@
1
+ import os, sys
2
+
3
+ currentdir = os.path.dirname(os.path.realpath(__file__))
4
+ parentdir = os.path.dirname(currentdir)
5
+ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
6
+
7
+ import h5py
8
+ from tqdm import tqdm
9
+ from functools import partial
10
+ import numpy as np
11
+ from scipy.spatial.distance import euclidean
12
+ import pandas as pd
13
+ from EvaluationScripts.Evaluate_class import resize_img_to_original_space, resize_pts_to_original_space
14
+ from Centerline.visualization_utils import plot_cpd_registration_step, plot_cpd
15
+ from Centerline.cpd_utils import cpd_non_rigid_transform_pt, radial_basis_function, deform_registration, rigid_registration
16
+ from scipy.spatial.distance import cdist
17
+ from skimage.morphology import skeletonize_3d
18
+ import re
19
+
20
+ DATASET_LOCATION = '/mnt/EncryptedData1/Users/javier/vessel_registration/3Dirca/dataset/EVAL'
21
+ DATASET_NAMES = ['Affine', 'None', 'Translation']
22
+ DATASET_FILENAME = 'points'
23
+
24
+ OUT_IMG_FOLDER = '/mnt/EncryptedData1/Users/javier/vessel_registration/Centerline/cpd/skeleton'
25
+
26
+ SCALE = 1e-2 # mm to cm
27
+ # CPD PARAMS (deform)
28
+ MAX_ITER = 200
29
+ ALPHA = 0.1
30
+ BETA = 1.0 # None = Use default
31
+ TOLERANCE = 1e-8
32
+
33
+ if __name__ == '__main__':
34
+ for dataset_name in DATASET_NAMES:
35
+ dataset_loc = os.path.join(DATASET_LOCATION, dataset_name)
36
+ dataset_files = os.listdir(dataset_loc)
37
+ dataset_files.sort()
38
+ dataset_files = [os.path.join(dataset_loc, f) for f in dataset_files if DATASET_FILENAME in f]
39
+
40
+ iterator = tqdm(dataset_files)
41
+ df = pd.DataFrame(columns=['DATASET',
42
+ 'ITERATIONS_DEF', 'ITERATIONS_R_DEF__R', 'ITERATIONS_R_DEF__DEF',
43
+ 'TIME_DEF', 'TIME_R_DEF',
44
+ 'Q_DEF', 'Q_R_DEF__R', 'Q_R_DEF__DEF',
45
+ 'TRE_DEF', 'TRE_R_DEF',
46
+ 'DS_DISP',
47
+ 'DATA_PATH',
48
+ 'DIST_CENTR', 'DIST_CENTR_DEF_95', 'SAMPLE_NUM'])
49
+ for i, file_path in enumerate(iterator):
50
+ fn = os.path.split(file_path)[-1].split('.hd5')[0]
51
+ fnum = int(re.findall(r'(\d+)', fn)[0])
52
+ iterator.set_description('{}: start'.format(fn))
53
+ pts_file = h5py.File(file_path, 'r')
54
+ # fix_pts = pts_file['fix/points'][:]
55
+ # fix_nodes = pts_file['fix/nodes'][:]
56
+ fix_skel = pts_file['fix/skeleton'][:]
57
+ fix_centroid = pts_file['fix/centroid'][:]
58
+
59
+ # mov_pts = pts_file['mov/points'][:]
60
+ # mov_nodes = pts_file['mov/nodes'][:]
61
+ mov_skel = pts_file['mov/skeleton'][:]
62
+ mov_centroid = pts_file['mov/centroid'][:]
63
+
64
+ bbox = pts_file['parameters/bbox'][:]
65
+ first_reshape = pts_file['parameters/first_reshape'][:]
66
+ isotropic_shape = pts_file['parameters/isotropic_shape'][:]
67
+ iterator.set_description('{}: Loaded data'.format(fn))
68
+ # TODO: bring back to original shape!
69
+ # Reshape to original_shape
70
+ # fix_nodes = resize_pts_to_original_space(fix_nodes, bbox, [64]*3, first_reshape, original_shape)
71
+ # fix_pts = resize_pts_to_original_space(fix_pts, bbox, [64] * 3, first_reshape, original_shape)
72
+ fix_centroid = resize_pts_to_original_space(fix_centroid, bbox, [64] * 3, first_reshape, isotropic_shape)
73
+ fix_skel = resize_img_to_original_space(fix_skel, bbox, first_reshape, isotropic_shape)
74
+ fix_skel = skeletonize_3d(fix_skel)
75
+ fix_skel_pts = np.argwhere(fix_skel)
76
+ # mov_nodes = resize_pts_to_original_space(mov_nodes, bbox, [64] * 3, first_reshape, original_shape)
77
+ # mov_pts = resize_pts_to_original_space(mov_pts, bbox, [64] * 3, first_reshape, original_shape)
78
+ mov_centroid = resize_pts_to_original_space(mov_centroid, bbox, [64] * 3, first_reshape, isotropic_shape)
79
+ mov_skel = resize_img_to_original_space(mov_skel, bbox, first_reshape, isotropic_shape)
80
+ mov_skel = skeletonize_3d(mov_skel)
81
+ mov_skel_pts = np.argwhere(mov_skel)
82
+ iterator.set_description('{}: reshaped data'.format(fn))
83
+
84
+ ill_cond_def = False
85
+ ill_cond_r_def = False
86
+ # Deformable only
87
+ iterator.set_description('{}: Computing only deformable reg.'.format(fn))
88
+
89
+ deform_cb = partial(plot_cpd_registration_step,
90
+ out_folder=os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/DEF'.format(dataset_name, fnum)))
91
+
92
+ # _, _, deform_reg_def = deform_registration(fix_pts, mov_pts, deform_cb)
93
+ time_def, deform_reg_def = deform_registration(fix_skel_pts*SCALE, mov_skel_pts*SCALE, time_it=True,
94
+ tolerance=TOLERANCE, max_iterations=MAX_ITER,
95
+ alpha=ALPHA, beta=BETA)
96
+ if np.isnan(deform_reg_def.diff):
97
+ tre_def = np.nan
98
+ pred_mov_centroid = mov_centroid
99
+ else:
100
+ tps, ill_cond_def = radial_basis_function(mov_skel_pts, np.dot(*deform_reg_def.get_registration_parameters()) / SCALE)
101
+ displacement_mov_centroid = tps(mov_centroid)
102
+ pred_mov_centroid = mov_centroid + displacement_mov_centroid
103
+
104
+ tre_def = euclidean(pred_mov_centroid, fix_centroid)
105
+
106
+ plot_file = os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/DEF'.format(dataset_name, fnum))
107
+ os.makedirs(plot_file, exist_ok=True)
108
+ plot_cpd(fix_skel_pts, mov_skel_pts, fix_centroid, mov_centroid, plot_file + '/before_registration')
109
+ plot_cpd(fix_skel_pts, deform_reg_def.TY/SCALE, fix_centroid, pred_mov_centroid, plot_file + '/after_registration')
110
+
111
+ # Rigid followed by deformable
112
+ iterator.set_description('{}: Computing rigid and deform. reg.'.format(fn))
113
+
114
+ rigid_cb = partial(plot_cpd_registration_step, out_folder=os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/RIGID_DEF/rigid'.format(dataset_name, fnum)))
115
+ deform_cb = partial(plot_cpd_registration_step, out_folder=os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/RIGID_DEF/deform'.format(dataset_name, fnum)))
116
+
117
+ # rigid_yt, rigid_p, rigid_reg_r_def = rigid_registration(fix_pts, mov_pts, rigid_cb)
118
+ # deform_yt, deform_p, deform_reg_r_def = deform_registration(fix_pts, rigid_yt, deform_cb)
119
+
120
+ time_r_def__r, rigid_reg_r_def = rigid_registration(fix_skel_pts*SCALE, mov_skel_pts*SCALE, time_it=True)
121
+ rigid_yt = rigid_reg_r_def.TY
122
+ time_r_def__def, deform_reg_r_def = deform_registration(fix_skel_pts*SCALE, rigid_yt, time_it=True,
123
+ tolerance=TOLERANCE, max_iterations=MAX_ITER,
124
+ alpha=ALPHA, beta=BETA)
125
+
126
+ if np.isnan(deform_reg_r_def.diff):
127
+ pred_mov_centroid = rigid_reg_r_def.transform_point_cloud(mov_centroid*SCALE)/SCALE
128
+ else:
129
+ mov_centroid_t = rigid_reg_r_def.transform_point_cloud(mov_centroid*SCALE)/SCALE
130
+ tps, ill_cond_r_def = radial_basis_function(rigid_yt / SCALE,
131
+ np.dot(*deform_reg_r_def.get_registration_parameters()) / SCALE)
132
+ displacement_mov_centroid_t = tps(mov_centroid_t)
133
+ pred_mov_centroid = mov_centroid_t + displacement_mov_centroid_t
134
+
135
+ tre_r_def = euclidean(pred_mov_centroid, fix_centroid)
136
+ dist_centroid_to_pts = cdist(mov_centroid[np.newaxis, ...], mov_skel_pts)
137
+
138
+ plot_file = os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/RIGID_DEF'.format(dataset_name, fnum))
139
+ os.makedirs(plot_file, exist_ok=True)
140
+ plot_cpd(fix_skel_pts, mov_skel_pts, fix_centroid, mov_centroid, plot_file + '/before_registration')
141
+ plot_cpd(fix_skel_pts, deform_reg_r_def.TY/SCALE, fix_centroid, pred_mov_centroid, plot_file + '/after_registration')
142
+
143
+
144
+ iterator.set_description('{}: Saving data'.format(fn))
145
+ df = df.append({'DATASET': dataset_name,
146
+ 'ITERATIONS_DEF': deform_reg_def.iteration,
147
+ 'ITERATIONS_R_DEF__R': rigid_reg_r_def.iteration,
148
+ 'ITERATIONS_R_DEF__DEF': deform_reg_r_def.iteration,
149
+ 'TIME_DEF': time_def,
150
+ 'TIME_R_DEF': time_r_def__r + time_r_def__def,
151
+ 'Q_DEF': deform_reg_def.diff,
152
+ 'Q_R_DEF__R': rigid_reg_r_def.q,
153
+ 'Q_R_DEF__DEF': deform_reg_r_def.diff,
154
+ 'ILL_COND_DEF': ill_cond_def,
155
+ 'ILL_COND_R_DEF': ill_cond_r_def,
156
+ 'TRE_DEF': tre_def, 'TRE_R_DEF': tre_r_def,
157
+ 'DS_DISP':euclidean(mov_centroid, fix_centroid),
158
+ 'DATA_PATH': file_path,
159
+ 'DIST_CENTR': np.min(dist_centroid_to_pts),
160
+ 'DIST_CENTR_DEF_95': np.percentile(dist_centroid_to_pts, 95),
161
+ 'SAMPLE_NUM':fnum}, ignore_index=True)
162
+ pts_file.close()
163
+
164
+ df.to_csv(os.path.join(OUT_IMG_FOLDER, 'cpd_{}.csv'.format(dataset_name)))
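
A small sketch of how the skeleton volumes become the point clouds fed to CPD, and of the centroid-distance features (DIST_CENTR, DIST_CENTR_DEF_95) stored in the CSV. The toy skeleton and centroid below are illustrative only.

import numpy as np
from scipy.spatial.distance import cdist

skel = np.zeros((64, 64, 64), dtype=np.uint8)
skel[32, 32, 10:50] = 1                       # stand-in for a resampled vessel skeleton
skel_pts = np.argwhere(skel)                  # (N, 3) voxel coordinates, as in the script above

centroid = np.array([30., 30., 30.])          # stand-in for mov_centroid
d = cdist(centroid[np.newaxis, ...], skel_pts)
print(d.min(), np.percentile(d, 95))          # DIST_CENTR and DIST_CENTR_DEF_95
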
Centerline/get_vessels.py ADDED
@@ -0,0 +1,30 @@
1
+ from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
2
+ from tqdm import tqdm
3
+ import os
4
+ import h5py
5
+ import DeepDeformationMapRegistration.utils.constants as C
6
+
7
+ DATASET_LOCATION = '/mnt/EncryptedData1/Users/javier/vessel_registration/3Dirca/dataset/EVAL'
8
+ DATASET_NAMES = ['Affine', 'None', 'Translation']
9
+ DATASET_FILENAME = 'volume'
10
+
11
+
12
+ if __name__ == '__main__':
13
+ for dataset_name in DATASET_NAMES:
14
+ dataset_loc = os.path.join(DATASET_LOCATION, dataset_name)
15
+ dataset_files = os.listdir(dataset_loc)
16
+ dataset_files.sort()
17
+ dataset_files = [os.path.join(dataset_loc, f) for f in dataset_files if DATASET_FILENAME in f]
18
+
19
+ iterator = tqdm(dataset_files)
20
+ for fn in iterator:
21
+ f = os.path.split(fn)[-1].split('.hd5')[0]
22
+ vol_file = h5py.File(fn, 'r')
23
+ fix_vessels = vol_file[C.H5_FIX_VESSELS_MASK][..., 0]
24
+ mov_vessels = vol_file[C.H5_MOV_VESSELS_MASK][..., 0]
25
+
26
+ dst_folder = os.path.join(os.getcwd(), 'VESSELS', dataset_name)
27
+ os.makedirs(dst_folder, exist_ok=True)
28
+ save_nifti(fix_vessels, os.path.join(dst_folder, f+'_fix.nii.gz'))
29
+ save_nifti(mov_vessels, os.path.join(dst_folder, f+'_mov.nii.gz'))
30
+ vol_file.close()
Centerline/graph_utils.py ADDED
@@ -0,0 +1,85 @@
1
+ import networkx as nx
2
+ import numpy as np
3
+ from scipy.interpolate import RegularGridInterpolator, LinearNDInterpolator
4
+
5
+
6
+ def graph_to_ndarray(graph):
7
+ out_nodes = np.empty((0, 3))
8
+ out_edges = np.empty((0, 3))
9
+ visited_nodes = list()
10
+ visited_node_pairs = list()
11
+ for (start_node, end_node) in graph.edges():
12
+ if (not (start_node, end_node) in visited_node_pairs) and (not (end_node, start_node) in visited_node_pairs):
13
+ edge = graph[start_node][end_node]['pts']
14
+ out_edges = np.vstack([out_edges, edge])
15
+
16
+ # Avoid duplicates
17
+ if not (start_node in visited_nodes):
18
+ out_nodes = np.vstack([out_nodes, graph.nodes[start_node]['o']])
19
+ visited_nodes.append(start_node)
20
+ if not (end_node in visited_nodes):
21
+ out_nodes = np.vstack([out_nodes, graph.nodes[end_node]['o']])
22
+ visited_nodes.append(end_node)
23
+
24
+ visited_node_pairs.append((start_node, end_node))
25
+
26
+ return np.vstack([out_edges, out_nodes]), out_nodes, out_edges
27
+
28
+
29
+ def get_bifurcation_nodes(graph: nx.Graph):
30
+ # The vertex degree equals the number of branches connected to a given node
31
+ out_nodes = np.empty((0, 3))
32
+ bif_nodes_id = list()
33
+ for node_num, deg in graph.degree:
34
+ if deg > 1:
35
+ bif_nodes_id.append(node_num)
36
+ out_nodes = np.vstack([out_nodes, graph.nodes[node_num]['o']])
37
+
38
+ return out_nodes, bif_nodes_id
39
+
40
+
41
+ def apply_displacement(pts_list: np.ndarray, interpolator: [RegularGridInterpolator, LinearNDInterpolator]):
42
+ pts_list = pts_list.astype(float)
43
+ ret_val = pts_list + interpolator(pts_list).squeeze()
44
+ return ret_val
45
+
46
+
47
+ def deform_graph(graph, dm_interpolator: [RegularGridInterpolator, LinearNDInterpolator]):
48
+ def_graph = nx.Graph()
49
+ for (start_node, end_node) in graph.edges():
50
+ edge = graph[start_node][end_node]['pts']
51
+ def_edge = apply_displacement(edge, dm_interpolator)
52
+
53
+ def_start_node_pts = apply_displacement(graph.nodes[start_node]['pts'], dm_interpolator)
54
+ def_end_node_pts = apply_displacement(graph.nodes[end_node]['pts'], dm_interpolator)
55
+
56
+ def_start_node_o = apply_displacement(graph.nodes[start_node]['o'], dm_interpolator)
57
+ def_end_node_o = apply_displacement(graph.nodes[end_node]['o'], dm_interpolator)
58
+
59
+ def_graph.add_node(start_node, pts=def_start_node_pts, o=def_start_node_o)
60
+ def_graph.add_node(end_node, pts=def_end_node_pts, o=def_end_node_o)
61
+ def_graph.add_edge(start_node, end_node, pts=def_edge, weight=len(def_edge))
62
+ return def_graph
63
+
64
+
65
+ def subsample_graph(graph: nx.Graph, num_samples=3):
66
+ sub_graph = nx.Graph()
67
+ for (start_node, end_node) in graph.edges():
68
+ edge = graph[start_node][end_node]['pts']
69
+ edge_len = edge.shape[0]
70
+ sub_edge_len = (edge_len - 2) // num_samples # Do not count the pts corresponding to the nodes (-2)
71
+
72
+ sub_edge = [edge[0]]
73
+ include_last = bool((edge_len - 2) % num_samples) # Skip the last point, as this is too close to the node
74
+ if sub_edge_len:
75
+ idxs = np.arange(0, edge_len, num_samples)[1:] if include_last else np.arange(0, edge_len, num_samples)[1:-1]
76
+ for i in idxs:
77
+ sub_edge.append(edge[i])
78
+
79
+ sub_edge.append(edge[-1])
80
+ sub_edge = np.asarray(sub_edge)
81
+ sub_graph.add_node(start_node, pts=graph.nodes[start_node]['pts'], o=graph.nodes[start_node]['o'])
82
+ sub_graph.add_node(end_node, pts=graph.nodes[end_node]['pts'], o=graph.nodes[end_node]['o'])
83
+ sub_graph.add_edge(start_node, end_node, pts=sub_edge, weight=len(sub_edge))
84
+ return sub_graph
85
+
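
A minimal sketch of the graph conventions these helpers expect (nodes carry 'pts' and 'o' attributes, edges carry 'pts'). It assumes this commit's Centerline package is importable; the two-node toy graph is illustrative only.

import numpy as np
import networkx as nx
from Centerline.graph_utils import graph_to_ndarray, get_bifurcation_nodes, subsample_graph

g = nx.Graph()
edge_pts = np.stack([np.linspace(0, 10, 11)] * 3, axis=1)     # straight 11-point branch
g.add_node(0, pts=edge_pts[:1], o=edge_pts[0])
g.add_node(1, pts=edge_pts[-1:], o=edge_pts[-1])
g.add_edge(0, 1, pts=edge_pts, weight=len(edge_pts))

all_pts, nodes, edges = graph_to_ndarray(g)
bif_loc, bif_ids = get_bifurcation_nodes(g)
sub_g = subsample_graph(g, num_samples=3)
print(all_pts.shape, len(bif_ids), sub_g[0][1]['pts'].shape)
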
Centerline/skeleton_to_graph.py ADDED
@@ -0,0 +1,167 @@
1
+ # SRC: https://github.com/Image-Py/sknw/blob/master/sknw/sknw.py
2
+ import numpy as np
3
+ import networkx as nx
4
+ from Centerline.graph_utils import subsample_graph
5
+
6
+
7
+ def neighbors(shape):
8
+ dim = len(shape)
9
+ block = np.ones([3] * dim)
10
+ block[tuple([1] * dim)] = 0
11
+ idx = np.where(block > 0)
12
+ idx = np.array(idx, dtype=np.uint8).T
13
+ idx = np.array(idx - [1] * dim)
14
+ acc = np.cumprod((1,) + shape[::-1][:-1])
15
+ return np.dot(idx, acc[::-1])
16
+
17
+
18
+ # my mark
19
+ def mark(img, nbs): # mark the array with 0 (background), 1 (edge pixel) or 2 (node pixel)
20
+ img = img.ravel()
21
+ for p in range(len(img)):
22
+ if img[p] == 0: continue
23
+ s = 0
24
+ for dp in nbs:
25
+ if img[p + dp] != 0: s += 1
26
+ if s == 2:
27
+ img[p] = 1
28
+ else:
29
+ img[p] = 2
30
+
31
+
32
+ # trans index to r, c...
33
+ def idx2rc(idx, acc):
34
+ rst = np.zeros((len(idx), len(acc)), dtype=np.int16)
35
+ for i in range(len(idx)):
36
+ for j in range(len(acc)):
37
+ rst[i, j] = idx[i] // acc[j]
38
+ idx[i] -= rst[i, j] * acc[j]
39
+ rst -= 1
40
+ return rst
41
+
42
+
43
+ # fill a node (may be two or more points)
44
+ def fill(img, p, num, nbs, acc, buf):
45
+ back = img[p]
46
+ img[p] = num
47
+ buf[0] = p
48
+ cur = 0
49
+ s = 1
50
+
51
+ while True:
52
+ p = buf[cur]
53
+ for dp in nbs:
54
+ cp = p + dp
55
+ if img[cp] == back:
56
+ img[cp] = num
57
+ buf[s] = cp
58
+ s += 1
59
+ cur += 1
60
+ if cur == s: break
61
+ return idx2rc(buf[:s], acc)
62
+
63
+
64
+ # trace the edge and use a buffer, then buf.copy, if use [] numba not works
65
+ def trace(img, p, nbs, acc, buf):
66
+ c1 = 0
67
+ c2 = 0
68
+ newp = 0
69
+ cur = 0
70
+
71
+ while True:
72
+ buf[cur] = p
73
+ img[p] = 0
74
+ cur += 1
75
+ for dp in nbs:
76
+ cp = p + dp
77
+ if img[cp] >= 10:
78
+ if c1 == 0:
79
+ c1 = img[cp]
80
+ else:
81
+ c2 = img[cp]
82
+ if img[cp] == 1:
83
+ newp = cp
84
+ p = newp
85
+ if c2 != 0: break
86
+ return (c1 - 10, c2 - 10, idx2rc(buf[:cur], acc))
87
+
88
+
89
+ # parse the image then get the nodes and edges
90
+ def parse_struc(img, pts, nbs, acc):
91
+ img = img.ravel()
92
+ buf = np.zeros(131072, dtype=np.int64)
93
+ num = 10
94
+ nodes = []
95
+ for p in pts:
96
+ if img[p] == 2:
97
+ nds = fill(img, p, num, nbs, acc, buf)
98
+ num += 1
99
+ nodes.append(nds)
100
+
101
+ edges = []
102
+ for p in pts:
103
+ for dp in nbs:
104
+ if img[p + dp] == 1:
105
+ edge = trace(img, p + dp, nbs, acc, buf)
106
+ edges.append(edge)
107
+ return nodes, edges
108
+
109
+
110
+ # use nodes and edges build a networkx graph
111
+ def build_graph(nodes, edges, multi=False):
112
+ graph = nx.MultiGraph() if multi else nx.Graph()
113
+ for i in range(len(nodes)):
114
+ graph.add_node(i, pts=nodes[i], o=nodes[i].mean(axis=0))
115
+ for s, e, pts in edges:
116
+ l = np.linalg.norm(pts[1:] - pts[:-1], axis=1).sum()
117
+ graph.add_edge(s, e, pts=pts, weight=l)
118
+ return graph
119
+
120
+
121
+ def buffer(ske):
122
+ buf = np.zeros(tuple(np.array(ske.shape) + 2), dtype=np.uint16)
123
+ buf[tuple([slice(1, -1)] * buf.ndim)] = ske
124
+ return buf
125
+
126
+
127
+ def build_sknw(ske, multi=False):
128
+ buf = buffer(ske)
129
+ nbs = neighbors(buf.shape)
130
+ acc = np.cumprod((1,) + buf.shape[::-1][:-1])[::-1]
131
+ mark(buf, nbs)
132
+ pts = np.array(np.where(buf.ravel() == 2))[0]
133
+ nodes, edges = parse_struc(buf, pts, nbs, acc)
134
+ return build_graph(nodes, edges, multi)
135
+
136
+
137
+ # draw the graph
138
+ def draw_graph(img, graph, cn=255, ce=128):
139
+ acc = np.cumprod((1,) + img.shape[::-1][:-1])[::-1]
140
+ img = img.ravel()
141
+ for idx in graph.nodes():
142
+ pts = graph.nodes[idx]['pts']
143
+ img[np.dot(pts, acc)] = cn
144
+ for (s, e) in graph.edges():
145
+ eds = graph[s][e]
146
+ for i in eds:
147
+ pts = eds[i]['pts']
148
+ img[np.dot(pts, acc)] = ce
149
+
150
+
151
+ def get_graph_from_skeleton(mask, subsample=False):
152
+ graph = build_sknw(mask, False)
153
+ if len(graph.nodes) > 1 and len(graph.edges) and subsample:
154
+ graph = subsample_graph(graph, 3)
155
+ return graph
156
+
157
+
158
+ if __name__ == '__main__':
159
+ g = nx.MultiGraph()
160
+ g.add_nodes_from([1, 2, 3, 4, 5])
161
+ g.add_edges_from([(1, 2), (1, 3), (2, 3), (4, 5), (5, 4)])
162
+ print(g.nodes())
163
+ print(g.edges())
164
+ a = g.subgraph(1)
165
+ print('d')
166
+ print(a)
167
+ print('d')
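A minimal usage sketch (not part of the commit) of building a centerline graph from a binary vessel mask with the function added above:

import numpy as np
from skimage.morphology import skeletonize_3d
from Centerline.skeleton_to_graph import get_graph_from_skeleton

mask = np.zeros((64, 64, 64), dtype=np.uint8)
mask[20:44, 30:34, 30:34] = 1                      # a simple tube-like structure
skeleton = skeletonize_3d(mask).astype(np.uint8)   # 1-voxel-wide centerline
graph = get_graph_from_skeleton(skeleton, subsample=False)
print(len(graph.nodes), len(graph.edges))          # typically two end nodes joined by one edge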
Centerline/skeletonization.py ADDED
@@ -0,0 +1,817 @@
1
+ import sys, os
2
+ import numpy as np
3
+ import nibabel as nib
4
+ from scipy import ndimage as ndi
5
+ from scipy.signal import convolve
6
+ from numpy.linalg import norm
7
+ import networkx as nx
8
+ import logging
9
+ import traceback
10
+ import timeit
11
+ import time
12
+ import math
13
+ from ast import literal_eval as make_tuple
14
+ from skimage.measure import label
15
+ import subprocess
16
+ import platform
17
+ import glob
18
+
19
+
20
+ def loadVolume(volumeFolderPath, volumeName):
21
+ """
22
+ Load nifti files (*.nii or *.nii.gz).
23
+ Parameters
24
+ ----------
25
+ volumeFolderPath : str
26
+ Folder of the volume file.
27
+ volumeName : str
28
+ Name of the volume file.
29
+
30
+ Returns
31
+ -------
32
+ volume : ndarray
33
+ Volume data in the form of numpy ndarray.
34
+ affine : ndarray
35
+ Associated affine transformation matrix in the form of numpy ndarray.
36
+ """
37
+ volumeFilePath = os.path.join(volumeFolderPath, volumeName)
38
+ volumeImg = nib.load(volumeFilePath)
39
+ volume = np.asanyarray(volumeImg.dataobj)  # get_data() is deprecated/removed in recent nibabel
40
+ shape = volume.shape
41
+ affine = volumeImg.affine
42
+ print('Volume loaded from {} with shape = {}.'.format(volumeFilePath, shape))
43
+
44
+ return volume, affine
45
+
46
+
47
+ def saveVolume(volume, affine, path, astype=None):
48
+ """
49
+ Save the given volume to the specified location in specified data type.
50
+ Parameters
51
+ ----------
52
+ volume : ndarray
53
+ Volume data to be saved.
54
+ affine : ndarray
55
+ The affine transformation matrix associated with the volume.
56
+ path : str
57
+ The absolute path where the volume is going to be saved.
58
+ astype : numpy dtype, optional
59
+ The desired data type of the volume data.
60
+ """
61
+ if astype is None:
62
+ astype = np.uint8
63
+
64
+ nib.save(nib.Nifti1Image(volume.astype(astype), affine), path)
65
+ print('Volume saved to {} as type {}.'.format(path, astype))
66
+
67
+
68
+ def labelVolume(volume, minSize=1, maxHop=3):
69
+ """
70
+ Partition the volume into several connected components and attach labels.
71
+ Parameters
72
+ ----------
73
+ volume : ndarray
74
+ Volume to be partitioned.
75
+ minSize : int, optional
76
+ Connected components smaller than this size will be discarded.
77
+ maxHop : int, optional
78
+ Controls how neighboring voxels are defined. See `label` doc for details.
79
+
80
+ Returns
81
+ -------
82
+ labeled : ndarray
83
+ The partitioned and labeled volume. Each connected component has a label (a positive integer) and the background
84
+ is labeled as 0.
85
+ labelResult : list
86
+ In the form of [[label1, size1], [label2, size2], ...]
87
+ """
88
+ labeled, maxNum = label(volume, return_num=True, connectivity=maxHop)
89
+ counts = np.bincount(labeled.ravel())
90
+ countLoc = np.nonzero(counts)[0]
91
+ sizeList = counts[countLoc]
92
+ labelResult = list(zip(countLoc[sizeList >= minSize], sizeList[sizeList >= minSize]))
93
+ # print(labelResult)
94
+ # print('Total segments: {}'.format(np.count_nonzero(sizeList >= minSize)))
95
+ return labeled, labelResult
96
+
97
+
98
+ def analyze(vesselVolumeMask, baseFolder):
99
+ """
100
+ Main function to invoke the skeletonization process. Note that this uses the Docker version of the code. If
101
+ you have already downloaded the original C++ code and successfully compiled it, then you can run that compiled code
102
+ instead of this one.
103
+ """
104
+ vesselVolumeMask = vesselVolumeMask.astype(np.uint8)
105
+ vesselVolumeMask[vesselVolumeMask != 0] = 1
106
+ vesselVolumeMask = np.swapaxes(vesselVolumeMask, 0, 2)
107
+ shape = vesselVolumeMask.shape
108
+
109
+ vesselVolumeMaskLabeled, vesselVolumeMaskLabelResult = labelVolume(vesselVolumeMask, minSize=1)
110
+ directory = os.path.join(baseFolder, 'skeletonizationResult')
111
+ if not os.path.exists(directory):
112
+ os.makedirs(directory)
113
+ print('Directory {} created.'.format(directory))
114
+
115
+ vesselVolumeMaskLabelInfoFilename = 'vesselVolumeMaskLabelInfo.npz'
116
+ vesselVolumeMaskLabelInfoFilePath = os.path.join(directory, vesselVolumeMaskLabelInfoFilename)
117
+ np.savez_compressed(vesselVolumeMaskLabelInfoFilePath, vesselVolumeMaskLabeled=vesselVolumeMaskLabeled,
118
+ vesselVolumeMaskLabelResult=vesselVolumeMaskLabelResult)
119
+ print('{} saved to {}.'.format(vesselVolumeMaskLabelInfoFilename, vesselVolumeMaskLabelInfoFilePath))
120
+
121
+ # directory2 = directory + 'labelNum=' + str(labelNum) + '/'
122
+ # if not os.path.exists(directory2):
123
+ # os.makedirs(directory2)
124
+ # with open(directory2 + 'BB.txt', 'w') as f1:
125
+ # f1.write('1\n')
126
+ # f1.write('{} {} {}\n'.format(0, 0, 0))
127
+ # f1.write('{} {} {}'.format(*shape))
128
+ # '''
129
+ BBFilePath = os.path.join(directory, 'BB.txt')
130
+ f1 = open(BBFilePath, 'w')
131
+ f1.write('1\n')
132
+ f1.write('{} {} {}\n'.format(0, 0, 0))
133
+ f1.write('{} {} {}'.format(*shape))
134
+ f1.close()
135
+
136
+ vesselCoords = np.array(np.where(vesselVolumeMask)).T
137
+ xyzFilePath = os.path.join(directory, 'xyz.txt')
138
+ np.savetxt(xyzFilePath, vesselCoords, fmt='%1u')
139
+ f2 = open(xyzFilePath, "r")
140
+ contents = f2.readlines()
141
+ f2.close()
142
+
143
+ contents.insert(0, '{}\n'.format(len(vesselCoords)))
144
+
145
+ f2 = open(xyzFilePath, "w")
146
+ contents = "".join(contents)
147
+ f2.write(contents)
148
+ f2.close()
149
+ # '''
150
+
151
+ # '''
152
+ currentPlatform = platform.system()
153
+ print('Current platform is {}.'.format(currentPlatform))
154
+ if currentPlatform == 'Windows':
155
+ cmd = '"C:/Program Files/Docker/Docker/Resources/bin/docker.exe" run -v ' + '"' + directory + '"' + ':/write_directory -e THRESH=1e-12 -e CC_FLAG=1 -e CONVERSION_TYPE=1 amytabb/curveskel-tabb-medeiros-2018-docker'
156
+ elif currentPlatform == 'Darwin':
157
+ cmd = 'docker run -v ' + '"' + directory + '"' + ':/write_directory -e THRESH=1e-12 -e CC_FLAG=1 -e CONVERSION_TYPE=1 amytabb/curveskel-tabb-nih-aug2018-docker2'
158
+ elif currentPlatform == 'Linux':
159
+ cmd = '/usr/local/bin/docker run -v ' + '"' + directory + '"' + ':/write_directory -e THRESH=1e-12 -e CC_FLAG=1 -e CONVERSION_TYPE=1 amytabb/curveskel-tabb-medeiros-2018-docker'
160
+ cmd = 'sudo docker run -v ' + '"' + directory + '"' + ':/write_directory -e THRESH=1e-12 -e CC_FLAG=1 -e CONVERSION_TYPE=1 amytabb/curveskel-tabb-medeiros-2018-docker'
161
+ cmd = 'sudo docker run -v ' + '"' + directory + '"' + ':/write_directory -e THRESH=1e-12 -e CC_FLAG=1 -e CONVERSION_TYPE=1 amytabb/curveskel-tabb-nih-aug2018-docker2'
162
+
163
+ print('cmd={}'.format(cmd))
164
+ subprocess.call(cmd, shell=True)
165
+ # '''
166
+
167
+
168
+ def combineSkeletonSegments(skeletonSegmentFolderPath):
169
+ """
170
+ Collect and combine the results from the skeletonization.
171
+ Parameters
172
+ ----------
173
+ skeletonSegmentFolderPath : str
174
+ The folder that contains the segments information (result_segments_xyz*.txt).
175
+
176
+ Returns
177
+ -------
178
+ segmentList : list
179
+ A list containing the segment information. Each sublist represents a segment and each element in the sublist
180
+ represents a centerpoint coordinates.
181
+ """
182
+ segmentList = []
183
+ files = glob.glob(os.path.join(skeletonSegmentFolderPath, 'result_segments_xyz*.txt'))
184
+ for segmentFile in files:
185
+ result = readSegmentFile(segmentFile)
186
+ segmentList += result
187
+
188
+ return segmentList
189
+
190
+
191
+ def readSegmentFile(segmentFile):
192
+ """
193
+ Parse the segment files (result_segments_xyz*.txt) and return segments information in a list.
194
+ Parameters
195
+ ----------
196
+ segmentFile : str
197
+ Path to the segment file.
198
+
199
+ Returns
200
+ -------
201
+ segmentList : list
202
+ A list containing the segment information. Each sublist represents a segment and each element in the sublist
203
+ represents a centerpoint coordinates.
204
+ """
205
+ isFirstLine = True
206
+ isSegmentLength = True
207
+ segmentList = []
208
+ with open(segmentFile) as f:
209
+ for line in f:
210
+ if isFirstLine:
211
+ numOfSegments = int(line)
212
+ isFirstLine = False
213
+ else:
214
+ if isSegmentLength:
215
+ segmentLength = int(line)
216
+ isSegmentLength = False
217
+ segmentCounter = 1
218
+ segment = []
219
+ else:
220
+ if segmentCounter <= segmentLength:
221
+ voxel = tuple([int(x) for x in line.split(' ')])
222
+ segment.append(voxel[::-1])
223
+ segmentCounter += 1
224
+ else:
225
+ segmentCounter += 1
226
+ isSegmentLength = True
227
+ segmentList.append(segment)
228
+ assert (len(segment) == segmentLength)
229
+
230
+ return segmentList
231
+
232
+
233
+ # def drawSegments(segmentList):
234
+ # pass
235
+
236
+ def processSegments(segmentList, shape):
237
+ """
238
+ Re-partition the segments so that each segment is a simple branch, i.e., it does not contain bifurcation point
239
+ unless at the two ends.
240
+ Note that this function might be replaced by another more concise function `getSegmentList`.
241
+ Parameters
242
+ ----------
243
+ segmentList : list
244
+ A list containing the segment information. Each sublist represents a segment and each element in the sublist
245
+ represents a centerpoint coordinates.
246
+ shape : tuple
247
+ Shape of the vessel volume (used for plotting).
248
+
249
+ Returns
250
+ -------
251
+ G : NetworkX graph
252
+ A graph in which each node represents a centerpoint and each edge represents a portion of a vessel branch.
253
+ segmentList : list
254
+ A list containing the segment information. Each sublist represents a segment and each element in the sublist
255
+ represents a centerpoint coordinates.
256
+ errorSegments : list
257
+ A list that contains segments that cannot be fixed.
258
+ """
259
+ ## Import pyqtgraph ##
260
+ from pyqtgraph.Qt import QtCore, QtGui
261
+ import pyqtgraph as pg
262
+ import pyqtgraph.opengl as gl
263
+
264
+ ## Init ##
265
+ app = pg.QtGui.QApplication([])
266
+ w = gl.GLViewWidget()
267
+ w.opts['distance'] = 800
268
+ w.setGeometry(0, 110, 1600, 900)
269
+ offset = np.array(shape) / (-2.0)
270
+
271
+ G = nx.Graph()
272
+ colorList = [pg.glColor('r'), pg.glColor('g'), pg.glColor('b'), pg.glColor('c'), pg.glColor('m'), pg.glColor('y')]
273
+ colorPointer = 0
274
+ skeleton = np.full(shape, 0)
275
+ for segment in segmentList:
276
+ # G.add_path(list(map(tuple, segment)))
277
+ nx.add_path(G, segment)  # Graph.add_path was removed in NetworkX >= 2.4
278
+ segmentCoords = np.array(segment)
279
+ skeleton[tuple(segmentCoords.T)] = 1
280
+ # segmentCoordsView = segmentCoords + offset
281
+ # aa = gl.GLLinePlotItem(pos=segmentCoordsView, color=colorList[colorPointer], width=3)
282
+ # w.addItem(aa)
283
+ # colorPointer = colorPointer + 1 if colorPointer < len(colorList) - 1 else 0
284
+
285
+ # skeletonCoords = np.array(np.where(skeleton)).T
286
+ # skeletonCoordsView = (skeletonCoords + offset) * affineTransform
287
+ # aa = gl.GLScatterPlotItem(pos=skeletonCoordsView, size=5)
288
+ # w.addItem(aa)
289
+
290
+ # w.show()
291
+
292
+ voxelDegrees = np.array([v for _, v in G.degree(G.nodes())])
293
+ maxVoxelDegree = np.amax(voxelDegrees)
294
+ voxelDegreesZippedResult = list(zip(np.arange(maxVoxelDegree + 1), np.bincount(voxelDegrees)))
295
+ print('Voxel degree distribution is \n{}'.format(voxelDegreesZippedResult))
296
+ print('Number of cycles is {}'.format(len(nx.cycle_basis(G))))
297
+
298
+ # Remove duplicate segments
299
+ keepList = np.full((len(segmentList),), True)
300
+ duplicateCounter = 0
301
+ for idx, seg in enumerate(segmentList):
302
+ for idx2, seg2 in enumerate(segmentList[idx + 1:]):
303
+ if seg == seg2 or seg == seg2[::-1]:
304
+ keepList[idx + 1 + idx2] = False  # seg2 is segmentList[idx + 1 + idx2]
305
+ duplicateCounter += 1
306
+
307
+ segmentList = [seg for idx, seg in enumerate(segmentList) if keepList[idx]]
308
+ print('{} duplicate segments removed!'.format(duplicateCounter))
309
+
310
+ # Cut segments into sub-pieces if there are bifurcation points in the middle
311
+ extraSegments = []
312
+ keepList = np.full((len(segmentList),), True)
313
+ for idx, segment in enumerate(segmentList):
314
+ voxelDegrees = np.array([v for _, v in G.degree(segment)])
315
+ if len(voxelDegrees) >= 3:
316
+ if voxelDegrees[0] == 2 or voxelDegrees[-1] == 2 or (not np.all(voxelDegrees[1:-1] == 2)):
317
+ keepList[idx] = False
318
+ locs = np.nonzero(voxelDegrees != 2)[0]
319
+ if voxelDegrees[0] == 2:
320
+ locs = np.hstack((0, locs))
321
+
322
+ if voxelDegrees[-1] == 2:
323
+ locs = np.hstack((locs, len(voxelDegrees)))
324
+
325
+ newSegments = []
326
+ for ii in range(len(locs) - 1):
327
+ newSegments.append(segment[locs[ii]:(locs[ii + 1] + 1)])
328
+
329
+ extraSegments += newSegments
330
+
331
+ segmentList = [seg for idx, seg in enumerate(segmentList) if keepList[idx]]
332
+ segmentList += extraSegments
333
+
334
+ # Remove duplicate segments again
335
+ keepList = np.full((len(segmentList),), True)
336
+ duplicateCounter = 0
337
+ for idx, seg in enumerate(segmentList):
338
+ for idx2, seg2 in enumerate(segmentList[idx + 1:]):
339
+ if seg == seg2 or seg == seg2[::-1]:
340
+ keepList[idx + 1 + idx2] = False  # seg2 is segmentList[idx + 1 + idx2]
341
+ duplicateCounter += 1
342
+
343
+ segmentList = [seg for idx, seg in enumerate(segmentList) if keepList[idx]]
344
+ print('{} duplicate segments removed in the second stage!'.format(duplicateCounter))
345
+
346
+ # Remove segment if it is completely contained in another segment
347
+ # keepList = np.full((len(segmentList),), True)
348
+ # sublistCounter = 0
349
+ # for idx, seg in enumerate(segmentList):
350
+ # for idx2, seg2 in enumerate(segmentList[idx + 1:]):
351
+ # if contains(seg, seg2):
352
+ # keepList[idx] = False
353
+ # sublistCounter += 1
354
+ # elif contains(seg2, seg):
355
+ # keepList[idx + idx2] = False
356
+ # sublistCounter += 1
357
+
358
+ # segmentList = [seg for idx, seg in enumerate(segmentList) if keepList[idx]]
359
+ # print('{} sublist segments removed!'.format(sublistCounter))
360
+
361
+ # Treat the segment if either end is not correct
362
+ hasInvalidSegments = False
363
+ for idx, segment in enumerate(segmentList):
364
+ voxelDegrees = np.array([v for _, v in G.degree(segment)])
365
+ if len(voxelDegrees) == 2:
366
+ if voxelDegrees[0] == 2 or voxelDegrees[-1] == 2:
367
+ # print('Degrees on either end is 2: {}'.format(voxelDegrees))
368
+ hasInvalidSegments = True
369
+ elif len(voxelDegrees) > 2:
370
+ if voxelDegrees[0] == 2 or voxelDegrees[-1] == 2 or np.any(voxelDegrees[1:-1] != 2):
371
+ # print('Degrees not correct: {}'.format(voxelDegrees))
372
+ hasInvalidSegments = True
373
+
374
+ if not hasInvalidSegments:
375
+ drawSegments(segmentList, shape)
376
+ print('No errors!')
377
+ errorSegments = []
378
+ return G, segmentList, errorSegments
379
+
380
+ iterCounter = 1
381
+ while hasInvalidSegments:
382
+ print('\n\nIter={}'.format(iterCounter))
383
+ keepList = np.full((len(segmentList),), True)
384
+ extraSegments = []
385
+ for idx, segment in enumerate(segmentList):
386
+ if keepList[idx]:
387
+ voxelDegrees = np.array([v for _, v in G.degree(segment)])
388
+ if voxelDegrees[0] == 2 and voxelDegrees[-1] == 2:
389
+ print('Both ends have 2 neighbours')
390
+ elif voxelDegrees[0] == 2 or voxelDegrees[-1] == 2:
391
+ # print('Degrees on either end is 2: {}'.format(voxelDegrees))
392
+ # pass
393
+ # segmentCoords = np.array(segment)
394
+ if voxelDegrees[0] == 2:
395
+ otherSegmentInfo = [(idx2, seg) for idx2, seg in enumerate(segmentList) if
396
+ (seg[0] == segment[0] or seg[-1] == segment[0]) and keepList[
397
+ idx2] and idx != idx2]
398
+ if len(otherSegmentInfo) != 0:
399
+ if len(otherSegmentInfo) > 1:
400
+ # print(contains(segment, otherSegmentInfo[0][1]), contains(otherSegmentInfo[1][1], segment))
401
+ otherSegmentInfoTemp = []
402
+ for idx2, seg in otherSegmentInfo:
403
+ if contains(segment, seg) or contains(segment[::-1], seg):
404
+ keepList[idx] = False
405
+ continue
406
+ elif contains(seg, segment) or contains(seg[::-1], segment):
407
+ keepList[idx2] = False
408
+ otherSegmentInfoTemp.append((idx2, seg))
409
+
410
+ otherSegmentInfo = otherSegmentInfoTemp
411
+ # otherSegmentInfo = [segInfo for segInfo in otherSegmentInfo if not (contains(segment, segInfo[1]) or contains(segInfo[1], segment))]
412
+ if len(otherSegmentInfo) > 1:
413
+ print('More than one other segments found!')
414
+ print('Current segment ({}) is {} ({})'.format(idx, segment, voxelDegrees))
415
+ for otherSegmentIdx, otherSegment in otherSegmentInfo:
416
+ otherSegmentVoxelDegrees = np.array([v for _, v in G.degree(otherSegment)])
417
+ print('Idx = {}: {} ({})'.format(otherSegmentIdx, otherSegment,
418
+ otherSegmentVoxelDegrees))
419
+ elif len(otherSegmentInfo) == 1:
420
+ otherSegmentIdx, otherSegment = otherSegmentInfo[0]
421
+ else:
422
+ print('No valid other segments found!')
423
+ continue
424
+ else:
425
+ otherSegmentIdx, otherSegment = otherSegmentInfo[0]
426
+ if contains(segment, otherSegment) or contains(segment[::-1], otherSegment):
427
+ keepList[idx] = False
428
+ continue
429
+ elif contains(otherSegment, segment) or contains(otherSegment[::-1], segment):
430
+ keepList[otherSegmentIdx] = False
431
+ continue
432
+
433
+ newSegment = otherSegment + segment[1:] if otherSegment[-1] == segment[0] else otherSegment[
434
+ ::-1] + segment[
435
+ 1:]
436
+ if not validateSegment(G, newSegment):
437
+ newSegmentVoxelDegrees = np.array([v for _, v in G.degree(newSegment)])
438
+ print('Old degree is {} () and new degree is {} ()'.format(voxelDegrees,
439
+ newSegmentVoxelDegrees))
440
+ else:
441
+ print('Two segments ({} and {}) merged together!'.format(idx, otherSegmentIdx))
442
+
443
+ extraSegments.append(newSegment)
444
+ keepList[idx] = False
445
+ keepList[otherSegmentIdx] = False
446
+ else:
447
+ print(
448
+ 'Could not find other segments for segment({}) {} with degrees {}'.format(idx, segment,
449
+ voxelDegrees))
450
+ possibleSegmentsInfo = [(idx2, seg) for idx2, seg in enumerate(segmentList) if
451
+ (seg[0] == segment[0] or seg[-1] == segment[0]) and idx != idx2]
452
+ print('Possible segments: {}'.format(len(possibleSegmentsInfo)))
453
+
454
+ elif voxelDegrees[-1] == 2:
455
+ otherSegmentInfo = [(idx2, seg) for idx2, seg in enumerate(segmentList) if
456
+ (seg[0] == segment[-1] or seg[-1] == segment[-1]) and keepList[
457
+ idx2] and idx != idx2]
458
+ if len(otherSegmentInfo) != 0:
459
+ if len(otherSegmentInfo) > 1:
460
+ # print(contains(segment, otherSegmentInfo[0][1]), contains(otherSegmentInfo[1][1], segment))
461
+ otherSegmentInfoTemp = []
462
+ for idx2, seg in otherSegmentInfo:
463
+ if contains(segment, seg) or contains(segment[::-1], seg):
464
+ keepList[idx] = False
465
+ continue
466
+ elif contains(seg, segment) or contains(seg[::-1], segment):
467
+ keepList[idx2] = False
468
+ otherSegmentInfoTemp.append((idx2, seg))
469
+
470
+ otherSegmentInfo = otherSegmentInfoTemp
471
+ # otherSegmentInfo = [segInfo for segInfo in otherSegmentInfo if not (contains(segment, segInfo[1]) or contains(segInfo[1], segment))]
472
+ if len(otherSegmentInfo) > 1:
473
+ print('More than one other segments found!')
474
+ print('Current segment ({}) is {} ({})'.format(idx, segment, voxelDegrees))
475
+ for otherSegmentIdx, otherSegment in otherSegmentInfo:
476
+ otherSegmentVoxelDegrees = np.array([v for _, v in G.degree(otherSegment)])
477
+ print('Idx = {}: {} ({})'.format(otherSegmentIdx, otherSegment,
478
+ otherSegmentVoxelDegrees))
479
+ elif len(otherSegmentInfo) == 1:
480
+ otherSegmentIdx, otherSegment = otherSegmentInfo[0]
481
+ else:
482
+ print('No valid other segments found!')
483
+ continue
484
+ else:
485
+ otherSegmentIdx, otherSegment = otherSegmentInfo[0]
486
+ if contains(segment, otherSegment) or contains(segment[::-1], otherSegment):
487
+ keepList[idx] = False
488
+ continue
489
+ elif contains(otherSegment, segment) or contains(otherSegment[::-1], segment):
490
+ keepList[otherSegmentIdx] = False
491
+ continue
492
+
493
+ newSegment = segment[:-1] + otherSegment if otherSegment[0] == segment[-1] else segment[
494
+ :-1] + otherSegment[
495
+ ::-1]
496
+ if not validateSegment(G, newSegment):
497
+ newSegmentVoxelDegrees = np.array([v for _, v in G.degree(newSegment)])
498
+ print('Old degree is {} () and new degree is {} ()'.format(voxelDegrees,
499
+ newSegmentVoxelDegrees))
500
+ else:
501
+ print('Two segments ({} and {}) merged together!'.format(idx, otherSegmentIdx))
502
+
503
+ extraSegments.append(newSegment)
504
+ keepList[idx] = False
505
+ keepList[otherSegmentIdx] = False
506
+ else:
507
+ print(
508
+ 'Could not find other segments for segment({}) {} with degrees {}'.format(idx, segment,
509
+ voxelDegrees))
510
+ possibleSegmentsInfo = [(idx2, seg) for idx2, seg in enumerate(segmentList) if
511
+ (seg[0] == segment[-1] or seg[-1] == segment[-1]) and idx != idx2]
512
+ print('Possible segments: {}'.format(len(possibleSegmentsInfo)))
513
+
514
+ segmentList = [segment for idx, segment in enumerate(segmentList) if keepList[idx]]
515
+ segmentList += extraSegments
516
+ hasInvalidSegments = False
517
+ errorSegments = []
518
+ for idx, segment in enumerate(segmentList):
519
+ voxelDegrees = np.array([v for _, v in G.degree(segment)])
520
+ if len(voxelDegrees) == 2:
521
+ if voxelDegrees[0] == 2 or voxelDegrees[-1] == 2:
522
+ print('Degrees on either end is 2: {}'.format(voxelDegrees))
523
+ hasInvalidSegments = True
524
+ errorSegments.append(segment)
525
+ elif len(voxelDegrees) > 2:
526
+ if voxelDegrees[0] == 2 or voxelDegrees[-1] == 2 or np.any(voxelDegrees[1:-1] != 2):
527
+ print('Degrees not correct: {}'.format(voxelDegrees))
528
+ hasInvalidSegments = True
529
+ errorSegments.append(segment)
530
+
531
+ print('hasInvalidSegments = {}'.format(hasInvalidSegments))
532
+ iterCounter += 1
533
+ if len(extraSegments) == 0:
534
+ hasInvalidSegments = False
535
+ print('While loop aborted because there is no change in segments!')
536
+
537
+ for errorSegment in errorSegments:
538
+ segmentList.remove(errorSegment)
539
+
540
+ # np.savez_compressed(directory + 'segmentList.npz', segmentList=segmentList)
541
+ # if partIdx != 10:
542
+ # nib.save(nib.Nifti1Image(skeleton.astype(np.int16), vesselImg.affine), directory + skeletonNamePartial + str(partIdx) + '.nii.gz')
543
+ # else:
544
+ # nib.save(nib.Nifti1Image(skeleton.astype(np.int16), vesselImg.affine), directory + skeletonNameTotal + '.nii.gz')
545
+
546
+ # nx.write_graphml(G, directory + 'graphRepresentation.graphml')
547
+
548
+ # drawAbstractGraph(offset, segmentList)
549
+ # drawAbstractGraph(offset, errorSegments)
550
+
551
+ print(errorSegments)
552
+
553
+ return G, segmentList, errorSegments
554
+
555
+
556
+ def getSegmentList(G, nodeInfoDict):
557
+ """
558
+ Generate segmentList from graph and nodeInfoDict.
559
+ Parameters
560
+ ----------
561
+ G : NetworkX graph
562
+ The graph representation of the network.
563
+ nodeInfoDict : dict
564
+ A dictionary containing the information about nodes.
565
+
566
+ Returns
567
+ -------
568
+ segmentList : list
569
+ A list of segments in which each segment is a simple branch.
570
+ """
571
+ startNodeIDList = [nodeID for nodeID in nodeInfoDict.keys() if nodeInfoDict[nodeID]['parentNodeID'] == -1]
572
+ print('startNodeIDList = {}'.format(startNodeIDList))
573
+ segmentList = []
574
+ for startNodeID in startNodeIDList:
575
+ segmentList = getSegmentListDetail(G, nodeInfoDict, segmentList, startNodeID)
576
+
577
+ print('There are {} segments in segmentList'.format(len(segmentList)))
578
+ print(segmentList)
579
+ return segmentList
580
+
581
+
582
+ def getSegmentListDetail(G, nodeInfoDict, segmentList, startNodeID):
583
+ """
584
+ Implementation of `getSegmentList`. Use DFS to traverse all the segments.
585
+ Parameters
586
+ ----------
587
+ G : NetworkX graph
588
+ The graph representation of the network.
589
+ nodeInfoDict : dict
590
+ A dictionary containing the information about nodes.
591
+ segmentList : list
592
+ A list of segments in which each segment is a simple branch.
593
+ startNodeID : int
594
+ The index of the start point of a segment.
595
+
596
+ Returns
597
+ -------
598
+ segmentList : list
599
+ A list of segments in which each segment is a simple branch.
600
+ """
601
+ neighborNodeIDList = [nodeID for nodeID in list(G[startNodeID].keys()) if
602
+ 'visited' not in G[startNodeID][nodeID]] # use adjacency dict to find neighbors
603
+ newSegmentList = []
604
+ for neighborNodeID in neighborNodeIDList:
605
+ newSegment = [startNodeID, neighborNodeID]
606
+ G[startNodeID][neighborNodeID]['visited'] = True
607
+ currentNodeID = neighborNodeID
608
+ while G.degree(currentNodeID) == 2:
609
+ newNodeID = [nodeID for nodeID in G[currentNodeID].keys() if 'visited' not in G[currentNodeID][nodeID]][0]
610
+ G[currentNodeID][newNodeID]['visited'] = True
611
+ newSegment.append(newNodeID)
612
+ currentNodeID = newNodeID
613
+
614
+ newSegmentList.append(newSegment)
615
+ segmentList.append(newSegment)
616
+ segmentList = getSegmentListDetail(G, nodeInfoDict, segmentList, currentNodeID)
617
+
618
+ return segmentList
619
+
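A toy illustration (not part of the commit) of `getSegmentList`, assuming the module's top-level dependencies are installed and that `nodeInfoDict[n]['parentNodeID'] == -1` marks root nodes:

import networkx as nx
from Centerline.skeletonization import getSegmentList

G = nx.Graph()
nx.add_path(G, [0, 1, 2, 3])
G.add_edge(2, 4)  # bifurcation at node 2
nodeInfoDict = {n: {'parentNodeID': -1 if n == 0 else 0} for n in G.nodes}  # node 0 is the only root
print(getSegmentList(G, nodeInfoDict))  # e.g. [[0, 1, 2], [2, 3], [2, 4]]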
620
+
621
+ def sublist(ls1, ls2):
622
+ '''
623
+ >>> sublist([], [1,2,3])
624
+ True
625
+ >>> sublist([1,2,3,4], [2,5,3])
626
+ True
627
+ >>> sublist([1,2,3,4], [0,3,2])
628
+ False
629
+ >>> sublist([1,2,3,4], [1,2,5,6,7,8,5,76,4,3])
630
+ False
631
+ '''
632
+
633
+ def get_all_in(one, another):
634
+ for element in one:
635
+ if element in another:
636
+ yield element
637
+
638
+ for x1, x2 in zip(get_all_in(ls1, ls2), get_all_in(ls2, ls1)):
639
+ if x1 != x2:
640
+ return False
641
+
642
+ return True
643
+
644
+
645
+ def contains(lst1, lst2):
646
+ lst1, lst2 = (lst2, lst1) if len(lst1) > len(lst2) else (lst1, lst2)
647
+ if lst1[0] in lst2:
648
+ startLoc = lst2.index(lst1[0])
649
+ else:
650
+ return False
651
+
652
+ if lst1[-1] in lst2:
653
+ endLoc = lst2.index(lst1[-1])
654
+ else:
655
+ return False
656
+
657
+ if startLoc < endLoc:
658
+ if lst1 == lst2[startLoc:(endLoc + 1)]:
659
+ return True
660
+ else:
661
+ return False
662
+ else:
663
+ if lst1 == lst2[endLoc:(startLoc + 1)][::-1]:
664
+ return True
665
+ else:
666
+ return False
667
+
668
+
669
+ def validateSegment(G, segment):
670
+ """
671
+ Check whether a segment is a simple branch.
672
+ Parameters
673
+ ----------
674
+ G : NetworkX graph
675
+ A graph in which each node represents a centerpoint and each edge represents a portion of a vessel branch.
676
+ segment : list
677
+ A list containing the coordinates of the centerpoints of a segment.
678
+
679
+ Returns
680
+ -------
681
+ result : bool
682
+ If True, the segment is a simple branch.
683
+ """
684
+ voxelDegrees = np.array([v for _, v in G.degree(segment)])
685
+ if voxelDegrees[0] != 2 and voxelDegrees[-1] != 2:
686
+ if len(voxelDegrees) == 2:
687
+ result = True
688
+ elif len(voxelDegrees) > 2:
689
+ if np.all(voxelDegrees[1:-1] == 2):
690
+ result = True
691
+ else:
692
+ result = False
693
+ else:
694
+ print('Error! Segment with length 1 found!')
695
+ result = False
696
+ else:
697
+ result = False
698
+
699
+ return result
700
+
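A toy check (not part of the commit) of `validateSegment`: a segment is a simple branch only if its interior nodes have degree 2 and neither endpoint does.

import networkx as nx
from Centerline.skeletonization import validateSegment

G = nx.Graph()
nx.add_path(G, [(0, 0, 0), (1, 0, 0), (2, 0, 0), (3, 0, 0)])
G.add_edge((1, 0, 0), (1, 1, 0))  # side branch: (1, 0, 0) now has degree 3
print(validateSegment(G, [(1, 0, 0), (2, 0, 0), (3, 0, 0)]))  # True: ends are branch/leaf nodes, interior degree 2
print(validateSegment(G, [(0, 0, 0), (1, 0, 0), (2, 0, 0)]))  # False: last node has degree 2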
701
+
702
+ def drawSegments(segmentList, shape):
703
+ """
704
+ Plot all the segments in `segmentList`. Try to assign different colors to the segments connected to the same node.
705
+ Parameters
706
+ ----------
707
+ segmentList : list
708
+ A list containing the segment information. Each sublist represents a segment and each element in the sublist
709
+ represents a centerpoint coordinates.
710
+ shape : tuple
711
+ Shape of the vessel volume (used for plotting).
712
+ """
713
+ ## Import pyqtgraph ##
714
+ from pyqtgraph.Qt import QtCore, QtGui
715
+ import pyqtgraph as pg
716
+ import pyqtgraph.opengl as gl
717
+
718
+ ## Init ##
719
+ app = pg.QtGui.QApplication([])
720
+ w = gl.GLViewWidget()
721
+ w.opts['distance'] = 800
722
+ w.setGeometry(0, 110, 1600, 900)
723
+ offset = np.array(shape) / (-2.0)
724
+
725
+ colorList = [pg.glColor('r'), pg.glColor('g'), pg.glColor('b'), pg.glColor('c'), pg.glColor('m'), pg.glColor('y')]
726
+ colorNames = ['Red', 'Green', 'Blue', 'Cyan', 'Magenta', 'Yellow']
727
+ numOfColors = len(colorList)
728
+ nodeColorDict = {}
729
+ for segment in segmentList:
730
+ startVoxel = segment[0]
731
+ endVoxel = segment[-1]
732
+ if startVoxel in nodeColorDict and endVoxel in nodeColorDict: # and endVoxel in [voxel for voxel, _ in nodeColorDict[startVoxel]]:
733
+ nodeColorDict[startVoxel].append([endVoxel, -1])
734
+ nodeColorDict[endVoxel].append([startVoxel, -1])
735
+ else:
736
+ if startVoxel not in nodeColorDict:
737
+ nodeColorDict[startVoxel] = [[endVoxel, -1]]
738
+ else:
739
+ nodeColorDict[startVoxel].append([endVoxel, -1])
740
+
741
+ if endVoxel not in nodeColorDict:
742
+ nodeColorDict[endVoxel] = [[startVoxel, -1]]
743
+ else:
744
+ nodeColorDict[endVoxel].append([startVoxel, -1])
745
+
746
+ existingColorsInStart = [colorCode for _, colorCode in nodeColorDict[startVoxel]]
747
+ existingColorsInEnd = [colorCode for _, colorCode in nodeColorDict[endVoxel]]
748
+ availableColors = [colorCode for colorCode in range(numOfColors) if
749
+ colorCode not in existingColorsInStart and colorCode not in existingColorsInEnd]
750
+ # print('color in start: {} and color in end: {}'.format(existingColorsInStart, existingColorsInEnd))
751
+ chosenColor = availableColors[0] if len(availableColors) != 0 else 0
752
+ nodeColorDict[startVoxel][-1][1] = chosenColor
753
+ nodeColorDict[endVoxel][-1][1] = chosenColor
754
+
755
+ segmentCoords = np.array(segment)
756
+ aa = gl.GLLinePlotItem(pos=segmentCoords, color=colorList[chosenColor], width=3)
757
+ aa.translate(*offset)
758
+ w.addItem(aa)
759
+
760
+ w.show()
761
+ pg.QtGui.QApplication.exec_()
762
+ # sys.exit(app.exec_())
763
+
764
+
765
+ def main():
766
+ start_time = timeit.default_timer()
767
+ baseFolder = os.path.abspath(os.path.dirname(__file__))
768
+
769
+ ## Load existing volume ##
770
+ vesselVolumeMaskFolderPath = baseFolder
771
+ vesselVolumeMaskFileName = 'vesselVolumeMask.nii.gz'
772
+ vesselVolumeMask, vesselVolumeMaskAffine = loadVolume(vesselVolumeMaskFolderPath, vesselVolumeMaskFileName)
773
+
774
+ ## Skeletonization ##
775
+ # analyze(vesselVolumeMask, baseFolder)
776
+
777
+ skeletonSegmentFolderPath = os.path.join(baseFolder, 'skeletonizationResult/segments_by_cc')
778
+ segmentListRough = combineSkeletonSegments(skeletonSegmentFolderPath)
779
+
780
+ shape = vesselVolumeMask.shape
781
+ # drawSegments(segmentListRough, shape)
782
+
783
+ G, segmentList, errorSegments = processSegments(segmentListRough, shape=shape)
784
+ # drawSegments(segmentList, shape)
785
+ G = nx.Graph()
786
+ segmentIndex = 0
787
+ for segment in segmentList:
788
+ nx.add_path(G, segment, segmentIndex=segmentIndex)  # Graph.add_path was removed in NetworkX >= 2.4
789
+ segmentIndex += 1
790
+
791
+ ## Save graph representation ##
792
+ graphFileName = 'graphRepresentation.graphml'
793
+ graphFilePath = os.path.join(baseFolder, graphFileName)
794
+ nx.write_graphml(G, graphFilePath)
795
+ print('{} saved to {}.'.format(graphFileName, graphFilePath))
796
+
797
+ ## Save segmentList ##
798
+ segmentListFileName = 'segmentList.npz'
799
+ segmentListFilePath = os.path.join(baseFolder, segmentListFileName)
800
+ np.savez_compressed(segmentListFilePath, segmentList=segmentList)
801
+ print('{} saved to {}.'.format(segmentListFileName, segmentListFilePath))
802
+
803
+ ## Save skeleton.nii.gz ##
804
+ skeleton = np.zeros_like(vesselVolumeMask)
805
+ for segment in segmentList:
806
+ skeleton[tuple(np.array(segment).T)] = 1
807
+
808
+ skeletonFileName = 'skeleton.nii.gz'
809
+ skeletonFilePath = os.path.join(baseFolder, skeletonFileName)
810
+ saveVolume(skeleton, vesselVolumeMaskAffine, skeletonFilePath, astype=np.uint8)
811
+
812
+ elapsed = timeit.default_timer() - start_time
813
+ print('Elapsed: {} sec'.format(elapsed))
814
+
815
+
816
+ if __name__ == "__main__":
817
+ main()
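A minimal sketch (not part of the commit) of the volume-level helpers above; the path and file name are hypothetical:

import numpy as np
from Centerline.skeletonization import loadVolume, labelVolume, saveVolume

volume, affine = loadVolume('/path/to/folder', 'vesselVolumeMask.nii.gz')   # hypothetical location
labeled, labelResult = labelVolume(volume, minSize=10)                      # [(label, size), ...], label 0 = background
largest = max((t for t in labelResult if t[0] != 0), key=lambda t: t[1])[0]
saveVolume(labeled == largest, affine, '/path/to/folder/largest_cc.nii.gz', astype=np.uint8)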
Centerline/thinPlateSplines.py ADDED
@@ -0,0 +1,82 @@
1
+ import numpy as np
2
+ from scipy.spatial.distance import pdist, cdist, squareform
3
+ from sklearn.metrics import pairwise_distances
4
+
5
+ class ThinPlateSplines:
6
+ def __init__(self, ctrl_pts: np.ndarray, target_pts: np.ndarray, reg=0.0):
7
+ """
8
+
9
+ :param ctrl_pts: [N, d] tensor of control d-dimensional points
10
+ :param target_pts: [N, d] tensor of target d-dimensional points
11
+ :param reg: regularization coefficient
12
+ """
13
+ self.__ctrl_pts = ctrl_pts
14
+ self.__target_pts = target_pts
15
+ self.__reg = reg
16
+ self.__num_ctrl_pts = ctrl_pts.shape[0]
17
+ self.__dim = ctrl_pts.shape[1]
18
+
19
+ self.__K = None
20
+ self.__compute_coeffs()
21
+ self.__aff_params = self.__coeffs[self.__num_ctrl_pts:, ...] # Affine parameters of the TPS
22
+ self.__non_aff_params = self.__coeffs[:self.__num_ctrl_pts, ...]  # Non-affine parameters of the TPS
23
+
24
+ def __compute_coeffs(self):
25
+ target_pts_aug = np.vstack([self.__target_pts,
26
+ np.zeros([self.__dim + 1, self.__dim])]).astype(self.__target_pts.dtype)
27
+
28
+ T_i = np.linalg.inv(self.__make_T()).astype(self.__target_pts.dtype)
29
+ self.__coeffs = np.matmul(T_i, target_pts_aug).astype(self.__target_pts.dtype)
30
+
31
+ def __make_T(self):
32
+ # cp: [K x 2] control points
33
+ # T: [(K+3) x (K+3)]
34
+ P = np.hstack([np.ones([self.__num_ctrl_pts, 1], dtype=float), self.__ctrl_pts])
35
+ zeros = np.zeros([self.__dim + 1, self.__dim + 1], dtype=float)
36
+ self.__K = self.__U_dist(self.__ctrl_pts)
37
+ alfa = np.mean(self.__K)
38
+
39
+ self.__K = self.__K + np.ones_like(self.__K) * np.power(alfa, 2) * self.__reg
40
+
41
+ top = np.hstack([self.__K, P])  # kernel block first, matching the [w; a] ordering used in __lift_pts and the coefficient slicing
42
+ bottom = np.hstack([P.transpose(), zeros])
43
+
44
+ return np.vstack([top, bottom])
45
+
46
+ def __U_dist(self, ctrl_pts, int_pts=None):
47
+ dist = pairwise_distances(ctrl_pts, int_pts)
48
+
49
+ if ctrl_pts.shape[-1] == 2:
50
+ u_dist = np.square(dist) * np.log(dist + 1e-6)
51
+ else:
52
+ u_dist = np.sqrt(dist)
53
+
54
+ return u_dist
55
+
56
+ def __lift_pts(self, int_pts: np.ndarray, num_pts):
57
+ # int_pts: [N x 2], input points
58
+ # cp: [K x 2], control points
59
+ # pLift: [N x (3+K)], lifted input points
60
+
61
+ int_pts_lift = np.hstack([self.__U_dist(int_pts, self.__ctrl_pts),  # [N x K] kernel block (argument order gives rows per input point)
62
+ np.ones([num_pts, 1], dtype=float),
63
+ int_pts])
64
+ return int_pts_lift
65
+
66
+ def _get_coefficients(self):
67
+ return self.__coeffs
68
+
69
+ def interpolate(self, int_points):
70
+ """
71
+
72
+ :param int_points: [K, d] flattened d-points of a mesh
73
+ :return:
74
+ """
75
+ num_pts = int_points.shape[0]
76
+ int_points_lift = self.__lift_pts(int_points, num_pts)
77
+ return np.dot(int_points_lift, self.__coeffs)
78
+
79
+ @property
80
+ def bending_energy(self):
81
+ aux = np.matmul(self.__non_aff_params.T, self.__K)  # NumPy instead of TensorFlow (tf is not imported in this module)
82
+ return np.matmul(aux, self.__non_aff_params)
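A minimal usage sketch (not part of the commit) of the TPS interpolator on 3-D point sets, assuming the small fixes noted above:

import numpy as np
from Centerline.thinPlateSplines import ThinPlateSplines

ctrl = np.random.rand(20, 3)
target = ctrl + 0.05 * np.random.randn(20, 3)   # slightly displaced control points
tps = ThinPlateSplines(ctrl, target, reg=0.01)

query = np.random.rand(100, 3)                  # arbitrary points to warp
warped = tps.interpolate(query)
print(warped.shape)                             # (100, 3)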
Centerline/visualization_utils.py ADDED
@@ -0,0 +1,161 @@
1
+ import matplotlib.pyplot as plt
2
+ from mpl_toolkits.mplot3d import Axes3D
3
+ from matplotlib.lines import Line2D
4
+ import numpy as np
5
+ from DeepDeformationMapRegistration.utils.visualization import add_axes_arrows_3d, remove_tick_labels, set_axes_size
6
+ import os
7
+
8
+
9
+ def _plot_graph(graph, ax, nodes_colour='C3', edges_colour='C1', plot_nodes=True, plot_edges=True, add_axes=True):
10
+ if plot_edges:
11
+ for (start_node, end_node) in graph.edges():
12
+ edge_pts = graph[start_node][end_node]['pts']
13
+ edge_pts = np.vstack([graph.nodes[start_node]['o'], edge_pts])
14
+ edge_pts = np.vstack([edge_pts, graph.nodes[end_node]['o']])
15
+ ax.plot(edge_pts[:, 0], edge_pts[:, 1], edge_pts[:, 2], edges_colour)
16
+
17
+ if plot_nodes:
18
+ nodes = graph.nodes()
19
+ ps = np.array([nodes[i]['o'] for i in nodes])
20
+ if len(ps.shape) > 1:
21
+ ax.scatter(ps[:, 0], ps[:, 1], ps[:, 2], c=nodes_colour)  # pass the colour explicitly; the 4th positional arg of Axes3D.scatter is zdir
22
+ else:
23
+ ax.scatter(ps[0], ps[1], ps[2], c=nodes_colour)
24
+ ax.set_xlim(0, 63)
25
+ ax.set_ylim(0, 63)
26
+ ax.set_zlim(0, 63)
27
+ remove_tick_labels(ax, True)
28
+ if add_axes:
29
+ add_axes_arrows_3d(ax, x_color='r', y_color='g', z_color='b')
30
+ ax.view_init(None, 45)
31
+
32
+ return ax
33
+
34
+
35
+ def plot_skeleton(img, skeleton, graph, filename='skeleton', extension=['.png']):
36
+ if not isinstance(extension, list):
37
+ extension = [extension]
38
+ # Skeleton
39
+ f = plt.figure(figsize=(5, 5))
40
+ ax = f.add_subplot(111, projection='3d')
41
+
42
+ coords = np.argwhere(skeleton)
43
+ i = coords[:, 0]
44
+ j = coords[:, 1]
45
+ k = coords[:, 2]
46
+
47
+ seg = ax.voxels(img, facecolors=(0., 0., 1., 0.3), label='image')
48
+ ske = ax.scatter(i, j, k, color='C1', label='skeleton', s=1)
49
+ ax.set_xlim(0, 63)
50
+ ax.set_ylim(0, 63)
51
+ ax.set_zlim(0, 63)
52
+ remove_tick_labels(ax, True)
53
+ add_axes_arrows_3d(ax, x_color='r', y_color='g', z_color='b')
54
+ ax.view_init(None, 45)
55
+ for ex in extension:
56
+ f.savefig(filename + '_segmentation_skeleton' + ex)
57
+
58
+ # Combined
59
+ ax = _plot_graph(graph, ax, 'r', 'r')
60
+
61
+ for ex in extension:
62
+ f.savefig(filename + '_combined' + ex)
63
+ plt.close()
64
+
65
+ # Graph
66
+ f = plt.figure(figsize=(5, 5))
67
+ ax = f.add_subplot(111, projection='3d')
68
+
69
+ ax = _plot_graph(graph, ax)
70
+
71
+ for ex in extension:
72
+ f.savefig(filename + '_graph' + ex)
73
+ plt.close()
74
+
75
+
76
+
77
+
78
+ def compare_graphs(graph_0, graph_1, graph_names=None, filename='compare_graphs'):
79
+ f = plt.figure(figsize=(12, 5))
80
+ if graph_names is None:
81
+ graph_names = ['graph_0', 'graph_1']
82
+ else:
83
+ assert len(graph_names) == 2, 'Exactly two graph names are expected'
84
+ ax = f.add_subplot(131, projection='3d')
85
+ ax = _plot_graph(graph_0, ax)
86
+ ax.set_title(graph_names[0], y=-0.01)
87
+
88
+ ax = f.add_subplot(132, projection='3d')
89
+ ax = _plot_graph(graph_1, ax)
90
+ ax.set_title(graph_names[1])
91
+
92
+ ax = f.add_subplot(133, projection='3d')
93
+ ax = _plot_graph(graph_0, ax, 'C2', 'C2', plot_nodes=False)
94
+ ax = _plot_graph(graph_1, ax, 'C4', 'C4', plot_nodes=False)
95
+ legend_elements = [Line2D([0], [0], color='C2', lw=2, label=graph_names[0]),
96
+ Line2D([0], [0], color='C4', lw=2, label=graph_names[1])]
97
+ ax.legend(handles=legend_elements)
98
+
99
+ f.savefig(filename + '_compare_graphs.png')
100
+ plt.close()
101
+
102
+
103
+ def plot_cpd_registration_step(iteration, error, X, Y, out_folder, add_axes=True, pdf=True):
104
+ fig = plt.figure(figsize=(8, 8))
105
+ ax = fig.add_axes([0, 0, .9, .9], projection='3d')
106
+ ax.scatter(X[:, 0], X[:, 1], X[:, 2], color='C1', label='Fixed')
107
+ ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], color='C2', label='Moving')
108
+
109
+ ax.text2D(0.95, 0.98, 'Iteration: {:d}'.format(
110
+ iteration), horizontalalignment='right', verticalalignment='center', transform=ax.transAxes, fontsize='x-large')
111
+ #ax.text2D(0.95, 0.90, 'Error: {:10.4f}'.format(
112
+ # error), horizontalalignment='right', verticalalignment='center', transform=ax.transAxes, fontsize='x-large')
113
+ ax.legend(loc='upper left', fontsize='x-large')
114
+
115
+ if add_axes:
116
+ x_range = [np.min(np.hstack([X[:, 0], Y[:, 0]])), np.max(np.hstack([X[:, 0], Y[:, 0]]))]
117
+ y_range = [np.min(np.hstack([X[:, 1], Y[:, 1]])), np.max(np.hstack([X[:, 1], Y[:, 1]]))]
118
+ z_range = [np.min(np.hstack([X[:, 2], Y[:, 2]])), np.max(np.hstack([X[:, 2], Y[:, 2]]))]
119
+ ax.set_xlim(x_range[0], x_range[1])
120
+ ax.set_ylim(y_range[0], y_range[1])
121
+ ax.set_zlim(z_range[0], z_range[1])
122
+
123
+ remove_tick_labels(ax, True)
124
+ add_axes_arrows_3d(ax, x_color='r', y_color='g', z_color='b', arrow_length=25, dist_arrow_text=3)
125
+ ax.view_init(None, 45)
126
+
127
+ os.makedirs(out_folder, exist_ok=True)
128
+ fig.savefig(os.path.join(out_folder, '{:04d}.png'.format(iteration)))
129
+ if pdf:
130
+ fig.savefig(os.path.join(out_folder, '{:04d}.pdf'.format(iteration)))
131
+ plt.close()
132
+
133
+
134
+ def plot_cpd(fix_pts, mov_pts, fix_centroid, mov_centroid, file_name):
135
+ fig = plt.figure(figsize=(8, 8))
136
+ ax = fig.add_axes([0, 0, .9, .9], projection='3d')
137
+ ax.scatter(fix_pts[:, 0], fix_pts[:, 1], fix_pts[:, 2], color='C1', label='Fixed')
138
+ ax.scatter(mov_pts[:, 0], mov_pts[:, 1], mov_pts[:, 2], color='C2', label='Moving')
139
+ ax.scatter(fix_centroid[0], fix_centroid[1], fix_centroid[2], color='none', s=100, edgecolor='b', label='Centroid')
140
+ ax.scatter(mov_centroid[0], mov_centroid[1], mov_centroid[2], color='none', s=100, edgecolor='b')
141
+ ax.scatter(fix_centroid[0], fix_centroid[1], fix_centroid[2], color='C1')
142
+ ax.scatter(mov_centroid[0], mov_centroid[1], mov_centroid[2], color='C2')
143
+
144
+ x_range = [np.min(np.hstack([fix_pts[:, 0], mov_pts[:, 0], fix_centroid[0], mov_centroid[0]])),
145
+ np.max(np.hstack([fix_pts[:, 0], mov_pts[:, 0], fix_centroid[0], mov_centroid[0]]))]
146
+ y_range = [np.min(np.hstack([fix_pts[:, 1], mov_pts[:, 1], fix_centroid[1], mov_centroid[1]])),
147
+ np.max(np.hstack([fix_pts[:, 1], mov_pts[:, 1], fix_centroid[1], mov_centroid[1]]))]
148
+ z_range = [np.min(np.hstack([fix_pts[:, 2], mov_pts[:, 2], fix_centroid[2], mov_centroid[2]])),
149
+ np.max(np.hstack([fix_pts[:, 2], mov_pts[:, 2], fix_centroid[2], mov_centroid[2]]))]
150
+ ax.set_xlim(x_range[0], x_range[1])
151
+ ax.set_ylim(y_range[0], y_range[1])
152
+ ax.set_zlim(z_range[0], z_range[1])
153
+
154
+ remove_tick_labels(ax, True)
155
+ add_axes_arrows_3d(ax, x_color='r', y_color='g', z_color='b', arrow_length=25, dist_arrow_text=3)
156
+ ax.view_init(None, 45)
157
+ ax.legend(fontsize='xx-large')
158
+ fig.savefig(file_name + '.png')
159
+ fig.savefig(file_name + '.pdf')
160
+ plt.close()
161
+
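A minimal sketch (not part of the commit) of the per-iteration CPD plotting callback above; it assumes the in-repo DeepDeformationMapRegistration visualization helpers are importable and uses a hypothetical output folder:

import numpy as np
from Centerline.visualization_utils import plot_cpd_registration_step

fixed = np.random.rand(200, 3) * 60
moving = fixed + np.random.randn(200, 3)   # noisy copy standing in for the moving point cloud
plot_cpd_registration_step(iteration=0, error=0.0, X=fixed, Y=moving,
                           out_folder='cpd_frames', pdf=False)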