jpdefrutos commited on
Commit
5b0dbe4
·
1 Parent(s): 61f0e36

SpatialTransfomer was embedded into a Keras model, and it is downloaded from the repo releases

Browse files
DeepDeformationMapRegistration/layers/SpatialTransformer.py CHANGED
@@ -3,6 +3,9 @@ import tensorflow.keras.backend as K
3
  import tensorflow as tf
4
  import neurite as ne
5
 
 
 
 
6
 
7
  class SpatialTransformer(kl.Layer):
8
  """
@@ -184,3 +187,15 @@ class SpatialTransformer(kl.Layer):
184
  # test single
185
  return ne.utils.interpn(vol, loc, interp_method=interp_method, fill_value=fill_value)
186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import tensorflow as tf
4
  import neurite as ne
5
 
6
+ import h5py
7
+ from DeepDeformationMapRegistration.utils.constants import IMG_SHAPE, DISP_MAP_SHAPE
8
+
9
 
10
  class SpatialTransformer(kl.Layer):
11
  """
 
187
  # test single
188
  return ne.utils.interpn(vol, loc, interp_method=interp_method, fill_value=fill_value)
189
 
190
+
191
+ if __name__ == "__main__":
192
+ output_file = './spatialtransformer.h5'
193
+
194
+ in_dm = tf.keras.Input(DISP_MAP_SHAPE)
195
+ in_image = tf.keras.Input(IMG_SHAPE)
196
+ pred = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([in_image, in_dm])
197
+
198
+ model = tf.keras.Model(inputs=[in_image, in_dm], outputs=pred)
199
+
200
+ model.save(output_file)
201
+ print(f"SpatialTransformer layer saved in: {output_file}")
DeepDeformationMapRegistration/main.py CHANGED
@@ -19,7 +19,7 @@ import DeepDeformationMapRegistration.utils.constants as C
19
  from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
20
  from DeepDeformationMapRegistration.utils.operators import min_max_norm
21
  from DeepDeformationMapRegistration.utils.misc import resize_displacement_map
22
- from DeepDeformationMapRegistration.utils.model_utils import get_models_path, load_model
23
  from DeepDeformationMapRegistration.utils.logger import LOGGER
24
 
25
  from importlib.util import find_spec
@@ -279,8 +279,11 @@ def main():
279
 
280
  LOGGER.info(f'Getting model: {"Brain" if args.anatomy == "B" else "Liver"} -> {args.model}')
281
  MODEL_FILE = get_models_path(args.anatomy, args.model, os.getcwd()) # MODELS_FILE[args.anatomy][args.model]
 
282
 
283
  network, registration_model = load_model(MODEL_FILE, False, True)
 
 
284
 
285
  LOGGER.info('Computing registration')
286
  with sess.as_default():
@@ -297,7 +300,8 @@ def main():
297
 
298
  LOGGER.info('Applying displacement map...')
299
  time_pred_img_start = time.time()
300
- pred_image = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([moving_image[np.newaxis, ...], disp_map[np.newaxis, ...]]).eval()
 
301
  time_pred_img_end = time.time()
302
  LOGGER.info(f'\t... done ({time_pred_img_end - time_pred_img_start} s)')
303
  pred_image = pred_image[0, ...]
 
19
  from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
20
  from DeepDeformationMapRegistration.utils.operators import min_max_norm
21
  from DeepDeformationMapRegistration.utils.misc import resize_displacement_map
22
+ from DeepDeformationMapRegistration.utils.model_utils import get_models_path, load_model, get_spatialtransformer_model
23
  from DeepDeformationMapRegistration.utils.logger import LOGGER
24
 
25
  from importlib.util import find_spec
 
279
 
280
  LOGGER.info(f'Getting model: {"Brain" if args.anatomy == "B" else "Liver"} -> {args.model}')
281
  MODEL_FILE = get_models_path(args.anatomy, args.model, os.getcwd()) # MODELS_FILE[args.anatomy][args.model]
282
+ ST_MODEL_FILE = get_spatialtransformer_model()
283
 
284
  network, registration_model = load_model(MODEL_FILE, False, True)
285
+ spatialtransformer_model = tf.keras.models.load_model(ST_MODEL_FILE,
286
+ custom_objects={'SpatialTransformer': SpatialTransformer})
287
 
288
  LOGGER.info('Computing registration')
289
  with sess.as_default():
 
300
 
301
  LOGGER.info('Applying displacement map...')
302
  time_pred_img_start = time.time()
303
+ # pred_image = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([moving_image[np.newaxis, ...], disp_map[np.newaxis, ...]]).eval()
304
+ pred_image = spatialtransformer_model.predict([moving_image[np.newaxis, ...], disp_map[np.newaxis, ...]])
305
  time_pred_img_end = time.time()
306
  LOGGER.info(f'\t... done ({time_pred_img_end - time_pred_img_start} s)')
307
  pred_image = pred_image[0, ...]
DeepDeformationMapRegistration/utils/model_utils.py CHANGED
@@ -63,3 +63,16 @@ def load_model(weights_file_path: str, trainable: bool = False, return_registrat
63
  ret_val = (ret_val, ret_val.get_registration_model())
64
 
65
  return ret_val
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  ret_val = (ret_val, ret_val.get_registration_model())
64
 
65
  return ret_val
66
+
67
+
68
+ def get_spatialtransformer_model():
69
+ url = 'https://github.com/jpdefrutos/DDMR/releases/download/spatialtransformer_model_v0/spatialtransformer.h5'
70
+ file_path = os.path.join(os.getcwd(), 'models', 'spatialtransformer.h5')
71
+ if not os.path.exists(file_path):
72
+ LOGGER.info(f'Model not found. Downloading from {url}... ')
73
+ os.makedirs(os.path.split(file_path)[0], exist_ok=True)
74
+ download(url, file_path)
75
+ LOGGER.info(f'... downloaded model. Stored in {file_path}')
76
+ else:
77
+ LOGGER.info(f'Found model: {file_path}')
78
+ return file_path