Commit
·
e1a3fc2
1
Parent(s):
4dfbecb
Added "step" parameters to reduce the interpolation resolution and memory footprint
Browse files
DeepDeformationMapRegistration/utils/misc.py
CHANGED
|
@@ -55,24 +55,28 @@ class DatasetCopy:
|
|
| 55 |
class DisplacementMapInterpolator:
|
| 56 |
def __init__(self,
|
| 57 |
image_shape=[64, 64, 64],
|
| 58 |
-
method='rbf'
|
|
|
|
| 59 |
assert method in ['rbf', 'griddata', 'tf', 'tps'], "Method must be 'rbf' or 'griddata'"
|
| 60 |
self.method = method
|
| 61 |
self.image_shape = image_shape
|
|
|
|
| 62 |
|
| 63 |
self.grid = self.__regular_grid()
|
| 64 |
|
| 65 |
def __regular_grid(self):
|
| 66 |
xx = np.linspace(0, self.image_shape[0], self.image_shape[0], endpoint=False, dtype=np.uint16)
|
| 67 |
-
yy = np.linspace(0, self.image_shape[
|
| 68 |
-
zz = np.linspace(0, self.image_shape[
|
| 69 |
|
| 70 |
xx, yy, zz = np.meshgrid(xx, yy, zz)
|
| 71 |
|
| 72 |
-
return np.stack([xx.
|
|
|
|
|
|
|
| 73 |
|
| 74 |
def __call__(self, disp_map, interp_points, backwards=False):
|
| 75 |
-
disp_map = disp_map.reshape([-1, 3])
|
| 76 |
grid_pts = self.grid.copy()
|
| 77 |
if backwards:
|
| 78 |
grid_pts = np.add(grid_pts, disp_map).astype(np.float32)
|
|
|
|
| 55 |
class DisplacementMapInterpolator:
|
| 56 |
def __init__(self,
|
| 57 |
image_shape=[64, 64, 64],
|
| 58 |
+
method='rbf',
|
| 59 |
+
step=1):
|
| 60 |
assert method in ['rbf', 'griddata', 'tf', 'tps'], "Method must be 'rbf' or 'griddata'"
|
| 61 |
self.method = method
|
| 62 |
self.image_shape = image_shape
|
| 63 |
+
self.step = step # If to use every point or even N-th point
|
| 64 |
|
| 65 |
self.grid = self.__regular_grid()
|
| 66 |
|
| 67 |
def __regular_grid(self):
|
| 68 |
xx = np.linspace(0, self.image_shape[0], self.image_shape[0], endpoint=False, dtype=np.uint16)
|
| 69 |
+
yy = np.linspace(0, self.image_shape[1], self.image_shape[1], endpoint=False, dtype=np.uint16)
|
| 70 |
+
zz = np.linspace(0, self.image_shape[2], self.image_shape[2], endpoint=False, dtype=np.uint16)
|
| 71 |
|
| 72 |
xx, yy, zz = np.meshgrid(xx, yy, zz)
|
| 73 |
|
| 74 |
+
return np.stack([xx[::self.step, ::self.step, ::self.step].flatten(),
|
| 75 |
+
yy[::self.step, ::self.step, ::self.step].flatten(),
|
| 76 |
+
zz[::self.step, ::self.step, ::self.step].flatten()], axis=0).T
|
| 77 |
|
| 78 |
def __call__(self, disp_map, interp_points, backwards=False):
|
| 79 |
+
disp_map = disp_map.squeeze()[::self.step, ::self.step, ::self.step, ...].reshape([-1, 3])
|
| 80 |
grid_pts = self.grid.copy()
|
| 81 |
if backwards:
|
| 82 |
grid_pts = np.add(grid_pts, disp_map).astype(np.float32)
|