Commit
·
62161e8
1
Parent(s):
ab9857f
Renamed model
Browse files
DeepDeformationMapRegistration/networks.py
CHANGED
@@ -10,7 +10,7 @@ import voxelmorph as vxm
|
|
10 |
from voxelmorph.tf.modelio import LoadableModel, store_config_args
|
11 |
|
12 |
|
13 |
-
class
|
14 |
|
15 |
@store_config_args
|
16 |
def __init__(self, inshape, all_labels: [list, tuple], nb_unet_features=None, int_steps=5, bidir=False, **kwargs):
|
@@ -61,3 +61,4 @@ class VxmWeaklySupervised(LoadableModel):
|
|
61 |
img_input = tf.keras.Input(shape=mov_img.shape[1:], name='input_img')
|
62 |
pred_img = vxm.layers.SpatialTransformer(interp_method=interp_method)([img_input, warp_model.output])
|
63 |
return tf.keras.Model(warp_model.inputs, pred_img).predict([mov_segm, fix_segm, mov_img])
|
|
|
|
10 |
from voxelmorph.tf.modelio import LoadableModel, store_config_args
|
11 |
|
12 |
|
13 |
+
class WeaklySupervised(LoadableModel):
|
14 |
|
15 |
@store_config_args
|
16 |
def __init__(self, inshape, all_labels: [list, tuple], nb_unet_features=None, int_steps=5, bidir=False, **kwargs):
|
|
|
61 |
img_input = tf.keras.Input(shape=mov_img.shape[1:], name='input_img')
|
62 |
pred_img = vxm.layers.SpatialTransformer(interp_method=interp_method)([img_input, warp_model.output])
|
63 |
return tf.keras.Model(warp_model.inputs, pred_img).predict([mov_segm, fix_segm, mov_img])
|
64 |
+
|
TrainingScripts/Train_3d_weaklySupervised.py
CHANGED
@@ -16,7 +16,7 @@ from datetime import datetime
|
|
16 |
import DeepDeformationMapRegistration.utils.constants as C
|
17 |
from DeepDeformationMapRegistration.data_generator import DataGeneratorManager
|
18 |
from DeepDeformationMapRegistration.utils.misc import try_mkdir
|
19 |
-
from DeepDeformationMapRegistration.networks import
|
20 |
from DeepDeformationMapRegistration.losses import HausdorffDistance
|
21 |
from DeepDeformationMapRegistration.layers import UncertaintyWeighting
|
22 |
|
@@ -44,7 +44,7 @@ in_shape = train_generator.get_input_shape()[1:-1]
|
|
44 |
enc_features = [16, 32, 32, 32, 32, 32]# const.ENCODER_FILTERS
|
45 |
dec_features = [32, 32, 32, 32, 32, 32, 32, 16, 16]# const.ENCODER_FILTERS[::-1]
|
46 |
nb_features = [enc_features, dec_features]
|
47 |
-
vxm_model =
|
48 |
|
49 |
# Losses and loss weights
|
50 |
|
|
|
16 |
import DeepDeformationMapRegistration.utils.constants as C
|
17 |
from DeepDeformationMapRegistration.data_generator import DataGeneratorManager
|
18 |
from DeepDeformationMapRegistration.utils.misc import try_mkdir
|
19 |
+
from DeepDeformationMapRegistration.networks import WeaklySupervised
|
20 |
from DeepDeformationMapRegistration.losses import HausdorffDistance
|
21 |
from DeepDeformationMapRegistration.layers import UncertaintyWeighting
|
22 |
|
|
|
44 |
enc_features = [16, 32, 32, 32, 32, 32]# const.ENCODER_FILTERS
|
45 |
dec_features = [32, 32, 32, 32, 32, 32, 32, 16, 16]# const.ENCODER_FILTERS[::-1]
|
46 |
nb_features = [enc_features, dec_features]
|
47 |
+
vxm_model = WeaklySupervised(inshape=in_shape, all_labels=[1], nb_unet_features=nb_features, int_steps=5)
|
48 |
|
49 |
# Losses and loss weights
|
50 |
|