import os
import errno
import shutil

import numpy as np
from scipy.interpolate import Rbf, LinearNDInterpolator, NearestNDInterpolator
from skimage.measure import regionprops

from DeepDeformationMapRegistration.layers.b_splines import interpolate_spline
from DeepDeformationMapRegistration.utils.thin_plate_splines import ThinPlateSplines
from tensorflow import squeeze


def try_mkdir(dir, verbose=True):
    try:
        os.makedirs(dir)
    except OSError as err:
        if err.errno == errno.EEXIST:
            # The directory already exists; only report it when verbose
            if verbose:
                print("Directory " + dir + " already exists")
        else:
            raise ValueError("Can't create dir " + dir)
    else:
        print("Created directory " + dir)


def function_decorator(new_name):
    """
    Change the __name__ property of a function to new_name.
    :param new_name: new value for the decorated function's __name__ attribute
    :return: decorator that renames the wrapped function
    """
    def decorator(func):
        func.__name__ = new_name
        return func
    return decorator
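
# Usage sketch (hypothetical names): give a metric a custom display name, e.g.
# so it is logged as 'ssim_loss' instead of its Python function name.
# @function_decorator('ssim_loss')
# def loss(y_true, y_pred):
#     ...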
class DatasetCopy:
    def __init__(self, dataset_location, copy_location=None, verbose=True):
        self.__copy_loc = os.path.join(os.getcwd(), 'temp_dataset') if copy_location is None else copy_location
        self.__dst_loc = dataset_location
        self.__verbose = verbose

    def copy_dataset(self):
        shutil.copytree(self.__dst_loc, self.__copy_loc)
        if self.__verbose:
            print('{} copied to {}'.format(self.__dst_loc, self.__copy_loc))
        return self.__copy_loc

    def delete_temp(self):
        shutil.rmtree(self.__copy_loc)
        if self.__verbose:
            print('Deleted: ', self.__copy_loc)
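
# Usage sketch (paths are illustrative): stage the dataset on local storage
# while training and remove the temporary copy afterwards.
# dataset = DatasetCopy('/path/to/dataset', '/tmp/dataset_copy')
# local_path = dataset.copy_dataset()
# ... train using local_path ...
# dataset.delete_temp()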
class DisplacementMapInterpolator:
    def __init__(self,
                 image_shape=(64, 64, 64),
                 method='rbf'):
        assert method in ['rbf', 'griddata', 'tf', 'tps'], "Method must be 'rbf', 'griddata', 'tf' or 'tps'"
        self.method = method
        self.image_shape = image_shape

        self.grid = self.__regular_grid()

    def __regular_grid(self):
        xx = np.linspace(0, self.image_shape[0], self.image_shape[0], endpoint=False, dtype=np.uint16)
        yy = np.linspace(0, self.image_shape[1], self.image_shape[1], endpoint=False, dtype=np.uint16)
        zz = np.linspace(0, self.image_shape[2], self.image_shape[2], endpoint=False, dtype=np.uint16)

        # 'ij' indexing keeps the point ordering consistent with the C-order
        # flattening of the displacement map in __call__
        xx, yy, zz = np.meshgrid(xx, yy, zz, indexing='ij')

        return np.stack([xx.flatten(), yy.flatten(), zz.flatten()], axis=0).T
    def __call__(self, disp_map, interp_points, backwards=False):
        disp_map = disp_map.reshape([-1, 3])
        grid_pts = self.grid.copy()
        if backwards:
            # Sample at the displaced grid locations and invert the displacements
            grid_pts = np.add(grid_pts, disp_map).astype(np.float32)
            disp_map = disp_map * -1

        if self.method == 'rbf':
            interpolator = Rbf(grid_pts[:, 0], grid_pts[:, 1], grid_pts[:, 2], disp_map[:, :],
                               function='thin_plate', mode='N-D')
            disp = interpolator(interp_points)
        elif self.method == 'griddata':
            linear_interp = LinearNDInterpolator(grid_pts, disp_map)
            disp = linear_interp(interp_points).copy()
            del linear_interp

            if np.any(np.isnan(disp)):
                # Linear interpolation returns NaN outside the convex hull of grid_pts.
                # Fill those entries with nearest-neighbour estimates, ignoring rows
                # whose NaNs come from the interpolation points themselves.
                nan_disp_idx = set(np.unique(np.argwhere(np.isnan(disp))[:, 0]))
                nan_interp_pts_idx = set(np.unique(np.argwhere(np.isnan(interp_points))[:, 0]))
                idx = nan_disp_idx - nan_interp_pts_idx if len(nan_disp_idx) > len(nan_interp_pts_idx) else nan_interp_pts_idx - nan_disp_idx
                idx = list(idx)
                if len(idx):
                    near_interp = NearestNDInterpolator(grid_pts, disp_map)
                    near_disp = near_interp(interp_points[idx, ...]).copy()
                    del near_interp
                    for n, i in enumerate(idx):
                        disp[i, ...] = near_disp[n, ...]
        elif self.method == 'tf':
            # Subsample the control points (every 4th) to keep the spline fit tractable
            disp = squeeze(interpolate_spline(grid_pts[::4, :][np.newaxis, ...],
                                              disp_map[::4, :][np.newaxis, ...],
                                              interp_points[np.newaxis, ...], order=2), axis=0)
        else:
            tps_interp = ThinPlateSplines(grid_pts[::8, :], self.grid.copy().astype(np.float32)[::8, :])
            disp = tps_interp.interpolate(interp_points).eval()
            del tps_interp

        return disp
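
# Usage sketch (shapes and names are assumptions): query a dense displacement
# field, predicted for a 64x64x64 volume, at a set of landmark coordinates.
# dm_interp = DisplacementMapInterpolator(image_shape=(64, 64, 64), method='griddata')
# disp_at_lms = dm_interp(disp_map, landmarks)  # disp_map: (64, 64, 64, 3); landmarks: (M, 3)
# warped_lms = landmarks + disp_at_lms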
def get_segmentations_centroids(segmentations, ohe=True, expected_lbls=range(0, 28), missing_centroid=[np.nan]*3, brain_study=True):
    segmentations = np.squeeze(segmentations)
    if ohe:
        # Collapse the label channels into a single 3D label map
        segmentations = np.sum(segmentations, axis=-1).astype(np.uint8)
        missing_lbls = set(expected_lbls) - set(np.unique(segmentations))
        if brain_study:
            segmentations += np.ones_like(segmentations)
    else:
        missing_lbls = set(expected_lbls) - set(np.unique(segmentations))

    seg_props = regionprops(segmentations)
    centroids = np.asarray([c.centroid for c in seg_props]).astype(np.float32)

    # Insert a placeholder centroid for every missing label, in ascending label
    # order so earlier insertions do not shift the later indices
    for lbl in sorted(missing_lbls):
        idx = expected_lbls.index(lbl)
        centroids = np.insert(centroids, idx, missing_centroid, axis=0)
    return centroids.copy(), missing_lbls
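
# Usage sketch (assumed one-hot input): centroids of the labels in a one-hot
# segmentation, with NaN rows marking the expected labels that were not found.
# centroids, missing = get_segmentations_centroids(seg_ohe, ohe=True)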
def segmentation_ohe_to_cardinal(segmentation):
    # Scale each one-hot channel by its label number and keep the highest label
    # per voxel, with an extra background channel so empty voxels map to 0
    cpy = segmentation.copy()
    for lbl in range(segmentation.shape[-1]):
        cpy[..., lbl] *= (lbl + 1)

    cpy = np.concatenate([np.zeros(segmentation.shape[:-1])[..., np.newaxis], cpy], axis=-1)
    return np.argmax(cpy, axis=-1)[..., np.newaxis]


def segmentation_cardinal_to_ohe(segmentation):
    # One channel per non-background label found in the volume (label 0 is dropped).
    # Expects a 3D cardinal label map; channels follow the sorted order of np.unique.
    lbls = np.unique(segmentation)[1:]
    cpy = np.tile(np.zeros_like(segmentation)[..., np.newaxis], (1, 1, 1, len(lbls)))
    for ch, lbl in enumerate(lbls):
        cpy[segmentation == lbl, ch] = 1
    return cpy
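
# Usage sketch (hypothetical array): convert between the two label encodings.
# seg_card has shape (64, 64, 64) with integer labels and 0 as background.
# seg_ohe = segmentation_cardinal_to_ohe(seg_card)   # (64, 64, 64, n_labels), one channel per label
# seg_back = segmentation_ohe_to_cardinal(seg_ohe)   # (64, 64, 64, 1), labels renumbered 1..n_labels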