# DDMR / Centerline / evaluate_BayesianCPD_skeleton.py
# Commit b10768a ("CPD scripts") by jpdefrutos.
import os
import re
import sys
import time
from functools import partial

import h5py
import numpy as np
import pandas as pd
from probreg import bcpd
from scipy.spatial.distance import cdist
from scipy.spatial.distance import euclidean
from skimage.morphology import skeletonize_3d
from tqdm import tqdm

currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir)  # PYTHON > 3.3 does not allow relative referencing

# Project-local imports: resolvable only after parentdir is on sys.path.
from EvaluationScripts.Evaluate_class import resize_img_to_original_space, resize_pts_to_original_space
from Centerline.visualization_utils import plot_cpd_registration_step, plot_cpd
from Centerline.cpd_utils import cpd_non_rigid_transform_pt, radial_basis_function, deform_registration, rigid_registration
# Input: one sub-folder of DATASET_LOCATION per entry of DATASET_NAMES; only files
# whose name contains DATASET_FILENAME are evaluated.
DATASET_LOCATION = '/mnt/EncryptedData1/Users/javier/vessel_registration/3Dirca/dataset/EVAL'
DATASET_NAMES = ['Affine', 'None', 'Translation']
DATASET_FILENAME = 'points'
# Output: per-sample before/after plots and one cpd_<dataset>.csv per dataset.
OUT_IMG_FOLDER = '/mnt/EncryptedData1/Users/javier/vessel_registration/Centerline/cpd/skeleton'
# Point coordinates are multiplied by SCALE before registration and divided by it
# afterwards, so results are reported in the input units.
# NOTE(review): the previous comment said "mm to cm", but that factor is 1e-1;
# 1e-2 is e.g. cm -> m. Confirm which unit conversion is intended.
SCALE = 1e-2
# CPD PARAMS (deform): passed to deform_registration in the rigid + deformable pipeline.
MAX_ITER = 200
ALPHA = 0.1
BETA = 1.0  # None = Use default
TOLERANCE = 1e-8
# The sample number is the first run of digits in the file name.
_SAMPLE_NUM_RE = re.compile(r'(\d+)')

# Columns of the per-dataset results CSV, in output order.
_CSV_COLUMNS = ['DATASET',
                'ITERATIONS_DEF', 'ITERATIONS_R_DEF__R', 'ITERATIONS_R_DEF__DEF',
                'TIME_DEF', 'TIME_R_DEF',
                'Q_DEF', 'Q_R_DEF__R', 'Q_R_DEF__DEF',
                'TRE_DEF', 'TRE_R_DEF',
                'DS_DISP',
                'DATA_PATH',
                'DIST_CENTR', 'DIST_CENTR_DEF_95', 'SAMPLE_NUM',
                'ILL_COND_DEF', 'ILL_COND_R_DEF']


def _load_skeleton_pair(file_path):
    """Load one fixed/moving sample and map it to the isotropic image space.

    Returns ``(fix_skel_pts, fix_centroid, mov_skel_pts, mov_centroid)``. The
    ``*_skel_pts`` arrays are the ``np.argwhere`` indices of the skeleton
    recomputed after resizing. The HDF5 file is closed before returning.
    """
    with h5py.File(file_path, 'r') as pts_file:
        fix_skel = pts_file['fix/skeleton'][:]
        fix_centroid = pts_file['fix/centroid'][:]
        mov_skel = pts_file['mov/skeleton'][:]
        mov_centroid = pts_file['mov/centroid'][:]
        bbox = pts_file['parameters/bbox'][:]
        first_reshape = pts_file['parameters/first_reshape'][:]
        isotropic_shape = pts_file['parameters/isotropic_shape'][:]

    # TODO: bring back to original shape! (currently stops at the isotropic shape)
    def to_isotropic(skel, centroid):
        centroid = resize_pts_to_original_space(centroid, bbox, [64] * 3, first_reshape, isotropic_shape)
        skel = resize_img_to_original_space(skel, bbox, first_reshape, isotropic_shape)
        # Re-skeletonize after resizing (presumably the resize no longer yields a thin skeleton).
        return np.argwhere(skeletonize_3d(skel)), centroid

    fix_skel_pts, fix_centroid = to_isotropic(fix_skel, fix_centroid)
    mov_skel_pts, mov_centroid = to_isotropic(mov_skel, mov_centroid)
    return fix_skel_pts, fix_centroid, mov_skel_pts, mov_centroid


