jpdefrutos commited on
Commit
62161e8
·
1 Parent(s): ab9857f

Renamed model

Browse files
DeepDeformationMapRegistration/networks.py CHANGED
@@ -10,7 +10,7 @@ import voxelmorph as vxm
10
  from voxelmorph.tf.modelio import LoadableModel, store_config_args
11
 
12
 
13
- class VxmWeaklySupervised(LoadableModel):
14
 
15
  @store_config_args
16
  def __init__(self, inshape, all_labels: [list, tuple], nb_unet_features=None, int_steps=5, bidir=False, **kwargs):
@@ -61,3 +61,4 @@ class VxmWeaklySupervised(LoadableModel):
61
  img_input = tf.keras.Input(shape=mov_img.shape[1:], name='input_img')
62
  pred_img = vxm.layers.SpatialTransformer(interp_method=interp_method)([img_input, warp_model.output])
63
  return tf.keras.Model(warp_model.inputs, pred_img).predict([mov_segm, fix_segm, mov_img])
 
 
10
  from voxelmorph.tf.modelio import LoadableModel, store_config_args
11
 
12
 
13
+ class WeaklySupervised(LoadableModel):
14
 
15
  @store_config_args
16
  def __init__(self, inshape, all_labels: [list, tuple], nb_unet_features=None, int_steps=5, bidir=False, **kwargs):
 
61
  img_input = tf.keras.Input(shape=mov_img.shape[1:], name='input_img')
62
  pred_img = vxm.layers.SpatialTransformer(interp_method=interp_method)([img_input, warp_model.output])
63
  return tf.keras.Model(warp_model.inputs, pred_img).predict([mov_segm, fix_segm, mov_img])
64
+
TrainingScripts/Train_3d_weaklySupervised.py CHANGED
@@ -16,7 +16,7 @@ from datetime import datetime
16
  import DeepDeformationMapRegistration.utils.constants as C
17
  from DeepDeformationMapRegistration.data_generator import DataGeneratorManager
18
  from DeepDeformationMapRegistration.utils.misc import try_mkdir
19
- from DeepDeformationMapRegistration.networks import VxmWeaklySupervised
20
  from DeepDeformationMapRegistration.losses import HausdorffDistance
21
  from DeepDeformationMapRegistration.layers import UncertaintyWeighting
22
 
@@ -44,7 +44,7 @@ in_shape = train_generator.get_input_shape()[1:-1]
44
  enc_features = [16, 32, 32, 32, 32, 32]# const.ENCODER_FILTERS
45
  dec_features = [32, 32, 32, 32, 32, 32, 32, 16, 16]# const.ENCODER_FILTERS[::-1]
46
  nb_features = [enc_features, dec_features]
47
- vxm_model = VxmWeaklySupervised(inshape=in_shape, all_labels=[1], nb_unet_features=nb_features, int_steps=5)
48
 
49
  # Losses and loss weights
50
 
 
16
  import DeepDeformationMapRegistration.utils.constants as C
17
  from DeepDeformationMapRegistration.data_generator import DataGeneratorManager
18
  from DeepDeformationMapRegistration.utils.misc import try_mkdir
19
+ from DeepDeformationMapRegistration.networks import WeaklySupervised
20
  from DeepDeformationMapRegistration.losses import HausdorffDistance
21
  from DeepDeformationMapRegistration.layers import UncertaintyWeighting
22
 
 
44
  enc_features = [16, 32, 32, 32, 32, 32]# const.ENCODER_FILTERS
45
  dec_features = [32, 32, 32, 32, 32, 32, 32, 16, 16]# const.ENCODER_FILTERS[::-1]
46
  nb_features = [enc_features, dec_features]
47
+ vxm_model = WeaklySupervised(inshape=in_shape, all_labels=[1], nb_unet_features=nb_features, int_steps=5)
48
 
49
  # Losses and loss weights
50