File size: 8,910 Bytes
b10768a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
import os, sys
currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
import h5py
from tqdm import tqdm
from functools import partial
import numpy as np
from scipy.spatial.distance import euclidean
import pandas as pd
from EvaluationScripts.Evaluate_class import resize_img_to_original_space, resize_pts_to_original_space
from Centerline.visualization_utils import plot_cpd_registration_step, plot_cpd
from Centerline.cpd_utils import cpd_non_rigid_transform_pt, radial_basis_function, deform_registration, rigid_registration
from scipy.spatial.distance import cdist
from skimage.morphology import skeletonize_3d
import re
from probreg import bcpd
# Root folder holding one sub-folder per evaluated dataset (see DATASET_NAMES).
DATASET_LOCATION = '/mnt/EncryptedData1/Users/javier/vessel_registration/3Dirca/dataset/EVAL'
# Sub-folders of DATASET_LOCATION processed one after the other; one results CSV is written per entry.
DATASET_NAMES = ['Affine', 'None', 'Translation']
# Only files whose name contains this substring are processed.
DATASET_FILENAME = 'points'
# Root folder for the per-sample registration plots and the 'cpd_<dataset>.csv' result files.
OUT_IMG_FOLDER = '/mnt/EncryptedData1/Users/javier/vessel_registration/Centerline/cpd/skeleton'
# Factor applied to point coordinates before CPD and divided out again afterwards.
# NOTE(review): this was previously labelled "mm to cm", but mm -> cm would be 1e-1. The value 1e-2
# is kept unchanged because the CPD parameters below were presumably tuned with it; confirm the units.
SCALE = 1e-2
# CPD PARAMS (deform): passed to deform_registration as max_iterations, alpha, beta and tolerance.
MAX_ITER = 200
ALPHA = 0.1
BETA = 1.0  # None = Use default
TOLERANCE = 1e-8
# Column order of the per-dataset results CSV. The ILL_COND_* columns go last, which is where
# the former DataFrame.append-based code placed them (they were not in the initial column list).
_CSV_COLUMNS = ['DATASET',
                'ITERATIONS_DEF', 'ITERATIONS_R_DEF__R', 'ITERATIONS_R_DEF__DEF',
                'TIME_DEF', 'TIME_R_DEF',
                'Q_DEF', 'Q_R_DEF__R', 'Q_R_DEF__DEF',
                'TRE_DEF', 'TRE_R_DEF',
                'DS_DISP',
                'DATA_PATH',
                'DIST_CENTR', 'DIST_CENTR_DEF_95', 'SAMPLE_NUM',
                'ILL_COND_DEF', 'ILL_COND_R_DEF']


def _read_sample(file_path):
    """Read the skeletons, centroids and resizing parameters of one evaluation sample.

    :param file_path: HDF5 file with 'fix/*', 'mov/*' and 'parameters/*' datasets.
    :return: (fix_skel, fix_centroid, mov_skel, mov_centroid, bbox, first_reshape, isotropic_shape)
    """
    # [:] reads each dataset into memory, so the file is closed right here instead of staying
    # open for the whole registration (and never being closed if a later step raises).
    with h5py.File(file_path, 'r') as pts_file:
        return (pts_file['fix/skeleton'][:],
                pts_file['fix/centroid'][:],
                pts_file['mov/skeleton'][:],
                pts_file['mov/centroid'][:],
                pts_file['parameters/bbox'][:],
                pts_file['parameters/first_reshape'][:],
                pts_file['parameters/isotropic_shape'][:])


def _to_isotropic_space(skel, centroid, bbox, first_reshape, isotropic_shape):
    """Resize one skeleton volume and its centroid to the isotropic shape.

    :return: (coordinates of the non-zero voxels of the re-skeletonised volume, as given by
              np.argwhere; resized centroid)
    """
    # TODO: bring back to original shape!
    centroid = resize_pts_to_original_space(centroid, bbox, [64] * 3, first_reshape, isotropic_shape)
    skel = resize_img_to_original_space(skel, bbox, first_reshape, isotropic_shape)
    # Skeletonised again after resizing, presumably because interpolation does not keep the
    # skeleton one voxel thick.
    skel_pts = np.argwhere(skeletonize_3d(skel))
    return skel_pts, centroid