def _plot_before_after(out_dir, fix_skel_pts, mov_skel_pts, fix_centroid, mov_centroid,
                       registered_pts, pred_mov_centroid):
    """Write the before/after registration plots into ``out_dir`` (created if missing)."""
    os.makedirs(out_dir, exist_ok=True)
    plot_cpd(fix_skel_pts, mov_skel_pts, fix_centroid, mov_centroid, out_dir + '/before_registration')
    plot_cpd(fix_skel_pts, registered_pts, fix_centroid, pred_mov_centroid, out_dir + '/after_registration')


def _bcpd_deformable(fix_skel_pts, mov_skel_pts, fix_centroid, mov_centroid):
    """Register the moving skeleton onto the fixed one with probreg's BCPD.

    Returns ``(iterations, elapsed_s, registered_pts, pred_mov_centroid, tre, ill_cond)``.
    ``registered_pts`` is in the input (unscaled) units. The centroid is not one of
    the registered points, so its displacement is interpolated with
    ``radial_basis_function`` from the per-point displacement. If the registered
    points contain NaNs, the centroid is left in place and ``tre`` is NaN.
    """
    iterations = [0]

    def count_iteration(*_args, **_kwargs):
        # NOTE(review): assumes probreg invokes each callback once per iteration — confirm.
        iterations[0] += 1

    start = time.perf_counter()
    tf_param = bcpd.registration_bcpd(mov_skel_pts * SCALE, fix_skel_pts * SCALE,
                                      callbacks=[count_iteration])
    elapsed = time.perf_counter() - start

    # Apply the fitted transformation only to the points it was fitted on; its
    # non-rigid part is presumably per source point, so it cannot map the centroid.
    registered_pts = np.asarray(tf_param.transform(mov_skel_pts * SCALE)) / SCALE
    if np.isnan(registered_pts).any():
        return iterations[0], elapsed, registered_pts, mov_centroid, np.nan, False

    tps, ill_cond = radial_basis_function(mov_skel_pts, registered_pts - mov_skel_pts)
    pred_mov_centroid = mov_centroid + tps(mov_centroid)
    tre = euclidean(pred_mov_centroid, fix_centroid)
    return iterations[0], elapsed, registered_pts, pred_mov_centroid, tre, ill_cond


def _rigid_then_deformable(fix_skel_pts, mov_skel_pts, fix_centroid, mov_centroid):
    """Run rigid CPD followed by deformable CPD (project wrappers).

    Returns ``(rigid_reg, deform_reg, elapsed_s, pred_mov_centroid, tre, ill_cond)``.
    If the deformable step's ``diff`` is NaN, only the rigid transform is applied
    to the centroid.
    """
    time_rigid, rigid_reg = rigid_registration(fix_skel_pts * SCALE, mov_skel_pts * SCALE, time_it=True)
    rigid_yt = rigid_reg.TY
    time_deform, deform_reg = deform_registration(fix_skel_pts * SCALE, rigid_yt, time_it=True,
                                                  tolerance=TOLERANCE, max_iterations=MAX_ITER,
                                                  alpha=ALPHA, beta=BETA)
    mov_centroid_t = rigid_reg.transform_point_cloud(mov_centroid * SCALE) / SCALE

    ill_cond = False
    if np.isnan(deform_reg.diff):
        pred_mov_centroid = mov_centroid_t
    else:
        # The dot product of the registration parameters is used as the per-point
        # deformable displacement (presumably G @ W as in pycpd — confirm).
        displacement = np.dot(*deform_reg.get_registration_parameters()) / SCALE
        tps, ill_cond = radial_basis_function(rigid_yt / SCALE, displacement)
        pred_mov_centroid = mov_centroid_t + tps(mov_centroid_t)
    tre = euclidean(pred_mov_centroid, fix_centroid)
    return rigid_reg, deform_reg, time_rigid + time_deform, pred_mov_centroid, tre, ill_cond


