added a flag to output the displacement map (resizing back to the original resolution takes a long time)
c292437
import os
import errno
import shutil
from typing import Union

import numpy as np
import tensorflow as tf
from scipy.interpolate import Rbf, LinearNDInterpolator, NearestNDInterpolator
from scipy.ndimage import zoom
from skimage.measure import regionprops
from tensorflow import squeeze

from DeepDeformationMapRegistration.layers.b_splines import interpolate_spline
from DeepDeformationMapRegistration.utils.thin_plate_splines import ThinPlateSplines
def try_mkdir(dir, verbose=True):
    try:
        os.makedirs(dir)
    except OSError as err:
        if err.errno == errno.EEXIST:
            if verbose:
                print("Directory " + dir + " already exists")
        else:
            raise ValueError("Can't create dir " + dir) from err
    else:
        if verbose:
            print("Created directory " + dir)
def function_decorator(new_name):
    """
    Change the __name__ property of a function to new_name.
    :param new_name: new value for the function's __name__ property.
    :return: the decorated function.
    """
    def decorator(func):
        func.__name__ = new_name
        return func
    return decorator
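# Usage sketch (hypothetical loss function; handy e.g. to give Keras metrics distinct display names):
#   @function_decorator('ssim_loss')
#   def loss(y_true, y_pred):
#       ...
#   # loss.__name__ == 'ssim_loss'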
class DatasetCopy:
    def __init__(self, dataset_location, copy_location=None, verbose=True):
        self.__copy_loc = os.path.join(os.getcwd(), 'temp_dataset') if copy_location is None else copy_location
        self.__src_loc = dataset_location
        self.__verbose = verbose

    def copy_dataset(self):
        shutil.copytree(self.__src_loc, self.__copy_loc)
        if self.__verbose:
            print('{} copied to {}'.format(self.__src_loc, self.__copy_loc))
        return self.__copy_loc

    def delete_temp(self):
        shutil.rmtree(self.__copy_loc)
        if self.__verbose:
            print('Deleted: ', self.__copy_loc)
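# Usage sketch (hypothetical paths; copies the dataset to local scratch storage and cleans up afterwards):
#   dataset = DatasetCopy('/data/train_set', copy_location='/tmp/train_set')
#   local_path = dataset.copy_dataset()
#   ...  # train against local_path
#   dataset.delete_temp()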
class DisplacementMapInterpolator:
    def __init__(self,
                 image_shape=[64, 64, 64],
                 method='rbf',
                 step=1):
        assert method in ['rbf', 'griddata', 'tf', 'tps'], "Method must be 'rbf', 'griddata', 'tf' or 'tps'"
        self.method = method
        self.image_shape = image_shape
        self.step = step  # Use every point (step=1) or every N-th point (step=N)

        self.grid = self.__regular_grid()

    def __regular_grid(self):
        xx = np.linspace(0, self.image_shape[0], self.image_shape[0], endpoint=False, dtype=np.uint16)
        yy = np.linspace(0, self.image_shape[1], self.image_shape[1], endpoint=False, dtype=np.uint16)
        zz = np.linspace(0, self.image_shape[2], self.image_shape[2], endpoint=False, dtype=np.uint16)

        # 'ij' indexing keeps the grid points aligned with the C-order flattened displacement map
        xx, yy, zz = np.meshgrid(xx, yy, zz, indexing='ij')

        return np.stack([xx[::self.step, ::self.step, ::self.step].flatten(),
                         yy[::self.step, ::self.step, ::self.step].flatten(),
                         zz[::self.step, ::self.step, ::self.step].flatten()], axis=0).T
    def __call__(self, disp_map, interp_points, backwards=False):
        disp_map = disp_map.squeeze()[::self.step, ::self.step, ::self.step, ...].reshape([-1, 3])
        grid_pts = self.grid.copy()
        if backwards:
            # Shift the grid by the forward displacements and negate the vectors to approximate the inverse mapping
            grid_pts = np.add(grid_pts, disp_map).astype(np.float32)
            disp_map *= -1

        if self.method == 'rbf':
            # scipy's Rbf selects the kernel via the 'function' keyword
            interpolator = Rbf(grid_pts[:, 0], grid_pts[:, 1], grid_pts[:, 2], disp_map,
                               function='thin_plate', mode='N-D')
            disp = interpolator(interp_points)
        elif self.method == 'griddata':
            linear_interp = LinearNDInterpolator(grid_pts, disp_map)
            disp = linear_interp(interp_points).copy()
            del linear_interp
            if np.any(np.isnan(disp)):
                # It can happen (though it shouldn't) that an interpolation point falls outside the convex hull
                # of the grid points. Linear interpolation then fails and returns NaN, but nearest-neighbour
                # interpolation can still provide a value, so substitute those unexpected NaNs with the nearest
                # value. Unexpected == the NaN is not caused by a NaN row in interp_points.
                nan_disp_idx = set(np.unique(np.argwhere(np.isnan(disp))[:, 0]))
                nan_interp_pts_idx = set(np.unique(np.argwhere(np.isnan(interp_points))[:, 0]))
                idx = sorted(nan_disp_idx - nan_interp_pts_idx)  # NaN outputs not explained by NaN inputs
                if len(idx):
                    # We have unexpected NaNs
                    near_interp = NearestNDInterpolator(grid_pts, disp_map)
                    near_disp = near_interp(interp_points[idx, ...]).copy()
                    del near_interp
                    for n, i in enumerate(idx):
                        disp[i, ...] = near_disp[n, ...]
        elif self.method == 'tf':
            # Order: 1 -> linear, 2 -> thin plate, 3 -> cubic
            disp = squeeze(interpolate_spline(grid_pts[::4, :][np.newaxis, ...],  # Subsample, then add the batch axis
                                              disp_map[::4, :][np.newaxis, ...],
                                              interp_points[np.newaxis, ...], order=2), axis=0)
        else:
            tps_interp = ThinPlateSplines(grid_pts[::8, :], self.grid.copy().astype(np.float32)[::8, :])
            disp = tps_interp.interpolate(interp_points).eval()  # Needs an active TF1-style session
            del tps_interp
        return disp
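# Usage sketch (hypothetical 64^3 displacement map `dm` of shape [64, 64, 64, 3] and
# landmark arrays of shape [N, 3]; the variable names are illustrative only):
#   dmi = DisplacementMapInterpolator(image_shape=[64, 64, 64], method='griddata', step=2)
#   disp_at_landmarks = dmi(dm, landmarks)                        # displacement at each landmark
#   inv_disp = dmi(dm, warped_landmarks, backwards=True)          # approximate inverse mapping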
def get_segmentations_centroids(segmentations, ohe=True, expected_lbls=range(1, 28), missing_centroid=[np.nan]*3, brain_study=True):
    # brain_study is currently unused and kept for API compatibility
    segmentations = np.squeeze(segmentations)
    if ohe:
        segmentations = segmentation_ohe_to_cardinal(segmentations)
        lbls = set(np.unique(segmentations)) - {0}  # Remove the 0 returned by np.unique: background, no label
        missing_lbls = set(expected_lbls) - lbls
    else:
        lbls = set(np.unique(segmentations)) if 0 in expected_lbls else set(np.unique(segmentations)) - {0}
        missing_lbls = set(expected_lbls) - lbls
        if 0 in expected_lbls:
            segmentations += np.ones_like(segmentations)  # regionprops ignores label 0, so offset all labels by 1
    segmentations = np.squeeze(segmentations)  # Remove the channel dimension, not needed anyway
    seg_props = regionprops(segmentations)

    centroids = np.asarray([c.centroid for c in seg_props]).astype(np.float32)
    for lbl in missing_lbls:
        idx = expected_lbls.index(lbl)
        centroids = np.insert(centroids, idx, missing_centroid, axis=0)
    return centroids.copy(), missing_lbls
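# Usage sketch (hypothetical one-hot segmentation of shape [64, 64, 64, 27]):
#   centroids, missing = get_segmentations_centroids(seg_ohe, ohe=True, expected_lbls=range(1, 28))
#   # centroids has shape [27, 3]; rows filled with NaN correspond to the labels in `missing`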
def segmentation_ohe_to_cardinal(segmentation):
    cpy = segmentation.copy()
    for lbl in range(segmentation.shape[-1]):
        cpy[..., lbl] *= (lbl + 1)
    # Add the background channel so unlabelled voxels map to 0
    cpy = np.concatenate([np.zeros(segmentation.shape[:-1])[..., np.newaxis], cpy], axis=-1)
    return np.argmax(cpy, axis=-1)[..., np.newaxis]
def segmentation_cardinal_to_ohe(segmentation, labels_list: list = None):
    # Keep in mind that we don't handle overlap between the segmentations!
    labels_list = np.unique(segmentation)[1:] if labels_list is None else labels_list  # Skip the background (0)
    num_labels = len(labels_list)
    expected_shape = segmentation.shape[:-1] + (num_labels,)
    cpy = np.zeros(expected_shape, dtype=np.uint8)
    seg_squeezed = np.squeeze(segmentation, axis=-1)
    for ch, lbl in enumerate(labels_list):
        cpy[seg_squeezed == lbl, ch] = 1
    return cpy
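# Round-trip sketch (hypothetical cardinal map `seg` of shape [64, 64, 64, 1] with labels {1, 2, 3}):
#   seg_ohe = segmentation_cardinal_to_ohe(seg, labels_list=[1, 2, 3])  # -> [64, 64, 64, 3]
#   seg_card = segmentation_ohe_to_cardinal(seg_ohe)                    # -> [64, 64, 64, 1]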
def resize_displacement_map(displacement_map: np.ndarray,
                            dest_shape: Union[list, tuple, np.ndarray],
                            scale_trf: np.ndarray = None,
                            resolution_factors: Union[tuple, np.ndarray] = np.ones((3,))):
    if scale_trf is None:
        scale_trf = scale_transformation(displacement_map.shape, dest_shape)
    else:
        assert isinstance(scale_trf, np.ndarray) and scale_trf.shape == (4, 4), 'Invalid transformation: {}'.format(scale_trf)
    zoom_factors = scale_trf.diagonal()
    # Resample the grid with zoom, then scale the vector magnitudes to the new resolution
    dm_resized = np.copy(displacement_map)
    dm_resized = zoom(dm_resized, zoom_factors)
    dm_resized *= np.asarray(resolution_factors)
    return dm_resized
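# Usage sketch (hypothetical upsampling of a 64^3 map `dm` to 128^3; the voxel size halves,
# so the vector magnitudes double):
#   dm_hi = resize_displacement_map(dm, dest_shape=(128, 128, 128, 3), resolution_factors=(2., 2., 2.))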
def scale_transformation(original_shape: Union[list, tuple, np.ndarray],
                         dest_shape: Union[list, tuple, np.ndarray]) -> np.ndarray:
    if isinstance(original_shape, (list, tuple)):
        original_shape = np.asarray(original_shape, dtype=int)
    if isinstance(dest_shape, (list, tuple)):
        dest_shape = np.asarray(dest_shape, dtype=int)
    original_shape = original_shape.astype(int)
    dest_shape = dest_shape.astype(int)

    trf = np.eye(4)
    np.fill_diagonal(trf, [*np.divide(dest_shape, original_shape), 1])
    return trf
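# Example: going from a 64^3 volume to 128^3 yields diag(2, 2, 2, 1):
#   scale_transformation((64, 64, 64), (128, 128, 128))
#   # -> [[2., 0., 0., 0.],
#   #     [0., 2., 0., 0.],
#   #     [0., 0., 2., 0.],
#   #     [0., 0., 0., 1.]]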
class GaussianFilter:
    def __init__(self, size, sigma, dim, num_channels, stride=None, batch: bool = True):
        """
        Gaussian filter
        :param size: Kernel size.
        :param sigma: Sigma of the Gaussian filter.
        :param dim: Data dimensionality. Must be {2, 3}.
        :param num_channels: Number of channels of the image to filter.
        :param stride: Stride of the convolution. Defaults to size // 2.
        :param batch: Whether the images to filter include a batch dimension.
        """
        self.size = size
        self.dim = dim
        self.sigma = float(sigma)
        self.num_channels = num_channels
        self.stride = size // 2 if stride is None else int(stride)
        if batch:
            self.stride = [1] + [self.stride] * self.dim + [1]  # No support for strides in the batch and channel dims
        else:
            self.stride = [self.stride] * self.dim + [1]  # No support for strides in the channel dim
        self.convDN = getattr(tf.nn, 'conv%dd' % dim)
        self.__GF = None
        self.__build_gaussian_filter()
    def __build_gaussian_filter(self):
        # Symmetric 1D coordinates around zero for both even and odd kernel sizes
        range_1d = tf.range(self.size, dtype=tf.float32) - (self.size - 1) / 2.
        g_1d = tf.math.exp(-1.0 * tf.pow(range_1d, 2) / (2. * tf.pow(self.sigma, 2)))
        g_1d_expanded = tf.expand_dims(g_1d, -1)
        iterator = tf.constant(1)
        # Build the dim-dimensional kernel by repeated outer products of the 1D Gaussian
        self.__GF = tf.while_loop(lambda iterator, g_1d: tf.less(iterator, self.dim),
                                  lambda iterator, g_1d: (iterator + 1,
                                                          tf.expand_dims(g_1d, -1) * tf.transpose(g_1d_expanded)),
                                  [iterator, g_1d],
                                  [iterator.get_shape(), tf.TensorShape(None)],  # Shape invariants
                                  back_prop=False
                                  )[-1]

        self.__GF = tf.divide(self.__GF, tf.reduce_sum(self.__GF))  # Normalization
        self.__GF = tf.reshape(self.__GF, (*[self.size] * self.dim, 1, 1))  # Add Ch_in and Ch_out dims for the convolution
        # Tile over both channel dims; note the convolution will then sum over the input channels
        self.__GF = tf.tile(self.__GF, (*[1] * self.dim, self.num_channels, self.num_channels))

    def apply_filter(self, in_image):
        return self.convDN(in_image, self.__GF, self.stride, 'SAME')

    def kernel(self):
        return self.__GF