|
import os, sys |
|
|
|
# Resolve this script's real location (following symlinks) and append its
# parent directory to sys.path, so the project packages imported below
# (EvaluationScripts, Centerline) resolve when the script is run directly.
currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.split(currentdir)[0]
sys.path.append(parentdir)
|
|
|
import h5py |
|
from tqdm import tqdm |
|
from functools import partial |
|
import numpy as np |
|
from scipy.spatial.distance import euclidean |
|
import pandas as pd |
|
from EvaluationScripts.Evaluate_class import resize_img_to_original_space, resize_pts_to_original_space |
|
from Centerline.visualization_utils import plot_cpd_registration_step, plot_cpd |
|
from Centerline.cpd_utils import cpd_non_rigid_transform_pt, radial_basis_function, deform_registration, rigid_registration |
|
from scipy.spatial.distance import cdist |
|
from skimage.morphology import skeletonize_3d |
|
import re |
|
from probreg import bcpd |
|
|
|
# Common root of every input and output location used by this script.
_PROJECT_ROOT = '/mnt/EncryptedData1/Users/javier/vessel_registration'

# Input: one sub-folder per dataset name under DATASET_LOCATION; only files
# whose name contains DATASET_FILENAME are evaluated.
DATASET_LOCATION = _PROJECT_ROOT + '/3Dirca/dataset/EVAL'
DATASET_NAMES = ['Affine', 'None', 'Translation']
DATASET_FILENAME = 'points'

# Output: registration plots and the per-dataset CSV files go here.
OUT_IMG_FOLDER = _PROJECT_ROOT + '/Centerline/cpd/skeleton'

# Point clouds are multiplied by SCALE before registration and divided by it
# afterwards to return to voxel units.
SCALE = 0.01

# Settings forwarded to deform_registration.
MAX_ITER = 200
ALPHA = 0.1
BETA = 1.0
TOLERANCE = 1e-8
|
|
|
# Source grid shape passed to resize_pts_to_original_space for the stored
# centroids (presumably the 64^3 network input resolution -- TODO confirm).
_CENTROID_GRID_SHAPE = (64, 64, 64)

# The sample number is the first run of digits in the file name.
_DIGITS_RE = re.compile(r'(\d+)')

# Column order of each per-dataset CSV.
_RESULT_COLUMNS = ['DATASET',
                   'ITERATIONS_DEF', 'ITERATIONS_R_DEF__R', 'ITERATIONS_R_DEF__DEF',
                   'TIME_DEF', 'TIME_R_DEF',
                   'Q_DEF', 'Q_R_DEF__R', 'Q_R_DEF__DEF',
                   'TRE_DEF', 'TRE_R_DEF',
                   'DS_DISP',
                   'DATA_PATH',
                   'DIST_CENTR', 'DIST_CENTR_DEF_95', 'SAMPLE_NUM',
                   'ILL_COND_DEF', 'ILL_COND_R_DEF']


def _to_original_space(skel, centroid, bbox, first_reshape, isotropic_shape):
    """Map one (skeleton volume, centroid) pair back to the original image space.

    The volume is resized with resize_img_to_original_space and skeletonised
    again with skeletonize_3d (presumably because resizing breaks the
    one-voxel-wide centreline), then converted to the coordinates of its
    non-zero voxels with np.argwhere.

    Returns:
        (skel_pts, centroid) in original-space voxel coordinates.
    """
    centroid = resize_pts_to_original_space(centroid, bbox, list(_CENTROID_GRID_SHAPE),
                                            first_reshape, isotropic_shape)
    skel = resize_img_to_original_space(skel, bbox, first_reshape, isotropic_shape)
    skel_pts = np.argwhere(skeletonize_3d(skel))
    return skel_pts, centroid


