Commit
·
e1a3fc2
1
Parent(s):
4dfbecb
Added "step" parameters to reduce the interpolation resolution and memory footprint
Browse files
DeepDeformationMapRegistration/utils/misc.py
CHANGED
@@ -55,24 +55,28 @@ class DatasetCopy:
|
|
55 |
class DisplacementMapInterpolator:
|
56 |
def __init__(self,
|
57 |
image_shape=[64, 64, 64],
|
58 |
-
method='rbf'
|
|
|
59 |
assert method in ['rbf', 'griddata', 'tf', 'tps'], "Method must be 'rbf' or 'griddata'"
|
60 |
self.method = method
|
61 |
self.image_shape = image_shape
|
|
|
62 |
|
63 |
self.grid = self.__regular_grid()
|
64 |
|
65 |
def __regular_grid(self):
|
66 |
xx = np.linspace(0, self.image_shape[0], self.image_shape[0], endpoint=False, dtype=np.uint16)
|
67 |
-
yy = np.linspace(0, self.image_shape[
|
68 |
-
zz = np.linspace(0, self.image_shape[
|
69 |
|
70 |
xx, yy, zz = np.meshgrid(xx, yy, zz)
|
71 |
|
72 |
-
return np.stack([xx.
|
|
|
|
|
73 |
|
74 |
def __call__(self, disp_map, interp_points, backwards=False):
|
75 |
-
disp_map = disp_map.reshape([-1, 3])
|
76 |
grid_pts = self.grid.copy()
|
77 |
if backwards:
|
78 |
grid_pts = np.add(grid_pts, disp_map).astype(np.float32)
|
|
|
55 |
class DisplacementMapInterpolator:
|
56 |
def __init__(self,
|
57 |
image_shape=[64, 64, 64],
|
58 |
+
method='rbf',
|
59 |
+
step=1):
|
60 |
assert method in ['rbf', 'griddata', 'tf', 'tps'], "Method must be 'rbf' or 'griddata'"
|
61 |
self.method = method
|
62 |
self.image_shape = image_shape
|
63 |
+
self.step = step # If to use every point or even N-th point
|
64 |
|
65 |
self.grid = self.__regular_grid()
|
66 |
|
67 |
def __regular_grid(self):
|
68 |
xx = np.linspace(0, self.image_shape[0], self.image_shape[0], endpoint=False, dtype=np.uint16)
|
69 |
+
yy = np.linspace(0, self.image_shape[1], self.image_shape[1], endpoint=False, dtype=np.uint16)
|
70 |
+
zz = np.linspace(0, self.image_shape[2], self.image_shape[2], endpoint=False, dtype=np.uint16)
|
71 |
|
72 |
xx, yy, zz = np.meshgrid(xx, yy, zz)
|
73 |
|
74 |
+
return np.stack([xx[::self.step, ::self.step, ::self.step].flatten(),
|
75 |
+
yy[::self.step, ::self.step, ::self.step].flatten(),
|
76 |
+
zz[::self.step, ::self.step, ::self.step].flatten()], axis=0).T
|
77 |
|
78 |
def __call__(self, disp_map, interp_points, backwards=False):
|
79 |
+
disp_map = disp_map.squeeze()[::self.step, ::self.step, ::self.step, ...].reshape([-1, 3])
|
80 |
grid_pts = self.grid.copy()
|
81 |
if backwards:
|
82 |
grid_pts = np.add(grid_pts, disp_map).astype(np.float32)
|