|
import os, sys |
|
currentdir = os.path.dirname(os.path.realpath(__file__)) |
|
parentdir = os.path.dirname(currentdir) |
|
sys.path.append(parentdir) |
|
|
|
import h5py |
|
from tqdm import tqdm |
|
from functools import partial |
|
import numpy as np |
|
from scipy.spatial.distance import euclidean |
|
import pandas as pd |
|
from EvaluationScripts.Evaluate_class import resize_img_to_original_space, resize_pts_to_original_space |
|
from Centerline.visualization_utils import plot_cpd_registration_step, plot_cpd |
|
from Centerline.cpd_utils import cpd_non_rigid_transform_pt, radial_basis_function, deform_registration, rigid_registration |
|
from scipy.spatial.distance import cdist |
|
import re |
|
|
|
# ---------------------------------------------------------------------------
# Input data
# ---------------------------------------------------------------------------
# Root folder; each entry of DATASET_NAMES is a sub-folder of it.
DATASET_LOCATION = '/mnt/EncryptedData1/Users/javier/vessel_registration/3Dirca/dataset/EVAL'
# NOTE(review): 'None' is listed twice, so that sub-folder is evaluated twice and
# its CSV/plots are overwritten by the second pass - confirm whether a different
# dataset name was intended for one of the entries.
DATASET_NAMES = ['None', 'Affine', 'None', 'Translation']
# Only files whose name contains this substring are evaluated.
DATASET_FILENAME = 'points'

# ---------------------------------------------------------------------------
# Output
# ---------------------------------------------------------------------------
# Plots go to <OUT_IMG_FOLDER>/<dataset>/<sample>/...; per-dataset CSVs at its root.
OUT_IMG_FOLDER = '/mnt/EncryptedData1/Users/javier/vessel_registration/Centerline/cpd/dense_final'

# ---------------------------------------------------------------------------
# CPD / RBF settings
# ---------------------------------------------------------------------------
# Point coordinates are multiplied by SCALE before registration and divided by
# it afterwards to return to the original coordinate space.
SCALE = 0.01
# Forwarded to deform_registration as max_iterations / alpha / beta / tolerance.
MAX_ITER = 200
ALPHA = 2.0
BETA = 2.0
TOLERANCE = 1e-08
# Kernel name handed to radial_basis_function when interpolating displacements.
RBF_FUNCTION = 'thin-plate'
|
|
|
# Column order of each per-dataset results CSV. The ILL_COND_* columns come
# last, matching the order the previous DataFrame.append() implementation
# produced (they were not part of the pre-declared columns).
_RESULT_COLUMNS = ['DATASET',
                   'ITERATIONS_DEF', 'ITERATIONS_R_DEF__R', 'ITERATIONS_R_DEF__DEF',
                   'TIME_DEF', 'TIME_R_DEF',
                   'Q_DEF', 'Q_R_DEF__R', 'Q_R_DEF__DEF',
                   'TRE_DEF', 'TRE_R_DEF',
                   'DS_DISP',
                   'DATA_PATH',
                   'DIST_CENTR', 'DIST_CENTR_DEF_95', 'SAMPLE_NUM',
                   'ILL_COND_DEF', 'ILL_COND_R_DEF']


def _deformable_only(fix_pts, mov_pts, fix_centroid, mov_centroid):
    """Register ``mov_pts`` onto ``fix_pts`` with non-rigid CPD only.

    Points are multiplied by SCALE for registration; the estimated displacement
    is divided by SCALE again before being interpolated (RBF_FUNCTION kernel)
    at ``mov_centroid``.

    Returns:
        ``(elapsed_time, registration, pred_mov_centroid, tre, ill_cond)``.
        ``tre`` is the Euclidean distance from the predicted moving centroid to
        ``fix_centroid``; ``ill_cond`` is the flag returned by
        ``radial_basis_function`` (False when it is not called). If the
        registration's ``diff`` is NaN, ``tre`` is NaN and the predicted
        centroid is a zero vector (it is then only used for plotting).
    """
    elapsed, reg = deform_registration(fix_pts * SCALE, mov_pts * SCALE, time_it=True,
                                       tolerance=TOLERANCE, max_iterations=MAX_ITER,
                                       alpha=ALPHA, beta=BETA)
    if np.isnan(reg.diff):
        return elapsed, reg, np.zeros((3,)), np.nan, False

    # Product of the registration parameters, back in unscaled units; it is
    # interpolated as a per-point displacement of mov_pts.
    displacement = np.dot(*reg.get_registration_parameters()) / SCALE
    tps, ill_cond = radial_basis_function(mov_pts, displacement, RBF_FUNCTION)
    pred_mov_centroid = mov_centroid + tps(mov_centroid)
    return elapsed, reg, pred_mov_centroid, euclidean(pred_mov_centroid, fix_centroid), ill_cond


