Commit
·
5b0dbe4
1
Parent(s):
61f0e36
SpatialTransfomer was embedded into a Keras model, and it is downloaded from the repo releases
Browse files
DeepDeformationMapRegistration/layers/SpatialTransformer.py
CHANGED
@@ -3,6 +3,9 @@ import tensorflow.keras.backend as K
|
|
3 |
import tensorflow as tf
|
4 |
import neurite as ne
|
5 |
|
|
|
|
|
|
|
6 |
|
7 |
class SpatialTransformer(kl.Layer):
|
8 |
"""
|
@@ -184,3 +187,15 @@ class SpatialTransformer(kl.Layer):
|
|
184 |
# test single
|
185 |
return ne.utils.interpn(vol, loc, interp_method=interp_method, fill_value=fill_value)
|
186 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
import tensorflow as tf
|
4 |
import neurite as ne
|
5 |
|
6 |
+
import h5py
|
7 |
+
from DeepDeformationMapRegistration.utils.constants import IMG_SHAPE, DISP_MAP_SHAPE
|
8 |
+
|
9 |
|
10 |
class SpatialTransformer(kl.Layer):
|
11 |
"""
|
|
|
187 |
# test single
|
188 |
return ne.utils.interpn(vol, loc, interp_method=interp_method, fill_value=fill_value)
|
189 |
|
190 |
+
|
191 |
+
if __name__ == "__main__":
|
192 |
+
output_file = './spatialtransformer.h5'
|
193 |
+
|
194 |
+
in_dm = tf.keras.Input(DISP_MAP_SHAPE)
|
195 |
+
in_image = tf.keras.Input(IMG_SHAPE)
|
196 |
+
pred = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([in_image, in_dm])
|
197 |
+
|
198 |
+
model = tf.keras.Model(inputs=[in_image, in_dm], outputs=pred)
|
199 |
+
|
200 |
+
model.save(output_file)
|
201 |
+
print(f"SpatialTransformer layer saved in: {output_file}")
|
DeepDeformationMapRegistration/main.py
CHANGED
@@ -19,7 +19,7 @@ import DeepDeformationMapRegistration.utils.constants as C
|
|
19 |
from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
|
20 |
from DeepDeformationMapRegistration.utils.operators import min_max_norm
|
21 |
from DeepDeformationMapRegistration.utils.misc import resize_displacement_map
|
22 |
-
from DeepDeformationMapRegistration.utils.model_utils import get_models_path, load_model
|
23 |
from DeepDeformationMapRegistration.utils.logger import LOGGER
|
24 |
|
25 |
from importlib.util import find_spec
|
@@ -279,8 +279,11 @@ def main():
|
|
279 |
|
280 |
LOGGER.info(f'Getting model: {"Brain" if args.anatomy == "B" else "Liver"} -> {args.model}')
|
281 |
MODEL_FILE = get_models_path(args.anatomy, args.model, os.getcwd()) # MODELS_FILE[args.anatomy][args.model]
|
|
|
282 |
|
283 |
network, registration_model = load_model(MODEL_FILE, False, True)
|
|
|
|
|
284 |
|
285 |
LOGGER.info('Computing registration')
|
286 |
with sess.as_default():
|
@@ -297,7 +300,8 @@ def main():
|
|
297 |
|
298 |
LOGGER.info('Applying displacement map...')
|
299 |
time_pred_img_start = time.time()
|
300 |
-
pred_image = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([moving_image[np.newaxis, ...], disp_map[np.newaxis, ...]]).eval()
|
|
|
301 |
time_pred_img_end = time.time()
|
302 |
LOGGER.info(f'\t... done ({time_pred_img_end - time_pred_img_start} s)')
|
303 |
pred_image = pred_image[0, ...]
|
|
|
19 |
from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
|
20 |
from DeepDeformationMapRegistration.utils.operators import min_max_norm
|
21 |
from DeepDeformationMapRegistration.utils.misc import resize_displacement_map
|
22 |
+
from DeepDeformationMapRegistration.utils.model_utils import get_models_path, load_model, get_spatialtransformer_model
|
23 |
from DeepDeformationMapRegistration.utils.logger import LOGGER
|
24 |
|
25 |
from importlib.util import find_spec
|
|
|
279 |
|
280 |
LOGGER.info(f'Getting model: {"Brain" if args.anatomy == "B" else "Liver"} -> {args.model}')
|
281 |
MODEL_FILE = get_models_path(args.anatomy, args.model, os.getcwd()) # MODELS_FILE[args.anatomy][args.model]
|
282 |
+
ST_MODEL_FILE = get_spatialtransformer_model()
|
283 |
|
284 |
network, registration_model = load_model(MODEL_FILE, False, True)
|
285 |
+
spatialtransformer_model = tf.keras.models.load_model(ST_MODEL_FILE,
|
286 |
+
custom_objects={'SpatialTransformer': SpatialTransformer})
|
287 |
|
288 |
LOGGER.info('Computing registration')
|
289 |
with sess.as_default():
|
|
|
300 |
|
301 |
LOGGER.info('Applying displacement map...')
|
302 |
time_pred_img_start = time.time()
|
303 |
+
# pred_image = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([moving_image[np.newaxis, ...], disp_map[np.newaxis, ...]]).eval()
|
304 |
+
pred_image = spatialtransformer_model.predict([moving_image[np.newaxis, ...], disp_map[np.newaxis, ...]])
|
305 |
time_pred_img_end = time.time()
|
306 |
LOGGER.info(f'\t... done ({time_pred_img_end - time_pred_img_start} s)')
|
307 |
pred_image = pred_image[0, ...]
|
DeepDeformationMapRegistration/utils/model_utils.py
CHANGED
@@ -63,3 +63,16 @@ def load_model(weights_file_path: str, trainable: bool = False, return_registrat
|
|
63 |
ret_val = (ret_val, ret_val.get_registration_model())
|
64 |
|
65 |
return ret_val
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
ret_val = (ret_val, ret_val.get_registration_model())
|
64 |
|
65 |
return ret_val
|
66 |
+
|
67 |
+
|
68 |
+
def get_spatialtransformer_model():
|
69 |
+
url = 'https://github.com/jpdefrutos/DDMR/releases/download/spatialtransformer_model_v0/spatialtransformer.h5'
|
70 |
+
file_path = os.path.join(os.getcwd(), 'models', 'spatialtransformer.h5')
|
71 |
+
if not os.path.exists(file_path):
|
72 |
+
LOGGER.info(f'Model not found. Downloading from {url}... ')
|
73 |
+
os.makedirs(os.path.split(file_path)[0], exist_ok=True)
|
74 |
+
download(url, file_path)
|
75 |
+
LOGGER.info(f'... downloaded model. Stored in {file_path}')
|
76 |
+
else:
|
77 |
+
LOGGER.info(f'Found model: {file_path}')
|
78 |
+
return file_path
|