Commit
·
ca253db
1
Parent(s):
8fb4d8e
Generalized the DataGenerator to accept two lists of input and output (wrt network) labels to fetch from the dataset files.
Browse files
DeepDeformationMapRegistration/data_generator.py
CHANGED
|
@@ -1,10 +1,3 @@
|
|
| 1 |
-
import sys, os
|
| 2 |
-
currentdir = os.path.dirname(os.path.realpath(__file__))
|
| 3 |
-
parentdir = os.path.dirname(currentdir)
|
| 4 |
-
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
| 5 |
-
|
| 6 |
-
PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'
|
| 7 |
-
|
| 8 |
import numpy as np
|
| 9 |
from tensorflow import keras
|
| 10 |
import os
|
|
@@ -17,9 +10,9 @@ from DeepDeformationMapRegistration.utils.operators import min_max_norm
|
|
| 17 |
|
| 18 |
|
| 19 |
class DataGeneratorManager(keras.utils.Sequence):
|
| 20 |
-
def __init__(self, dataset_path, batch_size=32, shuffle=True,
|
| 21 |
-
clip_range=[0., 1.],
|
| 22 |
-
|
| 23 |
# Get the list of files
|
| 24 |
self.__list_files = self.__get_dataset_files(dataset_path)
|
| 25 |
self.__list_files.sort()
|
|
@@ -32,9 +25,8 @@ class DataGeneratorManager(keras.utils.Sequence):
|
|
| 32 |
|
| 33 |
self.__validation_samples = validation_samples
|
| 34 |
|
| 35 |
-
self.
|
| 36 |
-
self.
|
| 37 |
-
self.__seg_labels = seg_labels
|
| 38 |
|
| 39 |
if num_samples is not None:
|
| 40 |
self.__num_samples = self.__total_samples if num_samples > self.__total_samples else num_samples
|
|
@@ -93,6 +85,14 @@ class DataGeneratorManager(keras.utils.Sequence):
|
|
| 93 |
def shuffle(self):
|
| 94 |
return self.__shuffle
|
| 95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
def get_generator_idxs(self, generator_type):
|
| 97 |
if generator_type == 'train':
|
| 98 |
return self.train_idxs
|
|
@@ -150,18 +150,6 @@ class DataGeneratorManager(keras.utils.Sequence):
|
|
| 150 |
else:
|
| 151 |
raise ValueError('Unknown dataset type "{}". Expected "train" or "validation"'.format(type))
|
| 152 |
|
| 153 |
-
@property
|
| 154 |
-
def is_voxelmorph(self):
|
| 155 |
-
return self.__voxelmorph
|
| 156 |
-
|
| 157 |
-
@property
|
| 158 |
-
def give_segmentations(self):
|
| 159 |
-
return self.__segmentations
|
| 160 |
-
|
| 161 |
-
@property
|
| 162 |
-
def seg_labels(self):
|
| 163 |
-
return self.__seg_labels
|
| 164 |
-
|
| 165 |
|
| 166 |
class DataGenerator(DataGeneratorManager):
|
| 167 |
def __init__(self, GeneratorManager: DataGeneratorManager, dataset_type='train'):
|
|
@@ -173,8 +161,6 @@ class DataGenerator(DataGeneratorManager):
|
|
| 173 |
self.__manager = GeneratorManager
|
| 174 |
self.__shuffle = GeneratorManager.shuffle
|
| 175 |
|
| 176 |
-
self.__seg_labels = GeneratorManager.seg_labels
|
| 177 |
-
|
| 178 |
self.__num_samples = len(self.__list_files)
|
| 179 |
self.__internal_idxs = np.arange(self.__num_samples)
|
| 180 |
# These indices are internal to the generator, they are not the same as the dataset_idxs!!
|
|
@@ -184,8 +170,8 @@ class DataGenerator(DataGeneratorManager):
|
|
| 184 |
self.__last_batch = 0
|
| 185 |
self.__batches_per_epoch = int(np.floor(len(self.__internal_idxs) / self.__batch_size))
|
| 186 |
|
| 187 |
-
self.
|
| 188 |
-
self.
|
| 189 |
|
| 190 |
@staticmethod
|
| 191 |
def __get_dataset_files(search_path):
|
|
@@ -228,6 +214,22 @@ class DataGenerator(DataGeneratorManager):
|
|
| 228 |
"""
|
| 229 |
return self.__batches_per_epoch
|
| 230 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
def __getitem__(self, index):
|
| 232 |
"""
|
| 233 |
Generate one batch of data
|
|
@@ -236,36 +238,15 @@ class DataGenerator(DataGeneratorManager):
|
|
| 236 |
"""
|
| 237 |
idxs = self.__internal_idxs[index * self.__batch_size:(index + 1) * self.__batch_size]
|
| 238 |
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
try:
|
| 242 |
-
fix_img = min_max_norm(fix_img).astype(np.float32)
|
| 243 |
-
mov_img = min_max_norm(mov_img).astype(np.float32)
|
| 244 |
-
except ValueError:
|
| 245 |
-
print(idxs, fix_img.shape, mov_img.shape)
|
| 246 |
-
er_str = 'ERROR:\t[file]:\t{}\t[idx]:\t{}\t[fix_img.shape]:\t{}\t[mov_img.shape]:\t{}\t'.format(self.__list_files[idxs], idxs, fix_img.shape, mov_img.shape)
|
| 247 |
-
raise ValueError(er_str)
|
| 248 |
-
|
| 249 |
-
fix_vessels[fix_vessels > 0.] = self.__seg_labels['vessels']
|
| 250 |
-
mov_vessels[mov_vessels > 0.] = self.__seg_labels['vessels']
|
| 251 |
-
|
| 252 |
-
# fix_tumour[fix_tumour > 0.] = self.__seg_labels['tumour']
|
| 253 |
-
# mov_tumour[mov_tumour > 0.] = self.__seg_labels['tumour']
|
| 254 |
|
| 255 |
# https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit
|
| 256 |
# A generator or keras.utils.Sequence returning (inputs, targets) or (inputs, targets, sample_weights)
|
| 257 |
# The second element must match the outputs of the model, in this case (image, displacement map)
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
outputs = [] #[fix_img, zero_grad]
|
| 263 |
-
else:
|
| 264 |
-
inputs = [mov_img, fix_img]
|
| 265 |
-
outputs = [fix_img, zero_grad]
|
| 266 |
-
return (inputs, outputs)
|
| 267 |
-
else:
|
| 268 |
-
return (fix_img, mov_img, fix_vessels, mov_vessels), # (None, fix_seg, fix_seg, fix_img)
|
| 269 |
|
| 270 |
def next_batch(self):
|
| 271 |
if self.__last_batch > self.__batches_per_epoch:
|
|
@@ -274,6 +255,24 @@ class DataGenerator(DataGeneratorManager):
|
|
| 274 |
self.__last_batch += 1
|
| 275 |
return batch
|
| 276 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
def __load_data(self, idx_list):
|
| 278 |
"""
|
| 279 |
Build the batch with the samples in idx_list
|
|
@@ -283,73 +282,87 @@ class DataGenerator(DataGeneratorManager):
|
|
| 283 |
if isinstance(idx_list, (list, np.ndarray)):
|
| 284 |
fix_img = np.empty((0, ) + C.IMG_SHAPE)
|
| 285 |
mov_img = np.empty((0, ) + C.IMG_SHAPE)
|
| 286 |
-
disp_map = np.empty((0, ) + C.DISP_MAP_SHAPE)
|
| 287 |
|
| 288 |
-
|
| 289 |
-
|
| 290 |
|
| 291 |
fix_vessels = np.empty((0, ) + C.IMG_SHAPE)
|
| 292 |
mov_vessels = np.empty((0, ) + C.IMG_SHAPE)
|
|
|
|
| 293 |
fix_tumors = np.empty((0, ) + C.IMG_SHAPE)
|
| 294 |
mov_tumors = np.empty((0, ) + C.IMG_SHAPE)
|
|
|
|
|
|
|
|
|
|
| 295 |
for idx in idx_list:
|
| 296 |
data_file = h5py.File(self.__list_files[idx], 'r')
|
| 297 |
|
| 298 |
-
fix_img =
|
| 299 |
-
mov_img =
|
|
|
|
|
|
|
|
|
|
| 300 |
|
| 301 |
-
|
| 302 |
-
|
| 303 |
|
| 304 |
-
|
|
|
|
| 305 |
|
| 306 |
-
|
| 307 |
-
mov_vessels = np.append(mov_vessels, [data_file[C.H5_MOV_VESSELS_MASK][:]], axis=0)
|
| 308 |
-
fix_tumors = np.append(fix_tumors, [data_file[C.H5_FIX_TUMORS_MASK][:]], axis=0)
|
| 309 |
-
mov_tumors = np.append(mov_tumors, [data_file[C.H5_MOV_TUMORS_MASK][:]], axis=0)
|
| 310 |
|
| 311 |
data_file.close()
|
|
|
|
| 312 |
else:
|
| 313 |
data_file = h5py.File(self.__list_files[idx_list], 'r')
|
| 314 |
|
| 315 |
-
fix_img =
|
| 316 |
-
mov_img =
|
| 317 |
|
| 318 |
-
|
| 319 |
-
|
| 320 |
|
| 321 |
-
fix_vessels =
|
| 322 |
-
mov_vessels =
|
| 323 |
-
fix_tumors = np.expand_dims(data_file[C.H5_FIX_TUMORS_MASK][:], axis=0)
|
| 324 |
-
mov_tumors = np.expand_dims(data_file[C.H5_MOV_TUMORS_MASK][:], axis=0)
|
| 325 |
|
| 326 |
-
|
|
|
|
| 327 |
|
| 328 |
-
data_file.
|
| 329 |
-
|
| 330 |
-
return fix_img, mov_img, fix_vessels, mov_vessels, fix_tumors, mov_tumors, disp_map
|
| 331 |
|
| 332 |
-
|
| 333 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
# return X, y
|
| 335 |
-
return
|
| 336 |
-
|
| 337 |
-
def get_random_sample(self, num_samples):
|
| 338 |
-
idxs = np.random.randint(0, self.__num_samples, num_samples)
|
| 339 |
-
fix_img, mov_img, fix_segm, mov_segm, disp_map = self.__load_data(idxs)
|
| 340 |
-
|
| 341 |
-
return (fix_img, mov_img, fix_segm, mov_segm, disp_map), [self.__list_files[f] for f in idxs]
|
| 342 |
|
| 343 |
def get_input_shape(self):
|
| 344 |
input_batch, _ = self.__getitem__(0)
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
ret_val = input_batch.shape
|
| 351 |
-
ret_val = (None, ) + ret_val[1:]
|
| 352 |
-
return ret_val # const.BATCH_SHAPE_SEGM
|
| 353 |
|
| 354 |
def who_are_you(self):
|
| 355 |
return self.__dataset_type
|
|
@@ -361,6 +374,7 @@ class DataGenerator(DataGeneratorManager):
|
|
| 361 |
class DataGeneratorManager2D:
|
| 362 |
FIX_IMG_H5 = 'input/1'
|
| 363 |
MOV_IMG_H5 = 'input/0'
|
|
|
|
| 364 |
def __init__(self, h5_file_list, batch_size=32, data_split=0.7, img_size=None,
|
| 365 |
fix_img_tag=FIX_IMG_H5, mov_img_tag=MOV_IMG_H5, multi_loss=False):
|
| 366 |
self.__file_list = h5_file_list #h5py.File(h5_file, 'r')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
from tensorflow import keras
|
| 3 |
import os
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
class DataGeneratorManager(keras.utils.Sequence):
|
| 13 |
+
def __init__(self, dataset_path, batch_size=32, shuffle=True,
|
| 14 |
+
num_samples=None, validation_split=None, validation_samples=None, clip_range=[0., 1.],
|
| 15 |
+
input_labels=[C.H5_MOV_IMG, C.H5_FIX_IMG], output_labels=[C.H5_FIX_IMG, 'zero_gradient']):
|
| 16 |
# Get the list of files
|
| 17 |
self.__list_files = self.__get_dataset_files(dataset_path)
|
| 18 |
self.__list_files.sort()
|
|
|
|
| 25 |
|
| 26 |
self.__validation_samples = validation_samples
|
| 27 |
|
| 28 |
+
self.__input_labels = input_labels
|
| 29 |
+
self.__output_labels = output_labels
|
|
|
|
| 30 |
|
| 31 |
if num_samples is not None:
|
| 32 |
self.__num_samples = self.__total_samples if num_samples > self.__total_samples else num_samples
|
|
|
|
| 85 |
def shuffle(self):
|
| 86 |
return self.__shuffle
|
| 87 |
|
| 88 |
+
@property
|
| 89 |
+
def input_labels(self):
|
| 90 |
+
return self.__input_labels
|
| 91 |
+
|
| 92 |
+
@property
|
| 93 |
+
def output_labels(self):
|
| 94 |
+
return self.__output_labels
|
| 95 |
+
|
| 96 |
def get_generator_idxs(self, generator_type):
|
| 97 |
if generator_type == 'train':
|
| 98 |
return self.train_idxs
|
|
|
|
| 150 |
else:
|
| 151 |
raise ValueError('Unknown dataset type "{}". Expected "train" or "validation"'.format(type))
|
| 152 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
|
| 154 |
class DataGenerator(DataGeneratorManager):
|
| 155 |
def __init__(self, GeneratorManager: DataGeneratorManager, dataset_type='train'):
|
|
|
|
| 161 |
self.__manager = GeneratorManager
|
| 162 |
self.__shuffle = GeneratorManager.shuffle
|
| 163 |
|
|
|
|
|
|
|
| 164 |
self.__num_samples = len(self.__list_files)
|
| 165 |
self.__internal_idxs = np.arange(self.__num_samples)
|
| 166 |
# These indices are internal to the generator, they are not the same as the dataset_idxs!!
|
|
|
|
| 170 |
self.__last_batch = 0
|
| 171 |
self.__batches_per_epoch = int(np.floor(len(self.__internal_idxs) / self.__batch_size))
|
| 172 |
|
| 173 |
+
self.__input_labels = GeneratorManager.input_labels
|
| 174 |
+
self.__output_labels = GeneratorManager.output_labels
|
| 175 |
|
| 176 |
@staticmethod
|
| 177 |
def __get_dataset_files(search_path):
|
|
|
|
| 214 |
"""
|
| 215 |
return self.__batches_per_epoch
|
| 216 |
|
| 217 |
+
@staticmethod
|
| 218 |
+
def __build_list(data_dict, labels):
|
| 219 |
+
ret_list = list()
|
| 220 |
+
for label in labels:
|
| 221 |
+
if label in data_dict.keys():
|
| 222 |
+
if label in [C.DG_LBL_FIX_IMG, C.DG_LBL_MOV_IMG]:
|
| 223 |
+
ret_list.append(min_max_norm(data_dict[label]).astype(np.float32))
|
| 224 |
+
elif label in [C.DG_LBL_FIX_PARENCHYMA, C.DG_LBL_FIX_VESSELS, C.DG_LBL_FIX_TUMOR,
|
| 225 |
+
C.DG_LBL_MOV_PARENCHYMA, C.DG_LBL_MOV_VESSELS, C.DG_LBL_MOV_TUMOR]:
|
| 226 |
+
aux = data_dict[label]
|
| 227 |
+
aux[aux > 0.] = 1.
|
| 228 |
+
ret_list.append(aux)
|
| 229 |
+
elif label == C.DG_LBL_ZERO_GRADS:
|
| 230 |
+
ret_list.append(np.zeros([data_dict['BATCH_SIZE'], *C.DISP_MAP_SHAPE]))
|
| 231 |
+
return ret_list
|
| 232 |
+
|
| 233 |
def __getitem__(self, index):
|
| 234 |
"""
|
| 235 |
Generate one batch of data
|
|
|
|
| 238 |
"""
|
| 239 |
idxs = self.__internal_idxs[index * self.__batch_size:(index + 1) * self.__batch_size]
|
| 240 |
|
| 241 |
+
data_dict = self.__load_data(idxs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
|
| 243 |
# https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit
|
| 244 |
# A generator or keras.utils.Sequence returning (inputs, targets) or (inputs, targets, sample_weights)
|
| 245 |
# The second element must match the outputs of the model, in this case (image, displacement map)
|
| 246 |
+
inputs = self.__build_list(data_dict, self.__input_labels)
|
| 247 |
+
outputs = self.__build_list(data_dict, self.__output_labels)
|
| 248 |
+
|
| 249 |
+
return (inputs, outputs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
|
| 251 |
def next_batch(self):
|
| 252 |
if self.__last_batch > self.__batches_per_epoch:
|
|
|
|
| 255 |
self.__last_batch += 1
|
| 256 |
return batch
|
| 257 |
|
| 258 |
+
def __try_load(self, data_file, label, append_array=None):
|
| 259 |
+
if label in self.__input_labels or label in self.__output_labels:
|
| 260 |
+
# To avoid extra overhead
|
| 261 |
+
try:
|
| 262 |
+
retVal = data_file[label][:]
|
| 263 |
+
except KeyError:
|
| 264 |
+
# That particular label is not found in the file. But this should be known by the user by now
|
| 265 |
+
retVal = None
|
| 266 |
+
|
| 267 |
+
if append_array is not None and retVal is not None:
|
| 268 |
+
return np.append(append_array, [data_file[C.H5_FIX_IMG][:]], axis=0)
|
| 269 |
+
elif append_array is None:
|
| 270 |
+
return retVal[np.newaxis, ...]
|
| 271 |
+
else:
|
| 272 |
+
return retVal # None
|
| 273 |
+
else:
|
| 274 |
+
return None
|
| 275 |
+
|
| 276 |
def __load_data(self, idx_list):
|
| 277 |
"""
|
| 278 |
Build the batch with the samples in idx_list
|
|
|
|
| 282 |
if isinstance(idx_list, (list, np.ndarray)):
|
| 283 |
fix_img = np.empty((0, ) + C.IMG_SHAPE)
|
| 284 |
mov_img = np.empty((0, ) + C.IMG_SHAPE)
|
|
|
|
| 285 |
|
| 286 |
+
fix_parench = np.empty((0, ) + C.IMG_SHAPE)
|
| 287 |
+
mov_parench = np.empty((0, ) + C.IMG_SHAPE)
|
| 288 |
|
| 289 |
fix_vessels = np.empty((0, ) + C.IMG_SHAPE)
|
| 290 |
mov_vessels = np.empty((0, ) + C.IMG_SHAPE)
|
| 291 |
+
|
| 292 |
fix_tumors = np.empty((0, ) + C.IMG_SHAPE)
|
| 293 |
mov_tumors = np.empty((0, ) + C.IMG_SHAPE)
|
| 294 |
+
|
| 295 |
+
disp_map = np.empty((0, ) + C.DISP_MAP_SHAPE)
|
| 296 |
+
|
| 297 |
for idx in idx_list:
|
| 298 |
data_file = h5py.File(self.__list_files[idx], 'r')
|
| 299 |
|
| 300 |
+
fix_img = self.__try_load(data_file, C.H5_FIX_IMG, fix_img)
|
| 301 |
+
mov_img = self.__try_load(data_file, C.H5_MOV_IMG, mov_img)
|
| 302 |
+
|
| 303 |
+
fix_parench = self.__try_load(data_file, C.H5_FIX_PARENCHYMA_MASK, fix_parench)
|
| 304 |
+
mov_parench = self.__try_load(data_file, C.H5_MOV_PARENCHYMA_MASK, mov_parench)
|
| 305 |
|
| 306 |
+
fix_vessels = self.__try_load(data_file, C.H5_FIX_VESSELS_MASK, fix_vessels)
|
| 307 |
+
mov_vessels = self.__try_load(data_file, C.H5_MOV_VESSELS_MASK, mov_vessels)
|
| 308 |
|
| 309 |
+
fix_tumors = self.__try_load(data_file, C.H5_FIX_TUMORS_MASK, mov_parench)
|
| 310 |
+
mov_tumors = self.__try_load(data_file, C.H5_MOV_TUMORS_MASK, mov_parench)
|
| 311 |
|
| 312 |
+
disp_map = self.__try_load(data_file, C.H5_GT_DISP, disp_map)
|
|
|
|
|
|
|
|
|
|
| 313 |
|
| 314 |
data_file.close()
|
| 315 |
+
batch_size = len(idx_list)
|
| 316 |
else:
|
| 317 |
data_file = h5py.File(self.__list_files[idx_list], 'r')
|
| 318 |
|
| 319 |
+
fix_img = self.__try_load(data_file, C.H5_FIX_IMG)
|
| 320 |
+
mov_img = self.__try_load(data_file, C.H5_MOV_IMG)
|
| 321 |
|
| 322 |
+
fix_parench = self.__try_load(data_file, C.H5_FIX_PARENCHYMA_MASK)
|
| 323 |
+
mov_parench = self.__try_load(data_file, C.H5_MOV_PARENCHYMA_MASK)
|
| 324 |
|
| 325 |
+
fix_vessels = self.__try_load(data_file, C.H5_FIX_VESSELS_MASK)
|
| 326 |
+
mov_vessels = self.__try_load(data_file, C.H5_MOV_VESSELS_MASK)
|
|
|
|
|
|
|
| 327 |
|
| 328 |
+
fix_tumors = self.__try_load(data_file, C.H5_FIX_TUMORS_MASK)
|
| 329 |
+
mov_tumors = self.__try_load(data_file, C.H5_MOV_TUMORS_MASK)
|
| 330 |
|
| 331 |
+
disp_map = self.__try_load(data_file, C.H5_GT_DISP)
|
|
|
|
|
|
|
| 332 |
|
| 333 |
+
data_file.close()
|
| 334 |
+
batch_size = 1
|
| 335 |
+
|
| 336 |
+
data_dict = {C.H5_FIX_IMG: fix_img,
|
| 337 |
+
C.H5_FIX_TUMORS_MASK: fix_tumors,
|
| 338 |
+
C.H5_FIX_VESSELS_MASK: fix_vessels,
|
| 339 |
+
C.H5_FIX_PARENCHYMA_MASK: fix_parench,
|
| 340 |
+
C.H5_MOV_IMG: mov_img,
|
| 341 |
+
C.H5_MOV_TUMORS_MASK: mov_tumors,
|
| 342 |
+
C.H5_MOV_VESSELS_MASK: mov_vessels,
|
| 343 |
+
C.H5_MOV_PARENCHYMA_MASK: mov_parench,
|
| 344 |
+
C.H5_GT_DISP: disp_map,
|
| 345 |
+
'BATCH_SIZE': batch_size
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
return data_dict
|
| 349 |
+
|
| 350 |
+
def get_samples(self, num_samples, random=False):
|
| 351 |
+
if random:
|
| 352 |
+
idxs = np.random.randint(0, self.__num_samples, num_samples)
|
| 353 |
+
else:
|
| 354 |
+
idxs = np.arange(0, num_samples)
|
| 355 |
+
data_dict = self.__load_data(idxs)
|
| 356 |
# return X, y
|
| 357 |
+
return self.__build_list(data_dict, self.__input_labels), self.__build_list(data_dict, self.__output_labels)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 358 |
|
| 359 |
def get_input_shape(self):
|
| 360 |
input_batch, _ = self.__getitem__(0)
|
| 361 |
+
data_dict = self.__load_data(0)
|
| 362 |
+
|
| 363 |
+
ret_val = data_dict[self.__input_labels[0]].shape
|
| 364 |
+
ret_val = (None, ) + ret_val[1:]
|
| 365 |
+
return ret_val # const.BATCH_SHAPE_SEGM
|
|
|
|
|
|
|
|
|
|
| 366 |
|
| 367 |
def who_are_you(self):
|
| 368 |
return self.__dataset_type
|
|
|
|
| 374 |
class DataGeneratorManager2D:
|
| 375 |
FIX_IMG_H5 = 'input/1'
|
| 376 |
MOV_IMG_H5 = 'input/0'
|
| 377 |
+
|
| 378 |
def __init__(self, h5_file_list, batch_size=32, data_split=0.7, img_size=None,
|
| 379 |
fix_img_tag=FIX_IMG_H5, mov_img_tag=MOV_IMG_H5, multi_loss=False):
|
| 380 |
self.__file_list = h5_file_list #h5py.File(h5_file, 'r')
|
DeepDeformationMapRegistration/utils/constants.py
CHANGED
|
@@ -65,6 +65,17 @@ MAX_FLIPS = 2 # Axes to flip over
|
|
| 65 |
NUM_ROTATIONS = 5
|
| 66 |
MAX_WORKERS = 10
|
| 67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
# Training constants
|
| 69 |
MODEL = 'unet'
|
| 70 |
BATCH_NORM = False
|
|
@@ -478,10 +489,9 @@ EPS_1_tf = tf.constant(EPS_1)
|
|
| 478 |
# LDDMM
|
| 479 |
GAUSSIAN_KERNEL_SHAPE = (8, 8, 8)
|
| 480 |
|
| 481 |
-
# Constants for
|
| 482 |
PRIOR_W = [1., 1 / 60, 1.]
|
| 483 |
MANUAL_W = [1.] * len(PRIOR_W)
|
| 484 |
|
| 485 |
REG_PRIOR_W = [1e-3]
|
| 486 |
REG_MANUAL_W = [1.] * len(REG_PRIOR_W)
|
| 487 |
-
|
|
|
|
| 65 |
NUM_ROTATIONS = 5
|
| 66 |
MAX_WORKERS = 10
|
| 67 |
|
| 68 |
+
# Labels to pass to the input_labels and output_labels parameter of DataGeneratorManager
|
| 69 |
+
DG_LBL_FIX_IMG = H5_FIX_IMG
|
| 70 |
+
DG_LBL_FIX_VESSELS = H5_FIX_VESSELS_MASK
|
| 71 |
+
DG_LBL_FIX_PARENCHYMA = H5_FIX_PARENCHYMA_MASK
|
| 72 |
+
DG_LBL_FIX_TUMOR = H5_FIX_TUMORS_MASK
|
| 73 |
+
DG_LBL_MOV_IMG = H5_MOV_IMG
|
| 74 |
+
DG_LBL_MOV_VESSELS = H5_MOV_VESSELS_MASK
|
| 75 |
+
DG_LBL_MOV_PARENCHYMA = H5_MOV_PARENCHYMA_MASK
|
| 76 |
+
DG_LBL_MOV_TUMOR = H5_MOV_TUMORS_MASK
|
| 77 |
+
DG_LBL_ZERO_GRADS = 'zero_gradients'
|
| 78 |
+
|
| 79 |
# Training constants
|
| 80 |
MODEL = 'unet'
|
| 81 |
BATCH_NORM = False
|
|
|
|
| 489 |
# LDDMM
|
| 490 |
GAUSSIAN_KERNEL_SHAPE = (8, 8, 8)
|
| 491 |
|
| 492 |
+
# Constants for Unsupervised Learning layer
|
| 493 |
PRIOR_W = [1., 1 / 60, 1.]
|
| 494 |
MANUAL_W = [1.] * len(PRIOR_W)
|
| 495 |
|
| 496 |
REG_PRIOR_W = [1e-3]
|
| 497 |
REG_MANUAL_W = [1.] * len(REG_PRIOR_W)
|
|
|