jpdefrutos commited on
Commit
0978cbc
·
1 Parent(s): f4c45ef

Created convenient function for loading models

Browse files
DeepDeformationMapRegistration/main.py CHANGED
@@ -7,9 +7,9 @@ import subprocess
7
  import logging
8
  import time
9
 
10
- currentdir = os.path.dirname(os.path.realpath(__file__))
11
- parentdir = os.path.dirname(currentdir)
12
- sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
13
 
14
  import tensorflow as tf
15
 
@@ -29,6 +29,7 @@ from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimila
29
  from DeepDeformationMapRegistration.utils.operators import min_max_norm
30
  from DeepDeformationMapRegistration.utils.misc import resize_displacement_map
31
  from DeepDeformationMapRegistration.utils.model_downloader import get_models_path
 
32
 
33
  from importlib.util import find_spec
34
 
@@ -284,39 +285,7 @@ def main():
284
  LOGGER.info(f'Using model: {"Brain" if args.anatomy == "B" else "Liver"} -> {args.model}')
285
  MODEL_FILE = get_models_path(args.anatomy, args.model, os.getcwd()) # MODELS_FILE[args.anatomy][args.model]
286
 
287
- # try:
288
- # network = tf.keras.models.load_model(MODEL_FILE,
289
- # {'VxmDense': vxm.networks.VxmDense,
290
- # # 'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
291
- # 'AdamAccumulated': AdamAccumulated
292
- # },
293
- # compile=False)
294
- # except ValueError as e:
295
- # enc_features = [32, 64, 128, 256, 512, 1024] # const.ENCODER_FILTERS
296
- # dec_features = enc_features[::-1] + [16, 16] # const.ENCODER_FILTERS[::-1]
297
- # nb_features = [enc_features, dec_features]
298
- # if re.search('^UW|SEGGUIDED_', MODEL_FILE):
299
- # network = vxm.networks.VxmDense(inshape=IMAGE_INTPUT_SHAPE[:-1],
300
- # nb_unet_features=nb_features,
301
- # int_steps=0,
302
- # int_downsize=1,
303
- # seg_downsize=1)
304
- # else:
305
- # network = vxm.networks.VxmDense(inshape=IMAGE_INTPUT_SHAPE[:-1],
306
- # nb_unet_features=nb_features,
307
- # int_steps=0)
308
- # network.load_weights(MODEL_FILE, by_name=True)
309
-
310
- enc_features = [32, 64, 128, 256, 512, 1024] # const.ENCODER_FILTERS
311
- dec_features = enc_features[::-1] + [16, 16] # const.ENCODER_FILTERS[::-1]
312
- nb_features = [enc_features, dec_features]
313
- network = vxm.networks.VxmDense(inshape=C.IMG_SHAPE[:-1],
314
- nb_unet_features=nb_features,
315
- int_steps=0)
316
- network.load_weights(MODEL_FILE, by_name=True)
317
- network.trainable = False
318
-
319
- registration_model = network.get_registration_model()
320
  deb_model = network.apply_transform
321
 
322
  LOGGER.info('Computing registration')
 
7
  import logging
8
  import time
9
 
10
+ # currentdir = os.path.dirname(os.path.realpath(__file__))
11
+ # parentdir = os.path.dirname(currentdir)
12
+ # sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
13
 
14
  import tensorflow as tf
15
 
 
29
  from DeepDeformationMapRegistration.utils.operators import min_max_norm
30
  from DeepDeformationMapRegistration.utils.misc import resize_displacement_map
31
  from DeepDeformationMapRegistration.utils.model_downloader import get_models_path
32
+ from DeepDeformationMapRegistration.networks import load_model
33
 
34
  from importlib.util import find_spec
35
 
 
285
  LOGGER.info(f'Using model: {"Brain" if args.anatomy == "B" else "Liver"} -> {args.model}')
286
  MODEL_FILE = get_models_path(args.anatomy, args.model, os.getcwd()) # MODELS_FILE[args.anatomy][args.model]
287
 
