Commit
·
dc4b749
1
Parent(s):
c9f9cef
Refactor the model utils functions (download weight files, get weight files path, and load models)
Browse files
DeepDeformationMapRegistration/main.py
CHANGED
|
@@ -28,8 +28,7 @@ from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplifie
|
|
| 28 |
from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
|
| 29 |
from DeepDeformationMapRegistration.utils.operators import min_max_norm
|
| 30 |
from DeepDeformationMapRegistration.utils.misc import resize_displacement_map
|
| 31 |
-
from DeepDeformationMapRegistration.utils.
|
| 32 |
-
from DeepDeformationMapRegistration.networks import load_model
|
| 33 |
|
| 34 |
from importlib.util import find_spec
|
| 35 |
|
|
|
|
| 28 |
from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
|
| 29 |
from DeepDeformationMapRegistration.utils.operators import min_max_norm
|
| 30 |
from DeepDeformationMapRegistration.utils.misc import resize_displacement_map
|
| 31 |
+
from DeepDeformationMapRegistration.utils.model_utils import get_models_path, load_model
|
|
|
|
| 32 |
|
| 33 |
from importlib.util import find_spec
|
| 34 |
|
DeepDeformationMapRegistration/networks.py
CHANGED
|
@@ -9,23 +9,6 @@ import tensorflow as tf
|
|
| 9 |
import voxelmorph as vxm
|
| 10 |
from voxelmorph.tf.modelio import LoadableModel, store_config_args
|
| 11 |
from tensorflow.keras.layers import UpSampling3D
|
| 12 |
-
from DeepDeformationMapRegistration.utils.constants import ENCODER_FILTERS, DECODER_FILTERS, IMG_SHAPE
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
def load_model(weights_file_path: str, trainable: bool = False, return_registration_model: bool=True):
|
| 16 |
-
assert os.path.exists(weights_file_path), f'File {weights_file_path} not found'
|
| 17 |
-
assert weights_file_path.endswith('h5'), 'Invalid file extension. Expected .h5'
|
| 18 |
-
|
| 19 |
-
ret_val = vxm.networks.VxmDense(inshape=IMG_SHAPE[:-1],
|
| 20 |
-
nb_unet_features=[ENCODER_FILTERS, DECODER_FILTERS],
|
| 21 |
-
int_steps=0)
|
| 22 |
-
ret_val.load_weights(weights_file_path, by_name=True)
|
| 23 |
-
ret_val.trainable = trainable
|
| 24 |
-
|
| 25 |
-
if return_registration_model:
|
| 26 |
-
ret_val = (ret_val, ret_val.get_registration_model())
|
| 27 |
-
|
| 28 |
-
return ret_val
|
| 29 |
|
| 30 |
|
| 31 |
class WeaklySupervised(LoadableModel):
|
|
|
|
| 9 |
import voxelmorph as vxm
|
| 10 |
from voxelmorph.tf.modelio import LoadableModel, store_config_args
|
| 11 |
from tensorflow.keras.layers import UpSampling3D
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
class WeaklySupervised(LoadableModel):
|
DeepDeformationMapRegistration/utils/{model_downloader.py → model_utils.py}
RENAMED
|
@@ -2,7 +2,8 @@ import os
|
|
| 2 |
import requests
|
| 3 |
from datetime import datetime
|
| 4 |
from email.utils import parsedate_to_datetime, formatdate
|
| 5 |
-
from DeepDeformationMapRegistration.utils.constants import ANATOMIES, MODEL_TYPES
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
# taken from: https://lenon.dev/blog/downloading-and-caching-large-files-using-python/
|
|
@@ -35,9 +36,25 @@ def get_models_path(anatomy: str, model_type: str, output_root_dir: str):
|
|
| 35 |
assert model_type in MODEL_TYPES.keys(), 'Invalid model type'
|
| 36 |
anatomy = ANATOMIES[anatomy]
|
| 37 |
model_type = MODEL_TYPES[model_type]
|
| 38 |
-
url = 'https://github.com/jpdefrutos/DDMR/releases/download/
|
| 39 |
file_path = os.path.join(output_root_dir, 'models', anatomy, model_type + '.h5')
|
| 40 |
if not os.path.exists(file_path):
|
| 41 |
os.makedirs(os.path.split(file_path)[0], exist_ok=True)
|
| 42 |
download(url, file_path)
|
| 43 |
return file_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import requests
|
| 3 |
from datetime import datetime
|
| 4 |
from email.utils import parsedate_to_datetime, formatdate
|
| 5 |
+
from DeepDeformationMapRegistration.utils.constants import ANATOMIES, MODEL_TYPES, ENCODER_FILTERS, DECODER_FILTERS, IMG_SHAPE
|
| 6 |
+
import voxelmorph as vxm
|
| 7 |
|
| 8 |
|
| 9 |
# taken from: https://lenon.dev/blog/downloading-and-caching-large-files-using-python/
|
|
|
|
| 36 |
assert model_type in MODEL_TYPES.keys(), 'Invalid model type'
|
| 37 |
anatomy = ANATOMIES[anatomy]
|
| 38 |
model_type = MODEL_TYPES[model_type]
|
| 39 |
+
url = 'https://github.com/jpdefrutos/DDMR/releases/download/trained_models_v0/' + anatomy + '_' + model_type + '.h5'
|
| 40 |
file_path = os.path.join(output_root_dir, 'models', anatomy, model_type + '.h5')
|
| 41 |
if not os.path.exists(file_path):
|
| 42 |
os.makedirs(os.path.split(file_path)[0], exist_ok=True)
|
| 43 |
download(url, file_path)
|
| 44 |
return file_path
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def load_model(weights_file_path: str, trainable: bool = False, return_registration_model: bool=True):
|
| 48 |
+
assert os.path.exists(weights_file_path), f'File {weights_file_path} not found'
|
| 49 |
+
assert weights_file_path.endswith('h5'), 'Invalid file extension. Expected .h5'
|
| 50 |
+
|
| 51 |
+
ret_val = vxm.networks.VxmDense(inshape=IMG_SHAPE[:-1],
|
| 52 |
+
nb_unet_features=[ENCODER_FILTERS, DECODER_FILTERS],
|
| 53 |
+
int_steps=0)
|
| 54 |
+
ret_val.load_weights(weights_file_path, by_name=True)
|
| 55 |
+
ret_val.trainable = trainable
|
| 56 |
+
|
| 57 |
+
if return_registration_model:
|
| 58 |
+
ret_val = (ret_val, ret_val.get_registration_model())
|
| 59 |
+
|
| 60 |
+
return ret_val
|