Commit
·
ca253db
1
Parent(s):
8fb4d8e
Generalized the DataGenerator to accept two lists of input and output (wrt network) labels to fetch from the dataset files.
Browse files
DeepDeformationMapRegistration/data_generator.py
CHANGED
@@ -1,10 +1,3 @@
|
|
1 |
-
import sys, os
|
2 |
-
currentdir = os.path.dirname(os.path.realpath(__file__))
|
3 |
-
parentdir = os.path.dirname(currentdir)
|
4 |
-
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
5 |
-
|
6 |
-
PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'
|
7 |
-
|
8 |
import numpy as np
|
9 |
from tensorflow import keras
|
10 |
import os
|
@@ -17,9 +10,9 @@ from DeepDeformationMapRegistration.utils.operators import min_max_norm
|
|
17 |
|
18 |
|
19 |
class DataGeneratorManager(keras.utils.Sequence):
|
20 |
-
def __init__(self, dataset_path, batch_size=32, shuffle=True,
|
21 |
-
clip_range=[0., 1.],
|
22 |
-
|
23 |
# Get the list of files
|
24 |
self.__list_files = self.__get_dataset_files(dataset_path)
|
25 |
self.__list_files.sort()
|
@@ -32,9 +25,8 @@ class DataGeneratorManager(keras.utils.Sequence):
|
|
32 |
|
33 |
self.__validation_samples = validation_samples
|
34 |
|
35 |
-
self.
|
36 |
-
self.
|
37 |
-
self.__seg_labels = seg_labels
|
38 |
|
39 |
if num_samples is not None:
|
40 |
self.__num_samples = self.__total_samples if num_samples > self.__total_samples else num_samples
|
@@ -93,6 +85,14 @@ class DataGeneratorManager(keras.utils.Sequence):
|
|
93 |
def shuffle(self):
|
94 |
return self.__shuffle
|
95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
def get_generator_idxs(self, generator_type):
|
97 |
if generator_type == 'train':
|
98 |
return self.train_idxs
|
@@ -150,18 +150,6 @@ class DataGeneratorManager(keras.utils.Sequence):
|
|
150 |
else:
|
151 |
raise ValueError('Unknown dataset type "{}". Expected "train" or "validation"'.format(type))
|
152 |
|
153 |
-
@property
|
154 |
-
def is_voxelmorph(self):
|
155 |
-
return self.__voxelmorph
|
156 |
-
|
157 |
-
@property
|
158 |
-
def give_segmentations(self):
|
159 |
-
return self.__segmentations
|
160 |
-
|
161 |
-
@property
|
162 |
-
def seg_labels(self):
|
163 |
-
return self.__seg_labels
|
164 |
-
|
165 |
|
166 |
class DataGenerator(DataGeneratorManager):
|
167 |
def __init__(self, GeneratorManager: DataGeneratorManager, dataset_type='train'):
|
@@ -173,8 +161,6 @@ class DataGenerator(DataGeneratorManager):
|
|
173 |
self.__manager = GeneratorManager
|
174 |
self.__shuffle = GeneratorManager.shuffle
|
175 |
|
176 |
-
self.__seg_labels = GeneratorManager.seg_labels
|
177 |
-
|
178 |
self.__num_samples = len(self.__list_files)
|
179 |
self.__internal_idxs = np.arange(self.__num_samples)
|
180 |
# These indices are internal to the generator, they are not the same as the dataset_idxs!!
|
@@ -184,8 +170,8 @@ class DataGenerator(DataGeneratorManager):
|
|
184 |
self.__last_batch = 0
|
185 |
self.__batches_per_epoch = int(np.floor(len(self.__internal_idxs) / self.__batch_size))
|
186 |
|
187 |
-
self.
|
188 |
-
self.
|
189 |
|
190 |
@staticmethod
|
191 |
def __get_dataset_files(search_path):
|
@@ -228,6 +214,22 @@ class DataGenerator(DataGeneratorManager):
|
|
228 |
"""
|
229 |
return self.__batches_per_epoch
|
230 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
231 |
def __getitem__(self, index):
|
232 |
"""
|
233 |
Generate one batch of data
|
@@ -236,36 +238,15 @@ class DataGenerator(DataGeneratorManager):
|
|
236 |
"""
|
237 |
idxs = self.__internal_idxs[index * self.__batch_size:(index + 1) * self.__batch_size]
|
238 |
|
239 |
-
|
240 |
-
|
241 |
-
try:
|
242 |
-
fix_img = min_max_norm(fix_img).astype(np.float32)
|
243 |
-
mov_img = min_max_norm(mov_img).astype(np.float32)
|
244 |
-
except ValueError:
|
245 |
-
print(idxs, fix_img.shape, mov_img.shape)
|
246 |
-
er_str = 'ERROR:\t[file]:\t{}\t[idx]:\t{}\t[fix_img.shape]:\t{}\t[mov_img.shape]:\t{}\t'.format(self.__list_files[idxs], idxs, fix_img.shape, mov_img.shape)
|
247 |
-
raise ValueError(er_str)
|
248 |
-
|
249 |
-
fix_vessels[fix_vessels > 0.] = self.__seg_labels['vessels']
|
250 |
-
mov_vessels[mov_vessels > 0.] = self.__seg_labels['vessels']
|
251 |
-
|
252 |
-
# fix_tumour[fix_tumour > 0.] = self.__seg_labels['tumour']
|
253 |
-
# mov_tumour[mov_tumour > 0.] = self.__seg_labels['tumour']
|
254 |
|
255 |
# https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit
|
256 |
# A generator or keras.utils.Sequence returning (inputs, targets) or (inputs, targets, sample_weights)
|
257 |
# The second element must match the outputs of the model, in this case (image, displacement map)
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
outputs = [] #[fix_img, zero_grad]
|
263 |
-
else:
|
264 |
-
inputs = [mov_img, fix_img]
|
265 |
-
outputs = [fix_img, zero_grad]
|
266 |
-
return (inputs, outputs)
|
267 |
-
else:
|
268 |
-
return (fix_img, mov_img, fix_vessels, mov_vessels), # (None, fix_seg, fix_seg, fix_img)
|
269 |
|
270 |
def next_batch(self):
|
271 |
if self.__last_batch > self.__batches_per_epoch:
|
@@ -274,6 +255,24 @@ class DataGenerator(DataGeneratorManager):
|
|
274 |
self.__last_batch += 1
|
275 |
return batch
|
276 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
277 |
def __load_data(self, idx_list):
|
278 |
"""
|
279 |
Build the batch with the samples in idx_list
|
@@ -283,73 +282,87 @@ class DataGenerator(DataGeneratorManager):
|
|
283 |
if isinstance(idx_list, (list, np.ndarray)):
|
284 |
fix_img = np.empty((0, ) + C.IMG_SHAPE)
|
285 |
mov_img = np.empty((0, ) + C.IMG_SHAPE)
|
286 |
-
disp_map = np.empty((0, ) + C.DISP_MAP_SHAPE)
|
287 |
|
288 |
-
|
289 |
-
|
290 |
|
291 |
fix_vessels = np.empty((0, ) + C.IMG_SHAPE)
|
292 |
mov_vessels = np.empty((0, ) + C.IMG_SHAPE)
|
|
|
293 |
fix_tumors = np.empty((0, ) + C.IMG_SHAPE)
|
294 |
mov_tumors = np.empty((0, ) + C.IMG_SHAPE)
|
|
|
|
|
|
|
295 |
for idx in idx_list:
|
296 |
data_file = h5py.File(self.__list_files[idx], 'r')
|
297 |
|
298 |
-
fix_img =
|
299 |
-
mov_img =
|
|
|
|
|
|
|
300 |
|
301 |
-
|
302 |
-
|
303 |
|
304 |
-
|
|
|
305 |
|
306 |
-
|
307 |
-
mov_vessels = np.append(mov_vessels, [data_file[C.H5_MOV_VESSELS_MASK][:]], axis=0)
|
308 |
-
fix_tumors = np.append(fix_tumors, [data_file[C.H5_FIX_TUMORS_MASK][:]], axis=0)
|
309 |
-
mov_tumors = np.append(mov_tumors, [data_file[C.H5_MOV_TUMORS_MASK][:]], axis=0)
|
310 |
|
311 |
data_file.close()
|
|
|
312 |
else:
|
313 |
data_file = h5py.File(self.__list_files[idx_list], 'r')
|
314 |
|
315 |
-
fix_img =
|
316 |
-
mov_img =
|
317 |
|
318 |
-
|
319 |
-
|
320 |
|
321 |
-
fix_vessels =
|
322 |
-
mov_vessels =
|
323 |
-
fix_tumors = np.expand_dims(data_file[C.H5_FIX_TUMORS_MASK][:], axis=0)
|
324 |
-
mov_tumors = np.expand_dims(data_file[C.H5_MOV_TUMORS_MASK][:], axis=0)
|
325 |
|
326 |
-
|
|
|
327 |
|
328 |
-
data_file.
|
329 |
-
|
330 |
-
return fix_img, mov_img, fix_vessels, mov_vessels, fix_tumors, mov_tumors, disp_map
|
331 |
|
332 |
-
|
333 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
334 |
# return X, y
|
335 |
-
return
|
336 |
-
|
337 |
-
def get_random_sample(self, num_samples):
|
338 |
-
idxs = np.random.randint(0, self.__num_samples, num_samples)
|
339 |
-
fix_img, mov_img, fix_segm, mov_segm, disp_map = self.__load_data(idxs)
|
340 |
-
|
341 |
-
return (fix_img, mov_img, fix_segm, mov_segm, disp_map), [self.__list_files[f] for f in idxs]
|
342 |
|
343 |
def get_input_shape(self):
|
344 |
input_batch, _ = self.__getitem__(0)
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
ret_val = input_batch.shape
|
351 |
-
ret_val = (None, ) + ret_val[1:]
|
352 |
-
return ret_val # const.BATCH_SHAPE_SEGM
|
353 |
|
354 |
def who_are_you(self):
|
355 |
return self.__dataset_type
|
@@ -361,6 +374,7 @@ class DataGenerator(DataGeneratorManager):
|
|
361 |
class DataGeneratorManager2D:
|
362 |
FIX_IMG_H5 = 'input/1'
|
363 |
MOV_IMG_H5 = 'input/0'
|
|
|
364 |
def __init__(self, h5_file_list, batch_size=32, data_split=0.7, img_size=None,
|
365 |
fix_img_tag=FIX_IMG_H5, mov_img_tag=MOV_IMG_H5, multi_loss=False):
|
366 |
self.__file_list = h5_file_list #h5py.File(h5_file, 'r')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import numpy as np
|
2 |
from tensorflow import keras
|
3 |
import os
|
|
|
10 |
|
11 |
|
12 |
class DataGeneratorManager(keras.utils.Sequence):
|
13 |
+
def __init__(self, dataset_path, batch_size=32, shuffle=True,
|
14 |
+
num_samples=None, validation_split=None, validation_samples=None, clip_range=[0., 1.],
|
15 |
+
input_labels=[C.H5_MOV_IMG, C.H5_FIX_IMG], output_labels=[C.H5_FIX_IMG, 'zero_gradient']):
|
16 |
# Get the list of files
|
17 |
self.__list_files = self.__get_dataset_files(dataset_path)
|
18 |
self.__list_files.sort()
|
|
|
25 |
|
26 |
self.__validation_samples = validation_samples
|
27 |
|
28 |
+
self.__input_labels = input_labels
|
29 |
+
self.__output_labels = output_labels
|
|
|
30 |
|
31 |
if num_samples is not None:
|
32 |
self.__num_samples = self.__total_samples if num_samples > self.__total_samples else num_samples
|
|
|
85 |
def shuffle(self):
|
86 |
return self.__shuffle
|
87 |
|
88 |
+
@property
|
89 |
+
def input_labels(self):
|
90 |
+
return self.__input_labels
|
91 |
+
|
92 |
+
@property
|
93 |
+
def output_labels(self):
|
94 |
+
return self.__output_labels
|
95 |
+
|
96 |
def get_generator_idxs(self, generator_type):
|
97 |
if generator_type == 'train':
|
98 |
return self.train_idxs
|
|
|
150 |
else:
|
151 |
raise ValueError('Unknown dataset type "{}". Expected "train" or "validation"'.format(type))
|
152 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
|
154 |
class DataGenerator(DataGeneratorManager):
|
155 |
def __init__(self, GeneratorManager: DataGeneratorManager, dataset_type='train'):
|
|
|
161 |
self.__manager = GeneratorManager
|
162 |
self.__shuffle = GeneratorManager.shuffle
|
163 |
|
|
|
|
|
164 |
self.__num_samples = len(self.__list_files)
|
165 |
self.__internal_idxs = np.arange(self.__num_samples)
|
166 |
# These indices are internal to the generator, they are not the same as the dataset_idxs!!
|
|
|
170 |
self.__last_batch = 0
|
171 |
self.__batches_per_epoch = int(np.floor(len(self.__internal_idxs) / self.__batch_size))
|
172 |
|
173 |
+
self.__input_labels = GeneratorManager.input_labels
|
174 |
+
self.__output_labels = GeneratorManager.output_labels
|
175 |
|
176 |
@staticmethod
|
177 |
def __get_dataset_files(search_path):
|
|
|
214 |
"""
|
215 |
return self.__batches_per_epoch
|
216 |
|
217 |
+
@staticmethod
|
218 |
+
def __build_list(data_dict, labels):
|
219 |
+
ret_list = list()
|
220 |
+
for label in labels:
|
221 |
+
if label in data_dict.keys():
|
222 |
+
if label in [C.DG_LBL_FIX_IMG, C.DG_LBL_MOV_IMG]:
|
223 |
+
ret_list.append(min_max_norm(data_dict[label]).astype(np.float32))
|
224 |
+
elif label in [C.DG_LBL_FIX_PARENCHYMA, C.DG_LBL_FIX_VESSELS, C.DG_LBL_FIX_TUMOR,
|
225 |
+
C.DG_LBL_MOV_PARENCHYMA, C.DG_LBL_MOV_VESSELS, C.DG_LBL_MOV_TUMOR]:
|
226 |
+
aux = data_dict[label]
|
227 |
+
aux[aux > 0.] = 1.
|
228 |
+
ret_list.append(aux)
|
229 |
+
elif label == C.DG_LBL_ZERO_GRADS:
|
230 |
+
ret_list.append(np.zeros([data_dict['BATCH_SIZE'], *C.DISP_MAP_SHAPE]))
|
231 |
+
return ret_list
|
232 |
+
|
233 |
def __getitem__(self, index):
|
234 |
"""
|
235 |
Generate one batch of data
|
|
|
238 |
"""
|
239 |
idxs = self.__internal_idxs[index * self.__batch_size:(index + 1) * self.__batch_size]
|
240 |
|
241 |
+
data_dict = self.__load_data(idxs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
242 |
|
243 |
# https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit
|
244 |
# A generator or keras.utils.Sequence returning (inputs, targets) or (inputs, targets, sample_weights)
|
245 |
# The second element must match the outputs of the model, in this case (image, displacement map)
|
246 |
+
inputs = self.__build_list(data_dict, self.__input_labels)
|
247 |
+
outputs = self.__build_list(data_dict, self.__output_labels)
|
248 |
+
|
249 |
+
return (inputs, outputs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
250 |
|
251 |
def next_batch(self):
|
252 |
if self.__last_batch > self.__batches_per_epoch:
|
|
|
255 |
self.__last_batch += 1
|
256 |
return batch
|
257 |
|
258 |
+
def __try_load(self, data_file, label, append_array=None):
|
259 |
+
if label in self.__input_labels or label in self.__output_labels:
|
260 |
+
# To avoid extra overhead
|
261 |
+
try:
|
262 |
+
retVal = data_file[label][:]
|
263 |
+
except KeyError:
|
264 |
+
# That particular label is not found in the file. But this should be known by the user by now
|
265 |
+
retVal = None
|
266 |
+
|
267 |
+
if append_array is not None and retVal is not None:
|
268 |
+
return np.append(append_array, [data_file[C.H5_FIX_IMG][:]], axis=0)
|
269 |
+
elif append_array is None:
|
270 |
+
return retVal[np.newaxis, ...]
|
271 |
+
else:
|
272 |
+
return retVal # None
|
273 |
+
else:
|
274 |
+
return None
|
275 |
+
|
276 |
def __load_data(self, idx_list):
|
277 |
"""
|
278 |
Build the batch with the samples in idx_list
|
|
|
282 |
if isinstance(idx_list, (list, np.ndarray)):
|
283 |
fix_img = np.empty((0, ) + C.IMG_SHAPE)
|
284 |
mov_img = np.empty((0, ) + C.IMG_SHAPE)
|
|
|
285 |
|
286 |
+
fix_parench = np.empty((0, ) + C.IMG_SHAPE)
|
287 |
+
mov_parench = np.empty((0, ) + C.IMG_SHAPE)
|
288 |
|
289 |
fix_vessels = np.empty((0, ) + C.IMG_SHAPE)
|
290 |
mov_vessels = np.empty((0, ) + C.IMG_SHAPE)
|
291 |
+
|
292 |
fix_tumors = np.empty((0, ) + C.IMG_SHAPE)
|
293 |
mov_tumors = np.empty((0, ) + C.IMG_SHAPE)
|
294 |
+
|
295 |
+
disp_map = np.empty((0, ) + C.DISP_MAP_SHAPE)
|
296 |
+
|
297 |
for idx in idx_list:
|
298 |
data_file = h5py.File(self.__list_files[idx], 'r')
|
299 |
|
300 |
+
fix_img = self.__try_load(data_file, C.H5_FIX_IMG, fix_img)
|
301 |
+
mov_img = self.__try_load(data_file, C.H5_MOV_IMG, mov_img)
|
302 |
+
|
303 |
+
fix_parench = self.__try_load(data_file, C.H5_FIX_PARENCHYMA_MASK, fix_parench)
|
304 |
+
mov_parench = self.__try_load(data_file, C.H5_MOV_PARENCHYMA_MASK, mov_parench)
|
305 |
|
306 |
+
fix_vessels = self.__try_load(data_file, C.H5_FIX_VESSELS_MASK, fix_vessels)
|
307 |
+
mov_vessels = self.__try_load(data_file, C.H5_MOV_VESSELS_MASK, mov_vessels)
|
308 |
|
309 |
+
fix_tumors = self.__try_load(data_file, C.H5_FIX_TUMORS_MASK, mov_parench)
|
310 |
+
mov_tumors = self.__try_load(data_file, C.H5_MOV_TUMORS_MASK, mov_parench)
|
311 |
|
312 |
+
disp_map = self.__try_load(data_file, C.H5_GT_DISP, disp_map)
|
|
|
|
|
|
|
313 |
|
314 |
data_file.close()
|
315 |
+
batch_size = len(idx_list)
|
316 |
else:
|
317 |
data_file = h5py.File(self.__list_files[idx_list], 'r')
|
318 |
|
319 |
+
fix_img = self.__try_load(data_file, C.H5_FIX_IMG)
|
320 |
+
mov_img = self.__try_load(data_file, C.H5_MOV_IMG)
|
321 |
|
322 |
+
fix_parench = self.__try_load(data_file, C.H5_FIX_PARENCHYMA_MASK)
|
323 |
+
mov_parench = self.__try_load(data_file, C.H5_MOV_PARENCHYMA_MASK)
|
324 |
|
325 |
+
fix_vessels = self.__try_load(data_file, C.H5_FIX_VESSELS_MASK)
|
326 |
+
mov_vessels = self.__try_load(data_file, C.H5_MOV_VESSELS_MASK)
|
|
|
|
|
327 |
|
328 |
+
fix_tumors = self.__try_load(data_file, C.H5_FIX_TUMORS_MASK)
|
329 |
+
mov_tumors = self.__try_load(data_file, C.H5_MOV_TUMORS_MASK)
|
330 |
|
331 |
+
disp_map = self.__try_load(data_file, C.H5_GT_DISP)
|
|
|
|
|
332 |
|
333 |
+
data_file.close()
|
334 |
+
batch_size = 1
|
335 |
+
|
336 |
+
data_dict = {C.H5_FIX_IMG: fix_img,
|
337 |
+
C.H5_FIX_TUMORS_MASK: fix_tumors,
|
338 |
+
C.H5_FIX_VESSELS_MASK: fix_vessels,
|
339 |
+
C.H5_FIX_PARENCHYMA_MASK: fix_parench,
|
340 |
+
C.H5_MOV_IMG: mov_img,
|
341 |
+
C.H5_MOV_TUMORS_MASK: mov_tumors,
|
342 |
+
C.H5_MOV_VESSELS_MASK: mov_vessels,
|
343 |
+
C.H5_MOV_PARENCHYMA_MASK: mov_parench,
|
344 |
+
C.H5_GT_DISP: disp_map,
|
345 |
+
'BATCH_SIZE': batch_size
|
346 |
+
}
|
347 |
+
|
348 |
+
return data_dict
|
349 |
+
|
350 |
+
def get_samples(self, num_samples, random=False):
|
351 |
+
if random:
|
352 |
+
idxs = np.random.randint(0, self.__num_samples, num_samples)
|
353 |
+
else:
|
354 |
+
idxs = np.arange(0, num_samples)
|
355 |
+
data_dict = self.__load_data(idxs)
|
356 |
# return X, y
|
357 |
+
return self.__build_list(data_dict, self.__input_labels), self.__build_list(data_dict, self.__output_labels)
|
|
|
|
|
|
|
|
|
|
|
|
|
358 |
|
359 |
def get_input_shape(self):
|
360 |
input_batch, _ = self.__getitem__(0)
|
361 |
+
data_dict = self.__load_data(0)
|
362 |
+
|
363 |
+
ret_val = data_dict[self.__input_labels[0]].shape
|
364 |
+
ret_val = (None, ) + ret_val[1:]
|
365 |
+
return ret_val # const.BATCH_SHAPE_SEGM
|
|
|
|
|
|
|
366 |
|
367 |
def who_are_you(self):
|
368 |
return self.__dataset_type
|
|
|
374 |
class DataGeneratorManager2D:
|
375 |
FIX_IMG_H5 = 'input/1'
|
376 |
MOV_IMG_H5 = 'input/0'
|
377 |
+
|
378 |
def __init__(self, h5_file_list, batch_size=32, data_split=0.7, img_size=None,
|
379 |
fix_img_tag=FIX_IMG_H5, mov_img_tag=MOV_IMG_H5, multi_loss=False):
|
380 |
self.__file_list = h5_file_list #h5py.File(h5_file, 'r')
|
DeepDeformationMapRegistration/utils/constants.py
CHANGED
@@ -65,6 +65,17 @@ MAX_FLIPS = 2 # Axes to flip over
|
|
65 |
NUM_ROTATIONS = 5
|
66 |
MAX_WORKERS = 10
|
67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
# Training constants
|
69 |
MODEL = 'unet'
|
70 |
BATCH_NORM = False
|
@@ -478,10 +489,9 @@ EPS_1_tf = tf.constant(EPS_1)
|
|
478 |
# LDDMM
|
479 |
GAUSSIAN_KERNEL_SHAPE = (8, 8, 8)
|
480 |
|
481 |
-
# Constants for
|
482 |
PRIOR_W = [1., 1 / 60, 1.]
|
483 |
MANUAL_W = [1.] * len(PRIOR_W)
|
484 |
|
485 |
REG_PRIOR_W = [1e-3]
|
486 |
REG_MANUAL_W = [1.] * len(REG_PRIOR_W)
|
487 |
-
|
|
|
65 |
NUM_ROTATIONS = 5
|
66 |
MAX_WORKERS = 10
|
67 |
|
68 |
+
# Labels to pass to the input_labels and output_labels parameter of DataGeneratorManager
|
69 |
+
DG_LBL_FIX_IMG = H5_FIX_IMG
|
70 |
+
DG_LBL_FIX_VESSELS = H5_FIX_VESSELS_MASK
|
71 |
+
DG_LBL_FIX_PARENCHYMA = H5_FIX_PARENCHYMA_MASK
|
72 |
+
DG_LBL_FIX_TUMOR = H5_FIX_TUMORS_MASK
|
73 |
+
DG_LBL_MOV_IMG = H5_MOV_IMG
|
74 |
+
DG_LBL_MOV_VESSELS = H5_MOV_VESSELS_MASK
|
75 |
+
DG_LBL_MOV_PARENCHYMA = H5_MOV_PARENCHYMA_MASK
|
76 |
+
DG_LBL_MOV_TUMOR = H5_MOV_TUMORS_MASK
|
77 |
+
DG_LBL_ZERO_GRADS = 'zero_gradients'
|
78 |
+
|
79 |
# Training constants
|
80 |
MODEL = 'unet'
|
81 |
BATCH_NORM = False
|
|
|
489 |
# LDDMM
|
490 |
GAUSSIAN_KERNEL_SHAPE = (8, 8, 8)
|
491 |
|
492 |
+
# Constants for Unsupervised Learning layer
|
493 |
PRIOR_W = [1., 1 / 60, 1.]
|
494 |
MANUAL_W = [1.] * len(PRIOR_W)
|
495 |
|
496 |
REG_PRIOR_W = [1e-3]
|
497 |
REG_MANUAL_W = [1.] * len(REG_PRIOR_W)
|
|