288
+ network, registration_model = load_model(MODEL_FILE, False, True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
  deb_model = network.apply_transform
290
 
291
  LOGGER.info('Computing registration')
DeepDeformationMapRegistration/networks.py CHANGED
@@ -1,14 +1,31 @@
1
  import os, sys
2
- currentdir = os.path.dirname(os.path.realpath(__file__))
3
- parentdir = os.path.dirname(currentdir)
4
- sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
5
-
6
- PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'
7
 
8
  import tensorflow as tf
9
  import voxelmorph as vxm
10
  from voxelmorph.tf.modelio import LoadableModel, store_config_args
11
  from tensorflow.keras.layers import UpSampling3D
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
 
14
  class WeaklySupervised(LoadableModel):
 
1
  import os, sys
2
+ # currentdir = os.path.dirname(os.path.realpath(__file__))
3
+ # parentdir = os.path.dirname(currentdir)
4
+ # sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
5
+ #
6
+ # PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'
7
 
8
  import tensorflow as tf
9
  import voxelmorph as vxm
10
  from voxelmorph.tf.modelio import LoadableModel, store_config_args
11
  from tensorflow.keras.layers import UpSampling3D
12
+ from DeepDeformationMapRegistration.utils.constants import ENCODER_FILTERS, DECODER_FILTERS, IMG_SHAPE
13
+
14
+
15
+ def load_model(weights_file_path: str, trainable: bool = False, return_registration_model: bool=True):
16
+ assert os.path.exists(weights_file_path), f'File {weights_file_path} not found'
17
+ assert weights_file_path.endswith('h5'), 'Invalid file extension. Expected .h5'
18
+
19
+ ret_val = vxm.networks.VxmDense(inshape=IMG_SHAPE[:-1],
20
+ nb_unet_features=[ENCODER_FILTERS, DECODER_FILTERS],
21
+ int_steps=0)
22
+ ret_val.load_weights(weights_file_path, by_name=True)
23
+ ret_val.trainable = trainable
24
+
25
+ if return_registration_model:
26
+ ret_val = (ret_val, ret_val.get_registration_model())
27
+
28
+ return ret_val
29
 
30
 
31
  class WeaklySupervised(LoadableModel):
DeepDeformationMapRegistration/utils/constants.py CHANGED
@@ -196,8 +196,8 @@ DROPOUT = True
196
  DROPOUT_RATE = 0.2
197
  MAX_DATA_SIZE = (1000, 1000, 1)
198
  PLATEAU_THR = 0.01 # A slope between +-PLATEAU_THR will be considered a plateau for the LR updating function
199
- ENCODER_FILTERS = [4, 8, 16, 32, 64]
200
-
201
  # SSIM
202
  SSIM_FILTER_SIZE = 11 # Size of Gaussian filter
203
  SSIM_FILTER_SIGMA = 1.5 # Width of Gaussian filter
@@ -205,7 +205,7 @@ SSIM_K1 = 0.01 # Def. 0.01
205
  SSIM_K2 = 0.03 # Recommended values 0 < K2 < 0.4
206
  MAX_VALUE = 1.0 # Maximum intensity values
207
 
208
- # Mathematic constants
209
  EPS = 1e-8
210
  EPS_tf = tf.constant(EPS, dtype=tf.float32)
211
  LOG2 = tf.math.log(tf.constant(2, dtype=tf.float32))
 
196
  DROPOUT_RATE = 0.2
197
  MAX_DATA_SIZE = (1000, 1000, 1)
198
  PLATEAU_THR = 0.01 # A slope between +-PLATEAU_THR will be considered a plateau for the LR updating function
199
+ ENCODER_FILTERS = [32, 64, 128, 256, 512, 1024]
200
+ DECODER_FILTERS = ENCODER_FILTERS[::-1] + [16, 16]
201
  # SSIM
202
  SSIM_FILTER_SIZE = 11 # Size of Gaussian filter
203
  SSIM_FILTER_SIGMA = 1.5 # Width of Gaussian filter
 
205
  SSIM_K2 = 0.03 # Recommended values 0 < K2 < 0.4
206
  MAX_VALUE = 1.0 # Maximum intensity values
207
 
208
+ # Mathematics constants
209
  EPS = 1e-8
210
  EPS_tf = tf.constant(EPS, dtype=tf.float32)
211
  LOG2 = tf.math.log(tf.constant(2, dtype=tf.float32))