def _rigid_then_deformable(fix_pts, mov_pts, fix_centroid, mov_centroid):
    """Rigid CPD followed by non-rigid CPD on the rigidly aligned moving points.

    Returns:
        ``(time_rigid, time_deform, rigid_reg, deform_reg, pred_mov_centroid,
        tre, ill_cond)``. When the deformable step's ``diff`` is NaN, the
        prediction falls back to the rigidly transformed centroid (the TRE is
        still computed for it) and ``ill_cond`` is False.
    """
    time_rigid, rigid_reg = rigid_registration(fix_pts * SCALE, mov_pts * SCALE, time_it=True)
    # Rigidly aligned moving points, still in SCALE-d space.
    rigid_yt = rigid_reg.TY
    time_deform, deform_reg = deform_registration(fix_pts * SCALE, rigid_yt, time_it=True,
                                                  tolerance=TOLERANCE, max_iterations=MAX_ITER,
                                                  alpha=ALPHA, beta=BETA)

    # The rigid transform is applied to the centroid in both branches.
    mov_centroid_t = rigid_reg.transform_point_cloud(mov_centroid * SCALE) / SCALE
    ill_cond = False
    if np.isnan(deform_reg.diff):
        pred_mov_centroid = mov_centroid_t
    else:
        displacement = np.dot(*deform_reg.get_registration_parameters()) / SCALE
        tps, ill_cond = radial_basis_function(rigid_yt / SCALE, displacement, RBF_FUNCTION)
        pred_mov_centroid = mov_centroid_t + tps(mov_centroid_t)

    tre = euclidean(pred_mov_centroid, fix_centroid)
    return time_rigid, time_deform, rigid_reg, deform_reg, pred_mov_centroid, tre, ill_cond


def _plot_before_after(dataset_name, fnum, tag, fix_pts, mov_pts, fix_centroid, mov_centroid,
                       registered_mov_pts, pred_mov_centroid):
    """Write before/after registration plots to <OUT_IMG_FOLDER>/<dataset>/<fnum:04d>/<tag>."""
    plot_file = os.path.join(OUT_IMG_FOLDER, '{}/{:04d}/{}'.format(dataset_name, fnum, tag))
    os.makedirs(plot_file, exist_ok=True)
    plot_cpd(fix_pts, mov_pts, fix_centroid, mov_centroid, plot_file + '/before_registration')
    plot_cpd(fix_pts, registered_mov_pts, fix_centroid, pred_mov_centroid, plot_file + '/after_registration')


