Commit
·
dc4b749
1
Parent(s):
c9f9cef
Refactor the model utils functions (download weight files, get weight files path, and load models)
Browse files
DeepDeformationMapRegistration/main.py
CHANGED
@@ -28,8 +28,7 @@ from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplifie
|
|
28 |
from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
|
29 |
from DeepDeformationMapRegistration.utils.operators import min_max_norm
|
30 |
from DeepDeformationMapRegistration.utils.misc import resize_displacement_map
|
31 |
-
from DeepDeformationMapRegistration.utils.model_downloader import get_models_path
|
32 |
-
from DeepDeformationMapRegistration.networks import load_model
|
33 |
|
34 |
from importlib.util import find_spec
|
35 |
|
|
|
28 |
from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
|
29 |
from DeepDeformationMapRegistration.utils.operators import min_max_norm
|
30 |
from DeepDeformationMapRegistration.utils.misc import resize_displacement_map
|
31 |
+
from DeepDeformationMapRegistration.utils.model_utils import get_models_path, load_model
|
|
|
32 |
|
33 |
from importlib.util import find_spec
|
34 |
|
DeepDeformationMapRegistration/networks.py
CHANGED
@@ -9,23 +9,6 @@ import tensorflow as tf
|
|
9 |
import voxelmorph as vxm
|
10 |
from voxelmorph.tf.modelio import LoadableModel, store_config_args
|
11 |
from tensorflow.keras.layers import UpSampling3D
|
12 |
-
from DeepDeformationMapRegistration.utils.constants import ENCODER_FILTERS, DECODER_FILTERS, IMG_SHAPE
|
13 |
-
|
14 |
-
|
15 |
-
def load_model(weights_file_path: str, trainable: bool = False, return_registration_model: bool=True):
|
16 |
-
assert os.path.exists(weights_file_path), f'File {weights_file_path} not found'
|
17 |
-
assert weights_file_path.endswith('h5'), 'Invalid file extension. Expected .h5'
|
18 |
-
|
19 |
-
ret_val = vxm.networks.VxmDense(inshape=IMG_SHAPE[:-1],
|
20 |
-
nb_unet_features=[ENCODER_FILTERS, DECODER_FILTERS],
|
21 |
-
int_steps=0)
|
22 |
-
ret_val.load_weights(weights_file_path, by_name=True)
|
23 |
-
ret_val.trainable = trainable
|
24 |
-
|
25 |
-
if return_registration_model:
|
26 |
-
ret_val = (ret_val, ret_val.get_registration_model())
|
27 |
-
|
28 |
-
return ret_val
|
29 |
|
30 |
|
31 |
class WeaklySupervised(LoadableModel):
|
|
|
9 |
import voxelmorph as vxm
|
10 |
from voxelmorph.tf.modelio import LoadableModel, store_config_args
|
11 |
from tensorflow.keras.layers import UpSampling3D
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
|
14 |
class WeaklySupervised(LoadableModel):
|
DeepDeformationMapRegistration/utils/{model_downloader.py → model_utils.py}
RENAMED
@@ -2,7 +2,8 @@ import os
|
|
2 |
import requests
|
3 |
from datetime import datetime
|
4 |
from email.utils import parsedate_to_datetime, formatdate
|
5 |
-
from DeepDeformationMapRegistration.utils.constants import ANATOMIES, MODEL_TYPES
|
|
|
6 |
|
7 |
|
8 |
# taken from: https://lenon.dev/blog/downloading-and-caching-large-files-using-python/
|
@@ -35,9 +36,25 @@ def get_models_path(anatomy: str, model_type: str, output_root_dir: str):
|
|
35 |
assert model_type in MODEL_TYPES.keys(), 'Invalid model type'
|
36 |
anatomy = ANATOMIES[anatomy]
|
37 |
model_type = MODEL_TYPES[model_type]
|
38 |
-
url = 'https://github.com/jpdefrutos/DDMR/releases/download/ […line truncated in extraction]
|
39 |
file_path = os.path.join(output_root_dir, 'models', anatomy, model_type + '.h5')
|
40 |
if not os.path.exists(file_path):
|
41 |
os.makedirs(os.path.split(file_path)[0], exist_ok=True)
|
42 |
download(url, file_path)
|
43 |
return file_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import requests
|
3 |
from datetime import datetime
|
4 |
from email.utils import parsedate_to_datetime, formatdate
|
5 |
+
from DeepDeformationMapRegistration.utils.constants import ANATOMIES, MODEL_TYPES, ENCODER_FILTERS, DECODER_FILTERS, IMG_SHAPE
|
6 |
+
import voxelmorph as vxm
|
7 |
|
8 |
|
9 |
# taken from: https://lenon.dev/blog/downloading-and-caching-large-files-using-python/
|
|
|
36 |
assert model_type in MODEL_TYPES.keys(), 'Invalid model type'
|
37 |
anatomy = ANATOMIES[anatomy]
|
38 |
model_type = MODEL_TYPES[model_type]
|
39 |
+
url = 'https://github.com/jpdefrutos/DDMR/releases/download/trained_models_v0/' + anatomy + '_' + model_type + '.h5'
|
40 |
file_path = os.path.join(output_root_dir, 'models', anatomy, model_type + '.h5')
|
41 |
if not os.path.exists(file_path):
|
42 |
os.makedirs(os.path.split(file_path)[0], exist_ok=True)
|
43 |
download(url, file_path)
|
44 |
return file_path
|
45 |
+
|
46 |
+
|
47 |
+
def load_model(weights_file_path: str, trainable: bool = False, return_registration_model: bool=True):
|
48 |
+
assert os.path.exists(weights_file_path), f'File {weights_file_path} not found'
|
49 |
+
assert weights_file_path.endswith('h5'), 'Invalid file extension. Expected .h5'
|
50 |
+
|
51 |
+
ret_val = vxm.networks.VxmDense(inshape=IMG_SHAPE[:-1],
|
52 |
+
nb_unet_features=[ENCODER_FILTERS, DECODER_FILTERS],
|
53 |
+
int_steps=0)
|
54 |
+
ret_val.load_weights(weights_file_path, by_name=True)
|
55 |
+
ret_val.trainable = trainable
|
56 |
+
|
57 |
+
if return_registration_model:
|
58 |
+
ret_val = (ret_val, ret_val.get_registration_model())
|
59 |
+
|
60 |
+
return ret_val
|