def _warp_centroid(source_pts, deform_reg, centroid):
    """Carry a deformable CPD result, defined on `source_pts`, over to `centroid`.

    The product of the registration parameters is used as the per-point displacement of
    `source_pts` (divided by SCALE to undo the coordinate scaling). It is interpolated with
    radial_basis_function and evaluated at `centroid`.

    :return: (displaced centroid, ill-conditioning flag returned by radial_basis_function)
    """
    displacement = np.dot(*deform_reg.get_registration_parameters()) / SCALE
    tps, ill_cond = radial_basis_function(source_pts, displacement)
    return centroid + tps(centroid), ill_cond


def _register_deformable(fix_skel_pts, mov_skel_pts, fix_centroid, mov_centroid, plot_dir):
    """Register the moving skeleton onto the fixed one with deformable CPD only.

    Writes before/after plots into `plot_dir` and returns the '*_DEF' columns of the results row.
    """
    # NOTE(review): this previously called probreg's bcpd.registration_bcpd, but its result was
    # unused and `deform_reg_def`/`time_def` were never defined (NameError). This restores the CPD
    # call that mirrors the rigid+deformable path.
    time_def, deform_reg = deform_registration(fix_skel_pts * SCALE, mov_skel_pts * SCALE, time_it=True,
                                               tolerance=TOLERANCE, max_iterations=MAX_ITER,
                                               alpha=ALPHA, beta=BETA)
    ill_cond = False
    if np.isnan(deform_reg.diff):
        # Failed registration: no TRE, and the centroid is plotted where it started.
        tre = np.nan
        pred_mov_centroid = mov_centroid
    else:
        pred_mov_centroid, ill_cond = _warp_centroid(mov_skel_pts, deform_reg, mov_centroid)
        tre = euclidean(pred_mov_centroid, fix_centroid)

    os.makedirs(plot_dir, exist_ok=True)
    plot_cpd(fix_skel_pts, mov_skel_pts, fix_centroid, mov_centroid, plot_dir + '/before_registration')
    plot_cpd(fix_skel_pts, deform_reg.TY / SCALE, fix_centroid, pred_mov_centroid,
             plot_dir + '/after_registration')

    return {'ITERATIONS_DEF': deform_reg.iteration,
            'TIME_DEF': time_def,
            'Q_DEF': deform_reg.diff,
            'ILL_COND_DEF': ill_cond,
            'TRE_DEF': tre}


def _register_rigid_deformable(fix_skel_pts, mov_skel_pts, fix_centroid, mov_centroid, plot_dir):
    """Register the moving skeleton with rigid CPD, then refine it with deformable CPD.

    Writes before/after plots into `plot_dir` and returns the '*_R_DEF*' columns of the results row.
    """
    time_rigid, rigid_reg = rigid_registration(fix_skel_pts * SCALE, mov_skel_pts * SCALE, time_it=True)
    rigid_yt = rigid_reg.TY  # rigidly aligned moving points (still scaled)
    time_deform, deform_reg = deform_registration(fix_skel_pts * SCALE, rigid_yt, time_it=True,
                                                  tolerance=TOLERANCE, max_iterations=MAX_ITER,
                                                  alpha=ALPHA, beta=BETA)
    ill_cond = False
    # The rigid stage is always applied to the centroid ...
    pred_mov_centroid = rigid_reg.transform_point_cloud(mov_centroid * SCALE) / SCALE
    # ... and the deformable stage only if it did not fail (diff is NaN on failure).
    if not np.isnan(deform_reg.diff):
        pred_mov_centroid, ill_cond = _warp_centroid(rigid_yt / SCALE, deform_reg, pred_mov_centroid)
    tre = euclidean(pred_mov_centroid, fix_centroid)

    os.makedirs(plot_dir, exist_ok=True)
    plot_cpd(fix_skel_pts, mov_skel_pts, fix_centroid, mov_centroid, plot_dir + '/before_registration')
    plot_cpd(fix_skel_pts, deform_reg.TY / SCALE, fix_centroid, pred_mov_centroid,
             plot_dir + '/after_registration')

    return {'ITERATIONS_R_DEF__R': rigid_reg.iteration,
            'ITERATIONS_R_DEF__DEF': deform_reg.iteration,
            'TIME_R_DEF': time_rigid + time_deform,
            'Q_R_DEF__R': rigid_reg.q,
            'Q_R_DEF__DEF': deform_reg.diff,
            'ILL_COND_R_DEF': ill_cond,
            'TRE_R_DEF': tre}