def main():
    """Evaluate deformable-only vs. rigid+deformable CPD on every sample file.

    For each dataset folder, writes before/after plots per sample and one
    ``cpd_<dataset>.csv`` with timings, iterations, CPD objective values and
    TREs of the propagated centroid.
    """
    # Needed even if a dataset has no matching files, so to_csv() can write.
    os.makedirs(OUT_IMG_FOLDER, exist_ok=True)

    # dict.fromkeys de-duplicates while keeping order: a repeated name would
    # only redo every registration and overwrite the same outputs.
    for dataset_name in dict.fromkeys(DATASET_NAMES):
        dataset_loc = os.path.join(DATASET_LOCATION, dataset_name)
        dataset_files = [os.path.join(dataset_loc, f)
                         for f in sorted(os.listdir(dataset_loc)) if DATASET_FILENAME in f]

        iterator = tqdm(dataset_files)
        rows = []
        for file_path in iterator:
            fn = os.path.split(file_path)[-1].split('.hd5')[0]
            # First run of digits in the file name is the sample number.
            fnum = int(re.findall(r'(\d+)', fn)[0])
            iterator.set_description('{}: start'.format(fn))

            # [:] copies every dataset into memory, so the file can be closed
            # right away; the context manager also closes it on error.
            with h5py.File(file_path, 'r') as pts_file:
                fix_pts = pts_file['fix/points'][:]
                fix_centroid = pts_file['fix/centroid'][:]
                mov_pts = pts_file['mov/points'][:]
                mov_centroid = pts_file['mov/centroid'][:]
                bbox = pts_file['parameters/bbox'][:]
                first_reshape = pts_file['parameters/first_reshape'][:]
                original_shape = pts_file['parameters/isotropic_shape'][:]
            iterator.set_description('{}: Loaded data'.format(fn))

            def to_original_space(pts):
                # NOTE(review): [64] * 3 presumably is the grid size the points
                # were stored in - confirm against the dataset generator.
                return resize_pts_to_original_space(pts, bbox, [64] * 3, first_reshape, original_shape)

            fix_pts, fix_centroid, mov_pts, mov_centroid = (
                to_original_space(p) for p in (fix_pts, fix_centroid, mov_pts, mov_centroid))
            iterator.set_description('{}: reshaped data'.format(fn))

            # --- Deformable registration only ------------------------------
            iterator.set_description('{}: Computing only deformable reg.'.format(fn))
            time_def, deform_reg_def, pred_def, tre_def, ill_cond_def = _deformable_only(
                fix_pts, mov_pts, fix_centroid, mov_centroid)
            _plot_before_after(dataset_name, fnum, 'DEF', fix_pts, mov_pts, fix_centroid, mov_centroid,
                               deform_reg_def.TY / SCALE, pred_def)

            # --- Rigid followed by deformable registration -----------------
            iterator.set_description('{}: Computing rigid and deform. reg.'.format(fn))
            (time_r_def__r, time_r_def__def, rigid_reg_r_def, deform_reg_r_def,
             pred_r_def, tre_r_def, ill_cond_r_def) = _rigid_then_deformable(
                fix_pts, mov_pts, fix_centroid, mov_centroid)

            dist_centroid_to_pts = cdist(mov_centroid[np.newaxis, ...], mov_pts)

            # TY is in SCALE-d space; divide by SCALE so it is plotted in the
            # same space as fix_pts (previously plotted unscaled by mistake).
            _plot_before_after(dataset_name, fnum, 'RIGID_DEF', fix_pts, mov_pts, fix_centroid, mov_centroid,
                               deform_reg_r_def.TY / SCALE, pred_r_def)

            iterator.set_description('{}: Saving data'.format(fn))
            rows.append({'DATASET': dataset_name,
                         'ITERATIONS_DEF': deform_reg_def.iteration,
                         'ITERATIONS_R_DEF__R': rigid_reg_r_def.iteration,
                         'ITERATIONS_R_DEF__DEF': deform_reg_r_def.iteration,
                         'TIME_DEF': time_def,
                         'TIME_R_DEF': time_r_def__r + time_r_def__def,
                         'Q_DEF': deform_reg_def.diff,
                         'Q_R_DEF__R': rigid_reg_r_def.q,
                         'Q_R_DEF__DEF': deform_reg_r_def.diff,
                         'ILL_COND_DEF': ill_cond_def,
                         'ILL_COND_R_DEF': ill_cond_r_def,
                         'TRE_DEF': tre_def, 'TRE_R_DEF': tre_r_def,
                         'DS_DISP': euclidean(mov_centroid, fix_centroid),
                         'DATA_PATH': file_path,
                         'DIST_CENTR': np.min(dist_centroid_to_pts),
                         'DIST_CENTR_DEF_95': np.percentile(dist_centroid_to_pts, 95),
                         'SAMPLE_NUM': fnum})

        # Build the frame once (DataFrame.append was removed in pandas 2.0).
        df = pd.DataFrame(rows, columns=_RESULT_COLUMNS)
        df.to_csv(os.path.join(OUT_IMG_FOLDER, 'cpd_{}.csv'.format(dataset_name)))


if __name__ == '__main__':
    main()
|
|