Commit
·
b10768a
1
Parent(s):
ed5ac4a
CPD scripts
Browse files- Centerline/__init__.py +0 -0
- Centerline/centerline.py +241 -0
- Centerline/cpd_utils.py +48 -0
- Centerline/evaluate_BayesianCPD_skeleton.py +160 -0
- Centerline/evaluate_CPD_dense.py +156 -0
- Centerline/evaluate_CPD_nodes.py +158 -0
- Centerline/evaluate_CPD_skeleton.py +164 -0
- Centerline/get_vessels.py +30 -0
- Centerline/graph_utils.py +85 -0
- Centerline/skeleton_to_graph.py +167 -0
- Centerline/skeletonization.py +817 -0
- Centerline/thinPlateSplines.py +82 -0
- Centerline/visualization_utils.py +161 -0
Centerline/__init__.py
ADDED
File without changes
|
Centerline/centerline.py
ADDED
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
|
3 |
+
currentdir = os.path.dirname(os.path.realpath(__file__))
|
4 |
+
parentdir = os.path.dirname(currentdir)
|
5 |
+
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
6 |
+
|
7 |
+
# import tensorflow as tf
|
8 |
+
# tf.enable_eager_execution()
|
9 |
+
# import neurite.py.utils as neurite_utils
|
10 |
+
|
11 |
+
from skimage.morphology import skeletonize_3d, ball
|
12 |
+
from skimage.morphology import binary_closing, binary_opening
|
13 |
+
from skimage.filters import median
|
14 |
+
from skimage.measure import regionprops, label
|
15 |
+
from skimage.transform import warp
|
16 |
+
|
17 |
+
from scipy.ndimage import zoom
|
18 |
+
from scipy.interpolate import LinearNDInterpolator, Rbf
|
19 |
+
|
20 |
+
import h5py
|
21 |
+
import numpy as np
|
22 |
+
from tqdm import tqdm
|
23 |
+
import re
|
24 |
+
import nibabel as nib
|
25 |
+
from nilearn.image import resample_img
|
26 |
+
|
27 |
+
from Centerline.graph_utils import graph_to_ndarray, deform_graph, get_bifurcation_nodes, subsample_graph, \
|
28 |
+
apply_displacement
|
29 |
+
from Centerline.skeleton_to_graph import get_graph_from_skeleton
|
30 |
+
from Centerline.visualization_utils import plot_skeleton, compare_graphs
|
31 |
+
|
32 |
+
from DeepDeformationMapRegistration.utils.operators import min_max_norm
|
33 |
+
from DeepDeformationMapRegistration.utils import constants as C
|
34 |
+
|
35 |
+
import cupy
|
36 |
+
from cupyx.scipy.ndimage import zoom as zoom_gpu
|
37 |
+
from cupyx.scipy.ndimage import map_coordinates
|
38 |
+
|
39 |
+
DATASET_LOCATION = '/mnt/EncryptedData1/Users/javier/vessel_registration/3Dirca/dataset/EVAL'
|
40 |
+
DATASET_NAMES = ['Affine', 'None', 'Translation']
|
41 |
+
DATASET_FILENAME = 'volume'
|
42 |
+
IMGS_FOLDER = '/home/jpdefrutos/workspace/DeepDeformationMapRegistration/Centerline/centerlines'
|
43 |
+
|
44 |
+
DATASTE_RAW_FILES = '/mnt/EncryptedData1/Users/javier/vessel_registration/3Dirca/nifti3'
|
45 |
+
LITS_SEGMENTATION_FILE = 'segmentation'
|
46 |
+
LITS_CT_FILE = 'volume'
|
47 |
+
|
48 |
+
|
49 |
+
def warp_volume(volume, disp_map, indexing='ij'):
|
50 |
+
assert indexing is 'ij' or 'xy', 'Invalid indexing option. Only "ij" or "xy"'
|
51 |
+
grid_i = np.linspace(0, disp_map.shape[0], disp_map.shape[0], endpoint=False)
|
52 |
+
grid_j = np.linspace(0, disp_map.shape[1], disp_map.shape[1], endpoint=False)
|
53 |
+
grid_k = np.linspace(0, disp_map.shape[2], disp_map.shape[2], endpoint=False)
|
54 |
+
grid_i, grid_j, grid_k = np.meshgrid(grid_i, grid_j, grid_k, indexing=indexing)
|
55 |
+
grid_i = (grid_i.flatten() + disp_map[..., 0].flatten())[..., np.newaxis]
|
56 |
+
grid_j = (grid_j.flatten() + disp_map[..., 1].flatten())[..., np.newaxis]
|
57 |
+
grid_k = (grid_k.flatten() + disp_map[..., 2].flatten())[..., np.newaxis]
|
58 |
+
coords = np.hstack([grid_i, grid_j, grid_k]).reshape([*disp_map.shape[:-1], -1])
|
59 |
+
coords = coords.transpose((-1, 0, 1, 2))
|
60 |
+
# The returned volume has indexing xy
|
61 |
+
return warp(volume, coords)
|
62 |
+
|
63 |
+
|
64 |
+
def keep_largest_segmentation(img):
|
65 |
+
label_img = label(img)
|
66 |
+
rp = regionprops(label_img) # Regions labeled with 0 (bg) are ignored
|
67 |
+
biggest_area = (0, 0)
|
68 |
+
for l in range(0, label_img.max()):
|
69 |
+
if rp[l].area > biggest_area[1]:
|
70 |
+
biggest_area = (l + 1, rp[l].area)
|
71 |
+
img[label_img != biggest_area[0]] = 0.
|
72 |
+
return img
|
73 |
+
|
74 |
+
|
75 |
+
def preprocess_image(img, keep_largest=False):
|
76 |
+
ret = binary_closing(img, ball(1))
|
77 |
+
ret = binary_opening(ret, ball(1))
|
78 |
+
#ret = median(ret, ball(1), mode='constant')
|
79 |
+
if keep_largest:
|
80 |
+
ret = keep_largest_segmentation(ret)
|
81 |
+
return ret.astype(np.float)
|
82 |
+
|
83 |
+
|
84 |
+
def build_displacement_map_interpolator(disp_map, backwards=False, indexing='ij'):
|
85 |
+
grid_i = np.linspace(0, disp_map.shape[0], disp_map.shape[0], endpoint=False)
|
86 |
+
grid_j = np.linspace(0, disp_map.shape[1], disp_map.shape[1], endpoint=False)
|
87 |
+
grid_k = np.linspace(0, disp_map.shape[2], disp_map.shape[2], endpoint=False)
|
88 |
+
grid_i, grid_j, grid_k = np.meshgrid(grid_i, grid_j, grid_k, indexing=indexing)
|
89 |
+
grid_i = grid_i.flatten()
|
90 |
+
grid_j = grid_j.flatten()
|
91 |
+
grid_k = grid_k.flatten()
|
92 |
+
# To generate the moving image, we used backwards mapping were the input was the fix image
|
93 |
+
# Now we are doing direct mapping from the fix graph coordinates to the moving coordinates
|
94 |
+
# The application points of the displacement map are thus the transformed "moving image"-grid
|
95 |
+
# and the displacement vectors are reversed
|
96 |
+
if backwards:
|
97 |
+
coords = np.hstack([grid_i[..., np.newaxis], grid_j[..., np.newaxis], grid_k[..., np.newaxis]])
|
98 |
+
return LinearNDInterpolator(coords, np.reshape(disp_map, [-1, 3]))
|
99 |
+
else:
|
100 |
+
grid_i = (grid_i + disp_map[..., 0].flatten())
|
101 |
+
grid_j = (grid_j + disp_map[..., 1].flatten())
|
102 |
+
grid_k = (grid_k + disp_map[..., 2].flatten())
|
103 |
+
|
104 |
+
coords = np.hstack([grid_i[..., np.newaxis], grid_j[..., np.newaxis], grid_k[..., np.newaxis]])
|
105 |
+
return LinearNDInterpolator(coords, -np.reshape(disp_map, [-1, 3]))
|
106 |
+
|
107 |
+
|
108 |
+
def resample_segmentation(img, output_shape, preserve_range, threshold=None, gpu=True):
|
109 |
+
# Preserve range can be a bool (keep or not the original dyn. range) or a list with a new dyn. range
|
110 |
+
zoom_f = np.divide(np.asarray(output_shape), np.asarray(img.shape))
|
111 |
+
|
112 |
+
if gpu:
|
113 |
+
out_img = zoom_gpu(cupy.asarray(img), zoom_f, order=1) # order = 0 or 1
|
114 |
+
else:
|
115 |
+
out_img = zoom(img, zoom_f)
|
116 |
+
if isinstance(preserve_range, bool):
|
117 |
+
if preserve_range:
|
118 |
+
range_min, range_max = np.amin(img), np.amax(img)
|
119 |
+
out_img = min_max_norm(out_img)
|
120 |
+
out_img = out_img * (range_max - range_min) + range_min
|
121 |
+
elif isinstance(preserve_range, list):
|
122 |
+
range_min, range_max = preserve_range
|
123 |
+
out_img = min_max_norm(out_img)
|
124 |
+
out_img = out_img * (range_max - range_min) + range_min
|
125 |
+
|
126 |
+
if threshold is not None and out_img.min() < threshold < out_img.max():
|
127 |
+
range_min, range_max = np.amin(out_img), np.amax(out_img)
|
128 |
+
out_img[out_img > threshold] = range_max
|
129 |
+
out_img[out_img < range_max] = range_min
|
130 |
+
return cupy.asnumpy(out_img) if gpu else out_img
|
131 |
+
|
132 |
+
|
133 |
+
if __name__ == '__main__':
|
134 |
+
for dataset_name in DATASET_NAMES:
|
135 |
+
dataset_loc = os.path.join(DATASET_LOCATION, dataset_name)
|
136 |
+
dataset_files = os.listdir(dataset_loc)
|
137 |
+
dataset_files.sort()
|
138 |
+
dataset_files = [os.path.join(dataset_loc, f) for f in dataset_files if DATASET_FILENAME in f]
|
139 |
+
|
140 |
+
iterator = tqdm(dataset_files)
|
141 |
+
for file_path in iterator:
|
142 |
+
file_num = int(re.findall('(\d+)', os.path.split(file_path)[-1])[0])
|
143 |
+
|
144 |
+
iterator.set_description('{} ({}): laoding data'.format(file_num, dataset_name))
|
145 |
+
vol_file = h5py.File(file_path, 'r')
|
146 |
+
# fix_vessels = vol_file[C.H5_FIX_VESSELS_MASK][..., 0]
|
147 |
+
disp_map = vol_file[C.H5_GT_DISP][:]
|
148 |
+
bbox = vol_file['parameters/bbox'][:]
|
149 |
+
bbox_min = bbox[:3]
|
150 |
+
bbox_max = bbox[3:] + bbox_min
|
151 |
+
|
152 |
+
# Load vessel segmentation mask and resize to 64^3
|
153 |
+
fix_labels = nib.load(os.path.join(DATASTE_RAW_FILES, 'segmentation-{:04d}.nii.gz'.format(file_num)))
|
154 |
+
fix_vessels = fix_labels.slicer[..., 1]
|
155 |
+
fix_vessels = resample_img(fix_vessels, np.eye(3))
|
156 |
+
fix_vessels = np.asarray(fix_vessels.dataobj)
|
157 |
+
fix_vessels = preprocess_image(fix_vessels)
|
158 |
+
fix_vessels = resample_segmentation(fix_vessels, vol_file['parameters/first_reshape'][:], [0, 1], 0.3,
|
159 |
+
gpu=True)
|
160 |
+
fix_vessels = fix_vessels[bbox_min[0]:bbox_max[0], bbox_min[1]:bbox_max[1], bbox_min[2]:bbox_max[2]]
|
161 |
+
fix_vessels = resample_segmentation(fix_vessels, [64] * 3, [0, 1], 0.3, gpu=True)
|
162 |
+
fix_vessels = preprocess_image(fix_vessels)
|
163 |
+
|
164 |
+
mov_vessels = preprocess_image(warp_volume(fix_vessels, disp_map))
|
165 |
+
mov_skel = skeletonize_3d(mov_vessels)
|
166 |
+
### Fix the incorrect scaling ###
|
167 |
+
disp_map *= 2
|
168 |
+
bbox_size = np.asarray(bbox[3:]) # Only load the bbox size
|
169 |
+
rescale_factors = 64 / bbox_size
|
170 |
+
|
171 |
+
disp_map[..., 0] = np.multiply(disp_map[..., 0], rescale_factors[0])
|
172 |
+
disp_map[..., 1] = np.multiply(disp_map[..., 1], rescale_factors[1])
|
173 |
+
disp_map[..., 2] = np.multiply(disp_map[..., 2], rescale_factors[2])
|
174 |
+
#################################
|
175 |
+
|
176 |
+
iterator.set_description('{} ({}): getting graphs'.format(file_num, dataset_name))
|
177 |
+
# Prepare displacement map
|
178 |
+
disp_map_interpolator = build_displacement_map_interpolator(disp_map, backwards=False)
|
179 |
+
|
180 |
+
# Get skeleton and graph
|
181 |
+
fix_skel = skeletonize_3d(fix_vessels)
|
182 |
+
fix_graph = get_graph_from_skeleton(fix_skel, subsample=True)
|
183 |
+
mov_graph = get_graph_from_skeleton(mov_skel, subsample=True) # deform_graph(fix_graph, disp_map_interpolator)
|
184 |
+
|
185 |
+
##### TODO: ERASE Check the mov graph ######
|
186 |
+
# check_mov_vessels = vol_file[C.H5_MOV_VESSELS_MASK][..., 0]
|
187 |
+
# check_mov_vessels = preprocess_image(check_mov_vessels)
|
188 |
+
# check_mov_skel = skeletonize_3d(check_mov_vessels)
|
189 |
+
# check_mov_graph = get_graph_from_skeleton(check_mov_skel, subsample=True)
|
190 |
+
###########
|
191 |
+
fix_pts, fix_nodes, fix_edges = graph_to_ndarray(fix_graph)
|
192 |
+
mov_pts, mov_nodes, mov_edges = graph_to_ndarray(mov_graph)
|
193 |
+
|
194 |
+
fix_bifur_loc, fix_bifur_id = get_bifurcation_nodes(fix_graph)
|
195 |
+
mov_bifur_loc, mov_bifur_id = get_bifurcation_nodes(mov_graph)
|
196 |
+
|
197 |
+
iterator.set_description('{} ({}): saving data'.format(file_num, dataset_name))
|
198 |
+
pts_file_path, pts_file_name = os.path.split(file_path)
|
199 |
+
pts_file_name = pts_file_name.replace(DATASET_FILENAME, 'points')
|
200 |
+
pts_file_path = os.path.join(pts_file_path, pts_file_name)
|
201 |
+
pts_file = h5py.File(pts_file_path, 'w')
|
202 |
+
|
203 |
+
pts_file.create_dataset('fix/points', data=fix_pts)
|
204 |
+
pts_file.create_dataset('fix/nodes', data=fix_nodes)
|
205 |
+
pts_file.create_dataset('fix/edges', data=fix_edges)
|
206 |
+
pts_file.create_dataset('fix/bifurcations', data=fix_bifur_loc)
|
207 |
+
pts_file.create_dataset('fix/graph', data=fix_graph)
|
208 |
+
pts_file.create_dataset('fix/img', data=fix_vessels)
|
209 |
+
pts_file.create_dataset('fix/skeleton', data=fix_skel)
|
210 |
+
pts_file.create_dataset('fix/centroid', data=vol_file[C.H5_FIX_CENTROID][:])
|
211 |
+
|
212 |
+
pts_file.create_dataset('mov/points', data=mov_pts)
|
213 |
+
pts_file.create_dataset('mov/nodes', data=mov_nodes)
|
214 |
+
pts_file.create_dataset('mov/edges', data=mov_edges)
|
215 |
+
pts_file.create_dataset('mov/bifurcations', data=mov_bifur_loc)
|
216 |
+
pts_file.create_dataset('mov/graph', data=mov_graph)
|
217 |
+
pts_file.create_dataset('mov/img', data=mov_vessels)
|
218 |
+
pts_file.create_dataset('mov/skeleton', data=mov_skel)
|
219 |
+
pts_file.create_dataset('mov/centroid', data=vol_file[C.H5_MOV_CENTROID][:])
|
220 |
+
|
221 |
+
pts_file.create_dataset('parameters/voxel_size', data=vol_file['parameters/voxel_size'][:])
|
222 |
+
pts_file.create_dataset('parameters/original_affine', data=vol_file['parameters/original_affine'][:])
|
223 |
+
pts_file.create_dataset('parameters/isotropic_affine', data=vol_file['parameters/isotropic_affine'][:])
|
224 |
+
pts_file.create_dataset('parameters/original_shape', data=vol_file['parameters/original_shape'][:])
|
225 |
+
pts_file.create_dataset('parameters/isotropic_shape', data=vol_file['parameters/isotropic_shape'][:])
|
226 |
+
pts_file.create_dataset('parameters/first_reshape', data=vol_file['parameters/first_reshape'][:])
|
227 |
+
pts_file.create_dataset('parameters/bbox', data=vol_file['parameters/bbox'][:])
|
228 |
+
pts_file.create_dataset('parameters/last_reshape', data=vol_file['parameters/last_reshape'][:])
|
229 |
+
|
230 |
+
pts_file.create_dataset('displacement_map', data=disp_map)
|
231 |
+
|
232 |
+
vol_file.close()
|
233 |
+
pts_file.close()
|
234 |
+
|
235 |
+
iterator.set_description('{} ({}): drawing plots'.format(file_num, dataset_name))
|
236 |
+
num = pts_file_name.split('-')[-1].split('.hd5')[0]
|
237 |
+
imgs_folder = os.path.join(IMGS_FOLDER, dataset_name, num)
|
238 |
+
os.makedirs(imgs_folder, exist_ok=True)
|
239 |
+
plot_skeleton(fix_vessels, fix_skel, fix_graph, os.path.join(imgs_folder, 'fix'), ['.pdf', '.png'])
|
240 |
+
plot_skeleton(mov_vessels, mov_skel, mov_graph, os.path.join(imgs_folder, 'mov'), ['.pdf', '.png'])
|
241 |
+
iterator.set_description('{} ({})'.format(file_num, dataset_name))
|
Centerline/cpd_utils.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pycpd import DeformableRegistration, RigidRegistration
|
2 |
+
import numpy as np
|
3 |
+
import time
|
4 |
+
from scipy.interpolate import Rbf
|
5 |
+
import warnings
|
6 |
+
|
7 |
+
def cpd_non_rigid_transform_pt(pt, Y, G, W):
|
8 |
+
from scipy.interpolate import LinearNDInterpolator
|
9 |
+
interp = LinearNDInterpolator(points=Y, values=np.dot(G, W), fill_value=0.)
|
10 |
+
return interp(pt)
|
11 |
+
|
12 |
+
|
13 |
+
def radial_basis_function(pts, vals, function='thin-plate'):
|
14 |
+
# The Rbf function does not handle n-D hyper-surfaces, so we need an interpolator per displacements. Actually it does mode='N-D'
|
15 |
+
pts_unique, idxs = np.unique(pts, return_index=True, axis=0) # Prevent singular matrices
|
16 |
+
ill_conditioned = False
|
17 |
+
with warnings.catch_warnings(record=True) as caught_warns:
|
18 |
+
warnings.simplefilter('always')
|
19 |
+
dx = Rbf(pts_unique[:, 0], pts_unique[:, 1], pts_unique[:, 2], vals[idxs][:, 0], function=function)
|
20 |
+
dy = Rbf(pts_unique[:, 0], pts_unique[:, 1], pts_unique[:, 2], vals[idxs][:, 1], function=function)
|
21 |
+
dz = Rbf(pts_unique[:, 0], pts_unique[:, 1], pts_unique[:, 2], vals[idxs][:, 2], function=function)
|
22 |
+
for w in caught_warns:
|
23 |
+
print(w)
|
24 |
+
ill_conditioned = ill_conditioned or 'ill-conditioned matrix' in str(w).lower()
|
25 |
+
return lambda int_pt: np.asarray([dx(*int_pt), dy(*int_pt), dz(*int_pt)]), ill_conditioned
|
26 |
+
|
27 |
+
|
28 |
+
def deform_registration(fix_pts, mov_pts, callback_fnc=None, time_it=False, max_iterations=100, tolerance=1e-8, alpha=None, beta=None):
|
29 |
+
deform_reg = DeformableRegistration(**{'Y': mov_pts, 'X': fix_pts},
|
30 |
+
alpha=alpha, beta=beta, tolerance=tolerance, max_iterations=max_iterations)
|
31 |
+
start_t = time.time()
|
32 |
+
trf_mov_pts, deform_p = deform_reg.register(callback_fnc)
|
33 |
+
end_t = time.time()
|
34 |
+
if time_it:
|
35 |
+
return end_t - start_t, deform_reg
|
36 |
+
else:
|
37 |
+
return trf_mov_pts, deform_p, deform_reg
|
38 |
+
|
39 |
+
|
40 |
+
def rigid_registration(fix_pts, mov_pts, callback_fnc=None, time_it=False):
|
41 |
+
rigid_reg = RigidRegistration(**{'Y': mov_pts, 'X': fix_pts})
|
42 |
+
start_t = time.time()
|
43 |
+
trf_mov_pts, trf_p = rigid_reg.register(callback_fnc)
|
44 |
+
end_t = time.time()
|
45 |
+
if time_it:
|
46 |
+
return end_t - start_t, rigid_reg
|
47 |
+
else:
|
48 |
+
return trf_mov_pts, trf_p, rigid_reg
|
Centerline/evaluate_BayesianCPD_skeleton.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
|
3 |
+
currentdir = os.path.dirname(os.path.realpath(__file__))
|
4 |
+
parentdir = os.path.dirname(currentdir)
|
5 |
+
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
6 |
+
|
7 |
+
import h5py
|
8 |
+
from tqdm import tqdm
|
9 |
+
from functools import partial
|
10 |
+
import numpy as np
|
11 |
+
from scipy.spatial.distance import euclidean
|
12 |
+
import pandas as pd
|
13 |
+
from EvaluationScripts.Evaluate_class import resize_img_to_original_space, resize_pts_to_original_space
|
14 |
+
from Centerline.visualization_utils import plot_cpd_registration_step, plot_cpd
|
15 |
+
from Centerline.cpd_utils import cpd_non_rigid_transform_pt, radial_basis_function, deform_registration, rigid_registration
|
16 |
+
from scipy.spatial.distance import cdist
|
17 |
+
from skimage.morphology import skeletonize_3d
|
18 |
+
import re
|
19 |
+
from probreg import bcpd
|
20 |
+
|
21 |
+
DATASET_LOCATION = '/mnt/EncryptedData1/Users/javier/vessel_registration/3Dirca/dataset/EVAL'
|
22 |
+
DATASET_NAMES = ['Affine', 'None', 'Translation']
|
23 |
+
DATASET_FILENAME = 'points'
|
24 |
+
|
25 |
+
OUT_IMG_FOLDER = '/mnt/EncryptedData1/Users/javier/vessel_registration/Centerline/cpd/skeleton'
|
26 |
+
|
27 |
+
SCALE = 1e-2 # mm to cm
|
28 |
+
# CPD PARAMS (deform)
|
29 |
+
MAX_ITER = 200
|
30 |
+
ALPHA = 0.1
|
31 |
+
BETA = 1.0 # None = Use default
|
32 |
+
TOLERANCE = 1e-8
|
33 |
+
|
34 |
+
if __name__ == '__main__':
|
35 |
+
for dataset_name in DATASET_NAMES:
|
36 |
+
dataset_loc = os.path.join(DATASET_LOCATION, dataset_name)
|
37 |
+
dataset_files = os.listdir(dataset_loc)
|
38 |
+
dataset_files.sort()
|
39 |
+
dataset_files = [os.path.join(dataset_loc, f) for f in dataset_files if DATASET_FILENAME in f]
|
40 |
+
|
41 |
+
iterator = tqdm(dataset_files)
|
42 |
+
df = pd.DataFrame(columns=['DATASET',
|
43 |
+
'ITERATIONS_DEF', 'ITERATIONS_R_DEF__R', 'ITERATIONS_R_DEF__DEF',
|
44 |
+
'TIME_DEF', 'TIME_R_DEF',
|
45 |
+
'Q_DEF', 'Q_R_DEF__R', 'Q_R_DEF__DEF',
|
46 |
+
'TRE_DEF', 'TRE_R_DEF',
|
47 |
+
'DS_DISP',
|
48 |
+
'DATA_PATH',
|
49 |
+
'DIST_CENTR', 'DIST_CENTR_DEF_95', 'SAMPLE_NUM'])
|
50 |
+
for i, file_path in enumerate(iterator):
|
51 |
+
fn = os.path.split(file_path)[-1].split('.hd5')[0]
|
52 |
+
fnum = int(re.findall('(\d+)', fn)[0])
|
53 |
+
iterator.set_description('{}: start'.format(fn))
|
54 |
+
pts_file = h5py.File(file_path, 'r')
|
55 |
+
# fix_pts = pts_file['fix/points'][:]
|
56 |
+
# fix_nodes = pts_file['fix/nodes'][:]
|
57 |
+
fix_skel = pts_file['fix/skeleton'][:]
|
58 |
+
fix_centroid = pts_file['fix/centroid'][:]
|
59 |
+
|
60 |
+
# mov_pts = pts_file['mov/points'][:]
|
61 |
+
# mov_nodes = pts_file['mov/nodes'][:]
|
62 |
+
mov_skel = pts_file['mov/skeleton'][:]
|
63 |
+
mov_centroid = pts_file['mov/centroid'][:]
|
64 |
+
|
65 |
+
bbox = pts_file['parameters/bbox'][:]
|
66 |
+
first_reshape = pts_file['parameters/first_reshape'][:]
|
67 |
+
isotropic_shape = pts_file['parameters/isotropic_shape'][:]
|
68 |
+
iterator.set_description('{}: Loaded data'.format(fn))
|
69 |
+
# TODO: bring back to original shape!
|
70 |
+
# Reshape to original_shape
|
71 |
+
# fix_nodes = resize_pts_to_original_space(fix_nodes, bbox, [64]*3, first_reshape, original_shape)
|
72 |
+
# fix_pts = resize_pts_to_original_space(fix_pts, bbox, [64] * 3, first_reshape, original_shape)
|
73 |
+
fix_centroid = resize_pts_to_original_space(fix_centroid, bbox, [64] * 3, first_reshape, isotropic_shape)
|
74 |
+
fix_skel = resize_img_to_original_space(fix_skel, bbox, first_reshape, isotropic_shape)
|
75 |
+
fix_skel = skeletonize_3d(fix_skel)
|
76 |
+
fix_skel_pts = np.argwhere(fix_skel)
|
77 |
+
# mov_nodes = resize_pts_to_original_space(mov_nodes, bbox, [64] * 3, first_reshape, original_shape)
|
78 |
+
# mov_pts = resize_pts_to_original_space(mov_pts, bbox, [64] * 3, first_reshape, original_shape)
|
79 |
+
mov_centroid = resize_pts_to_original_space(mov_centroid, bbox, [64] * 3, first_reshape, isotropic_shape)
|
80 |
+
mov_skel = resize_img_to_original_space(mov_skel, bbox, first_reshape, isotropic_shape)
|
81 |
+
mov_skel = skeletonize_3d(mov_skel)
|
82 |
+
mov_skel_pts = np.argwhere(mov_skel)
|
83 |
+
iterator.set_description('{}: reshaped data'.format(fn))
|
84 |
+
|
85 |
+
ill_cond_def = False
|
86 |
+
ill_cond_r_def = False
|
87 |
+
# Deformable only
|
88 |
+
iterator.set_description('{}: Computing only deformable reg.'.format(fn))
|
89 |
+
|
90 |
+
tf_param = bcpd.registration_bcpd(mov_skel_pts*SCALE, fix_skel_pts*SCALE)
|
91 |
+
|
92 |
+
if np.isnan(deform_reg_def.diff):
|
93 |
+
tre_def = np.nan
|
94 |
+
pred_mov_centroid = mov_centroid
|
95 |
+
else:
|
96 |
+
tps, ill_cond_def = radial_basis_function(mov_skel_pts, np.dot(*deform_reg_def.get_registration_parameters()) / SCALE)
|
97 |
+
displacement_mov_centroid = tps(mov_centroid)
|
98 |
+
pred_mov_centroid = mov_centroid + displacement_mov_centroid
|
99 |
+
|
100 |
+
tre_def = euclidean(pred_mov_centroid, fix_centroid)
|
101 |
+
|
102 |
+
plot_file = os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/DEF'.format(dataset_name, fnum))
|
103 |
+
os.makedirs(plot_file, exist_ok=True)
|
104 |
+
plot_cpd(fix_skel_pts, mov_skel_pts, fix_centroid, mov_centroid, plot_file + '/before_registration')
|
105 |
+
plot_cpd(fix_skel_pts, deform_reg_def.TY/SCALE, fix_centroid, pred_mov_centroid, plot_file + '/after_registration')
|
106 |
+
|
107 |
+
# Rigid followed by deformable
|
108 |
+
iterator.set_description('{}: Computing rigid and deform. reg.'.format(fn))
|
109 |
+
|
110 |
+
rigid_cb = partial(plot_cpd_registration_step, out_folder=os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/RIGID_DEF/rigid'.format(dataset_name, fnum)))
|
111 |
+
deform_cb = partial(plot_cpd_registration_step, out_folder=os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/RIGID_DEF/deform'.format(dataset_name, fnum)))
|
112 |
+
|
113 |
+
# rigid_yt, rigid_p, rigid_reg_r_def = rigid_registration(fix_pts, mov_pts, rigid_cb)
|
114 |
+
# deform_yt, deform_p, deform_reg_r_def = deform_registration(fix_pts, rigid_yt, deform_cb)
|
115 |
+
|
116 |
+
time_r_def__r, rigid_reg_r_def = rigid_registration(fix_skel_pts*SCALE, mov_skel_pts*SCALE, time_it=True)
|
117 |
+
rigid_yt = rigid_reg_r_def.TY
|
118 |
+
time_r_def__def, deform_reg_r_def = deform_registration(fix_skel_pts*SCALE, rigid_yt, time_it=True,
|
119 |
+
tolerance=TOLERANCE, max_iterations=MAX_ITER,
|
120 |
+
alpha=ALPHA, beta=BETA)
|
121 |
+
|
122 |
+
if np.isnan(deform_reg_r_def.diff):
|
123 |
+
pred_mov_centroid = rigid_reg_r_def.transform_point_cloud(mov_centroid*SCALE)/SCALE
|
124 |
+
else:
|
125 |
+
mov_centroid_t = rigid_reg_r_def.transform_point_cloud(mov_centroid*SCALE)/SCALE
|
126 |
+
tps, ill_cond_r_def = radial_basis_function(rigid_yt / SCALE,
|
127 |
+
np.dot(*deform_reg_r_def.get_registration_parameters()) / SCALE)
|
128 |
+
displacement_mov_centroid_t = tps(mov_centroid_t)
|
129 |
+
pred_mov_centroid = mov_centroid_t + displacement_mov_centroid_t
|
130 |
+
|
131 |
+
tre_r_def = euclidean(pred_mov_centroid, fix_centroid)
|
132 |
+
dist_centroid_to_pts = cdist(mov_centroid[np.newaxis, ...], mov_skel_pts)
|
133 |
+
|
134 |
+
plot_file = os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/RIGID_DEF'.format(dataset_name, fnum))
|
135 |
+
os.makedirs(plot_file, exist_ok=True)
|
136 |
+
plot_cpd(fix_skel_pts, mov_skel_pts, fix_centroid, mov_centroid, plot_file + '/before_registration')
|
137 |
+
plot_cpd(fix_skel_pts, deform_reg_r_def.TY/SCALE, fix_centroid, pred_mov_centroid, plot_file + '/after_registration')
|
138 |
+
|
139 |
+
|
140 |
+
iterator.set_description('{}: Saving data'.format(fn))
|
141 |
+
df = df.append({'DATASET': dataset_name,
|
142 |
+
'ITERATIONS_DEF': deform_reg_def.iteration,
|
143 |
+
'ITERATIONS_R_DEF__R': rigid_reg_r_def.iteration,
|
144 |
+
'ITERATIONS_R_DEF__DEF': deform_reg_r_def.iteration,
|
145 |
+
'TIME_DEF': time_def,
|
146 |
+
'TIME_R_DEF': time_r_def__r + time_r_def__def,
|
147 |
+
'Q_DEF': deform_reg_def.diff,
|
148 |
+
'Q_R_DEF__R': rigid_reg_r_def.q,
|
149 |
+
'Q_R_DEF__DEF': deform_reg_r_def.diff,
|
150 |
+
'ILL_COND_DEF': ill_cond_def,
|
151 |
+
'ILL_COND_R_DEF': ill_cond_r_def,
|
152 |
+
'TRE_DEF': tre_def, 'TRE_R_DEF': tre_r_def,
|
153 |
+
'DS_DISP':euclidean(mov_centroid, fix_centroid),
|
154 |
+
'DATA_PATH': file_path,
|
155 |
+
'DIST_CENTR': np.min(dist_centroid_to_pts),
|
156 |
+
'DIST_CENTR_DEF_95': np.percentile(dist_centroid_to_pts, 95),
|
157 |
+
'SAMPLE_NUM':fnum}, ignore_index=True)
|
158 |
+
pts_file.close()
|
159 |
+
|
160 |
+
df.to_csv(os.path.join(OUT_IMG_FOLDER, 'cpd_{}.csv'.format(dataset_name)))
|
Centerline/evaluate_CPD_dense.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
currentdir = os.path.dirname(os.path.realpath(__file__))
|
3 |
+
parentdir = os.path.dirname(currentdir)
|
4 |
+
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
5 |
+
|
6 |
+
import h5py
|
7 |
+
from tqdm import tqdm
|
8 |
+
from functools import partial
|
9 |
+
import numpy as np
|
10 |
+
from scipy.spatial.distance import euclidean
|
11 |
+
import pandas as pd
|
12 |
+
from EvaluationScripts.Evaluate_class import resize_img_to_original_space, resize_pts_to_original_space
|
13 |
+
from Centerline.visualization_utils import plot_cpd_registration_step, plot_cpd
|
14 |
+
from Centerline.cpd_utils import cpd_non_rigid_transform_pt, radial_basis_function, deform_registration, rigid_registration
|
15 |
+
from scipy.spatial.distance import cdist
|
16 |
+
import re
|
17 |
+
|
18 |
+
DATASET_LOCATION = '/mnt/EncryptedData1/Users/javier/vessel_registration/3Dirca/dataset/EVAL'
|
19 |
+
DATASET_NAMES = ['None', 'Affine', 'None', 'Translation']
|
20 |
+
DATASET_FILENAME = 'points'
|
21 |
+
|
22 |
+
OUT_IMG_FOLDER = '/mnt/EncryptedData1/Users/javier/vessel_registration/Centerline/cpd/dense_final'
|
23 |
+
|
24 |
+
SCALE = 1e-2 # mm to cm
|
25 |
+
|
26 |
+
# CPD PARAMS (deform)
|
27 |
+
MAX_ITER = 200
|
28 |
+
ALPHA = 2.
|
29 |
+
BETA = 2. # None = Use default
|
30 |
+
TOLERANCE = 1e-8
|
31 |
+
RBF_FUNCTION='thin-plate'
|
32 |
+
|
33 |
+
if __name__ == '__main__':
|
34 |
+
for dataset_name in DATASET_NAMES:
|
35 |
+
dataset_loc = os.path.join(DATASET_LOCATION, dataset_name)
|
36 |
+
dataset_files = os.listdir(dataset_loc)
|
37 |
+
dataset_files.sort()
|
38 |
+
dataset_files = [os.path.join(dataset_loc, f) for f in dataset_files if DATASET_FILENAME in f]
|
39 |
+
|
40 |
+
iterator = tqdm(dataset_files)
|
41 |
+
df = pd.DataFrame(columns=['DATASET',
|
42 |
+
'ITERATIONS_DEF', 'ITERATIONS_R_DEF__R', 'ITERATIONS_R_DEF__DEF',
|
43 |
+
'TIME_DEF', 'TIME_R_DEF',
|
44 |
+
'Q_DEF', 'Q_R_DEF__R', 'Q_R_DEF__DEF',
|
45 |
+
'TRE_DEF', 'TRE_R_DEF',
|
46 |
+
'DS_DISP',
|
47 |
+
'DATA_PATH',
|
48 |
+
'DIST_CENTR', 'DIST_CENTR_DEF_95', 'SAMPLE_NUM'])
|
49 |
+
for i, file_path in enumerate(iterator):
|
50 |
+
fn = os.path.split(file_path)[-1].split('.hd5')[0]
|
51 |
+
fnum = int(re.findall('(\d+)', fn)[0])
|
52 |
+
iterator.set_description('{}: start'.format(fn))
|
53 |
+
pts_file = h5py.File(file_path, 'r')
|
54 |
+
fix_pts = pts_file['fix/points'][:]
|
55 |
+
# fix_nodes = pts_file['fix/nodes'][:]
|
56 |
+
fix_centroid = pts_file['fix/centroid'][:]
|
57 |
+
|
58 |
+
mov_pts = pts_file['mov/points'][:]
|
59 |
+
# mov_nodes = pts_file['mov/nodes'][:]
|
60 |
+
mov_centroid = pts_file['mov/centroid'][:]
|
61 |
+
|
62 |
+
bbox = pts_file['parameters/bbox'][:]
|
63 |
+
first_reshape = pts_file['parameters/first_reshape'][:]
|
64 |
+
original_shape = pts_file['parameters/isotropic_shape'][:]
|
65 |
+
iterator.set_description('{}: Loaded data'.format(fn))
|
66 |
+
# TODO: bring back to original shape!
|
67 |
+
# Reshape to original_shape
|
68 |
+
# fix_nodes = resize_pts_to_original_space(fix_nodes, bbox, [64]*3, first_reshape, original_shape)
|
69 |
+
fix_pts = resize_pts_to_original_space(fix_pts, bbox, [64] * 3, first_reshape, original_shape)
|
70 |
+
fix_centroid = resize_pts_to_original_space(fix_centroid, bbox, [64] * 3, first_reshape, original_shape)
|
71 |
+
# mov_nodes = resize_pts_to_original_space(mov_nodes, bbox, [64] * 3, first_reshape, original_shape)
|
72 |
+
mov_pts = resize_pts_to_original_space(mov_pts, bbox, [64] * 3, first_reshape, original_shape)
|
73 |
+
mov_centroid = resize_pts_to_original_space(mov_centroid, bbox, [64] * 3, first_reshape, original_shape)
|
74 |
+
iterator.set_description('{}: reshaped data'.format(fn))
|
75 |
+
|
76 |
+
ill_cond_def = False
|
77 |
+
ill_cond_r_def = False
|
78 |
+
# Deformable only
|
79 |
+
iterator.set_description('{}: Computing only deformable reg.'.format(fn))
|
80 |
+
|
81 |
+
# deform_cb = partial(plot_cpd_registration_step, out_folder=os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/DEF'.format(dataset_name, fnum)))
|
82 |
+
|
83 |
+
# _, _, deform_reg_def = deform_registration(fix_pts, mov_pts, deform_cb)
|
84 |
+
time_def, deform_reg_def = deform_registration(fix_pts*SCALE, mov_pts*SCALE, time_it=True,
|
85 |
+
tolerance=TOLERANCE, max_iterations=MAX_ITER,
|
86 |
+
alpha=ALPHA, beta=BETA)
|
87 |
+
if np.isnan(deform_reg_def.diff):
|
88 |
+
tre_def = np.nan
|
89 |
+
pred_mov_centroid = np.zeros((3,))
|
90 |
+
else:
|
91 |
+
tps, ill_cond_def = radial_basis_function(mov_pts, np.dot(*deform_reg_def.get_registration_parameters()) / SCALE, RBF_FUNCTION)
|
92 |
+
displacement_mov_centroid = tps(mov_centroid)
|
93 |
+
pred_mov_centroid = mov_centroid + displacement_mov_centroid
|
94 |
+
|
95 |
+
tre_def = euclidean(pred_mov_centroid, fix_centroid)
|
96 |
+
|
97 |
+
plot_file = os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/DEF'.format(dataset_name, fnum))
|
98 |
+
os.makedirs(plot_file, exist_ok=True)
|
99 |
+
plot_cpd(fix_pts, mov_pts, fix_centroid, mov_centroid, plot_file + '/before_registration')
|
100 |
+
plot_cpd(fix_pts, deform_reg_def.TY/SCALE, fix_centroid, pred_mov_centroid, plot_file + '/after_registration')
|
101 |
+
|
102 |
+
# Rigid followed by deformable
|
103 |
+
iterator.set_description('{}: Computing rigid and deform. reg.'.format(fn))
|
104 |
+
|
105 |
+
# rigid_cb = partial(plot_cpd_registration_step, out_folder=os.path.join(OUT_IMG_FOLDER,
|
106 |
+
# '{}/{:04d}/RIGID_DEF/rigid'.format(
|
107 |
+
# dataset_name, fnum)))
|
108 |
+
# deform_cb = partial(plot_cpd_registration_step, out_folder=os.path.join(OUT_IMG_FOLDER,
|
109 |
+
# '{}/{:04d}/RIGID_DEF/deform'.format(
|
110 |
+
# dataset_name, fnum)))
|
111 |
+
# rigid_yt, rigid_p, rigid_reg_r_def = rigid_registration(fix_pts, mov_pts, rigid_cb)
|
112 |
+
# deform_yt, deform_p, deform_reg_r_def = deform_registration(fix_pts, rigid_yt, deform_cb)
|
113 |
+
|
114 |
+
time_r_def__r, rigid_reg_r_def = rigid_registration(fix_pts*SCALE, mov_pts*SCALE, time_it=True)
|
115 |
+
rigid_yt = rigid_reg_r_def.TY
|
116 |
+
time_r_def__def, deform_reg_r_def = deform_registration(fix_pts*SCALE, rigid_yt, time_it=True,
|
117 |
+
tolerance=TOLERANCE, max_iterations=MAX_ITER,
|
118 |
+
alpha=ALPHA, beta=BETA)
|
119 |
+
|
120 |
+
if np.isnan(deform_reg_r_def.diff):
|
121 |
+
pred_mov_centroid = rigid_reg_r_def.transform_point_cloud(mov_centroid*SCALE)/SCALE
|
122 |
+
else:
|
123 |
+
mov_centroid_t = rigid_reg_r_def.transform_point_cloud(mov_centroid*SCALE)/SCALE
|
124 |
+
tps, ill_cond_r_def = radial_basis_function(rigid_yt / SCALE, np.dot(*deform_reg_r_def.get_registration_parameters()) / SCALE, RBF_FUNCTION)
|
125 |
+
displacement_mov_centroid_t = tps(mov_centroid_t)
|
126 |
+
pred_mov_centroid = mov_centroid_t + displacement_mov_centroid_t
|
127 |
+
|
128 |
+
tre_r_def = euclidean(pred_mov_centroid, fix_centroid)
|
129 |
+
dist_centroid_to_pts = cdist(mov_centroid[np.newaxis, ...], mov_pts)
|
130 |
+
|
131 |
+
plot_file = os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/RIGID_DEF'.format(dataset_name, fnum))
|
132 |
+
os.makedirs(plot_file, exist_ok=True)
|
133 |
+
plot_cpd(fix_pts, mov_pts, fix_centroid, mov_centroid, plot_file + '/before_registration')
|
134 |
+
plot_cpd(fix_pts, deform_reg_r_def.TY, fix_centroid, pred_mov_centroid, plot_file + '/after_registration')
|
135 |
+
|
136 |
+
iterator.set_description('{}: Saving data'.format(fn))
|
137 |
+
df = df.append({'DATASET': dataset_name,
|
138 |
+
'ITERATIONS_DEF': deform_reg_def.iteration,
|
139 |
+
'ITERATIONS_R_DEF__R': rigid_reg_r_def.iteration,
|
140 |
+
'ITERATIONS_R_DEF__DEF': deform_reg_r_def.iteration,
|
141 |
+
'TIME_DEF': time_def,
|
142 |
+
'TIME_R_DEF': time_r_def__r + time_r_def__def,
|
143 |
+
'Q_DEF': deform_reg_def.diff,
|
144 |
+
'Q_R_DEF__R': rigid_reg_r_def.q,
|
145 |
+
'Q_R_DEF__DEF': deform_reg_r_def.diff,
|
146 |
+
'ILL_COND_DEF': ill_cond_def,
|
147 |
+
'ILL_COND_R_DEF': ill_cond_r_def,
|
148 |
+
'TRE_DEF': tre_def, 'TRE_R_DEF': tre_r_def,
|
149 |
+
'DS_DISP':euclidean(mov_centroid, fix_centroid),
|
150 |
+
'DATA_PATH': file_path,
|
151 |
+
'DIST_CENTR': np.min(dist_centroid_to_pts),
|
152 |
+
'DIST_CENTR_DEF_95': np.percentile(dist_centroid_to_pts, 95),
|
153 |
+
'SAMPLE_NUM': fnum}, ignore_index=True)
|
154 |
+
pts_file.close()
|
155 |
+
|
156 |
+
df.to_csv(os.path.join(OUT_IMG_FOLDER, 'cpd_{}.csv'.format(dataset_name)))
|
Centerline/evaluate_CPD_nodes.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
currentdir = os.path.dirname(os.path.realpath(__file__))
|
3 |
+
parentdir = os.path.dirname(currentdir)
|
4 |
+
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
5 |
+
|
6 |
+
import h5py
|
7 |
+
from tqdm import tqdm
|
8 |
+
from functools import partial
|
9 |
+
import numpy as np
|
10 |
+
from scipy.spatial.distance import euclidean
|
11 |
+
import pandas as pd
|
12 |
+
from EvaluationScripts.Evaluate_class import resize_img_to_original_space, resize_pts_to_original_space
|
13 |
+
from Centerline.visualization_utils import plot_cpd_registration_step, plot_cpd
|
14 |
+
from Centerline.cpd_utils import cpd_non_rigid_transform_pt, radial_basis_function, deform_registration, rigid_registration
|
15 |
+
from scipy.spatial.distance import cdist
|
16 |
+
import re
|
17 |
+
|
18 |
+
DATASET_LOCATION = '/mnt/EncryptedData1/Users/javier/vessel_registration/3Dirca/dataset/EVAL'
|
19 |
+
DATASET_NAMES = ['Affine', 'None', 'Translation']
|
20 |
+
DATASET_FILENAME = 'points'
|
21 |
+
|
22 |
+
OUT_IMG_FOLDER = '/mnt/EncryptedData1/Users/javier/vessel_registration/Centerline/cpd/nodes_final'
|
23 |
+
|
24 |
+
SCALE = 1e-2 # mm to cm
|
25 |
+
|
26 |
+
# CPD PARAMS (deform)
|
27 |
+
MAX_ITER = 200
|
28 |
+
ALPHA = 2.
|
29 |
+
BETA = 2. # None = Use default
|
30 |
+
TOLERANCE = 1e-8
|
31 |
+
RBF_FUNCTION='thin-plate'
|
32 |
+
|
33 |
+
if __name__ == '__main__':
|
34 |
+
for dataset_name in DATASET_NAMES:
|
35 |
+
dataset_loc = os.path.join(DATASET_LOCATION, dataset_name)
|
36 |
+
dataset_files = os.listdir(dataset_loc)
|
37 |
+
dataset_files.sort()
|
38 |
+
dataset_files = [os.path.join(dataset_loc, f) for f in dataset_files if DATASET_FILENAME in f]
|
39 |
+
|
40 |
+
iterator = tqdm(dataset_files)
|
41 |
+
df = pd.DataFrame(columns=['DATASET',
|
42 |
+
'ITERATIONS_DEF', 'ITERATIONS_R_DEF__R', 'ITERATIONS_R_DEF__DEF',
|
43 |
+
'TIME_DEF', 'TIME_R_DEF',
|
44 |
+
'Q_DEF', 'Q_R_DEF__R', 'Q_R_DEF__DEF',
|
45 |
+
'TRE_DEF', 'TRE_R_DEF',
|
46 |
+
'DS_DISP',
|
47 |
+
'DATA_PATH',
|
48 |
+
'DIST_CENTR', 'DIST_CENTR_DEF_95', 'SAMPLE_NUM'])
|
49 |
+
for i, file_path in enumerate(iterator):
|
50 |
+
fn = os.path.split(file_path)[-1].split('.hd5')[0]
|
51 |
+
fnum = int(re.findall('(\d+)', fn)[0])
|
52 |
+
iterator.set_description('{}: start'.format(fn))
|
53 |
+
pts_file = h5py.File(file_path, 'r')
|
54 |
+
fix_pts = pts_file['fix/points'][:]
|
55 |
+
fix_nodes = pts_file['fix/nodes'][:]
|
56 |
+
fix_centroid = pts_file['fix/centroid'][:]
|
57 |
+
|
58 |
+
mov_pts = pts_file['mov/points'][:]
|
59 |
+
mov_nodes = pts_file['mov/nodes'][:]
|
60 |
+
mov_centroid = pts_file['mov/centroid'][:]
|
61 |
+
|
62 |
+
bbox = pts_file['parameters/bbox'][:]
|
63 |
+
first_reshape = pts_file['parameters/first_reshape'][:]
|
64 |
+
isotropic_shape = pts_file['parameters/isotropic_shape'][:]
|
65 |
+
iterator.set_description('{}: Loaded data'.format(fn))
|
66 |
+
# TODO: bring back to original shape!
|
67 |
+
# Reshape to original_shape
|
68 |
+
fix_nodes = resize_pts_to_original_space(fix_nodes, bbox, [64] * 3, first_reshape, isotropic_shape)
|
69 |
+
fix_pts = resize_pts_to_original_space(fix_pts, bbox, [64] * 3, first_reshape, isotropic_shape)
|
70 |
+
fix_centroid = resize_pts_to_original_space(fix_centroid, bbox, [64] * 3, first_reshape, isotropic_shape)
|
71 |
+
mov_nodes = resize_pts_to_original_space(mov_nodes, bbox, [64] * 3, first_reshape, isotropic_shape)
|
72 |
+
mov_pts = resize_pts_to_original_space(mov_pts, bbox, [64] * 3, first_reshape, isotropic_shape)
|
73 |
+
mov_centroid = resize_pts_to_original_space(mov_centroid, bbox, [64] * 3, first_reshape, isotropic_shape)
|
74 |
+
iterator.set_description('{}: reshaped data'.format(fn))
|
75 |
+
|
76 |
+
if mov_nodes.shape[0] == 1:
|
77 |
+
# Otherwise we only have a point, and CPD can't handle that... absurd!
|
78 |
+
fix_nodes = fix_pts
|
79 |
+
mov_nodes = mov_pts
|
80 |
+
|
81 |
+
ill_cond_def = False
|
82 |
+
ill_cond_r_def = False
|
83 |
+
# Deformable only
|
84 |
+
iterator.set_description('{}: Computing only deformable reg.'.format(fn))
|
85 |
+
|
86 |
+
# deform_cb = partial(plot_cpd_registration_step,
|
87 |
+
# out_folder=os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/DEF'.format(dataset_name, fnum)))
|
88 |
+
|
89 |
+
# _, _, deform_reg_def = deform_registration(fix_nodes, mov_nodes, deform_cb)
|
90 |
+
time_def, deform_reg_def = deform_registration(fix_nodes*SCALE, mov_nodes*SCALE, time_it=True,
|
91 |
+
tolerance=TOLERANCE, max_iterations=MAX_ITER,
|
92 |
+
alpha=ALPHA, beta=BETA)
|
93 |
+
if np.isnan(deform_reg_def.diff):
|
94 |
+
tre_def = np.nan
|
95 |
+
pred_mov_centroid = np.zeros((3,))
|
96 |
+
else:
|
97 |
+
tps, ill_cond_def = radial_basis_function(mov_nodes, np.dot(*deform_reg_def.get_registration_parameters()) / SCALE, RBF_FUNCTION)
|
98 |
+
displacement_mov_centroid = tps(mov_centroid)
|
99 |
+
pred_mov_centroid = mov_centroid + displacement_mov_centroid
|
100 |
+
|
101 |
+
tre_def = euclidean(pred_mov_centroid, fix_centroid)
|
102 |
+
|
103 |
+
plot_file = os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/DEF'.format(dataset_name, fnum))
|
104 |
+
os.makedirs(plot_file, exist_ok=True)
|
105 |
+
plot_cpd(fix_nodes, mov_nodes, fix_centroid, mov_centroid, plot_file + '/before_registration')
|
106 |
+
plot_cpd(fix_nodes, deform_reg_def.TY/SCALE, fix_centroid, pred_mov_centroid, plot_file + '/after_registration')
|
107 |
+
|
108 |
+
# Rigid followed by deformable
|
109 |
+
iterator.set_description('{}: Computing rigid and deform. reg.'.format(fn))
|
110 |
+
|
111 |
+
# rigid_cb = partial(plot_cpd_registration_step, out_folder=os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/RIGID_DEF/rigid'.format(dataset_name, fnum)))
|
112 |
+
# deform_cb = partial(plot_cpd_registration_step, out_folder=os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/RIGID_DEF/deform'.format(dataset_name, fnum)))
|
113 |
+
# rigid_yt, rigid_p, rigid_reg_r_def = rigid_registration(fix_nodes, mov_nodes, rigid_cb)
|
114 |
+
# deform_yt, deform_p, deform_reg_r_def = deform_registration(fix_nodes, rigid_yt, deform_cb)
|
115 |
+
|
116 |
+
time_r_def__r, rigid_reg_r_def = rigid_registration(fix_nodes*SCALE, mov_nodes*SCALE, time_it=True)
|
117 |
+
rigid_yt = rigid_reg_r_def.TY
|
118 |
+
time_r_def__def, deform_reg_r_def = deform_registration(fix_nodes*SCALE, rigid_yt, time_it=True,
|
119 |
+
tolerance=TOLERANCE, max_iterations=MAX_ITER,
|
120 |
+
alpha=ALPHA, beta=BETA)
|
121 |
+
|
122 |
+
if np.isnan(deform_reg_r_def.diff):
|
123 |
+
pred_mov_centroid = rigid_reg_r_def.transform_point_cloud(mov_centroid*SCALE)/SCALE
|
124 |
+
else:
|
125 |
+
mov_centroid_t = rigid_reg_r_def.transform_point_cloud(mov_centroid*SCALE)/SCALE
|
126 |
+
tps, ill_cond_r_def = radial_basis_function(rigid_yt / SCALE, np.dot(*deform_reg_r_def.get_registration_parameters()) / SCALE, RBF_FUNCTION)
|
127 |
+
displacement_mov_centroid_t = tps(mov_centroid_t)
|
128 |
+
pred_mov_centroid = mov_centroid_t + displacement_mov_centroid_t
|
129 |
+
|
130 |
+
tre_r_def = euclidean(pred_mov_centroid, fix_centroid)
|
131 |
+
dist_centroid_to_pts = cdist(mov_centroid[np.newaxis, ...], mov_nodes)
|
132 |
+
|
133 |
+
plot_file = os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/RIGID_DEF'.format(dataset_name, fnum))
|
134 |
+
os.makedirs(plot_file, exist_ok=True)
|
135 |
+
plot_cpd(fix_nodes, mov_nodes, fix_centroid, mov_centroid, plot_file + '/before_registration')
|
136 |
+
plot_cpd(fix_nodes, deform_reg_r_def.TY/SCALE, fix_centroid, pred_mov_centroid, plot_file + '/after_registration')
|
137 |
+
|
138 |
+
iterator.set_description('{}: Saving data'.format(fn))
|
139 |
+
df = df.append({'DATASET':dataset_name,
|
140 |
+
'ITERATIONS_DEF': deform_reg_def.iteration,
|
141 |
+
'ITERATIONS_R_DEF__R': rigid_reg_r_def.iteration,
|
142 |
+
'ITERATIONS_R_DEF__DEF': deform_reg_r_def.iteration,
|
143 |
+
'TIME_DEF': time_def,
|
144 |
+
'TIME_R_DEF': time_r_def__r + time_r_def__def,
|
145 |
+
'Q_DEF': deform_reg_def.diff,
|
146 |
+
'Q_R_DEF__R': rigid_reg_r_def.q,
|
147 |
+
'Q_R_DEF__DEF': deform_reg_r_def.diff,
|
148 |
+
'ILL_COND_DEF': ill_cond_def,
|
149 |
+
'ILL_COND_R_DEF': ill_cond_r_def,
|
150 |
+
'TRE_DEF':tre_def, 'TRE_R_DEF':tre_r_def,
|
151 |
+
'DS_DISP':euclidean(mov_centroid, fix_centroid),
|
152 |
+
'DATA_PATH':file_path,
|
153 |
+
'DIST_CENTR':np.min(dist_centroid_to_pts),
|
154 |
+
'DIST_CENTR_DEF_95':np.percentile(dist_centroid_to_pts, 95),
|
155 |
+
'SAMPLE_NUM':fnum}, ignore_index=True)
|
156 |
+
pts_file.close()
|
157 |
+
|
158 |
+
df.to_csv(os.path.join(OUT_IMG_FOLDER, 'cpd_{}.csv'.format(dataset_name)))
|
Centerline/evaluate_CPD_skeleton.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
|
3 |
+
currentdir = os.path.dirname(os.path.realpath(__file__))
|
4 |
+
parentdir = os.path.dirname(currentdir)
|
5 |
+
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
6 |
+
|
7 |
+
import h5py
|
8 |
+
from tqdm import tqdm
|
9 |
+
from functools import partial
|
10 |
+
import numpy as np
|
11 |
+
from scipy.spatial.distance import euclidean
|
12 |
+
import pandas as pd
|
13 |
+
from EvaluationScripts.Evaluate_class import resize_img_to_original_space, resize_pts_to_original_space
|
14 |
+
from Centerline.visualization_utils import plot_cpd_registration_step, plot_cpd
|
15 |
+
from Centerline.cpd_utils import cpd_non_rigid_transform_pt, radial_basis_function, deform_registration, rigid_registration
|
16 |
+
from scipy.spatial.distance import cdist
|
17 |
+
from skimage.morphology import skeletonize_3d
|
18 |
+
import re
|
19 |
+
|
20 |
+
DATASET_LOCATION = '/mnt/EncryptedData1/Users/javier/vessel_registration/3Dirca/dataset/EVAL'
|
21 |
+
DATASET_NAMES = ['Affine', 'None', 'Translation']
|
22 |
+
DATASET_FILENAME = 'points'
|
23 |
+
|
24 |
+
OUT_IMG_FOLDER = '/mnt/EncryptedData1/Users/javier/vessel_registration/Centerline/cpd/skeleton'
|
25 |
+
|
26 |
+
SCALE = 1e-2 # mm to cm
|
27 |
+
# CPD PARAMS (deform)
|
28 |
+
MAX_ITER = 200
|
29 |
+
ALPHA = 0.1
|
30 |
+
BETA = 1.0 # None = Use default
|
31 |
+
TOLERANCE = 1e-8
|
32 |
+
|
33 |
+
if __name__ == '__main__':
|
34 |
+
for dataset_name in DATASET_NAMES:
|
35 |
+
dataset_loc = os.path.join(DATASET_LOCATION, dataset_name)
|
36 |
+
dataset_files = os.listdir(dataset_loc)
|
37 |
+
dataset_files.sort()
|
38 |
+
dataset_files = [os.path.join(dataset_loc, f) for f in dataset_files if DATASET_FILENAME in f]
|
39 |
+
|
40 |
+
iterator = tqdm(dataset_files)
|
41 |
+
df = pd.DataFrame(columns=['DATASET',
|
42 |
+
'ITERATIONS_DEF', 'ITERATIONS_R_DEF__R', 'ITERATIONS_R_DEF__DEF',
|
43 |
+
'TIME_DEF', 'TIME_R_DEF',
|
44 |
+
'Q_DEF', 'Q_R_DEF__R', 'Q_R_DEF__DEF',
|
45 |
+
'TRE_DEF', 'TRE_R_DEF',
|
46 |
+
'DS_DISP',
|
47 |
+
'DATA_PATH',
|
48 |
+
'DIST_CENTR', 'DIST_CENTR_DEF_95', 'SAMPLE_NUM'])
|
49 |
+
for i, file_path in enumerate(iterator):
|
50 |
+
fn = os.path.split(file_path)[-1].split('.hd5')[0]
|
51 |
+
fnum = int(re.findall('(\d+)', fn)[0])
|
52 |
+
iterator.set_description('{}: start'.format(fn))
|
53 |
+
pts_file = h5py.File(file_path, 'r')
|
54 |
+
# fix_pts = pts_file['fix/points'][:]
|
55 |
+
# fix_nodes = pts_file['fix/nodes'][:]
|
56 |
+
fix_skel = pts_file['fix/skeleton'][:]
|
57 |
+
fix_centroid = pts_file['fix/centroid'][:]
|
58 |
+
|
59 |
+
# mov_pts = pts_file['mov/points'][:]
|
60 |
+
# mov_nodes = pts_file['mov/nodes'][:]
|
61 |
+
mov_skel = pts_file['mov/skeleton'][:]
|
62 |
+
mov_centroid = pts_file['mov/centroid'][:]
|
63 |
+
|
64 |
+
bbox = pts_file['parameters/bbox'][:]
|
65 |
+
first_reshape = pts_file['parameters/first_reshape'][:]
|
66 |
+
isotropic_shape = pts_file['parameters/isotropic_shape'][:]
|
67 |
+
iterator.set_description('{}: Loaded data'.format(fn))
|
68 |
+
# TODO: bring back to original shape!
|
69 |
+
# Reshape to original_shape
|
70 |
+
# fix_nodes = resize_pts_to_original_space(fix_nodes, bbox, [64]*3, first_reshape, original_shape)
|
71 |
+
# fix_pts = resize_pts_to_original_space(fix_pts, bbox, [64] * 3, first_reshape, original_shape)
|
72 |
+
fix_centroid = resize_pts_to_original_space(fix_centroid, bbox, [64] * 3, first_reshape, isotropic_shape)
|
73 |
+
fix_skel = resize_img_to_original_space(fix_skel, bbox, first_reshape, isotropic_shape)
|
74 |
+
fix_skel = skeletonize_3d(fix_skel)
|
75 |
+
fix_skel_pts = np.argwhere(fix_skel)
|
76 |
+
# mov_nodes = resize_pts_to_original_space(mov_nodes, bbox, [64] * 3, first_reshape, original_shape)
|
77 |
+
# mov_pts = resize_pts_to_original_space(mov_pts, bbox, [64] * 3, first_reshape, original_shape)
|
78 |
+
mov_centroid = resize_pts_to_original_space(mov_centroid, bbox, [64] * 3, first_reshape, isotropic_shape)
|
79 |
+
mov_skel = resize_img_to_original_space(mov_skel, bbox, first_reshape, isotropic_shape)
|
80 |
+
mov_skel = skeletonize_3d(mov_skel)
|
81 |
+
mov_skel_pts = np.argwhere(mov_skel)
|
82 |
+
iterator.set_description('{}: reshaped data'.format(fn))
|
83 |
+
|
84 |
+
ill_cond_def = False
|
85 |
+
ill_cond_r_def = False
|
86 |
+
# Deformable only
|
87 |
+
iterator.set_description('{}: Computing only deformable reg.'.format(fn))
|
88 |
+
|
89 |
+
deform_cb = partial(plot_cpd_registration_step,
|
90 |
+
out_folder=os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/DEF'.format(dataset_name, fnum)))
|
91 |
+
|
92 |
+
# _, _, deform_reg_def = deform_registration(fix_pts, mov_pts, deform_cb)
|
93 |
+
time_def, deform_reg_def = deform_registration(fix_skel_pts*SCALE, mov_skel_pts*SCALE, time_it=True,
|
94 |
+
tolerance=TOLERANCE, max_iterations=MAX_ITER,
|
95 |
+
alpha=ALPHA, beta=BETA)
|
96 |
+
if np.isnan(deform_reg_def.diff):
|
97 |
+
tre_def = np.nan
|
98 |
+
pred_mov_centroid = mov_centroid
|
99 |
+
else:
|
100 |
+
tps, ill_cond_def = radial_basis_function(mov_skel_pts, np.dot(*deform_reg_def.get_registration_parameters()) / SCALE)
|
101 |
+
displacement_mov_centroid = tps(mov_centroid)
|
102 |
+
pred_mov_centroid = mov_centroid + displacement_mov_centroid
|
103 |
+
|
104 |
+
tre_def = euclidean(pred_mov_centroid, fix_centroid)
|
105 |
+
|
106 |
+
plot_file = os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/DEF'.format(dataset_name, fnum))
|
107 |
+
os.makedirs(plot_file, exist_ok=True)
|
108 |
+
plot_cpd(fix_skel_pts, mov_skel_pts, fix_centroid, mov_centroid, plot_file + '/before_registration')
|
109 |
+
plot_cpd(fix_skel_pts, deform_reg_def.TY/SCALE, fix_centroid, pred_mov_centroid, plot_file + '/after_registration')
|
110 |
+
|
111 |
+
# Rigid followed by deformable
|
112 |
+
iterator.set_description('{}: Computing rigid and deform. reg.'.format(fn))
|
113 |
+
|
114 |
+
rigid_cb = partial(plot_cpd_registration_step, out_folder=os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/RIGID_DEF/rigid'.format(dataset_name, fnum)))
|
115 |
+
deform_cb = partial(plot_cpd_registration_step, out_folder=os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/RIGID_DEF/deform'.format(dataset_name, fnum)))
|
116 |
+
|
117 |
+
# rigid_yt, rigid_p, rigid_reg_r_def = rigid_registration(fix_pts, mov_pts, rigid_cb)
|
118 |
+
# deform_yt, deform_p, deform_reg_r_def = deform_registration(fix_pts, rigid_yt, deform_cb)
|
119 |
+
|
120 |
+
time_r_def__r, rigid_reg_r_def = rigid_registration(fix_skel_pts*SCALE, mov_skel_pts*SCALE, time_it=True)
|
121 |
+
rigid_yt = rigid_reg_r_def.TY
|
122 |
+
time_r_def__def, deform_reg_r_def = deform_registration(fix_skel_pts*SCALE, rigid_yt, time_it=True,
|
123 |
+
tolerance=TOLERANCE, max_iterations=MAX_ITER,
|
124 |
+
alpha=ALPHA, beta=BETA)
|
125 |
+
|
126 |
+
if np.isnan(deform_reg_r_def.diff):
|
127 |
+
pred_mov_centroid = rigid_reg_r_def.transform_point_cloud(mov_centroid*SCALE)/SCALE
|
128 |
+
else:
|
129 |
+
mov_centroid_t = rigid_reg_r_def.transform_point_cloud(mov_centroid*SCALE)/SCALE
|
130 |
+
tps, ill_cond_r_def = radial_basis_function(rigid_yt / SCALE,
|
131 |
+
np.dot(*deform_reg_r_def.get_registration_parameters()) / SCALE)
|
132 |
+
displacement_mov_centroid_t = tps(mov_centroid_t)
|
133 |
+
pred_mov_centroid = mov_centroid_t + displacement_mov_centroid_t
|
134 |
+
|
135 |
+
tre_r_def = euclidean(pred_mov_centroid, fix_centroid)
|
136 |
+
dist_centroid_to_pts = cdist(mov_centroid[np.newaxis, ...], mov_skel_pts)
|
137 |
+
|
138 |
+
plot_file = os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/RIGID_DEF'.format(dataset_name, fnum))
|
139 |
+
os.makedirs(plot_file, exist_ok=True)
|
140 |
+
plot_cpd(fix_skel_pts, mov_skel_pts, fix_centroid, mov_centroid, plot_file + '/before_registration')
|
141 |
+
plot_cpd(fix_skel_pts, deform_reg_r_def.TY/SCALE, fix_centroid, pred_mov_centroid, plot_file + '/after_registration')
|
142 |
+
|
143 |
+
|
144 |
+
iterator.set_description('{}: Saving data'.format(fn))
|
145 |
+
df = df.append({'DATASET': dataset_name,
|
146 |
+
'ITERATIONS_DEF': deform_reg_def.iteration,
|
147 |
+
'ITERATIONS_R_DEF__R': rigid_reg_r_def.iteration,
|
148 |
+
'ITERATIONS_R_DEF__DEF': deform_reg_r_def.iteration,
|
149 |
+
'TIME_DEF': time_def,
|
150 |
+
'TIME_R_DEF': time_r_def__r + time_r_def__def,
|
151 |
+
'Q_DEF': deform_reg_def.diff,
|
152 |
+
'Q_R_DEF__R': rigid_reg_r_def.q,
|
153 |
+
'Q_R_DEF__DEF': deform_reg_r_def.diff,
|
154 |
+
'ILL_COND_DEF': ill_cond_def,
|
155 |
+
'ILL_COND_R_DEF': ill_cond_r_def,
|
156 |
+
'TRE_DEF': tre_def, 'TRE_R_DEF': tre_r_def,
|
157 |
+
'DS_DISP':euclidean(mov_centroid, fix_centroid),
|
158 |
+
'DATA_PATH': file_path,
|
159 |
+
'DIST_CENTR': np.min(dist_centroid_to_pts),
|
160 |
+
'DIST_CENTR_DEF_95': np.percentile(dist_centroid_to_pts, 95),
|
161 |
+
'SAMPLE_NUM':fnum}, ignore_index=True)
|
162 |
+
pts_file.close()
|
163 |
+
|
164 |
+
df.to_csv(os.path.join(OUT_IMG_FOLDER, 'cpd_{}.csv'.format(dataset_name)))
|
Centerline/get_vessels.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
|
2 |
+
from tqdm import tqdm
|
3 |
+
import os
|
4 |
+
import h5py
|
5 |
+
import DeepDeformationMapRegistration.utils.constants as C
|
6 |
+
|
7 |
+
DATASET_LOCATION = '/mnt/EncryptedData1/Users/javier/vessel_registration/3Dirca/dataset/EVAL'
|
8 |
+
DATASET_NAMES = ['Affine', 'None', 'Translation']
|
9 |
+
DATASET_FILENAME = 'volume'
|
10 |
+
|
11 |
+
|
12 |
+
if __name__ == '__main__':
|
13 |
+
for dataset_name in DATASET_NAMES:
|
14 |
+
dataset_loc = os.path.join(DATASET_LOCATION, dataset_name)
|
15 |
+
dataset_files = os.listdir(dataset_loc)
|
16 |
+
dataset_files.sort()
|
17 |
+
dataset_files = [os.path.join(dataset_loc, f) for f in dataset_files if DATASET_FILENAME in f]
|
18 |
+
|
19 |
+
iterator = tqdm(dataset_files)
|
20 |
+
for fn in iterator:
|
21 |
+
f = os.path.split(fn)[-1].split('.hd5')[0]
|
22 |
+
vol_file = h5py.File(fn, 'r')
|
23 |
+
fix_vessels = vol_file[C.H5_FIX_VESSELS_MASK][..., 0]
|
24 |
+
mov_vessels = vol_file[C.H5_MOV_VESSELS_MASK][..., 0]
|
25 |
+
|
26 |
+
dst_folder = os.path.join(os.getcwd(), 'VESSELS', dataset_name)
|
27 |
+
os.makedirs(dst_folder, exist_ok=True)
|
28 |
+
save_nifti(fix_vessels, os.path.join(dst_folder, f+'_fix.nii.gz'))
|
29 |
+
save_nifti(mov_vessels, os.path.join(dst_folder, f+'_mov.nii.gz'))
|
30 |
+
vol_file.close()
|
Centerline/graph_utils.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import networkx as nx
|
2 |
+
import numpy as np
|
3 |
+
from scipy.interpolate import RegularGridInterpolator, LinearNDInterpolator
|
4 |
+
|
5 |
+
|
6 |
+
def graph_to_ndarray(graph):
|
7 |
+
out_nodes = np.empty((1, 3))
|
8 |
+
out_edges = np.empty((1, 3))
|
9 |
+
visited_nodes = list()
|
10 |
+
visited_node_pairs = list()
|
11 |
+
for (start_node, end_node) in graph.edges():
|
12 |
+
if (not (start_node, end_node) in visited_node_pairs) and (not (end_node, start_node) in visited_node_pairs):
|
13 |
+
edge = graph[start_node][end_node]['pts']
|
14 |
+
out_edges = np.vstack([out_edges, edge])
|
15 |
+
|
16 |
+
# Avoid duplicates
|
17 |
+
if not (start_node in visited_nodes):
|
18 |
+
out_nodes = np.vstack([out_nodes, graph.nodes[start_node]['o']])
|
19 |
+
visited_nodes.append(start_node)
|
20 |
+
if not (end_node in visited_nodes):
|
21 |
+
out_nodes = np.vstack([out_nodes, graph.nodes[end_node]['o']])
|
22 |
+
visited_nodes.append(end_node)
|
23 |
+
|
24 |
+
visited_node_pairs.append((start_node, end_node))
|
25 |
+
|
26 |
+
return np.vstack([out_edges, out_nodes]), out_nodes, out_edges
|
27 |
+
|
28 |
+
|
29 |
+
def get_bifurcation_nodes(graph: nx.Graph):
|
30 |
+
# Vertex degree relates to the number of branches connected to a given node
|
31 |
+
out_nodes = np.empty((1, 3))
|
32 |
+
bif_nodes_id = list()
|
33 |
+
for node_num, deg in graph.degree:
|
34 |
+
if deg > 1:
|
35 |
+
bif_nodes_id.append(node_num)
|
36 |
+
out_nodes = np.vstack([out_nodes, graph.nodes[node_num]['o']])
|
37 |
+
|
38 |
+
return out_nodes, bif_nodes_id
|
39 |
+
|
40 |
+
|
41 |
+
def apply_displacement(pts_list: np.ndarray, interpolator: [RegularGridInterpolator, LinearNDInterpolator]):
|
42 |
+
pts_list = pts_list.astype(np.float)
|
43 |
+
ret_val = pts_list + interpolator(pts_list).squeeze()
|
44 |
+
return ret_val
|
45 |
+
|
46 |
+
|
47 |
+
def deform_graph(graph, dm_interpolator: [RegularGridInterpolator, LinearNDInterpolator]):
|
48 |
+
def_graph = nx.Graph()
|
49 |
+
for (start_node, end_node) in graph.edges():
|
50 |
+
edge = graph[start_node][end_node]['pts']
|
51 |
+
def_edge = apply_displacement(edge, dm_interpolator)
|
52 |
+
|
53 |
+
def_start_node_pts = apply_displacement(graph.nodes[start_node]['pts'], dm_interpolator)
|
54 |
+
def_end_node_pts = apply_displacement(graph.nodes[end_node]['pts'], dm_interpolator)
|
55 |
+
|
56 |
+
def_start_node_o = apply_displacement(graph.nodes[start_node]['o'], dm_interpolator)
|
57 |
+
def_end_node_o = apply_displacement(graph.nodes[end_node]['o'], dm_interpolator)
|
58 |
+
|
59 |
+
def_graph.add_node(start_node, pts=def_start_node_pts, o=def_start_node_o)
|
60 |
+
def_graph.add_node(end_node, pts=def_end_node_pts, o=def_end_node_o)
|
61 |
+
def_graph.add_edge(start_node, end_node, pts=def_edge, weight=len(def_edge))
|
62 |
+
return def_graph
|
63 |
+
|
64 |
+
|
65 |
+
def subsample_graph(graph: nx.Graph, num_samples=3):
|
66 |
+
sub_graph = nx.Graph()
|
67 |
+
for (start_node, end_node) in graph.edges():
|
68 |
+
edge = graph[start_node][end_node]['pts']
|
69 |
+
edge_len = edge.shape[0]
|
70 |
+
sub_edge_len = (edge_len - 2) // num_samples # Do not count the pts corresponding to the nodes (-2)
|
71 |
+
|
72 |
+
sub_edge = [edge[0]]
|
73 |
+
include_last = bool((edge_len - 2) % num_samples) # Skip the last point, as this is too close to the node
|
74 |
+
if sub_edge_len:
|
75 |
+
idxs = np.arange(0, edge_len, num_samples)[1:] if include_last else np.arange(0, edge_len, num_samples)[1:-1]
|
76 |
+
for i in idxs:
|
77 |
+
sub_edge.append(edge[i])
|
78 |
+
|
79 |
+
sub_edge.append(edge[-1])
|
80 |
+
sub_edge = np.asarray(sub_edge)
|
81 |
+
sub_graph.add_node(start_node, pts=graph.nodes[start_node]['pts'], o=graph.nodes[start_node]['o'])
|
82 |
+
sub_graph.add_node(end_node, pts=graph.nodes[end_node]['pts'], o=graph.nodes[end_node]['o'])
|
83 |
+
sub_graph.add_edge(start_node, end_node, pts=sub_edge, weight=len(sub_edge))
|
84 |
+
return sub_graph
|
85 |
+
|
Centerline/skeleton_to_graph.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SRC: https://github.com/Image-Py/sknw/blob/master/sknw/sknw.py
|
2 |
+
import numpy as np
|
3 |
+
import networkx as nx
|
4 |
+
from Centerline.graph_utils import subsample_graph
|
5 |
+
|
6 |
+
|
7 |
+
def neighbors(shape):
|
8 |
+
dim = len(shape)
|
9 |
+
block = np.ones([3] * dim)
|
10 |
+
block[tuple([1] * dim)] = 0
|
11 |
+
idx = np.where(block > 0)
|
12 |
+
idx = np.array(idx, dtype=np.uint8).T
|
13 |
+
idx = np.array(idx - [1] * dim)
|
14 |
+
acc = np.cumprod((1,) + shape[::-1][:-1])
|
15 |
+
return np.dot(idx, acc[::-1])
|
16 |
+
|
17 |
+
|
18 |
+
# my mark
|
19 |
+
def mark(img, nbs): # mark the array use (0, 1, 2)
|
20 |
+
img = img.ravel()
|
21 |
+
for p in range(len(img)):
|
22 |
+
if img[p] == 0: continue
|
23 |
+
s = 0
|
24 |
+
for dp in nbs:
|
25 |
+
if img[p + dp] != 0: s += 1
|
26 |
+
if s == 2:
|
27 |
+
img[p] = 1
|
28 |
+
else:
|
29 |
+
img[p] = 2
|
30 |
+
|
31 |
+
|
32 |
+
# trans index to r, c...
|
33 |
+
def idx2rc(idx, acc):
|
34 |
+
rst = np.zeros((len(idx), len(acc)), dtype=np.int16)
|
35 |
+
for i in range(len(idx)):
|
36 |
+
for j in range(len(acc)):
|
37 |
+
rst[i, j] = idx[i] // acc[j]
|
38 |
+
idx[i] -= rst[i, j] * acc[j]
|
39 |
+
rst -= 1
|
40 |
+
return rst
|
41 |
+
|
42 |
+
|
43 |
+
# fill a node (may be two or more points)
|
44 |
+
def fill(img, p, num, nbs, acc, buf):
|
45 |
+
back = img[p]
|
46 |
+
img[p] = num
|
47 |
+
buf[0] = p
|
48 |
+
cur = 0;
|
49 |
+
s = 1;
|
50 |
+
|
51 |
+
while True:
|
52 |
+
p = buf[cur]
|
53 |
+
for dp in nbs:
|
54 |
+
cp = p + dp
|
55 |
+
if img[cp] == back:
|
56 |
+
img[cp] = num
|
57 |
+
buf[s] = cp
|
58 |
+
s += 1
|
59 |
+
cur += 1
|
60 |
+
if cur == s: break
|
61 |
+
return idx2rc(buf[:s], acc)
|
62 |
+
|
63 |
+
|
64 |
+
# trace the edge and use a buffer, then buf.copy, if use [] numba not works
|
65 |
+
def trace(img, p, nbs, acc, buf):
|
66 |
+
c1 = 0;
|
67 |
+
c2 = 0;
|
68 |
+
newp = 0
|
69 |
+
cur = 0
|
70 |
+
|
71 |
+
while True:
|
72 |
+
buf[cur] = p
|
73 |
+
img[p] = 0
|
74 |
+
cur += 1
|
75 |
+
for dp in nbs:
|
76 |
+
cp = p + dp
|
77 |
+
if img[cp] >= 10:
|
78 |
+
if c1 == 0:
|
79 |
+
c1 = img[cp]
|
80 |
+
else:
|
81 |
+
c2 = img[cp]
|
82 |
+
if img[cp] == 1:
|
83 |
+
newp = cp
|
84 |
+
p = newp
|
85 |
+
if c2 != 0: break
|
86 |
+
return (c1 - 10, c2 - 10, idx2rc(buf[:cur], acc))
|
87 |
+
|
88 |
+
|
89 |
+
# parse the image then get the nodes and edges
|
90 |
+
def parse_struc(img, pts, nbs, acc):
|
91 |
+
img = img.ravel()
|
92 |
+
buf = np.zeros(131072, dtype=np.int64)
|
93 |
+
num = 10
|
94 |
+
nodes = []
|
95 |
+
for p in pts:
|
96 |
+
if img[p] == 2:
|
97 |
+
nds = fill(img, p, num, nbs, acc, buf)
|
98 |
+
num += 1
|
99 |
+
nodes.append(nds)
|
100 |
+
|
101 |
+
edges = []
|
102 |
+
for p in pts:
|
103 |
+
for dp in nbs:
|
104 |
+
if img[p + dp] == 1:
|
105 |
+
edge = trace(img, p + dp, nbs, acc, buf)
|
106 |
+
edges.append(edge)
|
107 |
+
return nodes, edges
|
108 |
+
|
109 |
+
|
110 |
+
# use nodes and edges build a networkx graph
|
111 |
+
def build_graph(nodes, edges, multi=False):
|
112 |
+
graph = nx.MultiGraph() if multi else nx.Graph()
|
113 |
+
for i in range(len(nodes)):
|
114 |
+
graph.add_node(i, pts=nodes[i], o=nodes[i].mean(axis=0))
|
115 |
+
for s, e, pts in edges:
|
116 |
+
l = np.linalg.norm(pts[1:] - pts[:-1], axis=1).sum()
|
117 |
+
graph.add_edge(s, e, pts=pts, weight=l)
|
118 |
+
return graph
|
119 |
+
|
120 |
+
|
121 |
+
def buffer(ske):
|
122 |
+
buf = np.zeros(tuple(np.array(ske.shape) + 2), dtype=np.uint16)
|
123 |
+
buf[tuple([slice(1, -1)] * buf.ndim)] = ske
|
124 |
+
return buf
|
125 |
+
|
126 |
+
|
127 |
+
def build_sknw(ske, multi=False):
|
128 |
+
buf = buffer(ske)
|
129 |
+
nbs = neighbors(buf.shape)
|
130 |
+
acc = np.cumprod((1,) + buf.shape[::-1][:-1])[::-1]
|
131 |
+
mark(buf, nbs)
|
132 |
+
pts = np.array(np.where(buf.ravel() == 2))[0]
|
133 |
+
nodes, edges = parse_struc(buf, pts, nbs, acc)
|
134 |
+
return build_graph(nodes, edges, multi)
|
135 |
+
|
136 |
+
|
137 |
+
# draw the graph
|
138 |
+
def draw_graph(img, graph, cn=255, ce=128):
|
139 |
+
acc = np.cumprod((1,) + img.shape[::-1][:-1])[::-1]
|
140 |
+
img = img.ravel()
|
141 |
+
for idx in graph.nodes():
|
142 |
+
pts = graph.nodes[idx]['pts']
|
143 |
+
img[np.dot(pts, acc)] = cn
|
144 |
+
for (s, e) in graph.edges():
|
145 |
+
eds = graph[s][e]
|
146 |
+
for i in eds:
|
147 |
+
pts = eds[i]['pts']
|
148 |
+
img[np.dot(pts, acc)] = ce
|
149 |
+
|
150 |
+
|
151 |
+
def get_graph_from_skeleton(mask, subsample=False):
|
152 |
+
graph = build_sknw(mask, False)
|
153 |
+
if len(graph.nodes) > 1 and len(graph.edges) and subsample:
|
154 |
+
graph = subsample_graph(graph, 3)
|
155 |
+
return graph
|
156 |
+
|
157 |
+
|
158 |
+
if __name__ == '__main__':
|
159 |
+
g = nx.MultiGraph()
|
160 |
+
g.add_nodes_from([1, 2, 3, 4, 5])
|
161 |
+
g.add_edges_from([(1, 2), (1, 3), (2, 3), (4, 5), (5, 4)])
|
162 |
+
print(g.nodes())
|
163 |
+
print(g.edges())
|
164 |
+
a = g.subgraph(1)
|
165 |
+
print('d')
|
166 |
+
print(a)
|
167 |
+
print('d')
|
Centerline/skeletonization.py
ADDED
@@ -0,0 +1,817 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys, os
|
2 |
+
import numpy as np
|
3 |
+
import nibabel as nib
|
4 |
+
from scipy import ndimage as ndi
|
5 |
+
from scipy.signal import convolve
|
6 |
+
from numpy.linalg import norm
|
7 |
+
import networkx as nx
|
8 |
+
import logging
|
9 |
+
import traceback
|
10 |
+
import timeit
|
11 |
+
import time
|
12 |
+
import math
|
13 |
+
from ast import literal_eval as make_tuple
|
14 |
+
from skimage.measure import label
|
15 |
+
import subprocess
|
16 |
+
import platform
|
17 |
+
import glob
|
18 |
+
|
19 |
+
|
20 |
+
def loadVolume(volumeFolderPath, volumeName):
|
21 |
+
"""
|
22 |
+
Load nifti files (*.nii or *.nii.gz).
|
23 |
+
Parameters
|
24 |
+
----------
|
25 |
+
volumeFolderPath : str
|
26 |
+
Folder of the volume file.
|
27 |
+
volumeName : str
|
28 |
+
Name of the volume file.
|
29 |
+
|
30 |
+
Returns
|
31 |
+
-------
|
32 |
+
volume : ndarray
|
33 |
+
Volume data in the form of numpy ndarray.
|
34 |
+
affine : ndarray
|
35 |
+
Associated affine transformation matrix in the form of numpy ndarray.
|
36 |
+
"""
|
37 |
+
volumeFilePath = os.path.join(volumeFolderPath, volumeName)
|
38 |
+
volumeImg = nib.load(volumeFilePath)
|
39 |
+
volume = volumeImg.get_data()
|
40 |
+
shape = volume.shape
|
41 |
+
affine = volumeImg.affine
|
42 |
+
print('Volume loaded from {} with shape = {}.'.format(volumeFilePath, shape))
|
43 |
+
|
44 |
+
return volume, affine
|
45 |
+
|
46 |
+
|
47 |
+
def saveVolume(volume, affine, path, astype=None):
|
48 |
+
"""
|
49 |
+
Save the given volume to the specified location in specified data type.
|
50 |
+
Parameters
|
51 |
+
----------
|
52 |
+
volume : ndarray
|
53 |
+
Volume data to be saved.
|
54 |
+
affine : ndarray
|
55 |
+
The affine transformation matrix associated with the volume.
|
56 |
+
path : str
|
57 |
+
The absolute path where the volume is going to be saved.
|
58 |
+
astype : numpy dtype, optional
|
59 |
+
The desired data type of the volume data.
|
60 |
+
"""
|
61 |
+
if astype is None:
|
62 |
+
astype = np.uint8
|
63 |
+
|
64 |
+
nib.save(nib.Nifti1Image(volume.astype(astype), affine), path)
|
65 |
+
print('Volume saved to {} as type {}.'.format(path, astype))
|
66 |
+
|
67 |
+
|
68 |
+
def labelVolume(volume, minSize=1, maxHop=3):
|
69 |
+
"""
|
70 |
+
Partition the volume into several connected components and attach labels.
|
71 |
+
Parameters
|
72 |
+
----------
|
73 |
+
volume : ndarray
|
74 |
+
Volume to be partitioned.
|
75 |
+
minSize : int, optional
|
76 |
+
The connected component that is less than this size will be disgarded.
|
77 |
+
maxHop : int, optional
|
78 |
+
Controls how neighboring voxels are defined. See `label` doc for details.
|
79 |
+
|
80 |
+
Returns
|
81 |
+
-------
|
82 |
+
labeled : ndarray
|
83 |
+
The partitioned and labeled volume. Each connected component has a label (a positive integer) and the background
|
84 |
+
is labeled as 0.
|
85 |
+
labelResult : list
|
86 |
+
In the form of [[label1, size1], [label2, size2], ...]
|
87 |
+
"""
|
88 |
+
labeled, maxNum = label(volume, return_num=True, connectivity=maxHop)
|
89 |
+
counts = np.bincount(labeled.ravel())
|
90 |
+
countLoc = np.nonzero(counts)[0]
|
91 |
+
sizeList = counts[countLoc]
|
92 |
+
labelResult = list(zip(countLoc[sizeList >= minSize], sizeList[sizeList >= minSize]))
|
93 |
+
# print(labelResult)
|
94 |
+
# print('Total segments: {}'.format(np.count_nonzero(sizeList >= minSize)))
|
95 |
+
return labeled, labelResult
|
96 |
+
|
97 |
+
|
98 |
+
def analyze(vesselVolumeMask, baseFolder):
|
99 |
+
"""
|
100 |
+
Main function to provoke the skeletonization process. Note that here I am using the docker version of the code. If
|
101 |
+
you have already downloaded the original C++ code and successfully compiled it, then you can run that compiled code
|
102 |
+
instead of this one.
|
103 |
+
"""
|
104 |
+
vesselVolumeMask = vesselVolumeMask.astype(np.uint8)
|
105 |
+
vesselVolumeMask[vesselVolumeMask != 0] = 1
|
106 |
+
vesselVolumeMask = np.swapaxes(vesselVolumeMask, 0, 2)
|
107 |
+
shape = vesselVolumeMask.shape
|
108 |
+
|
109 |
+
vesselVolumeMaskLabeled, vesselVolumeMaskLabelResult = labelVolume(vesselVolumeMask, minSize=1)
|
110 |
+
directory = os.path.join(baseFolder, 'skeletonizationResult')
|
111 |
+
if not os.path.exists(directory):
|
112 |
+
os.makedirs(directory)
|
113 |
+
print('Directory {} created.'.format(directory))
|
114 |
+
|
115 |
+
vesselVolumeMaskLabelInfoFilename = 'vesselVolumeMaskLabelInfo.npz'
|
116 |
+
vesselVolumeMaskLabelInfoFilePath = os.path.join(directory, vesselVolumeMaskLabelInfoFilename)
|
117 |
+
np.savez_compressed(vesselVolumeMaskLabelInfoFilePath, vesselVolumeMaskLabeled=vesselVolumeMaskLabeled,
|
118 |
+
vesselVolumeMaskLabelResult=vesselVolumeMaskLabelResult)
|
119 |
+
print('{} saved to {}.'.format(vesselVolumeMaskLabelInfoFilename, vesselVolumeMaskLabelInfoFilePath))
|
120 |
+
|
121 |
+
# directory2 = directory + 'labelNum=' + str(labelNum) + '/'
|
122 |
+
# if not os.path.exists(directory2):
|
123 |
+
# os.makedirs(directory2)
|
124 |
+
# with open(directory2 + 'BB.txt', 'w') as f1:
|
125 |
+
# f1.write('1\n')
|
126 |
+
# f1.write('{} {} {}\n'.format(0, 0, 0))
|
127 |
+
# f1.write('{} {} {}'.format(*shape))
|
128 |
+
# '''
|
129 |
+
BBFilePath = os.path.join(directory, 'BB.txt')
|
130 |
+
f1 = open(BBFilePath, 'w')
|
131 |
+
f1.write('1\n')
|
132 |
+
f1.write('{} {} {}\n'.format(0, 0, 0))
|
133 |
+
f1.write('{} {} {}'.format(*shape))
|
134 |
+
f1.close()
|
135 |
+
|
136 |
+
vesselCoords = np.array(np.where(vesselVolumeMask)).T
|
137 |
+
xyzFilePath = os.path.join(directory, 'xyz.txt')
|
138 |
+
np.savetxt(xyzFilePath, vesselCoords, fmt='%1u')
|
139 |
+
f2 = open(xyzFilePath, "r")
|
140 |
+
contents = f2.readlines()
|
141 |
+
f2.close()
|
142 |
+
|
143 |
+
contents.insert(0, '{}\n'.format(len(vesselCoords)))
|
144 |
+
|
145 |
+
f2 = open(xyzFilePath, "w")
|
146 |
+
contents = "".join(contents)
|
147 |
+
f2.write(contents)
|
148 |
+
f2.close()
|
149 |
+
# '''
|
150 |
+
|
151 |
+
# '''
|
152 |
+
currentPlatform = platform.system()
|
153 |
+
print('Current platform is {}.'.format(currentPlatform))
|
154 |
+
if currentPlatform == 'Windows':
|
155 |
+
cmd = '"C:/Program Files/Docker/Docker/Resources/bin/docker.exe" run -v ' + '"' + directory + '"' + ':/write_directory -e THRESH=1e-12 -e CC_FLAG=1 -e CONVERSION_TYPE=1 amytabb/curveskel-tabb-medeiros-2018-docker'
|
156 |
+
elif currentPlatform == 'Darwin':
|
157 |
+
cmd = 'docker run -v ' + '"' + directory + '"' + ':/write_directory -e THRESH=1e-12 -e CC_FLAG=1 -e CONVERSION_TYPE=1 amytabb/curveskel-tabb-nih-aug2018-docker2'
|
158 |
+
elif currentPlatform == 'Linux':
|
159 |
+
cmd = '/usr/local/bin/docker run -v ' + '"' + directory + '"' + ':/write_directory -e THRESH=1e-12 -e CC_FLAG=1 -e CONVERSION_TYPE=1 amytabb/curveskel-tabb-medeiros-2018-docker'
|
160 |
+
cmd = 'sudo docker run -v ' + '"' + directory + '"' + ':/write_directory -e THRESH=1e-12 -e CC_FLAG=1 -e CONVERSION_TYPE=1 amytabb/curveskel-tabb-medeiros-2018-docker'
|
161 |
+
cmd = 'sudo docker run -v ' + '"' + directory + '"' + ':/write_directory -e THRESH=1e-12 -e CC_FLAG=1 -e CONVERSION_TYPE=1 amytabb/curveskel-tabb-nih-aug2018-docker2'
|
162 |
+
|
163 |
+
print('cmd={}'.format(cmd))
|
164 |
+
subprocess.call(cmd, shell=True)
|
165 |
+
# '''
|
166 |
+
|
167 |
+
|
168 |
+
def combineSkeletonSegments(skeletonSegmentFolderPath):
|
169 |
+
"""
|
170 |
+
Collect and combine the results from the skeletonization.
|
171 |
+
Parameters
|
172 |
+
----------
|
173 |
+
skeletonSegmentFolderPath : str
|
174 |
+
The folder that contains the segments information (result_segments_xyz*.txt).
|
175 |
+
|
176 |
+
Returns
|
177 |
+
-------
|
178 |
+
segmentList : list
|
179 |
+
A list containing the segment information. Each sublist represents a segment and each element in the sublist
|
180 |
+
represents a centerpoint coordinates.
|
181 |
+
"""
|
182 |
+
segmentList = []
|
183 |
+
files = glob.glob(os.path.join(skeletonSegmentFolderPath, 'result_segments_xyz*.txt'))
|
184 |
+
for segmentFile in files:
|
185 |
+
result = readSegmentFile(segmentFile)
|
186 |
+
segmentList += result
|
187 |
+
|
188 |
+
return segmentList
|
189 |
+
|
190 |
+
|
191 |
+
def readSegmentFile(segmentFile):
|
192 |
+
"""
|
193 |
+
Parse the segment files (result_segments_xyz*.txt) and return segments information in a list.
|
194 |
+
Parameters
|
195 |
+
----------
|
196 |
+
segmentFile : str
|
197 |
+
Path to the segment file.
|
198 |
+
|
199 |
+
Returns
|
200 |
+
-------
|
201 |
+
segmentList : list
|
202 |
+
A list containing the segment information. Each sublist represents a segment and each element in the sublist
|
203 |
+
represents a centerpoint coordinates.
|
204 |
+
"""
|
205 |
+
isFirstLine = True
|
206 |
+
isSegmentLength = True
|
207 |
+
segmentList = []
|
208 |
+
with open(segmentFile) as f:
|
209 |
+
for line in f:
|
210 |
+
if isFirstLine:
|
211 |
+
numOfSegments = int(line)
|
212 |
+
isFirstLine = False
|
213 |
+
else:
|
214 |
+
if isSegmentLength:
|
215 |
+
segmentLength = int(line)
|
216 |
+
isSegmentLength = False
|
217 |
+
segmentCounter = 1
|
218 |
+
segment = []
|
219 |
+
else:
|
220 |
+
if segmentCounter <= segmentLength:
|
221 |
+
voxel = tuple([int(x) for x in line.split(' ')])
|
222 |
+
segment.append(voxel[::-1])
|
223 |
+
segmentCounter += 1
|
224 |
+
else:
|
225 |
+
segmentCounter += 1
|
226 |
+
isSegmentLength = True
|
227 |
+
segmentList.append(segment)
|
228 |
+
assert (len(segment) == segmentLength)
|
229 |
+
|
230 |
+
return segmentList
|
231 |
+
|
232 |
+
|
233 |
+
# def drawSegments(segmentList):
|
234 |
+
# pass
|
235 |
+
|
236 |
+
def processSegments(segmentList, shape):
|
237 |
+
"""
|
238 |
+
Re-partition the segments so that each segment is a simple branch, i.e., it does not contain bifurcation point
|
239 |
+
unless at the two ends.
|
240 |
+
Note that this function might be replaced by another more concise function `getSegmentList`.
|
241 |
+
Parameters
|
242 |
+
----------
|
243 |
+
segmentList : list
|
244 |
+
A list containing the segment information. Each sublist represents a segment and each element in the sublist
|
245 |
+
represents a centerpoint coordinates.
|
246 |
+
shape : tuple
|
247 |
+
Shape of the vessel volume (used for ploting).
|
248 |
+
|
249 |
+
Returns
|
250 |
+
-------
|
251 |
+
G : NetworkX graph
|
252 |
+
A graph in which each node represents a centerpoint and each edge represents a portion of a vessel branch.
|
253 |
+
segmentList : list
|
254 |
+
A list containing the segment information. Each sublist represents a segment and each element in the sublist
|
255 |
+
represents a centerpoint coordinates.
|
256 |
+
errorSegments : list
|
257 |
+
A list that contains segments that cannot be fixed.
|
258 |
+
"""
|
259 |
+
## Import pyqtgraph ##
|
260 |
+
from pyqtgraph.Qt import QtCore, QtGui
|
261 |
+
import pyqtgraph as pg
|
262 |
+
import pyqtgraph.opengl as gl
|
263 |
+
|
264 |
+
## Init ##
|
265 |
+
app = pg.QtGui.QApplication([])
|
266 |
+
w = gl.GLViewWidget()
|
267 |
+
w.opts['distance'] = 800
|
268 |
+
w.setGeometry(0, 110, 1600, 900)
|
269 |
+
offset = np.array(shape) / (-2.0)
|
270 |
+
|
271 |
+
G = nx.Graph()
|
272 |
+
colorList = [pg.glColor('r'), pg.glColor('g'), pg.glColor('b'), pg.glColor('c'), pg.glColor('m'), pg.glColor('y')]
|
273 |
+
colorPointer = 0
|
274 |
+
skeleton = np.full(shape, 0)
|
275 |
+
for segment in segmentList:
|
276 |
+
# G.add_path(list(map(tuple, segment)))
|
277 |
+
G.add_path(segment)
|
278 |
+
segmentCoords = np.array(segment)
|
279 |
+
skeleton[tuple(segmentCoords.T)] = 1
|
280 |
+
# segmentCoordsView = segmentCoords + offset
|
281 |
+
# aa = gl.GLLinePlotItem(pos=segmentCoordsView, color=colorList[colorPointer], width=3)
|
282 |
+
# w.addItem(aa)
|
283 |
+
# colorPointer = colorPointer + 1 if colorPointer < len(colorList) - 1 else 0
|
284 |
+
|
285 |
+
# skeletonCoords = np.array(np.where(skeleton)).T
|
286 |
+
# skeletonCoordsView = (skeletonCoords + offset) * affineTransform
|
287 |
+
# aa = gl.GLScatterPlotItem(pos=skeletonCoordsView, size=5)
|
288 |
+
# w.addItem(aa)
|
289 |
+
|
290 |
+
# w.show()
|
291 |
+
|
292 |
+
voxelDegrees = np.array([v for _, v in G.degree(G.nodes())])
|
293 |
+
maxVoxelDegree = np.amax(voxelDegrees)
|
294 |
+
voxelDegreesZippedResult = list(zip(np.arange(maxVoxelDegree + 1), np.bincount(voxelDegrees)))
|
295 |
+
print('Voxel degree distribution is \n{}'.format(voxelDegreesZippedResult))
|
296 |
+
print('Number of cycles is {}'.format(len(nx.cycle_basis(G))))
|
297 |
+
|
298 |
+
# Remove duplicate segments
|
299 |
+
keepList = np.full((len(segmentList),), True)
|
300 |
+
duplicateCounter = 0
|
301 |
+
for idx, seg in enumerate(segmentList):
|
302 |
+
for idx2, seg2 in enumerate(segmentList[idx + 1:]):
|
303 |
+
if seg == seg2 or seg == seg2[::-1]:
|
304 |
+
keepList[idx + idx2] = False
|
305 |
+
duplicateCounter += 1
|
306 |
+
|
307 |
+
segmentList = [seg for idx, seg in enumerate(segmentList) if keepList[idx]]
|
308 |
+
print('{} duplicate segments removed!'.format(duplicateCounter))
|
309 |
+
|
310 |
+
# Cut segments into sub-pieces if there are bifurcation points in the middle
|
311 |
+
extraSegments = []
|
312 |
+
keepList = np.full((len(segmentList),), True)
|
313 |
+
for idx, segment in enumerate(segmentList):
|
314 |
+
voxelDegrees = np.array([v for _, v in G.degree(segment)])
|
315 |
+
if len(voxelDegrees) >= 3:
|
316 |
+
if voxelDegrees[0] == 2 or voxelDegrees[-1] == 2 or (not np.all(voxelDegrees[1:-1] == 2)):
|
317 |
+
keepList[idx] = False
|
318 |
+
locs = np.nonzero(voxelDegrees != 2)[0]
|
319 |
+
if voxelDegrees[0] == 2:
|
320 |
+
locs = np.hstack((0, locs))
|
321 |
+
|
322 |
+
if voxelDegrees[-1] == 2:
|
323 |
+
locs = np.hstack((locs, len(voxelDegrees)))
|
324 |
+
|
325 |
+
newSegments = []
|
326 |
+
for ii in range(len(locs) - 1):
|
327 |
+
newSegments.append(segment[locs[ii]:(locs[ii + 1] + 1)])
|
328 |
+
|
329 |
+
extraSegments += newSegments
|
330 |
+
|
331 |
+
segmentList = [seg for idx, seg in enumerate(segmentList) if keepList[idx]]
|
332 |
+
segmentList += extraSegments
|
333 |
+
|
334 |
+
# Remove duplicate segments again
|
335 |
+
keepList = np.full((len(segmentList),), True)
|
336 |
+
duplicateCounter = 0
|
337 |
+
for idx, seg in enumerate(segmentList):
|
338 |
+
for idx2, seg2 in enumerate(segmentList[idx + 1:]):
|
339 |
+
if seg == seg2 or seg == seg2[::-1]:
|
340 |
+
keepList[idx + idx2] = False
|
341 |
+
duplicateCounter += 1
|
342 |
+
|
343 |
+
segmentList = [seg for idx, seg in enumerate(segmentList) if keepList[idx]]
|
344 |
+
print('{} duplicate segments removed in the second stage!'.format(duplicateCounter))
|
345 |
+
|
346 |
+
# Remove segment if it is completely contained in another segment
|
347 |
+
# keepList = np.full((len(segmentList),), True)
|
348 |
+
# sublistCounter = 0
|
349 |
+
# for idx, seg in enumerate(segmentList):
|
350 |
+
# for idx2, seg2 in enumerate(segmentList[idx + 1:]):
|
351 |
+
# if contains(seg, seg2):
|
352 |
+
# keepList[idx] = False
|
353 |
+
# sublistCounter += 1
|
354 |
+
# elif contains(seg2, seg):
|
355 |
+
# keepList[idx + idx2] = False
|
356 |
+
# sublistCounter += 1
|
357 |
+
|
358 |
+
# segmentList = [seg for idx, seg in enumerate(segmentList) if keepList[idx]]
|
359 |
+
# print('{} sublist segments removed!'.format(sublistCounter))
|
360 |
+
|
361 |
+
# Treat the segment if either end is not correct
|
362 |
+
hasInvalidSegments = False
|
363 |
+
for idx, segment in enumerate(segmentList):
|
364 |
+
voxelDegrees = np.array([v for _, v in G.degree(segment)])
|
365 |
+
if len(voxelDegrees) == 2:
|
366 |
+
if voxelDegrees[0] == 2 or voxelDegrees[-1] == 2:
|
367 |
+
# print('Degrees on either end is 2: {}'.format(voxelDegrees))
|
368 |
+
hasInvalidSegments = True
|
369 |
+
elif len(voxelDegrees) > 2:
|
370 |
+
if voxelDegrees[0] == 2 or voxelDegrees[-1] == 2 or np.any(voxelDegrees[1:-1] != 2):
|
371 |
+
# print('Degrees not correct: {}'.format(voxelDegrees))
|
372 |
+
hasInvalidSegments = True
|
373 |
+
|
374 |
+
if not hasInvalidSegments:
|
375 |
+
drawSegments(segmentList, shape)
|
376 |
+
print('No errors!')
|
377 |
+
errorSegments = []
|
378 |
+
return G, segmentList, errorSegments
|
379 |
+
|
380 |
+
iterCounter = 1
|
381 |
+
while hasInvalidSegments:
|
382 |
+
print('\n\nIter={}'.format(iterCounter))
|
383 |
+
keepList = np.full((len(segmentList),), True)
|
384 |
+
extraSegments = []
|
385 |
+
for idx, segment in enumerate(segmentList):
|
386 |
+
if keepList[idx]:
|
387 |
+
voxelDegrees = np.array([v for _, v in G.degree(segment)])
|
388 |
+
if voxelDegrees[0] == 2 and voxelDegrees[-1] == 2:
|
389 |
+
print('Both end have 2 neighbours')
|
390 |
+
elif voxelDegrees[0] == 2 or voxelDegrees[-1] == 2:
|
391 |
+
# print('Degrees on either end is 2: {}'.format(voxelDegrees))
|
392 |
+
# pass
|
393 |
+
# segmentCoords = np.array(segment)
|
394 |
+
if voxelDegrees[0] == 2:
|
395 |
+
otherSegmentInfo = [(idx2, seg) for idx2, seg in enumerate(segmentList) if
|
396 |
+
(seg[0] == segment[0] or seg[-1] == segment[0]) and keepList[
|
397 |
+
idx2] and idx != idx2]
|
398 |
+
if len(otherSegmentInfo) != 0:
|
399 |
+
if len(otherSegmentInfo) > 1:
|
400 |
+
# print(contains(segment, otherSegmentInfo[0][1]), contains(otherSegmentInfo[1][1], segment))
|
401 |
+
otherSegmentInfoTemp = []
|
402 |
+
for idx2, seg in otherSegmentInfo:
|
403 |
+
if contains(segment, seg) or contains(segment[::-1], seg):
|
404 |
+
keepList[idx] = False
|
405 |
+
continue
|
406 |
+
elif contains(seg, segment) or contains(seg[::-1], segment):
|
407 |
+
keepList[idx2] = False
|
408 |
+
otherSegmentInfoTemp.append((idx2, seg))
|
409 |
+
|
410 |
+
otherSegmentInfo = otherSegmentInfoTemp
|
411 |
+
# otherSegmentInfo = [segInfo for segInfo in otherSegmentInfo if not (contains(segment, segInfo[1]) or contains(segInfo[1], segment))]
|
412 |
+
if len(otherSegmentInfo) > 1:
|
413 |
+
print('More than one other segments found!')
|
414 |
+
print('Current segment ({}) is {} ({})'.format(idx, segment, voxelDegrees))
|
415 |
+
for otherSegmentIdx, otherSegment in otherSegmentInfo:
|
416 |
+
otherSegmentVoxelDegrees = np.array([v for _, v in G.degree(otherSegment)])
|
417 |
+
print('Idx = {}: {} ({})'.format(otherSegmentIdx, otherSegment,
|
418 |
+
otherSegmentVoxelDegrees))
|
419 |
+
elif len(otherSegmentInfo) == 1:
|
420 |
+
otherSegmentIdx, otherSegment = otherSegmentInfo[0]
|
421 |
+
else:
|
422 |
+
print('No valid other segments found!')
|
423 |
+
continue
|
424 |
+
else:
|
425 |
+
otherSegmentIdx, otherSegment = otherSegmentInfo[0]
|
426 |
+
if contains(segment, otherSegment) or contains(segment[::-1], otherSegment):
|
427 |
+
keepList[idx] = False
|
428 |
+
continue
|
429 |
+
elif contains(otherSegment, segment) or contains(otherSegment[::-1], segment):
|
430 |
+
keepList[otherSegmentIdx] = False
|
431 |
+
continue
|
432 |
+
|
433 |
+
newSegment = otherSegment + segment[1:] if otherSegment[-1] == segment[0] else otherSegment[
|
434 |
+
::-1] + segment[
|
435 |
+
1:]
|
436 |
+
if not validateSegment(G, newSegment):
|
437 |
+
newSegmentVoxelDegrees = np.array([v for _, v in G.degree(newSegment)])
|
438 |
+
print('Old degree is {} () and new degree is {} ()'.format(voxelDegrees,
|
439 |
+
newSegmentVoxelDegrees))
|
440 |
+
else:
|
441 |
+
print('Two segments ({} and {}) merged together!'.format(idx, otherSegmentIdx))
|
442 |
+
|
443 |
+
extraSegments.append(newSegment)
|
444 |
+
keepList[idx] = False
|
445 |
+
keepList[otherSegmentIdx] = False
|
446 |
+
else:
|
447 |
+
print(
|
448 |
+
'Could not find other segments for segment({}) {} with degrees {}'.format(idx, segment,
|
449 |
+
voxelDegrees))
|
450 |
+
possibleSegmentsInfo = [(idx2, seg) for idx2, seg in enumerate(segmentList) if
|
451 |
+
(seg[0] == segment[0] or seg[-1] == segment[0]) and idx != idx2]
|
452 |
+
print('Possible segments: {}'.format(len(possibleSegmentsInfo)))
|
453 |
+
|
454 |
+
elif voxelDegrees[-1] == 2:
|
455 |
+
otherSegmentInfo = [(idx2, seg) for idx2, seg in enumerate(segmentList) if
|
456 |
+
(seg[0] == segment[-1] or seg[-1] == segment[-1]) and keepList[
|
457 |
+
idx2] and idx != idx2]
|
458 |
+
if len(otherSegmentInfo) != 0:
|
459 |
+
if len(otherSegmentInfo) > 1:
|
460 |
+
# print(contains(segment, otherSegmentInfo[0][1]), contains(otherSegmentInfo[1][1], segment))
|
461 |
+
otherSegmentInfoTemp = []
|
462 |
+
for idx2, seg in otherSegmentInfo:
|
463 |
+
if contains(segment, seg) or contains(segment[::-1], seg):
|
464 |
+
keepList[idx] = False
|
465 |
+
continue
|
466 |
+
elif contains(seg, segment) or contains(seg[::-1], segment):
|
467 |
+
keepList[idx2] = False
|
468 |
+
otherSegmentInfoTemp.append((idx2, seg))
|
469 |
+
|
470 |
+
otherSegmentInfo = otherSegmentInfoTemp
|
471 |
+
# otherSegmentInfo = [segInfo for segInfo in otherSegmentInfo if not (contains(segment, segInfo[1]) or contains(segInfo[1], segment))]
|
472 |
+
if len(otherSegmentInfo) > 1:
|
473 |
+
print('More than one other segments found!')
|
474 |
+
print('Current segment ({}) is {} ({})'.format(idx, segment, voxelDegrees))
|
475 |
+
for otherSegmentIdx, otherSegment in otherSegmentInfo:
|
476 |
+
otherSegmentVoxelDegrees = np.array([v for _, v in G.degree(otherSegment)])
|
477 |
+
print('Idx = {}: {} ({})'.format(otherSegmentIdx, otherSegment,
|
478 |
+
otherSegmentVoxelDegrees))
|
479 |
+
elif len(otherSegmentInfo) == 1:
|
480 |
+
otherSegmentIdx, otherSegment = otherSegmentInfo[0]
|
481 |
+
else:
|
482 |
+
print('No valid other segments found!')
|
483 |
+
continue
|
484 |
+
else:
|
485 |
+
otherSegmentIdx, otherSegment = otherSegmentInfo[0]
|
486 |
+
if contains(segment, otherSegment) or contains(segment[::-1], otherSegment):
|
487 |
+
keepList[idx] = False
|
488 |
+
continue
|
489 |
+
elif contains(otherSegment, segment) or contains(otherSegment[::-1], segment):
|
490 |
+
keepList[otherSegmentIdx] = False
|
491 |
+
continue
|
492 |
+
|
493 |
+
newSegment = segment[:-1] + otherSegment if otherSegment[0] == segment[-1] else segment[
|
494 |
+
:-1] + otherSegment[
|
495 |
+
::-1]
|
496 |
+
if not validateSegment(G, newSegment):
|
497 |
+
newSegmentVoxelDegrees = np.array([v for _, v in G.degree(newSegment)])
|
498 |
+
print('Old degree is {} () and new degree is {} ()'.format(voxelDegrees,
|
499 |
+
newSegmentVoxelDegrees))
|
500 |
+
else:
|
501 |
+
print('Two segments ({} and {}) merged together!'.format(idx, otherSegmentIdx))
|
502 |
+
|
503 |
+
extraSegments.append(newSegment)
|
504 |
+
keepList[idx] = False
|
505 |
+
keepList[otherSegmentIdx] = False
|
506 |
+
else:
|
507 |
+
print(
|
508 |
+
'Could not find other segments for segment({}) {} with degrees {}'.format(idx, segment,
|
509 |
+
voxelDegrees))
|
510 |
+
possibleSegmentsInfo = [(idx2, seg) for idx2, seg in enumerate(segmentList) if
|
511 |
+
(seg[0] == segment[-1] or seg[-1] == segment[-1]) and idx != idx2]
|
512 |
+
print('Possible segments: {}'.format(len(possibleSegmentsInfo)))
|
513 |
+
|
514 |
+
segmentList = [segment for idx, segment in enumerate(segmentList) if keepList[idx]]
|
515 |
+
segmentList += extraSegments
|
516 |
+
hasInvalidSegments = False
|
517 |
+
errorSegments = []
|
518 |
+
for idx, segment in enumerate(segmentList):
|
519 |
+
voxelDegrees = np.array([v for _, v in G.degree(segment)])
|
520 |
+
if len(voxelDegrees) == 2:
|
521 |
+
if voxelDegrees[0] == 2 or voxelDegrees[-1] == 2:
|
522 |
+
print('Degrees on either end is 2: {}'.format(voxelDegrees))
|
523 |
+
hasInvalidSegments = True
|
524 |
+
errorSegments.append(segment)
|
525 |
+
elif len(voxelDegrees) > 2:
|
526 |
+
if voxelDegrees[0] == 2 or voxelDegrees[-1] == 2 or np.any(voxelDegrees[1:-1] != 2):
|
527 |
+
print('Degrees not correct: {}'.format(voxelDegrees))
|
528 |
+
hasInvalidSegments = True
|
529 |
+
errorSegments.append(segment)
|
530 |
+
|
531 |
+
print('hasInvalidSegments = {}'.format(hasInvalidSegments))
|
532 |
+
iterCounter += 1
|
533 |
+
if len(extraSegments) == 0:
|
534 |
+
hasInvalidSegments = False
|
535 |
+
print('While loop aborted because there is no change in segments!')
|
536 |
+
|
537 |
+
for errorSegment in errorSegments:
|
538 |
+
segmentList.remove(errorSegment)
|
539 |
+
|
540 |
+
# np.savez_compressed(directory + 'segmentList.npz', segmentList=segmentList)
|
541 |
+
# if partIdx != 10:
|
542 |
+
# nib.save(nib.Nifti1Image(skeleton.astype(np.int16), vesselImg.affine), directory + skeletonNamePartial + str(partIdx) + '.nii.gz')
|
543 |
+
# else:
|
544 |
+
# nib.save(nib.Nifti1Image(skeleton.astype(np.int16), vesselImg.affine), directory + skeletonNameTotal + '.nii.gz')
|
545 |
+
|
546 |
+
# nx.write_graphml(G, directory + 'graphRepresentation.graphml')
|
547 |
+
|
548 |
+
# drawAbstractGraph(offset, segmentList)
|
549 |
+
# drawAbstractGraph(offset, errorSegments)
|
550 |
+
|
551 |
+
print(errorSegments)
|
552 |
+
|
553 |
+
return G, segmentList, errorSegments
|
554 |
+
|
555 |
+
|
556 |
+
def getSegmentList(G, nodeInfoDict):
|
557 |
+
"""
|
558 |
+
Generate segmentList from graph and nodeInfoDict.
|
559 |
+
Parameters
|
560 |
+
----------
|
561 |
+
G : NetworkX graph
|
562 |
+
The graph representation of the network.
|
563 |
+
nodeInfoDict : dict
|
564 |
+
A dictionary containing the information about nodes.
|
565 |
+
|
566 |
+
Returns
|
567 |
+
-------
|
568 |
+
segmentList : list
|
569 |
+
A list of segments in which each segment is a simple branch.
|
570 |
+
"""
|
571 |
+
startNodeIDList = [nodeID for nodeID in nodeInfoDict.keys() if nodeInfoDict[nodeID]['parentNodeID'] == -1]
|
572 |
+
print('startNodeIDList = {}'.format(startNodeIDList))
|
573 |
+
segmentList = []
|
574 |
+
for startNodeID in startNodeIDList:
|
575 |
+
segmentList = getSegmentListDetail(G, nodeInfoDict, segmentList, startNodeID)
|
576 |
+
|
577 |
+
print('There are {} segments in segmentList'.format(len(segmentList)))
|
578 |
+
print(segmentList)
|
579 |
+
return segmentList
|
580 |
+
|
581 |
+
|
582 |
+
def getSegmentListDetail(G, nodeInfoDict, segmentList, startNodeID):
|
583 |
+
"""
|
584 |
+
Implementation of `getSegmentList`. Use DFS to traverse all the segments.
|
585 |
+
Parameters
|
586 |
+
----------
|
587 |
+
G : NetworkX graph
|
588 |
+
The graph representation of the network.
|
589 |
+
nodeInfoDict : dict
|
590 |
+
A dictionary containing the information about nodes.
|
591 |
+
segmentList : list
|
592 |
+
A list of segments in which each segment is a simple branch.
|
593 |
+
startNodeID : int
|
594 |
+
The index of the start point of a segment.
|
595 |
+
|
596 |
+
Returns
|
597 |
+
-------
|
598 |
+
segmentList : list
|
599 |
+
A list of segments in which each segment is a simple branch.
|
600 |
+
"""
|
601 |
+
neighborNodeIDList = [nodeID for nodeID in list(G[startNodeID].keys()) if
|
602 |
+
'visited' not in G[startNodeID][nodeID]] # use adjacency dict to find neighbors
|
603 |
+
newSegmentList = []
|
604 |
+
for neighborNodeID in neighborNodeIDList:
|
605 |
+
newSegment = [startNodeID, neighborNodeID]
|
606 |
+
G[startNodeID][neighborNodeID]['visited'] = True
|
607 |
+
currentNodeID = neighborNodeID
|
608 |
+
while G.degree(currentNodeID) == 2:
|
609 |
+
newNodeID = [nodeID for nodeID in G[currentNodeID].keys() if 'visited' not in G[currentNodeID][nodeID]][0]
|
610 |
+
G[currentNodeID][newNodeID]['visited'] = True
|
611 |
+
newSegment.append(newNodeID)
|
612 |
+
currentNodeID = newNodeID
|
613 |
+
|
614 |
+
newSegmentList.append(newSegment)
|
615 |
+
segmentList.append(newSegment)
|
616 |
+
segmentList = getSegmentListDetail(G, nodeInfoDict, segmentList, currentNodeID)
|
617 |
+
|
618 |
+
return segmentList
|
619 |
+
|
620 |
+
|
621 |
+
def sublist(ls1, ls2):
|
622 |
+
'''
|
623 |
+
>>> sublist([], [1,2,3])
|
624 |
+
True
|
625 |
+
>>> sublist([1,2,3,4], [2,5,3])
|
626 |
+
True
|
627 |
+
>>> sublist([1,2,3,4], [0,3,2])
|
628 |
+
False
|
629 |
+
>>> sublist([1,2,3,4], [1,2,5,6,7,8,5,76,4,3])
|
630 |
+
False
|
631 |
+
'''
|
632 |
+
|
633 |
+
def get_all_in(one, another):
|
634 |
+
for element in one:
|
635 |
+
if element in another:
|
636 |
+
yield element
|
637 |
+
|
638 |
+
for x1, x2 in zip(get_all_in(ls1, ls2), get_all_in(ls2, ls1)):
|
639 |
+
if x1 != x2:
|
640 |
+
return False
|
641 |
+
|
642 |
+
return True
|
643 |
+
|
644 |
+
|
645 |
+
def contains(lst1, lst2):
|
646 |
+
lst1, lst2 = (lst2, lst1) if len(lst1) > len(lst2) else (lst1, lst2)
|
647 |
+
if lst1[0] in lst2:
|
648 |
+
startLoc = lst2.index(lst1[0])
|
649 |
+
else:
|
650 |
+
return False
|
651 |
+
|
652 |
+
if lst1[-1] in lst2:
|
653 |
+
endLoc = lst2.index(lst1[-1])
|
654 |
+
else:
|
655 |
+
return False
|
656 |
+
|
657 |
+
if startLoc < endLoc:
|
658 |
+
if lst1 == lst2[startLoc:(endLoc + 1)]:
|
659 |
+
return True
|
660 |
+
else:
|
661 |
+
return False
|
662 |
+
else:
|
663 |
+
if lst1 == lst2[endLoc:(startLoc + 1)][::-1]:
|
664 |
+
return True
|
665 |
+
else:
|
666 |
+
return False
|
667 |
+
|
668 |
+
|
669 |
+
def validateSegment(G, segment):
|
670 |
+
"""
|
671 |
+
Check whether a segment is a simple branch.
|
672 |
+
Parameters
|
673 |
+
----------
|
674 |
+
G : NetworkX graph
|
675 |
+
A graph in which each node represents a centerpoint and each edge represents a portion of a vessel branch.
|
676 |
+
segment : list
|
677 |
+
A list containing the coordinates of the centerpoints of a segment.
|
678 |
+
|
679 |
+
Returns
|
680 |
+
-------
|
681 |
+
result : bool
|
682 |
+
If True, the segment is a simple branch.
|
683 |
+
"""
|
684 |
+
voxelDegrees = np.array([v for _, v in G.degree(segment)])
|
685 |
+
if voxelDegrees[0] != 2 and voxelDegrees[-1] != 2:
|
686 |
+
if len(voxelDegrees) == 2:
|
687 |
+
result = True
|
688 |
+
elif len(voxelDegrees) > 2:
|
689 |
+
if np.all(voxelDegrees[1:-1] == 2):
|
690 |
+
result = True
|
691 |
+
else:
|
692 |
+
result = False
|
693 |
+
else:
|
694 |
+
print('Error! Segment with length 1 found!')
|
695 |
+
result = False
|
696 |
+
else:
|
697 |
+
result = False
|
698 |
+
|
699 |
+
return result
|
700 |
+
|
701 |
+
|
702 |
+
def drawSegments(segmentList, shape):
|
703 |
+
"""
|
704 |
+
Plot all the segments in `segmentList`. Try to assign different colors to the segments connected to the same node.
|
705 |
+
Parameters
|
706 |
+
----------
|
707 |
+
segmentList : list
|
708 |
+
A list containing the segment information. Each sublist represents a segment and each element in the sublist
|
709 |
+
represents a centerpoint coordinates.
|
710 |
+
shape : tuple
|
711 |
+
Shape of the vessel volume (used for ploting).
|
712 |
+
"""
|
713 |
+
## Import pyqtgraph ##
|
714 |
+
from pyqtgraph.Qt import QtCore, QtGui
|
715 |
+
import pyqtgraph as pg
|
716 |
+
import pyqtgraph.opengl as gl
|
717 |
+
|
718 |
+
## Init ##
|
719 |
+
app = pg.QtGui.QApplication([])
|
720 |
+
w = gl.GLViewWidget()
|
721 |
+
w.opts['distance'] = 800
|
722 |
+
w.setGeometry(0, 110, 1600, 900)
|
723 |
+
offset = np.array(shape) / (-2.0)
|
724 |
+
|
725 |
+
colorList = [pg.glColor('r'), pg.glColor('g'), pg.glColor('b'), pg.glColor('c'), pg.glColor('m'), pg.glColor('y')]
|
726 |
+
colorNames = ['Red', 'Green', 'Blue', 'Cyan', 'Magneta', 'Yellow']
|
727 |
+
numOfColors = len(colorList)
|
728 |
+
nodeColorDict = {}
|
729 |
+
for segment in segmentList:
|
730 |
+
startVoxel = segment[0]
|
731 |
+
endVoxel = segment[-1]
|
732 |
+
if startVoxel in nodeColorDict and endVoxel in nodeColorDict: # and endVoxel in [voxel for voxel, _ in nodeColorDict[startVoxel]]:
|
733 |
+
nodeColorDict[startVoxel].append([endVoxel, -1])
|
734 |
+
nodeColorDict[endVoxel].append([startVoxel, -1])
|
735 |
+
else:
|
736 |
+
if startVoxel not in nodeColorDict:
|
737 |
+
nodeColorDict[startVoxel] = [[endVoxel, -1]]
|
738 |
+
else:
|
739 |
+
nodeColorDict[startVoxel].append([endVoxel, -1])
|
740 |
+
|
741 |
+
if endVoxel not in nodeColorDict:
|
742 |
+
nodeColorDict[endVoxel] = [[startVoxel, -1]]
|
743 |
+
else:
|
744 |
+
nodeColorDict[endVoxel].append([startVoxel, -1])
|
745 |
+
|
746 |
+
existingColorsInStart = [colorCode for _, colorCode in nodeColorDict[startVoxel]]
|
747 |
+
existingColorsInEnd = [colorCode for _, colorCode in nodeColorDict[endVoxel]]
|
748 |
+
availableColors = [colorCode for colorCode in range(numOfColors) if
|
749 |
+
colorCode not in existingColorsInStart and colorCode not in existingColorsInEnd]
|
750 |
+
# print('color in start: {} and color in end: {}'.format(existingColorsInStart, existingColorsInEnd))
|
751 |
+
chosenColor = availableColors[0] if len(availableColors) != 0 else 0
|
752 |
+
nodeColorDict[startVoxel][-1][1] = chosenColor
|
753 |
+
nodeColorDict[endVoxel][-1][1] = chosenColor
|
754 |
+
|
755 |
+
segmentCoords = np.array(segment)
|
756 |
+
aa = gl.GLLinePlotItem(pos=segmentCoords, color=colorList[chosenColor], width=3)
|
757 |
+
aa.translate(*offset)
|
758 |
+
w.addItem(aa)
|
759 |
+
|
760 |
+
w.show()
|
761 |
+
pg.QtGui.QApplication.exec_()
|
762 |
+
# sys.exit(app.exec_())
|
763 |
+
|
764 |
+
|
765 |
+
def main():
|
766 |
+
start_time = timeit.default_timer()
|
767 |
+
baseFolder = os.path.abspath(os.path.dirname(__file__))
|
768 |
+
|
769 |
+
## Load existing volume ##
|
770 |
+
vesselVolumeMaskFolderPath = baseFolder
|
771 |
+
vesselVolumeMaskFileName = 'vesselVolumeMask.nii.gz'
|
772 |
+
vesselVolumeMask, vesselVolumeMaskAffine = loadVolume(vesselVolumeMaskFolderPath, vesselVolumeMaskFileName)
|
773 |
+
|
774 |
+
## Skeletonization ##
|
775 |
+
# analyze(vesselVolumeMask, baseFolder)
|
776 |
+
|
777 |
+
skeletonSegmentFolderPath = os.path.join(baseFolder, 'skeletonizationResult/segments_by_cc')
|
778 |
+
segmentListRough = combineSkeletonSegments(skeletonSegmentFolderPath)
|
779 |
+
|
780 |
+
shape = vesselVolumeMask.shape
|
781 |
+
# drawSegments(segmentListRough, shape)
|
782 |
+
|
783 |
+
G, segmentList, errorSegments = processSegments(segmentListRough, shape=shape)
|
784 |
+
# drawSegments(segmentList, shape)
|
785 |
+
G = nx.Graph()
|
786 |
+
segmentIndex = 0
|
787 |
+
for segment in segmentList:
|
788 |
+
G.add_path(segment, segmentIndex=segmentIndex)
|
789 |
+
segmentIndex += 1
|
790 |
+
|
791 |
+
## Save graph representation ##
|
792 |
+
graphFileName = 'graphRepresentation.graphml'
|
793 |
+
graphFilePath = os.path.join(baseFolder, graphFileName)
|
794 |
+
nx.write_graphml(G, graphFilePath)
|
795 |
+
print('{} saved to {}.'.format(graphFileName, graphFilePath))
|
796 |
+
|
797 |
+
## Save segmentList ##
|
798 |
+
segmentListFileName = 'segmentList.npz'
|
799 |
+
segmentListFilePath = os.path.join(baseFolder, segmentListFileName)
|
800 |
+
np.savez_compressed(segmentListFilePath, segmentList=segmentList)
|
801 |
+
print('{} saved to {}.'.format(segmentListFileName, segmentListFilePath))
|
802 |
+
|
803 |
+
## Save skeleton.nii.gz ##
|
804 |
+
skeleton = np.zeros_like(vesselVolumeMask)
|
805 |
+
for segment in segmentList:
|
806 |
+
skeleton[tuple(np.array(segment).T)] = 1
|
807 |
+
|
808 |
+
skeletonFileName = 'skeleton.nii.gz'
|
809 |
+
skeletonFilePath = os.path.join(baseFolder, skeletonFileName)
|
810 |
+
saveVolume(skeleton, vesselVolumeMaskAffine, skeletonFilePath, astype=np.uint8)
|
811 |
+
|
812 |
+
elapsed = timeit.default_timer() - start_time
|
813 |
+
print('Elapsed: {} sec'.format(elapsed))
|
814 |
+
|
815 |
+
|
816 |
+
if __name__ == "__main__":
|
817 |
+
main()
|
Centerline/thinPlateSplines.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from scipy.spatial.distance import pdist, cdist, squareform
|
3 |
+
from sklearn.metrics import pairwise_distances
|
4 |
+
|
5 |
+
class ThinPlateSplines:
|
6 |
+
def __init__(self, ctrl_pts: np.ndarray, target_pts: np.ndarray, reg=0.0):
|
7 |
+
"""
|
8 |
+
|
9 |
+
:param ctrl_pts: [N, d] tensor of control d-dimensional points
|
10 |
+
:param target_pts: [N, d] tensor of target d-dimensional points
|
11 |
+
:param reg: regularization coefficient
|
12 |
+
"""
|
13 |
+
self.__ctrl_pts = ctrl_pts
|
14 |
+
self.__target_pts = target_pts
|
15 |
+
self.__reg = reg
|
16 |
+
self.__num_ctrl_pts = ctrl_pts.shape[0]
|
17 |
+
self.__dim = ctrl_pts.shape[1]
|
18 |
+
|
19 |
+
self.__K = None
|
20 |
+
self.__compute_coeffs()
|
21 |
+
self.__aff_params = self.__coeffs[self.__num_ctrl_pts:, ...] # Affine parameters of the TPS
|
22 |
+
self.__non_aff_paramms = self.__coeffs[:self.__num_ctrl_pts, ...] # Non-affine parameters of he TPS
|
23 |
+
|
24 |
+
def __compute_coeffs(self):
|
25 |
+
target_pts_aug = np.vstack([self.__target_pts,
|
26 |
+
np.zeros([self.__dim + 1, self.__dim])]).astype(self.__target_pts.dtype)
|
27 |
+
|
28 |
+
T_i = np.linalg.inv(self.__make_T()).astype(self.__target_pts.dtype)
|
29 |
+
self.__coeffs = np.matmul(T_i, target_pts_aug).astype(self.__target_pts.dtype)
|
30 |
+
|
31 |
+
def __make_T(self):
|
32 |
+
# cp: [K x 2] control points
|
33 |
+
# T: [(K+3) x (K+3)]
|
34 |
+
P = np.hstack([np.ones([self.__num_ctrl_pts, 1], dtype=np.float), self.__ctrl_pts])
|
35 |
+
zeros = np.zeros([self.__dim + 1, self.__dim + 1], dtype=np.float)
|
36 |
+
self.__K = self.__U_dist(self.__ctrl_pts)
|
37 |
+
alfa = np.mean(self.__K)
|
38 |
+
|
39 |
+
self.__K = self.__K + np.ones_like(self.__K) * np.power(alfa, 2) * self.__reg
|
40 |
+
|
41 |
+
top = np.hstack([P, self.__K])
|
42 |
+
bottom = np.hstack([P.transpose(), zeros])
|
43 |
+
|
44 |
+
return np.vstack([top, bottom])
|
45 |
+
|
46 |
+
def __U_dist(self, ctrl_pts, int_pts=None):
|
47 |
+
dist = pairwise_distances(ctrl_pts, int_pts)
|
48 |
+
|
49 |
+
if ctrl_pts.shape[-1] == 2:
|
50 |
+
u_dist = np.square(dist) * np.log(dist + 1e-6)
|
51 |
+
else:
|
52 |
+
u_dist = np.sqrt(dist)
|
53 |
+
|
54 |
+
return u_dist
|
55 |
+
|
56 |
+
def __lift_pts(self, int_pts: np.ndarray, num_pts):
|
57 |
+
# int_pts: [N x 2], input points
|
58 |
+
# cp: [K x 2], control points
|
59 |
+
# pLift: [N x (3+K)], lifted input points
|
60 |
+
|
61 |
+
int_pts_lift = np.hstack([self.__U_dist(self.__ctrl_pts, int_pts),
|
62 |
+
np.ones([num_pts, 1], dtype=np.float),
|
63 |
+
int_pts])
|
64 |
+
return int_pts_lift
|
65 |
+
|
66 |
+
def _get_coefficients(self):
|
67 |
+
return self.__coeffs
|
68 |
+
|
69 |
+
def interpolate(self, int_points):
|
70 |
+
"""
|
71 |
+
|
72 |
+
:param int_points: [K, d] flattened d-points of a mesh
|
73 |
+
:return:
|
74 |
+
"""
|
75 |
+
num_pts = int_points.shape[0]
|
76 |
+
int_points_lift = self.__lift_pts(int_points, num_pts)
|
77 |
+
return np.dot(int_points_lift, self.__coeffs)
|
78 |
+
|
79 |
+
@property
|
80 |
+
def bending_energy(self):
|
81 |
+
aux = tf.matmul(self.__non_aff_paramms, self.__K, transpose_a=True)
|
82 |
+
return tf.matmul(aux, self.__non_aff_paramms)
|
Centerline/visualization_utils.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib.pyplot as plt
|
2 |
+
from mpl_toolkits.mplot3d import Axes3D
|
3 |
+
from matplotlib.lines import Line2D
|
4 |
+
import numpy as np
|
5 |
+
from DeepDeformationMapRegistration.utils.visualization import add_axes_arrows_3d, remove_tick_labels, set_axes_size
|
6 |
+
import os
|
7 |
+
|
8 |
+
|
9 |
+
def _plot_graph(graph, ax, nodes_colour='C3', edges_colour='C1', plot_nodes=True, plot_edges=True, add_axes=True):
|
10 |
+
if plot_edges:
|
11 |
+
for (start_node, end_node) in graph.edges():
|
12 |
+
edge_pts = graph[start_node][end_node]['pts']
|
13 |
+
edge_pts = np.vstack([graph.nodes[start_node]['o'], edge_pts])
|
14 |
+
edge_pts = np.vstack([edge_pts, graph.nodes[end_node]['o']])
|
15 |
+
ax.plot(edge_pts[:, 0], edge_pts[:, 1], edge_pts[:, 2], edges_colour)
|
16 |
+
|
17 |
+
if plot_nodes:
|
18 |
+
nodes = graph.nodes()
|
19 |
+
ps = np.array([nodes[i]['o'] for i in nodes])
|
20 |
+
if len(ps.shape) > 1:
|
21 |
+
ax.scatter(ps[:, 0], ps[:, 1], ps[:, 2], nodes_colour)
|
22 |
+
else:
|
23 |
+
ax.scatter(ps[0], ps[1], ps[2], nodes_colour)
|
24 |
+
ax.set_xlim(0, 63)
|
25 |
+
ax.set_ylim(0, 63)
|
26 |
+
ax.set_zlim(0, 63)
|
27 |
+
remove_tick_labels(ax, True)
|
28 |
+
if add_axes:
|
29 |
+
add_axes_arrows_3d(ax, x_color='r', y_color='g', z_color='b')
|
30 |
+
ax.view_init(None, 45)
|
31 |
+
|
32 |
+
return ax
|
33 |
+
|
34 |
+
|
35 |
+
def plot_skeleton(img, skeleton, graph, filename='skeleton', extension=['.png']):
|
36 |
+
if not isinstance(extension, list):
|
37 |
+
extension = [extension]
|
38 |
+
# Skeleton
|
39 |
+
f = plt.figure(figsize=(5, 5))
|
40 |
+
ax = f.add_subplot(111, projection='3d')
|
41 |
+
|
42 |
+
coords = np.argwhere(skeleton)
|
43 |
+
i = coords[:, 0]
|
44 |
+
j = coords[:, 1]
|
45 |
+
k = coords[:, 2]
|
46 |
+
|
47 |
+
seg = ax.voxels(img, facecolors=(0., 0., 1., 0.3), label='image')
|
48 |
+
ske = ax.scatter(i, j, k, color='C1', label='skeleton', s=1)
|
49 |
+
ax.set_xlim(0, 63)
|
50 |
+
ax.set_ylim(0, 63)
|
51 |
+
ax.set_zlim(0, 63)
|
52 |
+
remove_tick_labels(ax, True)
|
53 |
+
add_axes_arrows_3d(ax, x_color='r', y_color='g', z_color='b')
|
54 |
+
ax.view_init(None, 45)
|
55 |
+
for ex in extension:
|
56 |
+
f.savefig(filename + '_segmentation_skeleton' + ex)
|
57 |
+
|
58 |
+
# Combined
|
59 |
+
ax = _plot_graph(graph, ax, 'r', 'r')
|
60 |
+
|
61 |
+
for ex in extension:
|
62 |
+
f.savefig(filename + '_combined' + ex)
|
63 |
+
plt.close()
|
64 |
+
|
65 |
+
# Graph
|
66 |
+
f = plt.figure(figsize=(5, 5))
|
67 |
+
ax = f.add_subplot(111, projection='3d')
|
68 |
+
|
69 |
+
ax = _plot_graph(graph, ax)
|
70 |
+
|
71 |
+
for ex in extension:
|
72 |
+
f.savefig(filename + '_graph' + ex)
|
73 |
+
plt.close()
|
74 |
+
|
75 |
+
|
76 |
+
|
77 |
+
|
78 |
+
def compare_graphs(graph_0, graph_1, graph_names=None, filename='compare_graphs'):
|
79 |
+
f = plt.figure(figsize=(12, 5))
|
80 |
+
if graph_names is None:
|
81 |
+
graph_names =['graph_0', 'graph_1']
|
82 |
+
else:
|
83 |
+
assert len(graph_names) == 2, 'A different name is expected for each graph'
|
84 |
+
ax = f.add_subplot(131, projection='3d')
|
85 |
+
ax = _plot_graph(graph_0, ax)
|
86 |
+
ax.set_title(graph_names[0], y=-0.01)
|
87 |
+
|
88 |
+
ax = f.add_subplot(132, projection='3d')
|
89 |
+
ax = _plot_graph(graph_1, ax)
|
90 |
+
ax.set_title(graph_names[1])
|
91 |
+
|
92 |
+
ax = f.add_subplot(133, projection='3d')
|
93 |
+
ax = _plot_graph(graph_0, ax, 'C2', 'C2', plot_nodes=False)
|
94 |
+
ax = _plot_graph(graph_1, ax, 'C4', 'C4', plot_nodes=False)
|
95 |
+
legend_elements = [Line2D([0], [0], color='C2', lw=2, label=graph_names[0]),
|
96 |
+
Line2D([0], [0], color='C4', lw=2, label=graph_names[1])]
|
97 |
+
ax.legend(handles=legend_elements)
|
98 |
+
|
99 |
+
f.savefig(filename + '_compare_graphs.png')
|
100 |
+
plt.close()
|
101 |
+
|
102 |
+
|
103 |
+
def plot_cpd_registration_step(iteration, error, X, Y, out_folder, add_axes=True, pdf=True):
|
104 |
+
fig = plt.figure(figsize=(8, 8))
|
105 |
+
ax = fig.add_axes([0, 0, .9, .9], projection='3d')
|
106 |
+
ax.scatter(X[:, 0], X[:, 1], X[:, 2], color='C1', label='Fixed')
|
107 |
+
ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], color='C2', label='Moving')
|
108 |
+
|
109 |
+
ax.text2D(0.95, 0.98, 'Iteration: {:d}'.format(
|
110 |
+
iteration), horizontalalignment='right', verticalalignment='center', transform=ax.transAxes, fontsize='x-large')
|
111 |
+
#ax.text2D(0.95, 0.90, 'Error: {:10.4f}'.format(
|
112 |
+
# error), horizontalalignment='right', verticalalignment='center', transform=ax.transAxes, fontsize='x-large')
|
113 |
+
ax.legend(loc='upper left', fontsize='x-large')
|
114 |
+
|
115 |
+
if add_axes:
|
116 |
+
x_range = [np.min(np.hstack([X[:, 0], Y[:, 0]])), np.max(np.hstack([X[:, 0], Y[:, 0]]))]
|
117 |
+
y_range = [np.min(np.hstack([X[:, 1], Y[:, 1]])), np.max(np.hstack([X[:, 1], Y[:, 1]]))]
|
118 |
+
z_range = [np.min(np.hstack([X[:, 2], Y[:, 2]])), np.max(np.hstack([X[:, 2], Y[:, 2]]))]
|
119 |
+
ax.set_xlim(x_range[0], x_range[1])
|
120 |
+
ax.set_ylim(y_range[0], y_range[1])
|
121 |
+
ax.set_zlim(z_range[0], z_range[1])
|
122 |
+
|
123 |
+
remove_tick_labels(ax, True)
|
124 |
+
add_axes_arrows_3d(ax, x_color='r', y_color='g', z_color='b', arrow_length=25, dist_arrow_text=3)
|
125 |
+
ax.view_init(None, 45)
|
126 |
+
|
127 |
+
os.makedirs(out_folder, exist_ok=True)
|
128 |
+
fig.savefig(os.path.join(out_folder, '{:04d}.png'.format(iteration)))
|
129 |
+
if pdf:
|
130 |
+
fig.savefig(os.path.join(out_folder, '{:04d}.pdf'.format(iteration)))
|
131 |
+
plt.close()
|
132 |
+
|
133 |
+
|
134 |
+
def plot_cpd(fix_pts, mov_pts, fix_centroid, mov_centroid, file_name):
|
135 |
+
fig = plt.figure(figsize=(8, 8))
|
136 |
+
ax = fig.add_axes([0, 0, .9, .9], projection='3d')
|
137 |
+
ax.scatter(fix_pts[:, 0], fix_pts[:, 1], fix_pts[:, 2], color='C1', label='Fixed')
|
138 |
+
ax.scatter(mov_pts[:, 0], mov_pts[:, 1], mov_pts[:, 2], color='C2', label='Moving')
|
139 |
+
ax.scatter(fix_centroid[0], fix_centroid[1], fix_centroid[2], color='none', s=100, edgecolor='b', label='Centroid')
|
140 |
+
ax.scatter(mov_centroid[0], mov_centroid[1], mov_centroid[2], color='none', s=100, edgecolor='b')
|
141 |
+
ax.scatter(fix_centroid[0], fix_centroid[1], fix_centroid[2], color='C1')
|
142 |
+
ax.scatter(mov_centroid[0], mov_centroid[1], mov_centroid[2], color='C2')
|
143 |
+
|
144 |
+
x_range = [np.min(np.hstack([fix_pts[:, 0], mov_pts[:, 0], fix_centroid[0], mov_centroid[0]])),
|
145 |
+
np.max(np.hstack([fix_pts[:, 0], mov_pts[:, 0], fix_centroid[0], mov_centroid[0]]))]
|
146 |
+
y_range = [np.min(np.hstack([fix_pts[:, 1], mov_pts[:, 1], fix_centroid[1], mov_centroid[1]])),
|
147 |
+
np.max(np.hstack([fix_pts[:, 1], mov_pts[:, 1], fix_centroid[1], mov_centroid[1]]))]
|
148 |
+
z_range = [np.min(np.hstack([fix_pts[:, 2], mov_pts[:, 2], fix_centroid[2], mov_centroid[2]])),
|
149 |
+
np.max(np.hstack([fix_pts[:, 2], mov_pts[:, 2], fix_centroid[2], mov_centroid[2]]))]
|
150 |
+
ax.set_xlim(x_range[0], x_range[1])
|
151 |
+
ax.set_ylim(y_range[0], y_range[1])
|
152 |
+
ax.set_zlim(z_range[0], z_range[1])
|
153 |
+
|
154 |
+
remove_tick_labels(ax, True)
|
155 |
+
add_axes_arrows_3d(ax, x_color='r', y_color='g', z_color='b', arrow_length=25, dist_arrow_text=3)
|
156 |
+
ax.view_init(None, 45)
|
157 |
+
ax.legend(fontsize='xx-large')
|
158 |
+
fig.savefig(file_name + '.png')
|
159 |
+
fig.savefig(file_name + '.pdf')
|
160 |
+
plt.close()
|
161 |
+
|