if __name__ == '__main__':
    for dataset_name in DATASET_NAMES:
        dataset_loc = os.path.join(DATASET_LOCATION, dataset_name)
        dataset_files = [os.path.join(dataset_loc, f) for f in sorted(os.listdir(dataset_loc))
                         if DATASET_FILENAME in f]
        # One dict per sample, turned into a DataFrame once at the end
        # (DataFrame.append was removed in pandas 2.0 and was quadratic anyway).
        rows = []
        iterator = tqdm(dataset_files)
        for file_path in iterator:
            fn = os.path.split(file_path)[-1].split('.hd5')[0]
            fnum = int(re.findall(r'(\d+)', fn)[0])  # first number in the file name is the sample id
            sample_dir = os.path.join(OUT_IMG_FOLDER, '{}/{:04d}'.format(dataset_name, fnum))

            iterator.set_description('{}: start'.format(fn))
            (fix_skel, fix_centroid, mov_skel, mov_centroid,
             bbox, first_reshape, isotropic_shape) = _read_sample(file_path)
            iterator.set_description('{}: Loaded data'.format(fn))

            fix_skel_pts, fix_centroid = _to_isotropic_space(fix_skel, fix_centroid,
                                                             bbox, first_reshape, isotropic_shape)
            mov_skel_pts, mov_centroid = _to_isotropic_space(mov_skel, mov_centroid,
                                                             bbox, first_reshape, isotropic_shape)
            iterator.set_description('{}: reshaped data'.format(fn))

            iterator.set_description('{}: Computing only deformable reg.'.format(fn))
            def_cols = _register_deformable(fix_skel_pts, mov_skel_pts, fix_centroid, mov_centroid,
                                            os.path.join(sample_dir, 'DEF'))

            iterator.set_description('{}: Computing rigid and deform. reg.'.format(fn))
            r_def_cols = _register_rigid_deformable(fix_skel_pts, mov_skel_pts, fix_centroid, mov_centroid,
                                                    os.path.join(sample_dir, 'RIGID_DEF'))

            iterator.set_description('{}: Saving data'.format(fn))
            # Distances from the moving centroid to every moving skeleton point.
            dist_centroid_to_pts = cdist(mov_centroid[np.newaxis, ...], mov_skel_pts)
            rows.append({'DATASET': dataset_name,
                         **def_cols,
                         **r_def_cols,
                         'DS_DISP': euclidean(mov_centroid, fix_centroid),
                         'DATA_PATH': file_path,
                         'DIST_CENTR': np.min(dist_centroid_to_pts),
                         'DIST_CENTR_DEF_95': np.percentile(dist_centroid_to_pts, 95),
                         'SAMPLE_NUM': fnum})

        # The plot sub-folders are only created per sample, so make sure the root exists even
        # when a dataset has no matching files.
        os.makedirs(OUT_IMG_FOLDER, exist_ok=True)
        pd.DataFrame(rows, columns=_CSV_COLUMNS).to_csv(
            os.path.join(OUT_IMG_FOLDER, 'cpd_{}.csv'.format(dataset_name)))
|