jpdefrutos commited on
Commit
dc4b749
·
1 Parent(s): c9f9cef

Refactor the model utils functions (download weight files, get weight files path, and load models)

Browse files
DeepDeformationMapRegistration/main.py CHANGED
@@ -28,8 +28,7 @@ from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplifie
28
  from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
29
  from DeepDeformationMapRegistration.utils.operators import min_max_norm
30
  from DeepDeformationMapRegistration.utils.misc import resize_displacement_map
31
- from DeepDeformationMapRegistration.utils.model_downloader import get_models_path
32
- from DeepDeformationMapRegistration.networks import load_model
33
 
34
  from importlib.util import find_spec
35
 
 
28
  from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
29
  from DeepDeformationMapRegistration.utils.operators import min_max_norm
30
  from DeepDeformationMapRegistration.utils.misc import resize_displacement_map
31
+ from DeepDeformationMapRegistration.utils.model_utils import get_models_path, load_model
 
32
 
33
  from importlib.util import find_spec
34
 
DeepDeformationMapRegistration/networks.py CHANGED
@@ -9,23 +9,6 @@ import tensorflow as tf
9
  import voxelmorph as vxm
10
  from voxelmorph.tf.modelio import LoadableModel, store_config_args
11
  from tensorflow.keras.layers import UpSampling3D
12
- from DeepDeformationMapRegistration.utils.constants import ENCODER_FILTERS, DECODER_FILTERS, IMG_SHAPE
13
-
14
-
15
- def load_model(weights_file_path: str, trainable: bool = False, return_registration_model: bool=True):
16
- assert os.path.exists(weights_file_path), f'File {weights_file_path} not found'
17
- assert weights_file_path.endswith('h5'), 'Invalid file extension. Expected .h5'
18
-
19
- ret_val = vxm.networks.VxmDense(inshape=IMG_SHAPE[:-1],
20
- nb_unet_features=[ENCODER_FILTERS, DECODER_FILTERS],
21
- int_steps=0)
22
- ret_val.load_weights(weights_file_path, by_name=True)
23
- ret_val.trainable = trainable
24
-
25
- if return_registration_model:
26
- ret_val = (ret_val, ret_val.get_registration_model())
27
-
28
- return ret_val
29
 
30
 
31
  class WeaklySupervised(LoadableModel):
 
9
  import voxelmorph as vxm
10
  from voxelmorph.tf.modelio import LoadableModel, store_config_args
11
  from tensorflow.keras.layers import UpSampling3D
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
 
14
  class WeaklySupervised(LoadableModel):
DeepDeformationMapRegistration/utils/{model_downloader.py → model_utils.py} RENAMED
@@ -2,7 +2,8 @@ import os
2
  import requests
3
  from datetime import datetime
4
  from email.utils import parsedate_to_datetime, formatdate
5
- from DeepDeformationMapRegistration.utils.constants import ANATOMIES, MODEL_TYPES
 
6
 
7
 
8
  # taken from: https://lenon.dev/blog/downloading-and-caching-large-files-using-python/
@@ -35,9 +36,25 @@ def get_models_path(anatomy: str, model_type: str, output_root_dir: str):
35
  assert model_type in MODEL_TYPES.keys(), 'Invalid model type'
36
  anatomy = ANATOMIES[anatomy]
37
  model_type = MODEL_TYPES[model_type]
38
- url = 'https://github.com/jpdefrutos/DDMR/releases/download/trained-models/' + anatomy + '_' + model_type + '.h5'
39
  file_path = os.path.join(output_root_dir, 'models', anatomy, model_type + '.h5')
40
  if not os.path.exists(file_path):
41
  os.makedirs(os.path.split(file_path)[0], exist_ok=True)
42
  download(url, file_path)
43
  return file_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import requests
3
  from datetime import datetime
4
  from email.utils import parsedate_to_datetime, formatdate
5
+ from DeepDeformationMapRegistration.utils.constants import ANATOMIES, MODEL_TYPES, ENCODER_FILTERS, DECODER_FILTERS, IMG_SHAPE
6
+ import voxelmorph as vxm
7
 
8
 
9
  # taken from: https://lenon.dev/blog/downloading-and-caching-large-files-using-python/
 
36
  assert model_type in MODEL_TYPES.keys(), 'Invalid model type'
37
  anatomy = ANATOMIES[anatomy]
38
  model_type = MODEL_TYPES[model_type]
39
+ url = 'https://github.com/jpdefrutos/DDMR/releases/download/trained_models_v0/' + anatomy + '_' + model_type + '.h5'
40
  file_path = os.path.join(output_root_dir, 'models', anatomy, model_type + '.h5')
41
  if not os.path.exists(file_path):
42
  os.makedirs(os.path.split(file_path)[0], exist_ok=True)
43
  download(url, file_path)
44
  return file_path
45
+
46
+
47
+ def load_model(weights_file_path: str, trainable: bool = False, return_registration_model: bool=True):
48
+ assert os.path.exists(weights_file_path), f'File {weights_file_path} not found'
49
+ assert weights_file_path.endswith('h5'), 'Invalid file extension. Expected .h5'
50
+
51
+ ret_val = vxm.networks.VxmDense(inshape=IMG_SHAPE[:-1],
52
+ nb_unet_features=[ENCODER_FILTERS, DECODER_FILTERS],
53
+ int_steps=0)
54
+ ret_val.load_weights(weights_file_path, by_name=True)
55
+ ret_val.trainable = trainable
56
+
57
+ if return_registration_model:
58
+ ret_val = (ret_val, ret_val.get_registration_model())
59
+
60
+ return ret_val