jpdefrutos committed
Commit 3b554c2 · 1 parent: 67a11d3

Updated train scripts

Brain_study/MultiTrain_config.py ADDED
@@ -0,0 +1,78 @@
+ import os, sys
+ currentdir = os.path.dirname(os.path.realpath(__file__))
+ parentdir = os.path.dirname(currentdir)
+ sys.path.append(parentdir)  # PYTHON > 3.3 does not allow relative referencing
+
+ import argparse
+ from configparser import ConfigParser
+ from datetime import datetime
+
+ import DeepDeformationMapRegistration.utils.constants as C
+
+ TRAIN_DATASET = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/training'
+
+ err = list()
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--ini', help='Configuration file')
+     args = parser.parse_args()
+
+     configFile = ConfigParser()
+     configFile.read(args.ini)
+     print('Loaded configuration file: ' + args.ini)
+     print({section: dict(configFile[section]) for section in configFile.sections()})
+     print('\n\n')
+
+     trainConfig = configFile['TRAIN']
+     lossesConfig = configFile['LOSSES']
+     datasetConfig = configFile['DATASETS']
+     othersConfig = configFile['OTHERS']
+     augmentationConfig = configFile['AUGMENTATION']
+
+     simil = lossesConfig['similarity'].split(',')
+     segm = lossesConfig['segmentation'].split(',')
+     if trainConfig['name'].lower() == 'uw':
+         from Brain_study.Train_UncertaintyWeighted import launch_train
+         loss_config = {'simil': simil, 'segm': segm}
+     elif trainConfig['name'].lower() == 'segguided':
+         from Brain_study.Train_SegmentationGuided import launch_train
+         loss_config = {'simil': simil[0], 'segm': segm[0]}
+     else:
+         from Brain_study.Train_Baseline import launch_train
+         loss_config = {'simil': simil[0]}
+
+     output_folder = os.path.join(othersConfig['outputFolder'],
+                                  '{}_Lsim_{}__Lseg_{}'.format(trainConfig['name'], '_'.join(simil), '_'.join(segm)))
+     output_folder = output_folder + '_' + datetime.now().strftime("%H%M%S-%d%m%Y")
+
+     print('TRAIN ' + datasetConfig['train'])
+
+     if augmentationConfig:
+         C.GAMMA_AUGMENTATION = augmentationConfig['gamma'].lower() == 'true'
+         C.BRIGHTNESS_AUGMENTATION = augmentationConfig['brightness'].lower() == 'true'
+
+     try:
+         unet = [int(x) for x in trainConfig['unet'].split(',')]
+     except KeyError as e:
+         unet = [16, 32, 64, 128, 256]
+
+     try:
+         head = [int(x) for x in trainConfig['head'].split(',')]
+     except KeyError as e:
+         head = [16, 16]
+
+     launch_train(dataset_folder=datasetConfig['train'],
+                  validation_folder=datasetConfig['validation'],
+                  output_folder=output_folder,
+                  gpu_num=eval(trainConfig['gpu']),
+                  lr=eval(trainConfig['learningRate']),
+                  rw=eval(trainConfig['regularizationWeight']),
+                  acc_gradients=eval(trainConfig['accumulativeGradients']),
+                  batch_size=eval(trainConfig['batchSize']),
+                  max_epochs=eval(trainConfig['epochs']),
+                  image_size=eval(trainConfig['imageSize']),
+                  early_stop_patience=eval(trainConfig['earlyStopPatience']),
+                  unet=unet,
+                  head=head,
+                  **loss_config)
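
Editor's note: for reference, below is a minimal .ini file covering every option the new script reads. Section and option names are taken from the script above; all values and paths are illustrative placeholders. The numeric options are passed through eval(), and unet/head may be omitted, in which case the except KeyError branches fall back to [16, 32, 64, 128, 256] and [16, 16]. ConfigParser does not strip inline comments by default, so keep values bare. The name option selects the trainer: 'uw', 'segguided', or anything else for the baseline.

    [TRAIN]
    name = segguided
    gpu = 0
    learningRate = 1e-4
    regularizationWeight = 5e-3
    accumulativeGradients = 16
    batchSize = 1
    epochs = 10000
    imageSize = 64
    earlyStopPatience = 1000
    unet = 16,32,64,128,256
    head = 16,16

    [LOSSES]
    similarity = ssim
    segmentation = dice

    [DATASETS]
    train = /path/to/training
    validation = /path/to/validation

    [OTHERS]
    outputFolder = /path/to/output

    [AUGMENTATION]
    gamma = True
    brightness = True

A run is then launched with: python Brain_study/MultiTrain_config.py --ini train_config.ini (train_config.ini being whatever the file above is saved as).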
Brain_study/Train_Baseline.py CHANGED
@@ -1,4 +1,6 @@
  import os, sys
+ import warnings
+
  currentdir = os.path.dirname(os.path.realpath(__file__))
  parentdir = os.path.dirname(currentdir)
  sys.path.append(parentdir)  # PYTHON > 3.3 does not allow relative referencing
@@ -32,25 +34,30 @@ from tqdm import tqdm
  from datetime import datetime


- def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr=1e-4, rw=5e-3, simil='ssim'):
+ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr=1e-4, rw=5e-3, simil='ssim',
+                  acc_gradients=16, batch_size=1, max_epochs=10000, early_stop_patience=1000, image_size=64,
+                  unet=[16, 32, 64, 128, 256], head=[16, 16]):
      assert dataset_folder is not None and output_folder is not None

      os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
      os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_num)  # Check availability before running using 'nvidia-smi'
      C.GPU_NUM = str(gpu_num)

+     if batch_size != 1 and acc_gradients != 1:
+         warnings.warn('WARNING: Batch size and Accumulative gradient step are set!')
+
      output_folder = os.path.join(output_folder + '_' + datetime.now().strftime("%H%M%S-%d%m%Y"))
      os.makedirs(output_folder, exist_ok=True)
      # dataset_copy = DatasetCopy(dataset_folder, os.path.join(output_folder, 'temp'))
      log_file = open(os.path.join(output_folder, 'log.txt'), 'w')
      C.TRAINING_DATASET = dataset_folder  # dataset_copy.copy_dataset()
      C.VALIDATION_DATASET = validation_folder
-     C.ACCUM_GRADIENT_STEP = 16
-     C.BATCH_SIZE = 16 if C.ACCUM_GRADIENT_STEP == 1 else C.ACCUM_GRADIENT_STEP
-     C.EARLY_STOP_PATIENCE = 5 * (C.ACCUM_GRADIENT_STEP / 2 if C.ACCUM_GRADIENT_STEP != 1 else 1)
+     C.ACCUM_GRADIENT_STEP = acc_gradients
+     C.BATCH_SIZE = batch_size if C.ACCUM_GRADIENT_STEP == 1 else 1
+     C.EARLY_STOP_PATIENCE = early_stop_patience
      C.LEARNING_RATE = lr
      C.LIMIT_NUM_SAMPLES = None
-     C.EPOCHS = 10000
+     C.EPOCHS = max_epochs

      aux = "[{}]\tINFO:\nTRAIN DATASET: {}\nVALIDATION DATASET: {}\n" \
            "GPU: {}\n" \
@@ -84,7 +91,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
      validation_generator = data_generator.get_validation_generator()

      image_input_shape = train_generator.get_data_shape()[-1][:-1]
-     image_output_shape = [64] * 3
+     image_output_shape = [image_size] * 3

      # Config the training sessions
      config = tf.compat.v1.ConfigProto()  # device_count={'GPU':0})
@@ -125,13 +132,15 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
      augm_model_valid = Model(inputs=input_layer_valid, outputs=augm_layer_valid(input_layer_valid))

      # Build model
-     enc_features = [16, 32, 32, 32]  # const.ENCODER_FILTERS
-     dec_features = [32, 32, 32, 32, 32, 16, 16]  # const.ENCODER_FILTERS[::-1]
+     # enc_features = [16, 32, 32, 32]  # const.ENCODER_FILTERS
+     # dec_features = [32, 32, 32, 32, 32, 16, 16]  # const.ENCODER_FILTERS[::-1]
+     enc_features = unet  # const.ENCODER_FILTERS
+     dec_features = enc_features[::-1] + head  # const.ENCODER_FILTERS[::-1]
      nb_features = [enc_features, dec_features]
      network = vxm.networks.VxmDense(inshape=image_output_shape,
                                      nb_unet_features=nb_features,
                                      int_steps=0)
-
+     network.summary(line_length=150)
      # Losses and loss weights
      SSIM_KER_SIZE = 5
      MS_SSIM_WEIGHTS = _MSSSIM_WEIGHTS[:3]
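
Editor's note: launch_train() now treats batch size and gradient accumulation as mutually exclusive knobs, forcing a micro-batch of 1 whenever acc_gradients != 1. A minimal sketch of that rule (not part of the commit; the effective-batch line assumes the accumulated optimizer imported by these scripts sums gradients over C.ACCUM_GRADIENT_STEP micro-batches before each update, which is not shown in this diff):

    import warnings

    def resolve_batching(batch_size=1, acc_gradients=16):
        # Mirrors the rule in launch_train(): accumulation wins over batch size.
        if batch_size != 1 and acc_gradients != 1:
            warnings.warn('WARNING: Batch size and Accumulative gradient step are set!')
        micro_batch = batch_size if acc_gradients == 1 else 1
        # Samples contributing to each optimizer update, under the assumption above.
        effective_batch = micro_batch * acc_gradients
        return micro_batch, effective_batch

    print(resolve_batching(batch_size=8, acc_gradients=1))   # (8, 8): plain batching
    print(resolve_batching(batch_size=8, acc_gradients=16))  # (1, 16): warning, batch_size ignored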
Brain_study/Train_SegmentationGuided.py CHANGED
@@ -29,25 +29,32 @@ from Brain_study.data_generator import BatchGenerator
  from Brain_study.utils import SummaryDictionary, named_logs

  import time
+ import warnings

+
- def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr=1e-4, rw=5e-3, simil='ssim', segm='hd'):
+ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr=1e-4, rw=5e-3, simil='ssim', segm='hd',
+                  acc_gradients=16, batch_size=1, max_epochs=10000, early_stop_patience=1000, image_size=64,
+                  unet=[16, 32, 64, 128, 256], head=[16, 16]):
      assert dataset_folder is not None and output_folder is not None

      os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
      os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_num)  # Check availability before running using 'nvidia-smi'
      C.GPU_NUM = str(gpu_num)

+     if batch_size != 1 and acc_gradients != 1:
+         warnings.warn('WARNING: Batch size and Accumulative gradient step are set!')
+
      output_folder = os.path.join(output_folder + '_' + datetime.now().strftime("%H%M%S-%d%m%Y"))
      os.makedirs(output_folder, exist_ok=True)
      log_file = open(os.path.join(output_folder, 'log.txt'), 'w')
      C.TRAINING_DATASET = dataset_folder
      C.VALIDATION_DATASET = validation_folder
-     C.ACCUM_GRADIENT_STEP = 16
-     C.BATCH_SIZE = 2 if C.ACCUM_GRADIENT_STEP == 1 else C.ACCUM_GRADIENT_STEP
-     C.EARLY_STOP_PATIENCE = 10 * (C.ACCUM_GRADIENT_STEP / 2 if C.ACCUM_GRADIENT_STEP != 1 else 1)
+     C.ACCUM_GRADIENT_STEP = acc_gradients
+     C.BATCH_SIZE = batch_size if C.ACCUM_GRADIENT_STEP == 1 else 1
+     C.EARLY_STOP_PATIENCE = early_stop_patience
      C.LEARNING_RATE = lr
      C.LIMIT_NUM_SAMPLES = None
-     C.EPOCHS = 10000
+     C.EPOCHS = max_epochs

      aux = "[{}]\tINFO:\nTRAIN DATASET: {}\nVALIDATION DATASET: {}\n" \
            "GPU: {}\n" \
@@ -81,7 +88,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
      validation_generator = data_generator.get_validation_generator()

      image_input_shape = train_generator.get_data_shape()[1][:-1]
-     image_output_shape = [64] * 3
+     image_output_shape = [image_size] * 3

      nb_labels = len(train_generator.get_segmentation_labels())

@@ -109,8 +116,10 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
                              trainable=False)
      augm_model = Model(inputs=input_layer_augm, outputs=augm_layer(input_layer_augm))

-     enc_features = [16, 32, 32, 32]  # const.ENCODER_FILTERS
-     dec_features = [32, 32, 32, 32, 32, 16, 16]  # const.ENCODER_FILTERS[::-1]
+     # enc_features = [16, 32, 32, 32]  # const.ENCODER_FILTERS
+     # dec_features = [32, 32, 32, 32, 32, 16, 16]  # const.ENCODER_FILTERS[::-1]
+     enc_features = unet  # const.ENCODER_FILTERS
+     dec_features = enc_features[::-1] + head  # const.ENCODER_FILTERS[::-1]
      nb_features = [enc_features, dec_features]

      network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
@@ -138,7 +147,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
          @function_decorator('MS_SSIM_MSE__loss')
          def loss_simil(y_true, y_pred):
              return MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).loss(y_true, y_pred) + vxm.losses.MSE().loss(y_true, y_pred)
-     elif simil=='ssim__ncc' or simil=='ncc__ssim' :
+     elif simil=='ssim__ncc' or simil=='ncc__ssim':
          @function_decorator('SSIM_NCC__loss')
          def loss_simil(y_true, y_pred):
              return StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss(y_true, y_pred) + NCC(image_input_shape).loss(y_true, y_pred)
@@ -153,6 +162,8 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
          loss_segm = HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [nb_labels]).loss
      elif segm == 'dice':
          loss_segm = GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).loss
+     elif segm == 'dice_macro':
+         loss_segm = GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).loss_macro
      else:
          raise ValueError('No valid value for segm')

@@ -163,8 +174,8 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
                     'seg_transformer': 1.,
                     'flow': 5e-3}
      metrics = {'transformer': [vxm.losses.MSE().loss, NCC(image_input_shape).metric, MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).metric],
-                'seg_transformer': [GeneralizedDICEScore(image_output_shape + [train_generator.get_data_shape()[2][-1]], num_labels=nb_labels).metric,
-                                    HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [train_generator.get_data_shape()[2][-1]]).metric
+                'seg_transformer': [GeneralizedDICEScore(image_output_shape + [train_generator.get_data_shape()[2][-1]], num_labels=nb_labels).metric_macro,
+                                    #HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [train_generator.get_data_shape()[2][-1]]).metric
                                     ]}
      metrics_weights = {'transformer': 1,
                         'seg_transformer': 1,
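
Editor's note: the hard-coded encoder/decoder filter lists are gone in all three scripts; the decoder is now the mirrored encoder plus a small head. A quick illustration of how the new unet/head arguments expand into the nb_unet_features list, using the default values from the diff (the printouts are just the list arithmetic):

    unet = [16, 32, 64, 128, 256]  # encoder filters, coarsest level last
    head = [16, 16]                # extra conv filters appended after the decoder

    enc_features = unet
    dec_features = enc_features[::-1] + head
    print(dec_features)                  # [256, 128, 64, 32, 16, 16, 16]
    print([enc_features, dec_features])  # what the VxmDense* networks receive

This is a deeper, symmetric scheme compared with the previous default (enc [16, 32, 32, 32], dec [32, 32, 32, 32, 32, 16, 16]) left commented out in the code above.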
Brain_study/Train_UncertaintyWeighted.py CHANGED
@@ -26,28 +26,33 @@ from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccum

  from Brain_study.data_generator import BatchGenerator
  from Brain_study.utils import SummaryDictionary, named_logs
+ import warnings


- def launch_train(dataset_folder, validation_folder, output_folder, prior_reg_w=5e-3, lr=1e-4,
-                  gpu_num=0, simil=['mse'], segm=['dice']):
+ def launch_train(dataset_folder, validation_folder, output_folder, prior_reg_w=5e-3, lr=1e-4, rw=5e-3,
+                  gpu_num=0, simil=['mse'], segm=['dice'], acc_gradients=16, batch_size=1, max_epochs=10000,
+                  early_stop_patience=1000, image_size=64, unet=[16, 32, 64, 128, 256], head=[16, 16]):
      assert dataset_folder is not None and output_folder is not None

      os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
      os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_num)  # Check availability before running using 'nvidia-smi'
      C.GPU_NUM = str(gpu_num)

+     if batch_size != 1 and acc_gradients != 1:
+         warnings.warn('WARNING: Batch size and Accumulative gradient step are set!')
+
      output_folder = os.path.join(output_folder + '_' + datetime.now().strftime("%H%M%S-%d%m%Y"))
      # dataset_copy = DatasetCopy(dataset_folder, os.path.join(output_folder, 'temp'))
      os.makedirs(output_folder, exist_ok=True)
      log_file = open(os.path.join(output_folder, 'log.txt'), 'w')
      C.TRAINING_DATASET = dataset_folder  # dataset_copy.copy_dataset()
      C.VALIDATION_DATASET = validation_folder
-     C.ACCUM_GRADIENT_STEP = 16
-     C.BATCH_SIZE = 2 if C.ACCUM_GRADIENT_STEP == 1 else C.ACCUM_GRADIENT_STEP
-     C.EARLY_STOP_PATIENCE = 10 * (C.ACCUM_GRADIENT_STEP/2 if C.ACCUM_GRADIENT_STEP != 1 else 1)
+     C.ACCUM_GRADIENT_STEP = acc_gradients
+     C.BATCH_SIZE = batch_size if C.ACCUM_GRADIENT_STEP == 1 else 1
+     C.EARLY_STOP_PATIENCE = early_stop_patience
      C.LEARNING_RATE = lr
      C.LIMIT_NUM_SAMPLES = None
-     C.EPOCHS = 10000
+     C.EPOCHS = max_epochs

      aux = "[{}]\tINFO:\nTRAIN DATASET: {}\nVALIDATION DATASET: {}\n" \
            "GPU: {}\n" \
@@ -81,7 +86,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, prior_reg_w=5
      validation_generator = data_generator.get_validation_generator()

      image_input_shape = train_generator.get_data_shape()[-1][:-1]
-     image_output_shape = [64] * 3
+     image_output_shape = [image_size] * 3

      nb_labels = len(train_generator.get_segmentation_labels())

@@ -119,13 +124,16 @@ def launch_train(dataset_folder, validation_folder, output_folder, prior_reg_w=5
      loss_segm = []
      for s in segm:
          if s=='dice':
-             loss_segm.append(GeneralizedDICEScore(image_output_shape + [train_generator.get_data_shape()[2][-1]], num_labels=nb_labels).loss)
+             loss_segm.append(GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).loss)
              prior_loss_w.append(1.)
          elif s=='hd':
-             loss_segm.append(HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [train_generator.get_data_shape()[2][-1]]).loss)
+             loss_segm.append(HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [nb_labels]).loss)
+             prior_loss_w.append(1.)
+         elif s == 'dice_macro':
+             loss_segm.append(GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).loss_macro)
              prior_loss_w.append(1.)
          else:
-             raise ValueError('Unknown similarity function: ', s)
+             raise ValueError('Unknown similarity function: ' + s)

      # Build augmentation layer model
      input_layer_augm = Input(shape=train_generator.get_data_shape()[0], name='input_augmentation')
@@ -142,8 +150,10 @@ def launch_train(dataset_folder, validation_folder, output_folder, prior_reg_w=5
                              trainable=False)
      augmentation_model = Model(inputs=input_layer_augm, outputs=augm_layer(input_layer_augm))

-     enc_features = [16, 32, 32, 32]  # const.ENCODER_FILTERS
-     dec_features = [32, 32, 32, 32, 32, 16, 16]  # const.ENCODER_FILTERS[::-1]
+     # enc_features = [16, 32, 32, 32]  # const.ENCODER_FILTERS
+     # dec_features = [32, 32, 32, 32, 32, 16, 16]  # const.ENCODER_FILTERS[::-1]
+     enc_features = unet  # const.ENCODER_FILTERS
+     dec_features = enc_features[::-1] + head  # const.ENCODER_FILTERS[::-1]
      nb_features = [enc_features, dec_features]
      network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
                                                       nb_labels=nb_labels,
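
Editor's note: taken together with MultiTrain_config.py, the three entry points differ mainly in how losses are passed to them: the uncertainty-weighted trainer receives lists of similarity/segmentation losses, while the other two take a single loss name each. A condensed restatement of that dispatch, for orientation only (the authoritative logic is the if/elif chain in the config script above):

    def build_loss_config(name, simil, segm):
        # Mirrors the dispatch in MultiTrain_config.py.
        if name.lower() == 'uw':           # Train_UncertaintyWeighted
            return {'simil': simil, 'segm': segm}
        elif name.lower() == 'segguided':  # Train_SegmentationGuided
            return {'simil': simil[0], 'segm': segm[0]}
        else:                              # Train_Baseline (segmentation losses unused)
            return {'simil': simil[0]}

    print(build_loss_config('uw', ['mse', 'ssim'], ['dice', 'hd']))
    # {'simil': ['mse', 'ssim'], 'segm': ['dice', 'hd']}
    print(build_loss_config('baseline', ['ssim'], ['dice']))
    # {'simil': 'ssim'}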