Commit
·
0978cbc
1
Parent(s):
f4c45ef
Created convenient function for loading models
Browse files
DeepDeformationMapRegistration/main.py
CHANGED
@@ -7,9 +7,9 @@ import subprocess
|
|
7 |
import logging
|
8 |
import time
|
9 |
|
10 |
-
currentdir = os.path.dirname(os.path.realpath(__file__))
|
11 |
-
parentdir = os.path.dirname(currentdir)
|
12 |
-
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
13 |
|
14 |
import tensorflow as tf
|
15 |
|
@@ -29,6 +29,7 @@ from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimila
|
|
29 |
from DeepDeformationMapRegistration.utils.operators import min_max_norm
|
30 |
from DeepDeformationMapRegistration.utils.misc import resize_displacement_map
|
31 |
from DeepDeformationMapRegistration.utils.model_downloader import get_models_path
|
|
|
32 |
|
33 |
from importlib.util import find_spec
|
34 |
|
@@ -284,39 +285,7 @@ def main():
|
|
284 |
LOGGER.info(f'Using model: {"Brain" if args.anatomy == "B" else "Liver"} -> {args.model}')
|
285 |
MODEL_FILE = get_models_path(args.anatomy, args.model, os.getcwd()) # MODELS_FILE[args.anatomy][args.model]
|
286 |
|
287 |
-
|
288 |
-
# network = tf.keras.models.load_model(MODEL_FILE,
|
289 |
-
# {'VxmDense': vxm.networks.VxmDense,
|
290 |
-
# # 'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
|
291 |
-
# 'AdamAccumulated': AdamAccumulated
|
292 |
-
# },
|
293 |
-
# compile=False)
|
294 |
-
# except ValueError as e:
|
295 |
-
# enc_features = [32, 64, 128, 256, 512, 1024] # const.ENCODER_FILTERS
|
296 |
-
# dec_features = enc_features[::-1] + [16, 16] # const.ENCODER_FILTERS[::-1]
|
297 |
-
# nb_features = [enc_features, dec_features]
|
298 |
-
# if re.search('^UW|SEGGUIDED_', MODEL_FILE):
|
299 |
-
# network = vxm.networks.VxmDense(inshape=IMAGE_INTPUT_SHAPE[:-1],
|
300 |
-
# nb_unet_features=nb_features,
|
301 |
-
# int_steps=0,
|
302 |
-
# int_downsize=1,
|
303 |
-
# seg_downsize=1)
|
304 |
-
# else:
|
305 |
-
# network = vxm.networks.VxmDense(inshape=IMAGE_INTPUT_SHAPE[:-1],
|
306 |
-
# nb_unet_features=nb_features,
|
307 |
-
# int_steps=0)
|
308 |
-
# network.load_weights(MODEL_FILE, by_name=True)
|
309 |
-
|
310 |
-
enc_features = [32, 64, 128, 256, 512, 1024] # const.ENCODER_FILTERS
|
311 |
-
dec_features = enc_features[::-1] + [16, 16] # const.ENCODER_FILTERS[::-1]
|
312 |
-
nb_features = [enc_features, dec_features]
|
313 |
-
network = vxm.networks.VxmDense(inshape=C.IMG_SHAPE[:-1],
|
314 |
-
nb_unet_features=nb_features,
|
315 |
-
int_steps=0)
|
316 |
-
network.load_weights(MODEL_FILE, by_name=True)
|
317 |
-
network.trainable = False
|
318 |
-
|
319 |
-
registration_model = network.get_registration_model()
|
320 |
deb_model = network.apply_transform
|
321 |
|
322 |
LOGGER.info('Computing registration')
|
|
|
7 |
import logging
|
8 |
import time
|
9 |
|
10 |
+
# currentdir = os.path.dirname(os.path.realpath(__file__))
|
11 |
+
# parentdir = os.path.dirname(currentdir)
|
12 |
+
# sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
13 |
|
14 |
import tensorflow as tf
|
15 |
|
|
|
29 |
from DeepDeformationMapRegistration.utils.operators import min_max_norm
|
30 |
from DeepDeformationMapRegistration.utils.misc import resize_displacement_map
|
31 |
from DeepDeformationMapRegistration.utils.model_downloader import get_models_path
|
32 |
+
from DeepDeformationMapRegistration.networks import load_model
|
33 |
|
34 |
from importlib.util import find_spec
|
35 |
|
|
|
285 |
LOGGER.info(f'Using model: {"Brain" if args.anatomy == "B" else "Liver"} -> {args.model}')
|
286 |
MODEL_FILE = get_models_path(args.anatomy, args.model, os.getcwd()) # MODELS_FILE[args.anatomy][args.model]
|
287 |
|
288 |
+
network, registration_model = load_model(MODEL_FILE, False, True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
289 |
deb_model = network.apply_transform
|
290 |
|
291 |
LOGGER.info('Computing registration')
|
DeepDeformationMapRegistration/networks.py
CHANGED
@@ -1,14 +1,31 @@
|
|
1 |
import os, sys
|
2 |
-
currentdir = os.path.dirname(os.path.realpath(__file__))
|
3 |
-
parentdir = os.path.dirname(currentdir)
|
4 |
-
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
5 |
-
|
6 |
-
PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'
|
7 |
|
8 |
import tensorflow as tf
|
9 |
import voxelmorph as vxm
|
10 |
from voxelmorph.tf.modelio import LoadableModel, store_config_args
|
11 |
from tensorflow.keras.layers import UpSampling3D
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
|
14 |
class WeaklySupervised(LoadableModel):
|
|
|
1 |
import os, sys
|
2 |
+
# currentdir = os.path.dirname(os.path.realpath(__file__))
|
3 |
+
# parentdir = os.path.dirname(currentdir)
|
4 |
+
# sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
5 |
+
#
|
6 |
+
# PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'
|
7 |
|
8 |
import tensorflow as tf
|
9 |
import voxelmorph as vxm
|
10 |
from voxelmorph.tf.modelio import LoadableModel, store_config_args
|
11 |
from tensorflow.keras.layers import UpSampling3D
|
12 |
+
from DeepDeformationMapRegistration.utils.constants import ENCODER_FILTERS, DECODER_FILTERS, IMG_SHAPE
|
13 |
+
|
14 |
+
|
15 |
+
def load_model(weights_file_path: str, trainable: bool = False, return_registration_model: bool=True):
|
16 |
+
assert os.path.exists(weights_file_path), f'File {weights_file_path} not found'
|
17 |
+
assert weights_file_path.endswith('h5'), 'Invalid file extension. Expected .h5'
|
18 |
+
|
19 |
+
ret_val = vxm.networks.VxmDense(inshape=IMG_SHAPE[:-1],
|
20 |
+
nb_unet_features=[ENCODER_FILTERS, DECODER_FILTERS],
|
21 |
+
int_steps=0)
|
22 |
+
ret_val.load_weights(weights_file_path, by_name=True)
|
23 |
+
ret_val.trainable = trainable
|
24 |
+
|
25 |
+
if return_registration_model:
|
26 |
+
ret_val = (ret_val, ret_val.get_registration_model())
|
27 |
+
|
28 |
+
return ret_val
|
29 |
|
30 |
|
31 |
class WeaklySupervised(LoadableModel):
|
DeepDeformationMapRegistration/utils/constants.py
CHANGED
@@ -196,8 +196,8 @@ DROPOUT = True
|
|
196 |
DROPOUT_RATE = 0.2
|
197 |
MAX_DATA_SIZE = (1000, 1000, 1)
|
198 |
PLATEAU_THR = 0.01 # A slope between +-PLATEAU_THR will be considered a plateau for the LR updating function
|
199 |
-
ENCODER_FILTERS = [
|
200 |
-
|
201 |
# SSIM
|
202 |
SSIM_FILTER_SIZE = 11 # Size of Gaussian filter
|
203 |
SSIM_FILTER_SIGMA = 1.5 # Width of Gaussian filter
|
@@ -205,7 +205,7 @@ SSIM_K1 = 0.01 # Def. 0.01
|
|
205 |
SSIM_K2 = 0.03 # Recommended values 0 < K2 < 0.4
|
206 |
MAX_VALUE = 1.0 # Maximum intensity values
|
207 |
|
208 |
-
#
|
209 |
EPS = 1e-8
|
210 |
EPS_tf = tf.constant(EPS, dtype=tf.float32)
|
211 |
LOG2 = tf.math.log(tf.constant(2, dtype=tf.float32))
|
|
|
196 |
DROPOUT_RATE = 0.2
|
197 |
MAX_DATA_SIZE = (1000, 1000, 1)
|
198 |
PLATEAU_THR = 0.01 # A slope between +-PLATEAU_THR will be considered a plateau for the LR updating function
|
199 |
+
ENCODER_FILTERS = [32, 64, 128, 256, 512, 1024]
|
200 |
+
DECODER_FILTERS = ENCODER_FILTERS[::-1] + [16, 16]
|
201 |
# SSIM
|
202 |
SSIM_FILTER_SIZE = 11 # Size of Gaussian filter
|
203 |
SSIM_FILTER_SIGMA = 1.5 # Width of Gaussian filter
|
|
|
205 |
SSIM_K2 = 0.03 # Recommended values 0 < K2 < 0.4
|
206 |
MAX_VALUE = 1.0 # Maximum intensity values
|
207 |
|
208 |
+
# Mathematics constants
|
209 |
EPS = 1e-8
|
210 |
EPS_tf = tf.constant(EPS, dtype=tf.float32)
|
211 |
LOG2 = tf.math.log(tf.constant(2, dtype=tf.float32))
|