jpdefrutos commited on
Commit
e1a3fc2
·
1 Parent(s): 4dfbecb

Added "step" parameters to reduce the interpolation resolution and memory footprint

Browse files
DeepDeformationMapRegistration/utils/misc.py CHANGED
@@ -55,24 +55,28 @@ class DatasetCopy:
55
  class DisplacementMapInterpolator:
56
  def __init__(self,
57
  image_shape=[64, 64, 64],
58
- method='rbf'):
 
59
  assert method in ['rbf', 'griddata', 'tf', 'tps'], "Method must be 'rbf' or 'griddata'"
60
  self.method = method
61
  self.image_shape = image_shape
 
62
 
63
  self.grid = self.__regular_grid()
64
 
65
  def __regular_grid(self):
66
  xx = np.linspace(0, self.image_shape[0], self.image_shape[0], endpoint=False, dtype=np.uint16)
67
- yy = np.linspace(0, self.image_shape[0], self.image_shape[0], endpoint=False, dtype=np.uint16)
68
- zz = np.linspace(0, self.image_shape[0], self.image_shape[0], endpoint=False, dtype=np.uint16)
69
 
70
  xx, yy, zz = np.meshgrid(xx, yy, zz)
71
 
72
- return np.stack([xx.flatten(), yy.flatten(), zz.flatten()], axis=0).T
 
 
73
 
74
  def __call__(self, disp_map, interp_points, backwards=False):
75
- disp_map = disp_map.reshape([-1, 3])
76
  grid_pts = self.grid.copy()
77
  if backwards:
78
  grid_pts = np.add(grid_pts, disp_map).astype(np.float32)
 
55
  class DisplacementMapInterpolator:
56
  def __init__(self,
57
  image_shape=[64, 64, 64],
58
+ method='rbf',
59
+ step=1):
60
  assert method in ['rbf', 'griddata', 'tf', 'tps'], "Method must be 'rbf' or 'griddata'"
61
  self.method = method
62
  self.image_shape = image_shape
63
+ self.step = step # If to use every point or even N-th point
64
 
65
  self.grid = self.__regular_grid()
66
 
67
  def __regular_grid(self):
68
  xx = np.linspace(0, self.image_shape[0], self.image_shape[0], endpoint=False, dtype=np.uint16)
69
+ yy = np.linspace(0, self.image_shape[1], self.image_shape[1], endpoint=False, dtype=np.uint16)
70
+ zz = np.linspace(0, self.image_shape[2], self.image_shape[2], endpoint=False, dtype=np.uint16)
71
 
72
  xx, yy, zz = np.meshgrid(xx, yy, zz)
73
 
74
+ return np.stack([xx[::self.step, ::self.step, ::self.step].flatten(),
75
+ yy[::self.step, ::self.step, ::self.step].flatten(),
76
+ zz[::self.step, ::self.step, ::self.step].flatten()], axis=0).T
77
 
78
  def __call__(self, disp_map, interp_points, backwards=False):
79
+ disp_map = disp_map.squeeze()[::self.step, ::self.step, ::self.step, ...].reshape([-1, 3])
80
  grid_pts = self.grid.copy()
81
  if backwards:
82
  grid_pts = np.add(grid_pts, disp_map).astype(np.float32)