def _tps_warp(source_pts, deform_reg, point):
    """Apply a deformable registration to a point outside the registered cloud.

    The per-point displacement of *source_pts*, i.e.
    np.dot(*deform_reg.get_registration_parameters()) rescaled by 1/SCALE
    (presumably G @ W as in pycpd -- confirm), is interpolated with
    radial_basis_function and evaluated at *point*.

    Returns:
        (warped_point, ill_cond), where ill_cond is the flag returned by
        radial_basis_function.
    """
    displacement = np.dot(*deform_reg.get_registration_parameters()) / SCALE
    tps, ill_cond = radial_basis_function(source_pts, displacement)
    return point + tps(point), ill_cond


def _plot_before_after(plot_dir, fix_pts, mov_pts, fix_centroid, mov_centroid,
                       reg_mov_pts, pred_mov_centroid):
    """Create plot_dir and write 'before_registration' / 'after_registration' plots."""
    os.makedirs(plot_dir, exist_ok=True)
    plot_cpd(fix_pts, mov_pts, fix_centroid, mov_centroid, plot_dir + '/before_registration')
    plot_cpd(fix_pts, reg_mov_pts, fix_centroid, pred_mov_centroid, plot_dir + '/after_registration')


def _evaluate_sample(dataset_name, file_path, iterator):
    """Run deformable-only and rigid+deformable CPD on one HDF5 sample.

    Args:
        dataset_name: Name of the dataset folder (used in output paths).
        file_path: Path to the sample's HDF5 file.
        iterator: tqdm progress bar whose description is updated per stage.

    Returns:
        A dict with one entry per name in _RESULT_COLUMNS.
    """
    fn = os.path.split(file_path)[-1].split('.hd5')[0]
    fnum = int(_DIGITS_RE.findall(fn)[0])
    iterator.set_description('{}: start'.format(fn))

    # Read everything with [:] inside the context manager so the data is in
    # memory and the file is closed even if a later step raises.
    with h5py.File(file_path, 'r') as pts_file:
        fix_skel = pts_file['fix/skeleton'][:]
        fix_centroid = pts_file['fix/centroid'][:]
        mov_skel = pts_file['mov/skeleton'][:]
        mov_centroid = pts_file['mov/centroid'][:]
        bbox = pts_file['parameters/bbox'][:]
        first_reshape = pts_file['parameters/first_reshape'][:]
        isotropic_shape = pts_file['parameters/isotropic_shape'][:]
    iterator.set_description('{}: Loaded data'.format(fn))

    fix_skel_pts, fix_centroid = _to_original_space(fix_skel, fix_centroid, bbox,
                                                    first_reshape, isotropic_shape)
    mov_skel_pts, mov_centroid = _to_original_space(mov_skel, mov_centroid, bbox,
                                                    first_reshape, isotropic_shape)
    iterator.set_description('{}: reshaped data'.format(fn))

    sample_dir = os.path.join(OUT_IMG_FOLDER, '{}/{:04d}'.format(dataset_name, fnum))

    # --- Deformable registration only ---
    iterator.set_description('{}: Computing only deformable reg.'.format(fn))
    # Same call and settings as the rigid+deformable path below; the rest of
    # this function reads .diff/.TY/.iteration/get_registration_parameters()
    # from this result and reports time_def.
    time_def, deform_reg_def = deform_registration(fix_skel_pts * SCALE, mov_skel_pts * SCALE,
                                                   time_it=True,
                                                   tolerance=TOLERANCE, max_iterations=MAX_ITER,
                                                   alpha=ALPHA, beta=BETA)
    ill_cond_def = False
    if np.isnan(deform_reg_def.diff):
        # Diverged: keep the unwarped centroid and record no TRE.
        tre_def = np.nan
        pred_mov_centroid = mov_centroid
    else:
        pred_mov_centroid, ill_cond_def = _tps_warp(mov_skel_pts, deform_reg_def, mov_centroid)
        tre_def = euclidean(pred_mov_centroid, fix_centroid)
    _plot_before_after(sample_dir + '/DEF', fix_skel_pts, mov_skel_pts, fix_centroid, mov_centroid,
                       deform_reg_def.TY / SCALE, pred_mov_centroid)

    # --- Rigid registration followed by deformable refinement ---
    iterator.set_description('{}: Computing rigid and deform. reg.'.format(fn))
    time_r_def__r, rigid_reg_r_def = rigid_registration(fix_skel_pts * SCALE, mov_skel_pts * SCALE,
                                                        time_it=True)
    rigid_yt = rigid_reg_r_def.TY
    time_r_def__def, deform_reg_r_def = deform_registration(fix_skel_pts * SCALE, rigid_yt,
                                                            time_it=True,
                                                            tolerance=TOLERANCE, max_iterations=MAX_ITER,
                                                            alpha=ALPHA, beta=BETA)
    # Rigidly aligned moving centroid, back in voxel units.
    mov_centroid_t = rigid_reg_r_def.transform_point_cloud(mov_centroid * SCALE) / SCALE
    ill_cond_r_def = False
    if np.isnan(deform_reg_r_def.diff):
        # Deformable step diverged: fall back to the rigid-only prediction.
        pred_mov_centroid = mov_centroid_t
    else:
        pred_mov_centroid, ill_cond_r_def = _tps_warp(rigid_yt / SCALE, deform_reg_r_def,
                                                      mov_centroid_t)
    tre_r_def = euclidean(pred_mov_centroid, fix_centroid)
    _plot_before_after(sample_dir + '/RIGID_DEF', fix_skel_pts, mov_skel_pts, fix_centroid,
                       mov_centroid, deform_reg_r_def.TY / SCALE, pred_mov_centroid)

    # Distances from the moving centroid to every moving skeleton point.
    dist_centroid_to_pts = cdist(mov_centroid[np.newaxis, ...], mov_skel_pts)

    iterator.set_description('{}: Saving data'.format(fn))
    return {'DATASET': dataset_name,
            'ITERATIONS_DEF': deform_reg_def.iteration,
            'ITERATIONS_R_DEF__R': rigid_reg_r_def.iteration,
            'ITERATIONS_R_DEF__DEF': deform_reg_r_def.iteration,
            'TIME_DEF': time_def,
            'TIME_R_DEF': time_r_def__r + time_r_def__def,
            'Q_DEF': deform_reg_def.diff,
            'Q_R_DEF__R': rigid_reg_r_def.q,
            'Q_R_DEF__DEF': deform_reg_r_def.diff,
            'ILL_COND_DEF': ill_cond_def,
            'ILL_COND_R_DEF': ill_cond_r_def,
            'TRE_DEF': tre_def,
            'TRE_R_DEF': tre_r_def,
            'DS_DISP': euclidean(mov_centroid, fix_centroid),
            'DATA_PATH': file_path,
            'DIST_CENTR': np.min(dist_centroid_to_pts),
            'DIST_CENTR_DEF_95': np.percentile(dist_centroid_to_pts, 95),
            'SAMPLE_NUM': fnum}


def _evaluate_dataset(dataset_name):
    """Evaluate every file of *dataset_name* whose name contains DATASET_FILENAME.

    Returns:
        A DataFrame with one row per sample and columns _RESULT_COLUMNS.
    """
    dataset_loc = os.path.join(DATASET_LOCATION, dataset_name)
    dataset_files = [os.path.join(dataset_loc, f)
                     for f in sorted(os.listdir(dataset_loc)) if DATASET_FILENAME in f]
    iterator = tqdm(dataset_files)
    # Build the DataFrame once from a list of rows: DataFrame.append was
    # removed in pandas 2.0 and copied the whole frame on every call.
    rows = [_evaluate_sample(dataset_name, file_path, iterator) for file_path in iterator]
    return pd.DataFrame(rows, columns=_RESULT_COLUMNS)


if __name__ == '__main__':
    # Make sure the CSV destination exists even if no plot directory was created.
    os.makedirs(OUT_IMG_FOLDER, exist_ok=True)
    for dataset_name in DATASET_NAMES:
        df = _evaluate_dataset(dataset_name)
        df.to_csv(os.path.join(OUT_IMG_FOLDER, 'cpd_{}.csv'.format(dataset_name)))
|
|