def main():
    """Evaluate both registration pipelines on every sample; write one CSV per dataset."""
    os.makedirs(OUT_IMG_FOLDER, exist_ok=True)
    for dataset_name in DATASET_NAMES:
        dataset_loc = os.path.join(DATASET_LOCATION, dataset_name)
        dataset_files = [os.path.join(dataset_loc, f) for f in sorted(os.listdir(dataset_loc))
                         if DATASET_FILENAME in f]
        iterator = tqdm(dataset_files)
        rows = []
        for file_path in iterator:
            fn = os.path.split(file_path)[-1].split('.hd5')[0]
            fnum = int(_SAMPLE_NUM_RE.findall(fn)[0])
            sample_dir = os.path.join(OUT_IMG_FOLDER, '{}/{:04d}'.format(dataset_name, fnum))
            iterator.set_description('{}: start'.format(fn))

            fix_skel_pts, fix_centroid, mov_skel_pts, mov_centroid = _load_skeleton_pair(file_path)
            iterator.set_description('{}: reshaped data'.format(fn))

            # Deformable only (BCPD)
            iterator.set_description('{}: Computing only deformable reg.'.format(fn))
            (iters_def, time_def, registered_def, pred_def,
             tre_def, ill_cond_def) = _bcpd_deformable(fix_skel_pts, mov_skel_pts, fix_centroid, mov_centroid)
            _plot_before_after(os.path.join(sample_dir, 'DEF'), fix_skel_pts, mov_skel_pts,
                               fix_centroid, mov_centroid, registered_def, pred_def)

            # Rigid followed by deformable (CPD)
            iterator.set_description('{}: Computing rigid and deform. reg.'.format(fn))
            (rigid_reg, deform_reg, time_r_def, pred_r_def,
             tre_r_def, ill_cond_r_def) = _rigid_then_deformable(fix_skel_pts, mov_skel_pts,
                                                                  fix_centroid, mov_centroid)
            _plot_before_after(os.path.join(sample_dir, 'RIGID_DEF'), fix_skel_pts, mov_skel_pts,
                               fix_centroid, mov_centroid, deform_reg.TY / SCALE, pred_r_def)

            iterator.set_description('{}: Saving data'.format(fn))
            dist_centroid_to_pts = cdist(mov_centroid[np.newaxis, ...], mov_skel_pts)
            rows.append({'DATASET': dataset_name,
                         'ITERATIONS_DEF': iters_def,
                         'ITERATIONS_R_DEF__R': rigid_reg.iteration,
                         'ITERATIONS_R_DEF__DEF': deform_reg.iteration,
                         'TIME_DEF': time_def,
                         'TIME_R_DEF': time_r_def,
                         # probreg's BCPD transformation exposes no objective value.
                         'Q_DEF': np.nan,
                         'Q_R_DEF__R': rigid_reg.q,
                         'Q_R_DEF__DEF': deform_reg.diff,
                         'ILL_COND_DEF': ill_cond_def,
                         'ILL_COND_R_DEF': ill_cond_r_def,
                         'TRE_DEF': tre_def,
                         'TRE_R_DEF': tre_r_def,
                         'DS_DISP': euclidean(mov_centroid, fix_centroid),
                         'DATA_PATH': file_path,
                         'DIST_CENTR': np.min(dist_centroid_to_pts),
                         'DIST_CENTR_DEF_95': np.percentile(dist_centroid_to_pts, 95),
                         'SAMPLE_NUM': fnum})
        # Build the frame once: DataFrame.append was removed in pandas 2.0.
        df = pd.DataFrame(rows, columns=_CSV_COLUMNS)
        df.to_csv(os.path.join(OUT_IMG_FOLDER, 'cpd_{}.csv'.format(dataset_name)))


if __name__ == '__main__':
    main()