jpdefrutos committed
Commit ab9857f · 0 Parent(s)

Initial commit

DeepDeformationMapRegistration/__init__.py ADDED
File without changes
DeepDeformationMapRegistration/data_generator.py ADDED
@@ -0,0 +1,497 @@
import sys, os
currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir)  # PYTHON > 3.3 does not allow relative referencing

PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'

import numpy as np
from tensorflow import keras
import h5py
import random
from PIL import Image

import DeepDeformationMapRegistration.utils.constants as C
from DeepDeformationMapRegistration.utils.operators import min_max_norm


class DataGeneratorManager(keras.utils.Sequence):
    def __init__(self, dataset_path, batch_size=32, shuffle=True, num_samples=None, validation_split=None, validation_samples=None,
                 clip_range=[0., 1.], voxelmorph=False, segmentations=False,
                 seg_labels: dict = {'bg': 0, 'vessels': 1, 'tumour': 2, 'parenchyma': 3}):
        # Get the list of files
        self.__list_files = self.__get_dataset_files(dataset_path)
        self.__list_files.sort()
        self.__dataset_path = dataset_path
        self.__shuffle = shuffle
        self.__total_samples = len(self.__list_files)
        self.__validation_split = validation_split
        self.__clip_range = clip_range
        self.__batch_size = batch_size

        self.__validation_samples = validation_samples

        self.__voxelmorph = voxelmorph
        self.__segmentations = segmentations
        self.__seg_labels = seg_labels

        if num_samples is not None:
            self.__num_samples = self.__total_samples if num_samples > self.__total_samples else num_samples
        else:
            self.__num_samples = self.__total_samples

        self.__internal_idxs = np.arange(self.__num_samples)

        # Split it accordingly
        if validation_split is None:
            self.__validation_num_samples = None
            self.__validation_idxs = list()
            if self.__shuffle:
                random.shuffle(self.__internal_idxs)
            self.__training_idxs = self.__internal_idxs

            self.__validation_generator = None
        else:
            self.__validation_num_samples = int(np.ceil(self.__num_samples * validation_split))
            if self.__shuffle:
                self.__validation_idxs = np.random.choice(self.__internal_idxs, self.__validation_num_samples)
            else:
                self.__validation_idxs = self.__internal_idxs[0: self.__validation_num_samples]
            self.__training_idxs = np.asarray([idx for idx in self.__internal_idxs if idx not in self.__validation_idxs])
            # Build the DataGenerators
            self.__validation_generator = DataGenerator(self, 'validation')

        self.__train_generator = DataGenerator(self, 'train')
        self.reshuffle_indices()

    @property
    def dataset_path(self):
        return self.__dataset_path

    @property
    def dataset_list_files(self):
        return self.__list_files

    @property
    def train_idxs(self):
        return self.__training_idxs

    @property
    def validation_idxs(self):
        return self.__validation_idxs

    @property
    def batch_size(self):
        return self.__batch_size

    @property
    def clip_range(self):
        return self.__clip_range

    @property
    def shuffle(self):
        return self.__shuffle

    def get_generator_idxs(self, generator_type):
        if generator_type == 'train':
            return self.train_idxs
        elif generator_type == 'validation':
            return self.validation_idxs
        else:
            raise ValueError('Invalid generator type: ', generator_type)

    @staticmethod
    def __get_dataset_files(search_path):
        """
        Get the path to the dataset files
        :param search_path: dir path to search for the hd5 files
        :return:
        """
        file_list = list()
        for root, dirs, files in os.walk(search_path):
            file_list.sort()
            for data_file in files:
                file_name, extension = os.path.splitext(data_file)
                if extension.lower() == '.hd5':
                    file_list.append(os.path.join(root, data_file))

        if not file_list:
            raise ValueError('No files found to train in ', search_path)

        print('Found {} files in {}'.format(len(file_list), search_path))
        return file_list

    def reshuffle_indices(self):
        if self.__validation_num_samples is None:
            if self.__shuffle:
                random.shuffle(self.__internal_idxs)
            self.__training_idxs = self.__internal_idxs
        else:
            if self.__shuffle:
                self.__validation_idxs = np.random.choice(self.__internal_idxs, self.__validation_num_samples)
            else:
                self.__validation_idxs = self.__internal_idxs[0: self.__validation_num_samples]
            self.__training_idxs = np.asarray([idx for idx in self.__internal_idxs if idx not in self.__validation_idxs])

            # Update the indices
            self.__validation_generator.update_samples(self.__validation_idxs)

        self.__train_generator.update_samples(self.__training_idxs)

    def get_generator(self, type='train'):
        if type.lower() == 'train':
            return self.__train_generator
        elif type.lower() == 'validation':
            if self.__validation_generator is not None:
                return self.__validation_generator
            else:
                raise Warning('No validation generator available. Set a non-zero validation_split to build one.')
        else:
            raise ValueError('Unknown dataset type "{}". Expected "train" or "validation"'.format(type))

    @property
    def is_voxelmorph(self):
        return self.__voxelmorph

    @property
    def give_segmentations(self):
        return self.__segmentations

    @property
    def seg_labels(self):
        return self.__seg_labels


class DataGenerator(DataGeneratorManager):
    def __init__(self, GeneratorManager: DataGeneratorManager, dataset_type='train'):
        self.__complete_list_files = GeneratorManager.dataset_list_files
        self.__list_files = [self.__complete_list_files[idx] for idx in GeneratorManager.get_generator_idxs(dataset_type)]
        self.__batch_size = GeneratorManager.batch_size
        self.__total_samples = len(self.__list_files)
        self.__clip_range = GeneratorManager.clip_range
        self.__manager = GeneratorManager
        self.__shuffle = GeneratorManager.shuffle

        self.__seg_labels = GeneratorManager.seg_labels

        self.__num_samples = len(self.__list_files)
        self.__internal_idxs = np.arange(self.__num_samples)
        # These indices are internal to the generator, they are not the same as the dataset_idxs!!

        self.__dataset_type = dataset_type

        self.__last_batch = 0
        self.__batches_per_epoch = int(np.floor(len(self.__internal_idxs) / self.__batch_size))

        self.__voxelmorph = GeneratorManager.is_voxelmorph
        self.__segmentations = GeneratorManager.is_voxelmorph and GeneratorManager.give_segmentations

    @staticmethod
    def __get_dataset_files(search_path):
        """
        Get the path to the dataset files
        :param search_path: dir path to search for the hd5 files
        :return:
        """
        file_list = list()
        for root, dirs, files in os.walk(search_path):
            for data_file in files:
                file_name, extension = os.path.splitext(data_file)
                if extension.lower() == '.hd5':
                    file_list.append(os.path.join(root, data_file))

        if not file_list:
            raise ValueError('No files found to train in ', search_path)

        print('Found {} files in {}'.format(len(file_list), search_path))
        return file_list

    def update_samples(self, new_sample_idxs):
        self.__list_files = [self.__complete_list_files[idx] for idx in new_sample_idxs]
        self.__num_samples = len(self.__list_files)
        self.__internal_idxs = np.arange(self.__num_samples)

    def on_epoch_end(self):
        """
        To be executed at the end of each epoch. Reshuffle the assigned samples
        :return:
        """
        if self.__shuffle:
            random.shuffle(self.__internal_idxs)
        self.__last_batch = 0

    def __len__(self):
        """
        Number of batches per epoch
        :return:
        """
        return self.__batches_per_epoch

    def __getitem__(self, index):
        """
        Generate one batch of data
        :param index: batch index
        :return:
        """
        idxs = self.__internal_idxs[index * self.__batch_size:(index + 1) * self.__batch_size]

        fix_img, mov_img, fix_vessels, mov_vessels, fix_tumour, mov_tumour, disp_map = self.__load_data(idxs)

        try:
            fix_img = min_max_norm(fix_img).astype(np.float32)
            mov_img = min_max_norm(mov_img).astype(np.float32)
        except ValueError:
            print(idxs, fix_img.shape, mov_img.shape)
            er_str = 'ERROR:\t[file]:\t{}\t[idx]:\t{}\t[fix_img.shape]:\t{}\t[mov_img.shape]:\t{}\t'.format([self.__list_files[i] for i in idxs], idxs, fix_img.shape, mov_img.shape)
            raise ValueError(er_str)

        fix_vessels[fix_vessels > 0.] = self.__seg_labels['vessels']
        mov_vessels[mov_vessels > 0.] = self.__seg_labels['vessels']

        # fix_tumour[fix_tumour > 0.] = self.__seg_labels['tumour']
        # mov_tumour[mov_tumour > 0.] = self.__seg_labels['tumour']

        # https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit
        # A generator or keras.utils.Sequence returning (inputs, targets) or (inputs, targets, sample_weights)
        # The second element must match the outputs of the model, in this case (image, displacement map)
        if self.__voxelmorph:
            zero_grad = np.zeros([fix_img.shape[0], *C.DISP_MAP_SHAPE])
            if self.__segmentations:
                inputs = [mov_vessels, fix_vessels, mov_img, fix_img, zero_grad]
                outputs = []  # [fix_img, zero_grad]
            else:
                inputs = [mov_img, fix_img]
                outputs = [fix_img, zero_grad]
            return (inputs, outputs)
        else:
            return (fix_img, mov_img, fix_vessels, mov_vessels),  # (None, fix_seg, fix_seg, fix_img)

    def next_batch(self):
        if self.__last_batch > self.__batches_per_epoch:
            raise ValueError('No more batches for this epoch')
        batch = self.__getitem__(self.__last_batch)
        self.__last_batch += 1
        return batch

    def __load_data(self, idx_list):
        """
        Build the batch with the samples in idx_list
        :param idx_list:
        :return:
        """
        if isinstance(idx_list, (list, np.ndarray)):
            fix_img = np.empty((0, ) + C.IMG_SHAPE)
            mov_img = np.empty((0, ) + C.IMG_SHAPE)
            disp_map = np.empty((0, ) + C.DISP_MAP_SHAPE)

            # fix_segm = np.empty((0, ) + const.IMG_SHAPE)
            # mov_segm = np.empty((0, ) + const.IMG_SHAPE)

            fix_vessels = np.empty((0, ) + C.IMG_SHAPE)
            mov_vessels = np.empty((0, ) + C.IMG_SHAPE)
            fix_tumors = np.empty((0, ) + C.IMG_SHAPE)
            mov_tumors = np.empty((0, ) + C.IMG_SHAPE)
            for idx in idx_list:
                data_file = h5py.File(self.__list_files[idx], 'r')

                fix_img = np.append(fix_img, [data_file[C.H5_FIX_IMG][:]], axis=0)
                mov_img = np.append(mov_img, [data_file[C.H5_MOV_IMG][:]], axis=0)

                # fix_segm = np.append(fix_segm, [data_file[const.H5_FIX_PARENCHYMA_MASK][:]], axis=0)
                # mov_segm = np.append(mov_segm, [data_file[const.H5_MOV_PARENCHYMA_MASK][:]], axis=0)

                disp_map = np.append(disp_map, [data_file[C.H5_GT_DISP][:]], axis=0)

                fix_vessels = np.append(fix_vessels, [data_file[C.H5_FIX_VESSELS_MASK][:]], axis=0)
                mov_vessels = np.append(mov_vessels, [data_file[C.H5_MOV_VESSELS_MASK][:]], axis=0)
                fix_tumors = np.append(fix_tumors, [data_file[C.H5_FIX_TUMORS_MASK][:]], axis=0)
                mov_tumors = np.append(mov_tumors, [data_file[C.H5_MOV_TUMORS_MASK][:]], axis=0)

                data_file.close()
        else:
            data_file = h5py.File(self.__list_files[idx_list], 'r')

            fix_img = np.expand_dims(data_file[C.H5_FIX_IMG][:], 0)
            mov_img = np.expand_dims(data_file[C.H5_MOV_IMG][:], 0)

            # fix_segm = np.expand_dims(data_file[const.H5_FIX_PARENCHYMA_MASK][:], 0)
            # mov_segm = np.expand_dims(data_file[const.H5_MOV_PARENCHYMA_MASK][:], 0)

            fix_vessels = np.expand_dims(data_file[C.H5_FIX_VESSELS_MASK][:], axis=0)
            mov_vessels = np.expand_dims(data_file[C.H5_MOV_VESSELS_MASK][:], axis=0)
            fix_tumors = np.expand_dims(data_file[C.H5_FIX_TUMORS_MASK][:], axis=0)
            mov_tumors = np.expand_dims(data_file[C.H5_MOV_TUMORS_MASK][:], axis=0)

            disp_map = np.expand_dims(data_file[C.H5_GT_DISP][:], 0)

            data_file.close()

        return fix_img, mov_img, fix_vessels, mov_vessels, fix_tumors, mov_tumors, disp_map

    def get_single_sample(self):
        # __load_data returns seven arrays; the vessel masks are used as the segmentations here
        fix_img, mov_img, fix_segm, mov_segm, *_ = self.__load_data(0)
        # return X, y
        return np.expand_dims(np.concatenate([mov_img, fix_img, mov_segm, mov_segm], axis=-1), axis=0)

    def get_random_sample(self, num_samples):
        idxs = np.random.randint(0, self.__num_samples, num_samples)
        fix_img, mov_img, fix_segm, mov_segm, _, _, disp_map = self.__load_data(idxs)

        return (fix_img, mov_img, fix_segm, mov_segm, disp_map), [self.__list_files[f] for f in idxs]

    def get_input_shape(self):
        input_batch, _ = self.__getitem__(0)
        if self.__voxelmorph:
            ret_val = list(input_batch[0].shape)
            ret_val[-1] = 2
            ret_val = (None, ) + tuple(ret_val[1:])
        else:
            ret_val = input_batch.shape
            ret_val = (None, ) + ret_val[1:]
        return ret_val  # const.BATCH_SHAPE_SEGM

    def who_are_you(self):
        return self.__dataset_type

    def print_datafiles(self):
        return self.__list_files


class DataGeneratorManager2D:
    FIX_IMG_H5 = 'input/1'
    MOV_IMG_H5 = 'input/0'

    def __init__(self, h5_file_list, batch_size=32, data_split=0.7, img_size=None,
                 fix_img_tag=FIX_IMG_H5, mov_img_tag=MOV_IMG_H5, multi_loss=False):
        self.__file_list = h5_file_list  # h5py.File(h5_file, 'r')
        self.__batch_size = batch_size
        self.__data_split = data_split

        self.__initialize()

        self.__train_generator = DataGenerator2D(self.__train_file_list,
                                                 batch_size=self.__batch_size,
                                                 img_size=img_size,
                                                 fix_img_tag=fix_img_tag,
                                                 mov_img_tag=mov_img_tag,
                                                 multi_loss=multi_loss)
        self.__val_generator = DataGenerator2D(self.__val_file_list,
                                               batch_size=self.__batch_size,
                                               img_size=img_size,
                                               fix_img_tag=fix_img_tag,
                                               mov_img_tag=mov_img_tag,
                                               multi_loss=multi_loss)

    def __initialize(self):
        num_samples = len(self.__file_list)
        random.shuffle(self.__file_list)

        data_split = int(np.floor(num_samples * self.__data_split))
        self.__val_file_list = self.__file_list[0:data_split]
        self.__train_file_list = self.__file_list[data_split:]

    @property
    def train_generator(self):
        return self.__train_generator

    @property
    def validation_generator(self):
        return self.__val_generator


class DataGenerator2D(keras.utils.Sequence):
    FIX_IMG_H5 = 'input/1'
    MOV_IMG_H5 = 'input/0'

    def __init__(self, file_list: list, batch_size=32, img_size=None, fix_img_tag=FIX_IMG_H5, mov_img_tag=MOV_IMG_H5, multi_loss=False):
        self.__file_list = file_list  # h5py.File(h5_file, 'r')
        self.__file_list.sort()
        self.__batch_size = batch_size
        self.__idx_list = np.arange(0, len(self.__file_list))
        self.__multi_loss = multi_loss

        self.__tags = {'fix_img': fix_img_tag,
                       'mov_img': mov_img_tag}

        self.__batches_seen = 0
        self.__batches_per_epoch = 0

        self.__img_size = img_size

        self.__initialize()

    def __len__(self):
        return self.__batches_per_epoch

    def __initialize(self):
        random.shuffle(self.__idx_list)

        if self.__img_size is None:
            f = h5py.File(self.__file_list[0], 'r')
            self.input_shape = f[self.__tags['fix_img']].shape  # Already defined in super()
            f.close()
        else:
            self.input_shape = self.__img_size

        if self.__multi_loss:
            self.input_shape = (self.input_shape, (*self.input_shape[:-1], 2))

        self.__batches_per_epoch = int(np.ceil(len(self.__file_list) / self.__batch_size))

    def __load_and_preprocess(self, fh, tag):
        img = fh[tag][:]

        if (self.__img_size is not None) and (img[..., 0].shape != self.__img_size):
            im = Image.fromarray(img[..., 0])  # Can't handle the 1 channel
            img = np.array(im.resize(self.__img_size[:-1], Image.LANCZOS)).astype(np.float32)
            img = img[..., np.newaxis]

        if img.max() > 1. or img.min() < 0.:
            try:
                img = min_max_norm(img).astype(np.float32)
            except ValueError:
                print(fh, tag, img.shape)
                er_str = 'ERROR:\t[file]:\t{}\t[tag]:\t{}\t[img.shape]:\t{}\t'.format(fh, tag, img.shape)
                raise ValueError(er_str)
        return img.astype(np.float32)

    def __getitem__(self, idx):
        idxs = self.__idx_list[idx * self.__batch_size:(idx + 1) * self.__batch_size]

        fix_imgs, mov_imgs = self.__load_samples(idxs)

        zero_grad = np.zeros((*fix_imgs.shape[:-1], 2))

        inputs = [mov_imgs, fix_imgs]
        outputs = [fix_imgs, zero_grad]

        if self.__multi_loss:
            return [mov_imgs, fix_imgs, zero_grad],
        else:
            return (inputs, outputs)

    def __load_samples(self, idx_list):
        if self.__multi_loss:
            img_shape = (0, *self.input_shape[0])
        else:
            img_shape = (0, *self.input_shape)

        fix_imgs = np.empty(img_shape)
        mov_imgs = np.empty(img_shape)
        for i in idx_list:
            f = h5py.File(self.__file_list[i], 'r')
            fix_imgs = np.append(fix_imgs, [self.__load_and_preprocess(f, self.__tags['fix_img'])], axis=0)
            mov_imgs = np.append(mov_imgs, [self.__load_and_preprocess(f, self.__tags['mov_img'])], axis=0)
            f.close()

        return fix_imgs, mov_imgs

    def on_epoch_end(self):
        np.random.shuffle(self.__idx_list)

    def get_single_sample(self):
        idx = random.randint(0, len(self.__idx_list) - 1)
        fix, mov = self.__load_samples([idx])
        return mov, fix
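The generators above follow the keras.utils.Sequence protocol, so they can be handed directly to Model.fit. A minimal usage sketch, not part of the commit, assuming a directory of .hd5 samples laid out with the H5_* datasets expected by __load_data (the path, batch size and split are illustrative):

    # Hypothetical usage sketch (not part of the commit).
    from DeepDeformationMapRegistration.data_generator import DataGeneratorManager

    manager = DataGeneratorManager('data/train_samples', batch_size=8, shuffle=True,
                                   validation_split=0.2, voxelmorph=True)
    train_gen = manager.get_generator('train')
    val_gen = manager.get_generator('validation')
    inputs, outputs = train_gen[0]   # one batch: ([mov_img, fix_img], [fix_img, zero_grad])
    # model.fit(train_gen, validation_data=val_gen, epochs=10)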
DeepDeformationMapRegistration/layers.py ADDED
@@ -0,0 +1,189 @@
import os, sys

currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir)  # PYTHON > 3.3 does not allow relative referencing

PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'

import tensorflow.keras.layers as kl
import tensorflow.keras.backend as K
import tensorflow as tf
import numpy as np

from DeepDeformationMapRegistration.utils.operators import soft_threshold


class UncertaintyWeighting(kl.Layer):
    def __init__(self, num_loss_fns=1, num_reg_fns=0, loss_fns: list = [tf.keras.losses.mean_squared_error],
                 reg_fns: list = list(), prior_loss_w=[1.], manual_loss_w=[1.], prior_reg_w=[1.], manual_reg_w=[1.],
                 **kwargs):
        assert isinstance(loss_fns, list) and (num_loss_fns == len(loss_fns) or len(loss_fns) == 1)
        assert isinstance(reg_fns, list) and (num_reg_fns == len(reg_fns))
        self.num_loss = num_loss_fns
        if len(loss_fns) == 1 and self.num_loss > 1:
            self.loss_fns = loss_fns * self.num_loss
        else:
            self.loss_fns = loss_fns

        if len(prior_loss_w) == 1:
            self.prior_loss_w = prior_loss_w * num_loss_fns
        else:
            self.prior_loss_w = prior_loss_w
        self.prior_loss_w = np.log(self.prior_loss_w)

        if len(manual_loss_w) == 1:
            self.manual_loss_w = manual_loss_w * num_loss_fns
        else:
            self.manual_loss_w = manual_loss_w

        self.num_reg = num_reg_fns
        if self.num_reg != 0:
            if len(reg_fns) == 1 and self.num_reg > 1:
                self.reg_fns = reg_fns * self.num_reg
            else:
                self.reg_fns = reg_fns

        self.is_placeholder = True
        if self.num_reg != 0:
            if len(prior_reg_w) == 1:
                self.prior_reg_w = prior_reg_w * num_reg_fns
            else:
                self.prior_reg_w = prior_reg_w
            self.prior_reg_w = np.log(self.prior_reg_w)

            if len(manual_reg_w) == 1:
                self.manual_reg_w = manual_reg_w * num_reg_fns
            else:
                self.manual_reg_w = manual_reg_w

        else:
            self.prior_reg_w = list()
            self.manual_reg_w = list()

        super(UncertaintyWeighting, self).__init__(**kwargs)

    def build(self, input_shape=None):
        self.log_loss_vars = self.add_weight(name='loss_log_vars', shape=(self.num_loss,),
                                             initializer=tf.keras.initializers.Constant(self.prior_loss_w),
                                             trainable=True)
        self.loss_weights = tf.math.softmax(self.log_loss_vars, name='SM_loss_weights')

        if self.num_reg != 0:
            self.log_reg_vars = self.add_weight(name='loss_reg_vars', shape=(self.num_reg,),
                                                initializer=tf.keras.initializers.Constant(self.prior_reg_w),
                                                trainable=True)
            if self.num_reg == 1:
                self.reg_weights = tf.math.exp(self.log_reg_vars, name='EXP_reg_weights')
            else:
                self.reg_weights = tf.math.softmax(self.log_reg_vars, name='SM_reg_weights')

        super(UncertaintyWeighting, self).build(input_shape)

    def multi_loss(self, ys_true, ys_pred, regs_true, regs_pred):
        loss_values = list()
        loss_names_loss = list()
        loss_names_reg = list()

        for y_true, y_pred, loss_fn, man_w in zip(ys_true, ys_pred, self.loss_fns, self.manual_loss_w):
            loss_values.append(tf.keras.backend.mean(man_w * loss_fn(y_true, y_pred)))
            loss_names_loss.append(loss_fn.__name__)

        loss_values = tf.convert_to_tensor(loss_values, dtype=tf.float32, name="step_loss_values")
        loss = tf.math.multiply(self.loss_weights, loss_values, name='step_weighted_loss')

        if self.num_reg != 0:
            loss_reg = list()
            for reg_true, reg_pred, reg_fn, man_w in zip(regs_true, regs_pred, self.reg_fns, self.manual_reg_w):
                loss_reg.append(K.mean(man_w * reg_fn(reg_true, reg_pred)))
                loss_names_reg.append(reg_fn.__name__)

            reg_values = tf.convert_to_tensor(loss_reg, dtype=tf.float32, name="step_reg_values")
            loss = loss + tf.math.multiply(self.reg_weights, reg_values, name='step_weighted_reg')

        for i, loss_name in enumerate(loss_names_loss):
            self.add_metric(tf.slice(self.loss_weights, [i], [1]), name='LOSS_WEIGHT_{}_{}'.format(i, loss_name),
                            aggregation='mean')
            self.add_metric(tf.slice(loss_values, [i], [1]), name='LOSS_VALUE_{}_{}'.format(i, loss_name),
                            aggregation='mean')
        if self.num_reg != 0:
            for i, loss_name in enumerate(loss_names_reg):
                self.add_metric(tf.slice(self.reg_weights, [i], [1]), name='REG_WEIGHT_{}_{}'.format(i, loss_name),
                                aggregation='mean')
                self.add_metric(tf.slice(reg_values, [i], [1]), name='REG_VALUE_{}_{}'.format(i, loss_name),
                                aggregation='mean')

        return K.sum(loss)

    def call(self, inputs):
        ys_true = inputs[:self.num_loss]
        ys_pred = inputs[self.num_loss:self.num_loss*2]
        reg_true = inputs[-self.num_reg*2:-self.num_reg]
        reg_pred = inputs[-self.num_reg:]  # The last terms are the regularization ones which have no GT
        loss = self.multi_loss(ys_true, ys_pred, reg_true, reg_pred)
        self.add_loss(loss, inputs=inputs)
        # We won't actually use the output, but we need something for the TF graph
        return K.concatenate(inputs, -1)

    def get_config(self):
        base_config = super(UncertaintyWeighting, self).get_config()
        base_config['num_loss_fns'] = self.num_loss
        base_config['num_reg_fns'] = self.num_reg

        return base_config


def distance_map(coord1, coord2, dist, img_shape_w_channel=(64, 64, 1)):
    max_dist = np.max(img_shape_w_channel)
    dm_p = np.ones(img_shape_w_channel, np.float32) * max_dist
    dm_n = np.ones(img_shape_w_channel, np.float32) * max_dist

    for c1, c2, d in zip(coord1, coord2, dist):
        dm_p[c1, c2, 0] = d if dm_p[c1, c2, 0] > d else dm_p[c1, c2, 0]
        d_n = 64. - max_dist
        dm_n[c1, c2, 0] = d_n if dm_n[c1, c2, 0] > d_n else dm_n[c1, c2, 0]

    return dm_p / max_dist, dm_n / max_dist


def volume_to_ov_and_dm(in_volume: tf.Tensor):
    # This one is run as a preprocessing step
    def get_ov_projections_and_dm(volume):
        # tf.sign returns -1, 0, 1 depending on the sign of the elements of the input (negative, zero, positive)
        # tf.where gives the (N, 4) indices of the non-zero voxels; split them per axis
        i, j, k, c = tf.unstack(tf.where(volume > 0.0), num=4, axis=-1)
        top = tf.sign(tf.reduce_sum(volume, axis=0), name='ov_top')
        right = tf.sign(tf.reduce_sum(volume, axis=1), name='ov_right')
        front = tf.sign(tf.reduce_sum(volume, axis=2), name='ov_front')

        top_p, top_n = tf.py_func(distance_map, [j, k, i], [tf.float32, tf.float32])
        right_p, right_n = tf.py_func(distance_map, [i, k, j], [tf.float32, tf.float32])
        front_p, front_n = tf.py_func(distance_map, [i, j, k], [tf.float32, tf.float32])

        return [front, right, top], [front_p, front_n, top_p, top_n, right_p, right_n]

    if len(in_volume.shape.as_list()) > 4:
        return tf.map_fn(get_ov_projections_and_dm, in_volume, [tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32])
    else:
        return get_ov_projections_and_dm(in_volume)


def ov_and_dm_to_volume(ov_projections):
    front, right, top = ov_projections

    def get_volume(front: tf.Tensor, right: tf.Tensor, top: tf.Tensor):
        front_shape = front.shape.as_list()  # Assume (H, W, C)
        top_shape = top.shape.as_list()

        front_vol = tf.tile(tf.expand_dims(front, 2), [1, 1, top_shape[0], 1])
        right_vol = tf.tile(tf.expand_dims(right, 1), [1, front_shape[1], 1, 1])
        top_vol = tf.tile(tf.expand_dims(top, 0), [front_shape[0], 1, 1, 1])
        sum = tf.add(tf.add(front_vol, right_vol), top_vol)
        return soft_threshold(sum, 2., 'get_volume')

    if len(front.shape.as_list()) > 3:
        return tf.map_fn(lambda x: get_volume(x[0], x[1], x[2]), ov_projections, tf.float32)
    else:
        return get_volume(front, right, top)

# TODO: Recovering the coordinates from the distance maps to prevent artifacts
#       will the gradients be backpropagated??!?!!?!?!
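UncertaintyWeighting registers the weighted multi-task loss on the graph through add_loss, so it is wired into a model by feeding it the ground-truth and prediction tensors and compiling without an external loss. A minimal wiring sketch, not part of the commit, with illustrative shapes and a toy Conv3D head standing in for the registration network:

    # Hypothetical wiring sketch (not part of the commit); shapes and layer sizes are illustrative.
    import tensorflow as tf
    from DeepDeformationMapRegistration.layers import UncertaintyWeighting

    mov_img = tf.keras.Input((64, 64, 64, 1), name='mov_img')
    fix_img = tf.keras.Input((64, 64, 64, 1), name='fix_img')        # ground truth for the image loss
    zero_grad = tf.keras.Input((64, 64, 64, 3), name='zero_grad')    # "ground truth" for the smoothness term

    x = tf.keras.layers.Conv3D(8, 3, padding='same', activation='relu')(
        tf.keras.layers.concatenate([mov_img, fix_img]))
    pred_img = tf.keras.layers.Conv3D(1, 3, padding='same', name='pred_img')(x)
    disp_map = tf.keras.layers.Conv3D(3, 3, padding='same', name='disp_map')(x)

    # Inputs are ordered [y_true..., y_pred..., reg_true..., reg_pred...], as expected by call()
    loss_out = UncertaintyWeighting(num_loss_fns=1, num_reg_fns=1,
                                    loss_fns=[tf.keras.losses.mean_squared_error],
                                    reg_fns=[tf.keras.losses.mean_squared_error],
                                    name='uncertainty_loss')([fix_img, pred_img, zero_grad, disp_map])

    model = tf.keras.Model(inputs=[mov_img, fix_img, zero_grad], outputs=loss_out)
    model.compile(optimizer='adam')  # no external loss: the layer adds it via add_loss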
DeepDeformationMapRegistration/losses.py ADDED
@@ -0,0 +1,47 @@
import os, sys
currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir)  # PYTHON > 3.3 does not allow relative referencing

PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'

import tensorflow as tf
from scipy.ndimage import generate_binary_structure

import DeepDeformationMapRegistration.utils.constants as C
from DeepDeformationMapRegistration.utils.operators import soft_threshold


class HausdorffDistance:
    def __init__(self, ndim=3, nerosion=10):
        self.ndims = ndim
        self.conv = getattr(tf.nn, 'conv%dd' % self.ndims)
        self.nerosions = nerosion

    def _erode(self, in_tensor, kernel):
        out = 1. - tf.squeeze(self.conv(tf.expand_dims(1. - in_tensor, 0), kernel, [1] * (self.ndims + 2), 'SAME'), axis=0)
        return soft_threshold(out, 0.5, name='soft_thresholding')

    def _erosion_distance_single(self, y_true, y_pred):
        diff = tf.math.pow(y_pred - y_true, 2)
        alpha = 2.

        norm = 1 / self.ndims * 2 + 1
        kernel = generate_binary_structure(self.ndims, 1).astype(int) * norm
        kernel = tf.constant(kernel, tf.float32)
        kernel = tf.expand_dims(tf.expand_dims(kernel, -1), -1)

        ret = 0.
        for i in range(self.nerosions):
            for j in range(i + 1):
                er = self._erode(diff, kernel)
            ret += tf.reduce_sum(tf.multiply(er, tf.pow(i + 1., alpha)))

        return tf.multiply(C.IMG_SIZE ** -self.ndims, ret)  # Divide by the image size

    def loss(self, y_true, y_pred):
        batched_dist = tf.map_fn(lambda x: self._erosion_distance_single(x[0], x[1]), (y_true, y_pred),
                                 dtype=tf.float32)

        return batched_dist  # tf.reduce_mean(batched_dist)
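HausdorffDistance.loss maps the erosion-based distance over the batch and returns one value per sample. A minimal usage sketch, not part of the commit, assuming batched single-channel masks of C.IMG_SHAPE (64^3) and that soft_threshold is available from utils.operators:

    # Hypothetical usage sketch (not part of the commit).
    import tensorflow as tf
    from DeepDeformationMapRegistration.losses import HausdorffDistance

    hd = HausdorffDistance(ndim=3, nerosion=10)
    y_true = tf.zeros((2, 64, 64, 64, 1))     # batch of binary masks
    y_pred = tf.ones((2, 64, 64, 64, 1))
    per_sample = hd.loss(y_true, y_pred)      # shape (2,): one distance value per batch element
    # model.compile(optimizer='adam', loss=hd.loss)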
DeepDeformationMapRegistration/networks.py ADDED
@@ -0,0 +1,63 @@
import os, sys
currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir)  # PYTHON > 3.3 does not allow relative referencing

PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'

import tensorflow as tf
import voxelmorph as vxm
from voxelmorph.tf.modelio import LoadableModel, store_config_args


class VxmWeaklySupervised(LoadableModel):

    @store_config_args
    def __init__(self, inshape, all_labels: [list, tuple], nb_unet_features=None, int_steps=5, bidir=False, **kwargs):
        """
        Parameters:
            inshape: Input shape, e.g. (192, 192, 192)
            all_labels: List of all labels included in training segmentations.
            nb_unet_features: UNet convolutional features. See VxmDense documentation for more information.
            int_steps: Number of flow integration steps. The warp is non-diffeomorphic when this value is 0.
            kwargs: Forwarded to the internal VxmDense model.
        """
        fix_segm = tf.keras.Input((*inshape, len(all_labels)), name='fix_segmentations_input')
        mov_segm = tf.keras.Input((*inshape, len(all_labels)), name='mov_segmentations_input')

        mov_img = tf.keras.Input((*inshape, 1), name='mov_image_input')

        unet_input_model = tf.keras.Model(inputs=[mov_segm, fix_segm], outputs=[mov_segm, fix_segm])

        vxm_model = vxm.networks.VxmDense(inshape=inshape,
                                          nb_unet_features=nb_unet_features,
                                          input_model=unet_input_model,
                                          int_steps=int_steps,
                                          bidir=bidir, **kwargs)

        pred_img = vxm.layers.SpatialTransformer(interp_method='linear', indexing='ij', name='pred_fix_img')(
            [mov_img, vxm_model.references.pos_flow])

        inputs = [mov_segm, fix_segm, mov_img]  # mov_img, mov_segm, fix_segm
        outputs = [pred_img] + vxm_model.outputs

        self.references = LoadableModel.ReferenceContainer()
        self.references.pred_segm = vxm_model.outputs[0]
        self.references.pred_img = pred_img
        self.references.pos_flow = vxm_model.references.pos_flow

        super().__init__(inputs=inputs, outputs=outputs)

    def get_registration_model(self):
        return tf.keras.Model(self.inputs, self.references.pos_flow)

    def register(self, mov_img, mov_segm, fix_segm):
        return self.get_registration_model().predict([mov_segm, fix_segm, mov_img])

    def apply_transform(self, mov_img, mov_segm, fix_segm, interp_method='linear'):
        warp_model = self.get_registration_model()
        img_input = tf.keras.Input(shape=mov_img.shape[1:], name='input_img')
        pred_img = vxm.layers.SpatialTransformer(interp_method=interp_method)([img_input, warp_model.output])
        return tf.keras.Model(warp_model.inputs, pred_img).predict([mov_segm, fix_segm, mov_img])
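The model takes the moving and fixed segmentations plus the moving image as inputs, and outputs the warped image followed by the internal VxmDense outputs. A construction and registration sketch, not part of the commit, with illustrative shape, labels and loss choices:

    # Hypothetical usage sketch (not part of the commit).
    import numpy as np
    import tensorflow as tf
    from DeepDeformationMapRegistration.networks import VxmWeaklySupervised

    model = VxmWeaklySupervised(inshape=(64, 64, 64), all_labels=[0, 1], int_steps=5)
    model.compile(optimizer=tf.keras.optimizers.Adam(1e-4),
                  loss=['mse'] * len(model.outputs))      # one loss per output; illustrative choice

    mov_segm = np.zeros((1, 64, 64, 64, 2), np.float32)   # one channel per label
    fix_segm = np.zeros((1, 64, 64, 64, 2), np.float32)
    mov_img = np.zeros((1, 64, 64, 64, 1), np.float32)
    flow = model.register(mov_img, mov_segm, fix_segm)    # dense displacement field (pos_flow)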
DeepDeformationMapRegistration/utils/acummulated_optimizer.py ADDED
@@ -0,0 +1,57 @@
from tensorflow.keras.optimizers import Optimizer
from tensorflow.keras import backend as K


class AccumOptimizer(Optimizer):
    """Optimizer
    Inheriting Optimizer class, wrapping the original optimizer
    to achieve a new corresponding optimizer of gradient accumulation.
    # Arguments
        optimizer: an instance of keras optimizer (supporting
                   all keras optimizers currently available);
        steps_per_update: the steps of gradient accumulation
    # Returns
        a new keras optimizer.
    """
    def __init__(self, optimizer, steps_per_update=1, **kwargs):
        super(AccumOptimizer, self).__init__(**kwargs)
        self.optimizer = optimizer
        with K.name_scope(self.__class__.__name__):
            self.steps_per_update = steps_per_update
            self.iterations = K.variable(0, dtype='int64', name='iterations')
            self.cond = K.equal(self.iterations % self.steps_per_update, 0)
            self.lr = self.optimizer.lr
            self.optimizer.lr = K.switch(self.cond, self.optimizer.lr, 0.)
            for attr in ['momentum', 'rho', 'beta_1', 'beta_2']:
                if hasattr(self.optimizer, attr):
                    value = getattr(self.optimizer, attr)
                    setattr(self, attr, value)
                    setattr(self.optimizer, attr, K.switch(self.cond, value, 1 - 1e-7))
            for attr in self.optimizer.get_config():
                if not hasattr(self, attr):
                    value = getattr(self.optimizer, attr)
                    setattr(self, attr, value)
            # Cover the original get_gradients method with accumulative gradients.
            def get_gradients(loss, params):
                return [ag / self.steps_per_update for ag in self.accum_grads]
            self.optimizer.get_gradients = get_gradients

    def get_updates(self, loss, params):
        self.updates = [
            K.update_add(self.iterations, 1),
            K.update_add(self.optimizer.iterations, K.cast(self.cond, 'int64')),
        ]
        # gradient accumulation
        self.accum_grads = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        grads = self.get_gradients(loss, params)
        for g, ag in zip(grads, self.accum_grads):
            self.updates.append(K.update(ag, K.switch(self.cond, g, ag + g)))
        # inheriting updates of original optimizer
        self.updates.extend(self.optimizer.get_updates(loss, params)[1:])
        self.weights.extend(self.optimizer.weights)
        return self.updates

    def get_config(self):
        iterations = K.eval(self.iterations)
        K.set_value(self.iterations, 0)
        config = self.optimizer.get_config()
        K.set_value(self.iterations, iterations)
        return config
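The wrapper only applies the accumulated gradients every steps_per_update batches, which emulates a larger effective batch size. A usage sketch, not part of the commit, assuming an already-built Keras model; the base optimizer and step count are illustrative:

    # Hypothetical usage sketch (not part of the commit).
    from tensorflow.keras.optimizers import Adam
    from DeepDeformationMapRegistration.utils.acummulated_optimizer import AccumOptimizer

    # Effective batch size = batch_size * steps_per_update (e.g. 8 * 4 = 32).
    opt = AccumOptimizer(Adam(learning_rate=1e-4), steps_per_update=4)
    # model.compile(optimizer=opt, loss='mse')
    # model.fit(train_gen, epochs=10)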
DeepDeformationMapRegistration/utils/cmd_args_parser.py ADDED
@@ -0,0 +1,111 @@
import sys, getopt
import DeepDeformationMapRegistration.utils.constants as C
import os


def parse_arguments(argv):

    try:
        opts, args = getopt.getopt(argv, "hg:b:l:r:d:t:i:f:x:p:q:", ["gpu-num=",
                                                                     "batch-size=",
                                                                     "loss=",
                                                                     "remote=",
                                                                     "debug=",
                                                                     "debug-training=",
                                                                     "debug-input-data=",
                                                                     "destination-folder=",
                                                                     "destination-folder-fixed=",
                                                                     "training-dataset=",
                                                                     "test-dataset=",
                                                                     "help"])
    except getopt.GetoptError:
        print('\n\t\t--gpu-num:\t\tGPU number to use'
              '\n\t\t--batch-size:\t\tsize of the training batch'
              '\n\t\t--loss:\t\tLoss function: ncc, mse, dssim'
              '\n\t\t--remote:\t\tExecuting the script in The Beast: "True"/"False". Def: False'
              '\n\t\t--debug:\t\tEnable debugging logs: "True"/"False". Def: False'
              '\n\t\t--debug-training:\t\tEnable debugging training logs: "True"/"False". Def: False'
              '\n\t\t--debug-input-data:\t\tEnable debugging input data logs: "True"/"False". Def: False'
              '\n\t\t--destination-folder:\t\tName of the folder where to save the generated training files'
              '\n\t\t--destination-folder-fixed:\t\tSame as --destination-folder but do not add the timestamp'
              '\n\t\t--training-dataset:\t\tPath to the training dataset file'
              '\n\t\t--test-dataset:\t\tPath to the test dataset file'
              '\n')
        sys.exit(2)

    for opt, arg in opts:
        if opt in ('--help', '-h'):
            print('\n\t\t--gpu-num:\t\tGPU number to use\n\t\t--batch-size:\t\tsize of the training batch'
                  '\n\t\t--loss:\t\tLoss function: ncc, mse, dssim\n')
            continue
        elif opt in ('--gpu-num', '-g'):
            old = C.GPU_NUM
            C.GPU_NUM = arg
            os.environ['CUDA_VISIBLE_DEVICES'] = C.GPU_NUM
            print('\t\tGPU_NUM: {} -> {}'.format(old, C.GPU_NUM))

        elif opt in ('--batch-size', '-b'):
            old = C.BATCH_SIZE
            C.BATCH_SIZE = int(arg)
            print('\t\tBATCH_SIZE: {} -> {}'.format(old, C.BATCH_SIZE))

        elif opt in ('--destination-folder', '-f'):
            old = C.DESTINATION_FOLDER
            C.DESTINATION_FOLDER = arg + '_' + C.CUR_DATETIME
            print('\t\tDESTINATION_FOLDER: {} -> {}'.format(old, C.DESTINATION_FOLDER))

        elif opt in ('--destination-folder-fixed', '-x'):
            old = C.DESTINATION_FOLDER
            C.DESTINATION_FOLDER = arg
            print('\t\tDESTINATION_FOLDER: {} -> {}'.format(old, C.DESTINATION_FOLDER))

        elif opt in ('--training-dataset', '-p'):
            old = C.TRAINING_DATASET
            C.TRAINING_DATASET = arg
            print('\t\tTRAINING_DATASET: {} -> {}'.format(old, C.TRAINING_DATASET))

        elif opt in ('--test-dataset', '-q'):
            old = C.TEST_DATASET
            C.TEST_DATASET = arg
            print('\t\tTEST_DATASET: {} -> {}'.format(old, C.TEST_DATASET))

        elif opt in ('--remote', '-r'):
            old = C.REMOTE
            if arg.lower() in ('1', 'true', 't'):
                C.REMOTE = True
            else:
                C.REMOTE = False
            print('\t\tREMOTE: {} -> {}'.format(old, C.REMOTE))

        elif opt in ('--debug', '-d'):
            old = C.DEBUG
            if arg.lower() in ('1', 'true', 't'):
                C.DEBUG = True
            else:
                C.DEBUG = False
            print('\t\tDEBUG: {} -> {}'.format(old, C.DEBUG))

        elif opt in ('--debug-training', '-t'):
            old = C.DEBUG_TRAINING
            if arg.lower() in ('1', 'true', 't'):
                C.DEBUG_TRAINING = True
            else:
                C.DEBUG_TRAINING = False
            print('\t\tDEBUG_TRAINING: {} -> {}'.format(old, C.DEBUG_TRAINING))

        elif opt in ('--debug-input-data', '-i'):
            old = C.DEBUG_INPUT_DATA
            if arg.lower() in ('1', 'true', 't'):
                C.DEBUG_INPUT_DATA = True
            else:
                C.DEBUG_INPUT_DATA = False
            print('\t\tDEBUG_INPUT_DATA: {} -> {}'.format(old, C.DEBUG_INPUT_DATA))

        elif opt in ('--loss', '-l'):
            old = C.LOSS_FNC
            if arg in ('ncc', 'mse', 'dssim', 'dice'):
                C.LOSS_FNC = arg
            else:
                print('Invalid option for --loss. Expected: "mse", "ncc", "dssim" or "dice", got {}'.format(arg))
                sys.exit(2)
            print('\t\tLOSS_FNC: {} -> {}'.format(old, C.LOSS_FNC))
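parse_arguments mutates the module-level constants in DeepDeformationMapRegistration.utils.constants before training starts, so it is meant to be called once at the top of a training script. A sketch of the intended call pattern, not part of the commit; script name and flag values are illustrative:

    # Hypothetical call pattern (not part of the commit).
    # python train.py --gpu-num 0 --batch-size 8 --loss ncc --training-dataset data/training
    import sys
    from DeepDeformationMapRegistration.utils.cmd_args_parser import parse_arguments
    import DeepDeformationMapRegistration.utils.constants as C

    if __name__ == '__main__':
        parse_arguments(sys.argv[1:])
        print(C.summary())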
DeepDeformationMapRegistration/utils/conf_file_utils.py ADDED
@@ -0,0 +1,52 @@
import DeepDeformationMapRegistration.utils.constants as C
import re
import os


class ConfigurationFile:
    def __init__(self,
                 file_path: str):
        self.__file = file_path
        self.__load_configuration()

    def __load_configuration(self):
        fd = open(self.__file, 'r')
        file_lines = fd.readlines()
        fd.close()

        for line in file_lines:
            if '#' not in line and line != '\n':
                match = re.match('(.*)=(.*)', line)
                if match[1] in C.__dict__.keys():
                    # Careful with eval!!
                    try:
                        new_val = eval(match[2])
                    except NameError:
                        new_val = match[2]
                    old = C.__dict__[match[1]]
                    C.__dict__[match[1]] = new_val

                    # Special case
                    if match[1] == 'GPU_NUM':
                        C.__dict__[match[1]] = str(new_val)
                        os.environ['CUDA_VISIBLE_DEVICES'] = C.GPU_NUM

                    if match[1] == 'EPOCHS':
                        C.__dict__[match[1]] = new_val
                        C.__dict__['SAVE_EPOCH'] = new_val // 10
                        C.__dict__['VERBOSE_EPOCH'] = new_val // 10

                    if match[1] == 'SAVE_EPOCH' or match[1] == 'VERBOSE_EPOCH':
                        if new_val is not None:
                            C.__dict__[match[1]] = C.__dict__['EPOCHS'] // new_val
                        else:
                            C.__dict__[match[1]] = None

                    if match[1] == 'VALIDATION_ERR_LIMIT_COUNTER':
                        C.__dict__[match[1]] = new_val
                        C.__dict__['VALIDATION_ERR_LIMIT_COUNTER_BACKUP'] = new_val

                    print('INFO: Updating constant {}: {} -> {}'.format(match[1], old, C.__dict__[match[1]]))
                else:
                    print('ERROR: Unknown constant {}'.format(match[1]))
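ConfigurationFile overrides existing constants by evaluating NAME=value lines, skipping blank lines and any line containing '#'. A sketch of a compatible configuration file and how it is loaded, not part of the commit; the values shown are illustrative, and 'configuration.txt' follows CONF_FILE_NAME in constants.py:

    # Hypothetical configuration.txt contents (one NAME=value per line, no '#' characters):
    #   GPU_NUM=0
    #   BATCH_SIZE=16
    #   EPOCHS=200
    #   LOSS_FNC='ncc'
    #
    # Loading it simply instantiates the class:
    from DeepDeformationMapRegistration.utils.conf_file_utils import ConfigurationFile
    ConfigurationFile('configuration.txt')   # constants in utils.constants are updated in place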
DeepDeformationMapRegistration/utils/constants.py ADDED
@@ -0,0 +1,487 @@
"""
Constants
"""
import tensorflow as tf
import os
import datetime
import numpy as np

# RUN CONFIG
REMOTE = False  # os.popen('hostname').read().encode('utf-8') == 'medtech-beast' #os.environ.get('REMOTE') == 'True'

# Remote execution
DEV_ORDER = 'PCI_BUS_ID'
GPU_NUM = '0'

# Dataset generation constants
# See batchGenerator __next__ method: return [in_mov, in_fix], [disp_map, out_img]
MOVING_IMG = 0
FIXED_IMG = 1
MOVING_PARENCHYMA_MASK = 2
FIXED_PARENCHYMA_MASK = 3
MOVING_VESSELS_MASK = 4
FIXED_VESSELS_MASK = 5
MOVING_TUMORS_MASK = 6
FIXED_TUMORS_MASK = 7
MOVING_SEGMENTATIONS = 8  # Combination of vessels and tumors
FIXED_SEGMENTATIONS = 9  # Combination of vessels and tumors
DISP_MAP_GT = 0
PRED_IMG_GT = 1
DISP_VECT_GT = 2
DISP_VECT_LOC_GT = 3

IMG_SIZE = 64  # Assumed a square image
IMG_SHAPE = (IMG_SIZE, IMG_SIZE, IMG_SIZE, 1)  # (IMG_SIZE, IMG_SIZE, 1)
DISP_MAP_SHAPE = (IMG_SIZE, IMG_SIZE, IMG_SIZE, 3)
BATCH_SHAPE = (None, IMG_SIZE, IMG_SIZE, IMG_SIZE, 2)  # Expected batch shape by the network
BATCH_SHAPE_SEGM = (None, IMG_SIZE, IMG_SIZE, IMG_SIZE, 3)  # Expected batch shape by the network
IMG_BATCH_SHAPE = (None, IMG_SIZE, IMG_SIZE, IMG_SIZE, 1)  # Batch shape for single images

RAW_DATA_BASE_DIR = './data'
DEFORMED_DATA_NAME = 'deformed'
GROUND_TRUTH_DATA_NAME = 'groundTruth'
GROUND_TRUTH_COORDS_FILE = 'centerlineCoords_GT.txt'
DEFORMED_COORDS_FILE = 'centerlineCoords_DF.txt'
H5_MOV_IMG = 'input/{}'.format(MOVING_IMG)
H5_FIX_IMG = 'input/{}'.format(FIXED_IMG)
H5_MOV_PARENCHYMA_MASK = 'input/{}'.format(MOVING_PARENCHYMA_MASK)
H5_FIX_PARENCHYMA_MASK = 'input/{}'.format(FIXED_PARENCHYMA_MASK)
H5_MOV_VESSELS_MASK = 'input/{}'.format(MOVING_VESSELS_MASK)
H5_FIX_VESSELS_MASK = 'input/{}'.format(FIXED_VESSELS_MASK)
H5_MOV_TUMORS_MASK = 'input/{}'.format(MOVING_TUMORS_MASK)
H5_FIX_TUMORS_MASK = 'input/{}'.format(FIXED_TUMORS_MASK)
H5_FIX_SEGMENTATIONS = 'input/{}'.format(FIXED_SEGMENTATIONS)
H5_MOV_SEGMENTATIONS = 'input/{}'.format(MOVING_SEGMENTATIONS)

H5_GT_DISP = 'output/{}'.format(DISP_MAP_GT)
H5_GT_IMG = 'output/{}'.format(PRED_IMG_GT)
H5_GT_DISP_VECT = 'output/{}'.format(DISP_VECT_GT)
H5_GT_DISP_VECT_LOC = 'output/{}'.format(DISP_VECT_LOC_GT)
H5_PARAMS_INTENSITY_RANGE = 'parameters/intensity'
TRAINING_PERC = 0.8
VALIDATION_PERC = 1 - TRAINING_PERC
MAX_ANGLE = 45.0  # degrees
MAX_FLIPS = 2  # Axes to flip over
NUM_ROTATIONS = 5
MAX_WORKERS = 10

# Training constants
MODEL = 'unet'
BATCH_NORM = False
TENSORBOARD = False
LIMIT_NUM_SAMPLES = None  # If you don't want to use all the samples in the training set. None to use all
TRAINING_DATASET = 'data/training.hd5'
TEST_DATASET = 'data/test.hd5'
VALIDATION_DATASET = 'data/validation.hd5'
LOSS_FNC = 'mse'
LOSS_SCHEME = 'unidirectional'
NUM_EPOCHS = 10
DATA_FORMAT = 'channels_last'  # or 'channels_first'
DATA_DIR = './data'
MODEL_CHECKPOINT = './model_checkpoint'
BATCH_SIZE = 8
EPOCHS = 100
SAVE_EPOCH = EPOCHS // 10  # Epoch when to save the model
VERBOSE_EPOCH = EPOCHS // 10
VALIDATION_ERR_LIMIT = 0.2  # Stop training if reached this limit
VALIDATION_ERR_LIMIT_COUNTER = 10  # Number of successive times the validation error was smaller than the threshold
VALIDATION_ERR_LIMIT_COUNTER_BACKUP = 10
THRESHOLD = 0.5  # Threshold to select the centerline in the interpolated images
RESTORE_TRAINING = True  # look for previously saved models to resume training
EARLY_STOP_PATIENCE = 10
LOG_FIELD_NAMES = ['time', 'epoch', 'step',
                   'training_loss_mean', 'training_loss_std',
                   'training_loss1_mean', 'training_loss1_std',
                   'training_loss2_mean', 'training_loss2_std',
                   'training_loss3_mean', 'training_loss3_std',
                   'training_ncc1_mean', 'training_ncc1_std',
                   'training_ncc2_mean', 'training_ncc2_std',
                   'training_ncc3_mean', 'training_ncc3_std',
                   'validation_loss_mean', 'validation_loss_std',
                   'validation_loss1_mean', 'validation_loss1_std',
                   'validation_loss2_mean', 'validation_loss2_std',
                   'validation_loss3_mean', 'validation_loss3_std',
                   'validation_ncc1_mean', 'validation_ncc1_std',
                   'validation_ncc2_mean', 'validation_ncc2_std',
                   'validation_ncc3_mean', 'validation_ncc3_std']
LOG_FIELD_NAMES_SHORT = ['time', 'epoch', 'step',
                         'training_loss_mean', 'training_loss_std',
                         'training_loss1_mean', 'training_loss1_std',
                         'training_loss2_mean', 'training_loss2_std',
                         'training_ncc1_mean', 'training_ncc1_std',
                         'training_ncc2_mean', 'training_ncc2_std',
                         'validation_loss_mean', 'validation_loss_std',
                         'validation_loss1_mean', 'validation_loss1_std',
                         'validation_loss2_mean', 'validation_loss2_std',
                         'validation_ncc1_mean', 'validation_ncc1_std',
                         'validation_ncc2_mean', 'validation_ncc2_std']
LOG_FIELD_NAMES_UNET = ['time', 'epoch', 'step', 'reg_smooth_coeff', 'reg_jacob_coeff',
                        'training_loss_mean', 'training_loss_std',
                        'training_loss_dissim_mean', 'training_loss_dissim_std',
                        'training_reg_smooth_mean', 'training_reg_smooth_std',
                        'training_reg_jacob_mean', 'training_reg_jacob_std',
                        'training_ncc_mean', 'training_ncc_std',
                        'training_dice_mean', 'training_dice_std',
                        'training_owo_mean', 'training_owo_std',
                        'validation_loss_mean', 'validation_loss_std',
                        'validation_loss_dissim_mean', 'validation_loss_dissim_std',
                        'validation_reg_smooth_mean', 'validation_reg_smooth_std',
                        'validation_reg_jacob_mean', 'validation_reg_jacob_std',
                        'validation_ncc_mean', 'validation_ncc_std',
                        'validation_dice_mean', 'validation_dice_std',
                        'validation_owo_mean', 'validation_owo_std']
CUR_DATETIME = datetime.datetime.now().strftime("%H%M_%d%m%Y")
DESTINATION_FOLDER = 'training_log_' + CUR_DATETIME
CSV_DELIMITER = ";"
CSV_QUOTE_CHAR = '"'
REG_SMOOTH = 0.0
REG_MAG = 1.0
REG_TYPE = 'l2'
MAX_DISP_DM = 10.
MAX_DISP_DM_TF = tf.constant((MAX_DISP_DM,), tf.float32, name='MAX_DISP_DM')
MAX_DISP_DM_PERC = 0.25

W_SIM = 0.7
W_REG = 0.3
W_INV = 0.1

# Loss function parameters
REG_SMOOTH1 = 1 / 100000
REG_SMOOTH2 = 1 / 5000
REG_SMOOTH3 = 1 / 5000
LOSS1 = 1.0
LOSS2 = 0.6
LOSS3 = 0.3
REG_JACOBIAN = 0.1

LOSS_COEFFICIENT = 1.0
REG_COEFFICIENT = 1.0

DICE_SMOOTH = 1.

CC_WINDOW = [9, 9, 9]

# Adam optimizer
LEARNING_RATE = 1e-3
B1 = 0.9
B2 = 0.999
LEARNING_RATE_DECAY = 0.01
LEARNING_RATE_DECAY_STEP = 10000  # Update the learning rate every LEARNING_RATE_DECAY_STEP steps
OPTIMIZER = 'adam'

# Network architecture constants
LAYER_MAXPOOL = 0
LAYER_UPSAMP = 1
LAYER_CONV = 2
AFFINE_TRANSF = False
OUTPUT_LAYER = 3
DROPOUT = True
DROPOUT_RATE = 0.2
MAX_DATA_SIZE = (1000, 1000, 1)
PLATEAU_THR = 0.01  # A slope between +-PLATEAU_THR will be considered a plateau for the LR updating function
ENCODER_FILTERS = [4, 8, 16, 32, 64]

# SSIM
SSIM_FILTER_SIZE = 11  # Size of Gaussian filter
SSIM_FILTER_SIGMA = 1.5  # Width of Gaussian filter
SSIM_K1 = 0.01  # Def. 0.01
SSIM_K2 = 0.03  # Recommended values 0 < K2 < 0.4
MAX_VALUE = 1.0  # Maximum intensity values

# Mathematical constants
EPS = 1e-8
EPS_tf = tf.constant(EPS, dtype=tf.float32)
LOG2 = tf.math.log(tf.constant(2, dtype=tf.float32))

# Debug constants
VERBOSE = False
DEBUG = False
DEBUG_TRAINING = False
DEBUG_INPUT_DATA = False

# Plotting
FONT_SIZE = 10
DPI = 200  # Dots Per Inch

# Coordinates
B = 0  # Batch dimension
H = 1  # Height dimension
W = 2  # Width dimension
D = 3  # Depth
C = -1  # Channel dimension

D_DISP = 2
W_DISP = 1
H_DISP = 0

DIMENSIONALITY = 3

# Interpolation type
BIL_INTERP = 0
TPS_INTERP = 1
CUADRATIC_C = 0.5

# Data augmentation
MAX_DISP = 5  # Test = 15
NUM_ROT = 5
NUM_FLIPS = 2
MAX_ANGLE = 10

# Thin Plate Splines implementation constants
TPS_NUM_CTRL_PTS_PER_AXIS = 4
TPS_NUM_CTRL_PTS = np.power(TPS_NUM_CTRL_PTS_PER_AXIS, DIMENSIONALITY)
TPS_REG = 0.01
DISP_SCALE = 2  # Scaling of the output of the CNN to increase the range of tanh


class CoordinatesGrid:
    def __init__(self):
        self.__grid = 0
        self.__grid_fl = 0
        self.__norm = False
        self.__num_pts = 0
        self.__batches = False
        self.__shape = None
        self.__shape_flat = None

    def set_coords_grid(self, img_shape: tf.TensorShape, num_ppa: int = None, batches: bool = False,
                        img_type: tf.DType = tf.float32, norm: bool = False):
        self.__batches = batches
        not_batches = not batches  # Just to not make a too complex code when indexing the values
        if num_ppa is None:
            num_ppa = img_shape
        if norm:
            x_coords = tf.linspace(-1., 1.,
                                   num_ppa[W - int(not_batches)])  # np.linspace works fine, tf had some issues...
            y_coords = tf.linspace(-1., 1., num_ppa[H - int(not_batches)])  # num_ppa: number of points per axis
            z_coords = tf.linspace(-1., 1., num_ppa[D - int(not_batches)])
        else:
            x_coords = tf.linspace(0., img_shape[W - int(not_batches)] - 1.,
                                   num_ppa[W - int(not_batches)])  # np.linspace works fine, tf had some issues...
            y_coords = tf.linspace(0., img_shape[H - int(not_batches)] - 1.,
                                   num_ppa[H - int(not_batches)])  # num_ppa: number of points per axis
            z_coords = tf.linspace(0., img_shape[D - int(not_batches)] - 1., num_ppa[D - int(not_batches)])

        coords = tf.meshgrid(x_coords, y_coords, z_coords, indexing='ij')
        self.__num_pts = num_ppa[W - int(not_batches)] * num_ppa[H - int(not_batches)] * num_ppa[D - int(not_batches)]

        grid = tf.stack([coords[0], coords[1], coords[2]], axis=-1)
        grid = tf.cast(grid, img_type)

        grid_fl = tf.stack([tf.reshape(coords[0], [-1]),
                            tf.reshape(coords[1], [-1]),
                            tf.reshape(coords[2], [-1])], axis=-1)
        grid_fl = tf.cast(grid_fl, img_type)

        grid_homogeneous = tf.stack([tf.reshape(coords[0], [-1]),
                                     tf.reshape(coords[1], [-1]),
                                     tf.reshape(coords[2], [-1]),
                                     tf.ones_like(tf.reshape(coords[0], [-1]))], axis=-1)

        self.__shape = np.asarray([num_ppa[W - int(not_batches)], num_ppa[H - int(not_batches)], num_ppa[D - int(not_batches)], 3])
        total_num_pts = np.prod(self.__shape[:-1])
        self.__shape_flat = np.asarray([total_num_pts, 3])
        if batches:
            grid = tf.expand_dims(grid, axis=0)
            grid = tf.tile(grid, [img_shape[B], 1, 1, 1, 1])
            grid_fl = tf.expand_dims(grid_fl, axis=0)
            grid_fl = tf.tile(grid_fl, [img_shape[B], 1, 1])
            grid_homogeneous = tf.expand_dims(grid_homogeneous, axis=0)
            grid_homogeneous = tf.tile(grid_homogeneous, [img_shape[B], 1, 1])
            self.__shape = np.concatenate([np.asarray((img_shape[B],)), self.__shape])
            self.__shape_flat = np.concatenate([np.asarray((img_shape[B],)), self.__shape_flat])

        self.__norm = norm
        self.__grid_fl = grid_fl
        self.__grid = grid
        self.__grid_homogeneous = grid_homogeneous

    @property
    def grid(self):
        return self.__grid

    @property
    def size(self):
        return self.__len__()

    def grid_flat(self, transpose=False):
        if transpose:
            if self.__batches:
                ret = tf.transpose(self.__grid_fl, (0, 2, 1))
            else:
                ret = tf.transpose(self.__grid_fl)
        else:
            ret = self.__grid_fl
        return ret

    def grid_homogeneous(self, transpose=False):
        if transpose:
            if self.__batches:
                ret = tf.transpose(self.__grid_homogeneous, (0, 2, 1))
            else:
                ret = tf.transpose(self.__grid_homogeneous)
        else:
            ret = self.__grid_homogeneous
        return ret

    @property
    def is_normalized(self):
        return self.__norm

    def __len__(self):
        return tf.size(self.__grid)

    @property
    def number_pts(self):
        return self.__num_pts

    @property
    def shape_grid_flat(self):
        return self.__shape_flat

    @property
    def shape(self):
        return self.__shape


COORDS_GRID = CoordinatesGrid()


class VisualizationParameters:
    def __init__(self):
        self.__scale = None  # See https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.axes.Axes.quiver.html
        self.__spacing = 5

    def set_spacing(self, img_shape: tf.TensorShape):
        self.__spacing = int(5 * np.log(img_shape[W]))

    @property
    def spacing(self):
        return self.__spacing

    def set_arrow_scale(self, scale: int):
        self.__scale = scale

    @property
    def arrow_scale(self):
        return self.__scale


QUIVER_PARAMS = VisualizationParameters()

# Configuration file
CONF_FILE_NAME = 'configuration.txt'


def summary():
    return '##### CONFIGURATION: REMOTE {} DEBUG {} DEBUG TRAINING {}' \
           '\n\t\tLEARNING RATE: {}' \
           '\n\t\tBATCH SIZE: {}' \
           '\n\t\tLIMIT NUM SAMPLES: {}' \
           '\n\t\tLOSS_FNC: {}' \
           '\n\t\tTRAINING_DATASET: {} ({:.1f}%/{:.1f}%)' \
           '\n\t\tTEST_DATASET: {}'.format(REMOTE, DEBUG, DEBUG_TRAINING, LEARNING_RATE, BATCH_SIZE, LIMIT_NUM_SAMPLES,
                                           LOSS_FNC, TRAINING_DATASET, TRAINING_PERC * 100, (1 - TRAINING_PERC) * 100,
                                           TEST_DATASET)


# LOG severity levels
# https://docs.python.org/2/library/logging.html#logging-levels
INF = 20  # Information
WAR = 30  # Warning
ERR = 40  # Error
DEB = 10  # Debug
CRI = 50  # Critical

SEVERITY_STR = {INF: 'INFO',
                WAR: 'WARNING',
                ERR: 'ERROR',
                DEB: 'DEBUG',
                CRI: 'CRITICAL'}

HL_LOG_FIELD_NAMES = ['Time', 'Epoch', 'Step',
                      'train_loss', 'train_loss_std',
                      'train_loss1', 'train_loss1_std',
                      'train_loss2', 'train_loss2_std',
                      'train_loss3', 'train_loss3_std',
                      'train_NCC', 'train_NCC_std',
                      'val_loss', 'val_loss_std',
                      'val_loss1', 'val_loss1_std',
                      'val_loss2', 'val_loss2_std',
                      'val_loss3', 'val_loss3_std',
                      'val_NCC', 'val_NCC_std']

# Sobel filters
SOBEL_W_2D = tf.constant([[-1., 0., 1.],
                          [-2., 0., 2.],
                          [-1., 0., 1.]], dtype=tf.float32, name='sobel_w_2d')
SOBEL_W_3D = tf.tile(tf.expand_dims(SOBEL_W_2D, axis=-1), [1, 1, 3])
SOBEL_H_3D = tf.transpose(SOBEL_W_3D, [1, 0, 2])
SOBEL_D_3D = tf.transpose(SOBEL_W_3D, [2, 1, 0])

aux = tf.expand_dims(tf.expand_dims(SOBEL_W_3D, axis=-1), axis=-1)
SOBEL_FILTER_W_3D_IMAGE = aux
SOBEL_FILTER_W_3D = tf.tile(aux, [1, 1, 1, 3, 3])
# tf.nn.conv3d expects the filter in [D, H, W, C_in, C_out] order
SOBEL_FILTER_W_3D = tf.transpose(SOBEL_FILTER_W_3D, [2, 0, 1, 3, 4], name='sobel_filter_i_3d')

aux = tf.expand_dims(tf.expand_dims(SOBEL_H_3D, axis=-1), axis=-1)
SOBEL_FILTER_H_3D_IMAGE = aux
SOBEL_FILTER_H_3D = tf.tile(aux, [1, 1, 1, 3, 3])
SOBEL_FILTER_H_3D = tf.transpose(SOBEL_FILTER_H_3D, [2, 0, 1, 3, 4], name='sobel_filter_j_3d')

aux = tf.expand_dims(tf.expand_dims(SOBEL_D_3D, axis=-1), axis=-1)
SOBEL_FILTER_D_3D_IMAGE = aux
SOBEL_FILTER_D_3D = tf.tile(aux, [1, 1, 1, 3, 3])
SOBEL_FILTER_D_3D = tf.transpose(SOBEL_FILTER_D_3D, [2, 1, 0, 3, 4], name='sobel_filter_k_3d')

# Filters for spatial integration of the displacement map
INTEG_WIND_SIZE = IMG_SIZE
INTEG_STEPS = 4  # VoxelMorph default value for the integration of the stationary velocity field. >4 memory alloc issue
INTEG_FILTER_D = tf.ones([INTEG_WIND_SIZE, 1, 1, 1, 1], name='integrate_h_filter')
INTEG_FILTER_H = tf.ones([1, INTEG_WIND_SIZE, 1, 1, 1], name='integrate_w_filter')
INTEG_FILTER_W = tf.ones([1, 1, INTEG_WIND_SIZE, 1, 1], name='integrate_d_filter')

# Laplacian filter
LAPLACIAN_27_P = tf.constant(np.asarray([np.ones((3, 3)),
                                         [[1, 1, 1],
                                          [1, -26, 1],
                                          [1, 1, 1]],
                                         np.ones((3, 3))]), tf.float32)
LAPLACIAN_27_P = tf.expand_dims(tf.expand_dims(LAPLACIAN_27_P, axis=-1), axis=-1)
LAPLACIAN_27_P = tf.tile(LAPLACIAN_27_P, [1, 1, 1, 3, 3], name='laplacian_27_p')


LAPLACIAN_7_P = tf.constant(np.asarray([[[0, 0, 0],
                                         [0, 1, 0],
                                         [0, 0, 0]],
                                        [[0, 1, 0],
                                         [1, -6, 1],
                                         [0, 1, 0]],
                                        [[0, 0, 0],
                                         [0, 1, 0],
                                         [0, 0, 0]]]), tf.float32)
LAPLACIAN_7_P = tf.expand_dims(tf.expand_dims(LAPLACIAN_7_P, axis=-1), axis=-1)
LAPLACIAN_7_P = tf.tile(LAPLACIAN_7_P, [1, 1, 1, 3, 3], name='laplacian_7_p')

# Constants for bias loss
ZERO_WARP = tf.zeros((1,) + DISP_MAP_SHAPE, name='zero_warp')
BIAS_WARP_WEIGHT = 1e-02
BIAS_AFFINE_WEIGHT = 1e-02

# Overlapping score
OS_SCALE = 10
EPS_1 = 1.0
EPS_1_tf = tf.constant(EPS_1)

# LDDMM
GAUSSIAN_KERNEL_SHAPE = (8, 8, 8)
480
+
481
+ # Constants for MultiLoss layer
482
+ PRIOR_W = [1., 1 / 60, 1.]
483
+ MANUAL_W = [1.] * len(PRIOR_W)
484
+
485
+ REG_PRIOR_W = [1e-3]
486
+ REG_MANUAL_W = [1.] * len(REG_PRIOR_W)
487
+
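The Sobel and Laplacian kernels above are meant to be convolved with tf.nn.conv3d. Below is a minimal, self-contained sketch (not part of the commit; the toy volume and single-channel kernel are illustrative assumptions) of how a 3D Sobel kernel is applied to a volume.

import tensorflow as tf

# Build one single-channel 3D Sobel kernel; tf.nn.conv3d expects the filter
# as [depth, height, width, in_channels, out_channels].
sobel_w_2d = tf.constant([[-1., 0., 1.],
                          [-2., 0., 2.],
                          [-1., 0., 1.]], dtype=tf.float32)
sobel_w_3d = tf.tile(tf.expand_dims(sobel_w_2d, axis=-1), [1, 1, 3])
kernel = tf.reshape(sobel_w_3d, [3, 3, 3, 1, 1])

volume = tf.random.uniform([1, 64, 64, 64, 1])           # toy [N, D, H, W, C] volume
grad_w = tf.nn.conv3d(volume, kernel, strides=[1, 1, 1, 1, 1], padding='SAME')
print(grad_w.shape)                                      # (1, 64, 64, 64, 1)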
DeepDeformationMapRegistration/utils/misc.py ADDED
@@ -0,0 +1,28 @@
1
+ import os
2
+ import errno
3
+ import nibabel as nb
4
+ import numpy as np
5
+ import re
6
+
7
+ def try_mkdir(dir):
8
+ try:
9
+ os.makedirs(dir)
10
+ except OSError as err:
11
+ if err.errno == errno.EEXIST:
12
+ print("Directory " + dir + " already exists")
13
+ else:
14
+ raise ValueError("Can't create dir " + dir)
15
+ else:
16
+ print("Created directory " + dir)
17
+
18
+
19
+ def function_decorator(new_name):
20
+ """"
21
+ Change the __name__ property of a function using new_name.
22
+ :param new_name:
23
+ :return:
24
+ """
25
+ def decorator(func):
26
+ func.__name__ = new_name
27
+ return func
28
+ return decorator
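A short usage sketch (illustrative only; the loss function below is hypothetical): function_decorator is handy for renaming wrapped loss functions so that Keras logs them under a readable name.

@function_decorator('mse_loss')
def my_loss(y_true, y_pred):
    # toy mean-squared-error on numpy arrays
    return ((y_true - y_pred) ** 2).mean()

print(my_loss.__name__)   # 'mse_loss'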
DeepDeformationMapRegistration/utils/nifty_utils.py ADDED
@@ -0,0 +1,42 @@
1
+ import os
2
+ import errno
3
+ import nibabel as nb
4
+ import numpy as np
5
+ import re
6
+ import zipfile
7
+ import tensorflow as tf
8
+
9
+
10
+ TEMP_UNZIP_PATH = '/mnt/EncryptedData1/Users/javier/ext_datasets/LITS17/temp'
11
+ NII_EXTENSION = '.nii'
12
+
13
+
14
+ def save_nifti(data, save_path):
15
+ data_nifti = nb.Nifti1Image(data, affine=np.eye(4))
16
+
17
+ data_nifti.header.get_xyzt_units()
18
+ try:
19
+ data_nifti.to_filename(save_path) # Save as NiBabel file
20
+ print('Saved {}'.format(save_path))
21
+ except ValueError:
22
+ print('Could not save {}'.format(save_path))
23
+
24
+
25
+ def unzip_nii_file(file_path):
26
+ file_dir, file_name = os.path.split(file_path)
27
+ file_name = file_name.split('.zip')[0]
28
+
29
+ dest_path = os.path.join(TEMP_UNZIP_PATH, file_name)
30
+ zipfile.ZipFile(file_path).extractall(TEMP_UNZIP_PATH)
31
+
32
+ if not os.path.exists(dest_path):
33
+ print('ERR: File {} not unzip-ed!'.format(file_path))
34
+ dest_path = None
35
+ return dest_path
36
+
37
+
38
+ def delete_temp(file_path, verbose=False):
39
+ assert NII_EXTENSION in file_path
40
+ os.remove(file_path)
41
+ if verbose:
42
+ print('Deleted file: ', file_path)
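A possible workflow with these helpers (paths are hypothetical; relies on the module-level imports above): unzip a compressed NIfTI, normalise it, save the result, and clean up the temporary file.

nii_path = unzip_nii_file('/data/LITS17/volume-0.nii.zip')
if nii_path is not None:
    vol = np.asarray(nb.load(nii_path).dataobj, dtype=np.float32)
    save_nifti(vol / vol.max(), '/data/LITS17/volume-0_norm.nii')
    delete_temp(nii_path, verbose=True)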
DeepDeformationMapRegistration/utils/operators.py ADDED
@@ -0,0 +1,28 @@
1
+ import numpy as np
2
+ import tensorflow as tf
3
+
4
+
5
+ def min_max_norm(img: np.ndarray, out_max_val=1.):
6
+ out_img = img
7
+ max_val = np.amax(img)
8
+ min_val = np.amin(img)
9
+ if (max_val - min_val) != 0:
10
+ out_img = (img - min_val) / (max_val - min_val)
11
+ return out_img * out_max_val
12
+
13
+
14
+ def soft_threshold(x, threshold, name=None):
15
+ # https://www.tensorflow.org/probability/api_docs/python/tfp/math/soft_threshold
16
+ with tf.name_scope(name or 'soft_threshold'):
17
+ x = tf.convert_to_tensor(x, name='x')
18
+ threshold = tf.convert_to_tensor(threshold, dtype=x.dtype, name='threshold')
19
+ return tf.sign(x) * tf.maximum(tf.abs(x) - threshold, 0.)
20
+
21
+
22
+ def binary_activation(x):
23
+ # https://stackoverflow.com/questions/37743574/hard-limiting-threshold-activation-function-in-tensorflow
24
+ cond = tf.less(x, tf.zeros(tf.shape(x)))
25
+ out = tf.where(cond, tf.zeros(tf.shape(x)), tf.ones(tf.shape(x)))
26
+
27
+ return out
28
+
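Quick sanity check of the operators above (assumes TensorFlow eager execution; values are illustrative):

img = np.array([[10., 20.], [30., 40.]])
print(min_max_norm(img))          # rescaled to [0, 1]
print(min_max_norm(img, 255.))    # rescaled to [0, 255]

x = tf.constant([-2., -0.5, 0.5, 2.])
print(soft_threshold(x, 1.))      # [-1.  0.  0.  1.]
print(binary_activation(x))       # [ 0.  0.  1.  1.]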
DeepDeformationMapRegistration/utils/user_interface.py ADDED
@@ -0,0 +1,23 @@
1
+ import numpy as np
2
+ import re
3
+ import os
4
+
5
+
6
+ def show_and_select(file_list, msg='Select a file by the number: ', int_if_single=True):
7
+ # If the selection is a single number, then return that number instead of the list of length 1
8
+ invalid_selection = True
9
+ while invalid_selection:
10
+ for i, f in enumerate(file_list):
11
+ print('{:03d}) {}'. format(i+1, os.path.split(f)[-1]))
12
+
13
+ sel = np.asarray(re.split(r',\s|,|\s', input(msg)), dtype=int) - 1
14
+
15
+ if np.all(sel >= 0) and np.all(sel < len(file_list)):
16
+ invalid_selection = False
17
+ sel = [file_list[s] for s in sel]
18
+ print('Selected: ', ', '.join([os.path.split(f)[-1] for f in sel]))
19
+
20
+ if int_if_single:
21
+ if len(sel) == 1:
22
+ sel = sel[0]
23
+ return sel
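Example prompt (file names are hypothetical): the user may answer with a single index or a comma/space-separated list, e.g. '2' or '1, 3'.

files = ['/data/sample_0000.h5', '/data/sample_0001.h5', '/data/sample_0002.h5']
chosen = show_and_select(files, msg='Select the training file(s): ')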
DeepDeformationMapRegistration/utils/visualization.py ADDED
@@ -0,0 +1,1151 @@
1
+ # import matplotlib
2
+ # matplotlib.use('TkAgg')
3
+ import matplotlib.pyplot as plt
4
+ from mpl_toolkits.mplot3d import Axes3D
5
+ import matplotlib.colors as mcolors
6
+ from matplotlib.lines import Line2D
7
+ from matplotlib import cm
8
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
9
+ import tensorflow as tf
10
+ import numpy as np
11
+ import DeepDeformationMapRegistration.utils.constants as const
12
+ from skimage.exposure import rescale_intensity
13
+ import scipy.misc as scpmisc
14
+ import os
15
+
16
+ THRES = 0.9
17
+
18
+ # COLOR MAPS
19
+ chunks = np.linspace(0, 1, 10)
20
+ cmap1 = plt.get_cmap('hsv', 4)
21
+ # cmaplist = [cmap1(i) for i in range(cmap1.N)]
22
+ cmaplist = [(1, 1, 1, 1), (0, 0, 1, 1), (230 / 255, 97 / 255, 1 / 255, 1), (128 / 255, 0 / 255, 32 / 255, 1)]
23
+ cmaplist[0] = (1, 1, 1, 1.0)
24
+ cmap1 = mcolors.LinearSegmentedColormap.from_list('custom', cmaplist, cmap1.N)
25
+
26
+ colors = [(0, 0, 1, i) for i in chunks]
27
+ cmap2 = mcolors.LinearSegmentedColormap.from_list('mycmap', colors, N=100)
28
+
29
+ colors = [(230 / 255, 97 / 255, 1 / 255, i) for i in chunks]
30
+ cmap3 = mcolors.LinearSegmentedColormap.from_list('mycmap', colors, N=100)
31
+
32
+ colors = [(128 / 255, 0 / 255, 32 / 255, i) for i in chunks]
33
+ cmap4 = mcolors.LinearSegmentedColormap.from_list('mycmap', colors, N=100)
34
+
35
+ cmap_bin = cm.get_cmap('viridis', 3) # viridis is the default colormap
36
+
37
+
38
+ def view_centerline_sample(sample: np.ndarray, dimensionality: int, ax=None, c=None, name=None):
39
+ if dimensionality == 2:
40
+ _plot_2d(sample, ax, c, name=name)
41
+ elif dimensionality == 3:
42
+ _plot_3d(sample, ax, c, name=name)
43
+ else:
44
+ raise ValueError('Invalid value for dimensionality. Expected int 2 or 3')
45
+
46
+
47
+ def matrix_to_orthographicProjection(matrix: np.ndarray, ret_list=False):
48
+ """ Given a 3D matrix, returns the three orthographic projections: top, front, right.
49
+ Top corresponds to dimensions 1 and 2
50
+ Front corresponds to dimensions 0 and 1
51
+ Right corresponds to dimensions 0 and 2
52
+
53
+ :param matrix: 3D matrix
54
+ :param ret_list: return a list instead of an array (optional)
55
+ :return: list or array with the three views [top, front, right]
56
+ """
57
+ top = _getProjection(matrix, dim=0) # YZ
58
+ front = _getProjection(matrix, dim=2) # XY
59
+ right = _getProjection(matrix, dim=1) # XZ
60
+
61
+ if ret_list:
62
+ return top, front, right
63
+ else:
64
+ return np.asarray([top, front, right])
65
+
66
+
67
+ def _getProjection(matrix: np.ndarray, dim: int):
68
+ orth_view = matrix.sum(axis=dim, dtype=float)
69
+ orth_view = orth_view > 0.0
70
+ orth_view = orth_view.astype(float)  # keep a float mask, not bool
71
+
72
+ return orth_view
73
+
74
+
75
+ def orthographicProjection_to_matrix(top: np.ndarray, front: np.ndarray, right: np.ndarray):
76
+ """ Given the three orthographic projections, it returns a 3D-view of the object based on back projection
77
+
78
+ :param top: 2D view top view
79
+ :param front: 2D front view
80
+ :param right: 2D right view
81
+ :return: matrix with the 3D-view
82
+ """
83
+ top_mat = np.tile(top, (front.shape[0], 1, 1))
84
+ front_mat = np.tile(front, (right.shape[1], 1, 1))
85
+ right_mat = np.tile(right, (top.shape[0], 1, 1))
86
+
87
+ reconstruction = np.zeros((front.shape[0], right.shape[1], top.shape[0]))
88
+ iter = np.nditer([top_mat, front_mat, right_mat, reconstruction], flags=['multi_index'], op_flags=['readwrite'])
89
+ while not iter.finished:
90
+ if iter[0] and iter[1] and iter[2]:
91
+ iter[3] = 1
92
+ iter.iternext()
93
+
94
+ return reconstruction
95
+
96
+
97
+ def _plot_2d(sample: np.ndarray, ax=None, c=None, name=None):
98
+ if isinstance(sample, tf.Tensor):
99
+ sample = sample.eval(session=tf.Session())
100
+
101
+ x_range = list()
102
+ y_range = list()
103
+ marker_size = list()
104
+ for idx, val in np.ndenumerate(sample):
105
+ if val >= THRES:
106
+ x_range.append(idx[0])
107
+ y_range.append(idx[1])
108
+ marker_size.append(val ** 2)
109
+
110
+ if not ax:
111
+ fig = plt.figure()
112
+ ax = fig.add_subplot(111)
113
+
114
+ if c:
115
+ ax.scatter(x_range, y_range, c=c, s=marker_size)
116
+ else:
117
+ ax.scatter(x_range, y_range, s=marker_size)
118
+
119
+ ax.set_xlabel('X')
120
+ ax.set_ylabel('Y')
121
+ if name:
122
+ ax.set_title(name)
123
+
124
+ return ax
125
+
126
+
127
+ def _plot_3d(sample: np.ndarray, ax=None, c=None, name=None):
128
+ from mpl_toolkits.mplot3d import Axes3D
129
+ if isinstance(sample, tf.Tensor):
130
+ sample = sample.eval(session=tf.Session())
131
+
132
+ x_range = list()
133
+ y_range = list()
134
+ z_range = list()
135
+ marker_size = list()
136
+ for idx, val in np.ndenumerate(sample):
137
+ if val >= THRES:
138
+ x_range.append(idx[0])
139
+ y_range.append(idx[1])
140
+ z_range.append(idx[2])
141
+ marker_size.append(val ** 2)
142
+
143
+ print('Found ', len(x_range), ' points')
144
+ # x_range = np.linspace(start=0, stop=sample.shape[0], num=sample.shape[0])
145
+ # y_range = np.linspace(start=0, stop=sample.shape[1], num=sample.shape[1])
146
+ # z_range = np.linspace(start=0, stop=sample.shape[2], num=sample.shape[2])
147
+ #
148
+ # sample_flat = sample.flatten(order='C')
149
+
150
+ if len(x_range):
151
+ if not ax:
152
+ fig = plt.figure()
153
+ ax = fig.add_subplot(111, projection='3d')
154
+
155
+ if c:
156
+ ax.scatter(x_range, y_range, z_range, c=c, s=marker_size)
157
+ else:
158
+ ax.scatter(x_range, y_range, z_range, s=marker_size)
159
+ # ax.scatter(x_range, y_range, z_range, s=marker_size)#, c=sample_flat)
160
+
161
+ ax.set_xlabel('X')
162
+ ax.set_ylabel('Y')
163
+ ax.set_zlabel('Z')
164
+ if name:
165
+ ax.set_title(name)
166
+
167
+ return ax
168
+ else:
169
+ print('Nothing to plot')
170
+ return None
171
+
172
+
173
+ def plot_training(list_imgs: [np.ndarray], affine_transf=True, filename='img', fig=None):
174
+ if fig is not None:
175
+ fig.clear()
176
+ plt.figure(fig.number)
177
+ else:
178
+ fig = plt.figure(dpi=const.DPI)
179
+
180
+ ax_fix = fig.add_subplot(231)
181
+ im_fix = ax_fix.imshow(list_imgs[0][:, :, 0])
182
+ ax_fix.set_title('Fix image', fontsize=const.FONT_SIZE)
183
+ ax_fix.tick_params(axis='both',
184
+ which='both',
185
+ bottom=False,
186
+ left=False,
187
+ labelleft=False,
188
+ labelbottom=False)
189
+ ax_mov = fig.add_subplot(232)
190
+ im_mov = ax_mov.imshow(list_imgs[1][:, :, 0])
191
+ ax_mov.set_title('Moving image', fontsize=const.FONT_SIZE)
192
+ ax_mov.tick_params(axis='both',
193
+ which='both',
194
+ bottom=False,
195
+ left=False,
196
+ labelleft=False,
197
+ labelbottom=False)
198
+
199
+ ax_pred_im = fig.add_subplot(233)
200
+ im_pred_im = ax_pred_im.imshow(list_imgs[2][:, :, 0])
201
+ ax_pred_im.set_title('Prediction', fontsize=const.FONT_SIZE)
202
+ ax_pred_im.tick_params(axis='both',
203
+ which='both',
204
+ bottom=False,
205
+ left=False,
206
+ labelleft=False,
207
+ labelbottom=False)
208
+
209
+ ax_pred_disp = fig.add_subplot(234)
210
+ if affine_transf:
211
+ fake_bg = np.array([[0.0, 0.0, 0.0, 0.0],
212
+ [0.0, 0.0, 0.0, 0.0],
213
+ [0.0, 0.0, 0.0, 0.0],
214
+ [0.0, 0.0, 0.0, 0.0]])
215
+
216
+ bottom = np.asarray([0, 0, 0, 1])
217
+
218
+ transf_mat = np.reshape(list_imgs[3], (2, 3))
219
+ transf_mat = np.stack([transf_mat, bottom], axis=0)
220
+
221
+ im_pred_disp = ax_pred_disp.imshow(fake_bg)
222
+ for i in range(4):
223
+ for j in range(4):
224
+ ax_pred_disp.text(i, j, transf_mat[i, j], ha="center", va="center", color="b")
225
+
226
+ ax_pred_disp.set_title('Affine transformation matrix')
227
+
228
+ else:
229
+ (cx, cy), (dx, dy), s = _prepare_quiver_map(list_imgs[3])
230
+ im_pred_disp = ax_pred_disp.imshow(s, interpolation='none', aspect='equal')
231
+ ax_pred_disp.quiver(cx, cy, dx, dy, scale=const.QUIVER_PARAMS.arrow_scale)
232
+ ax_pred_disp.set_title('Pred disp map', fontsize=const.FONT_SIZE)
233
+ ax_pred_disp.tick_params(axis='both',
234
+ which='both',
235
+ bottom=False,
236
+ left=False,
237
+ labelleft=False,
238
+ labelbottom=False)
239
+
240
+ ax_gt_disp = fig.add_subplot(235)
241
+ if affine_transf:
242
+ fake_bg = np.array([[0.0, 0.0, 0.0, 0.0],
243
+ [0.0, 0.0, 0.0, 0.0],
244
+ [0.0, 0.0, 0.0, 0.0],
245
+ [0.0, 0.0, 0.0, 0.0]])
246
+
247
+ bottom = np.asarray([0, 0, 0, 1])
248
+
249
+ transf_mat = np.reshape(list_imgs[4], (2, 3))
250
+ transf_mat = np.stack([transf_mat, bottom], axis=0)
251
+
252
+ im_gt_disp = ax_gt_disp.imshow(fake_bg)
253
+ for i in range(4):
254
+ for j in range(4):
255
+ ax_gt_disp.text(i, j, transf_mat[i, j], ha="center", va="center", color="b")
256
+
257
+ ax_gt_disp.set_title('Affine transformation matrix')
258
+
259
+ else:
260
+ (cx, cy), (dx, dy), s = _prepare_quiver_map(list_imgs[4])
261
+ im_gt_disp = ax_gt_disp.imshow(s, interpolation='none', aspect='equal')
262
+ ax_gt_disp.quiver(cx, cy, dx, dy, scale=const.QUIVER_PARAMS.arrow_scale)
263
+ ax_gt_disp.set_title('GT disp map', fontsize=const.FONT_SIZE)
264
+ ax_gt_disp.tick_params(axis='both',
265
+ which='both',
266
+ bottom=False,
267
+ left=False,
268
+ labelleft=False,
269
+ labelbottom=False)
270
+
271
+ cb_fix = _set_colorbar(fig, ax_fix, im_fix, False)
272
+ cb_mov = _set_colorbar(fig, ax_mov, im_mov, False)
273
+ cb_pred = _set_colorbar(fig, ax_pred_im, im_pred_im, False)
274
+ cb_pred_disp = _set_colorbar(fig, ax_pred_disp, im_pred_disp, False)
275
+ cd_gt_disp = _set_colorbar(fig, ax_gt_disp, im_gt_disp, False)
276
+
277
+ if filename is not None:
278
+ plt.savefig(filename, format='png') # Call before show
279
+ if not const.REMOTE:
280
+ plt.show()
281
+ else:
282
+ plt.close()
283
+ return fig
284
+
285
+
286
+ def save_centreline_img(img, title, filename, fig=None):
287
+ if fig is not None:
288
+ fig.clear()
289
+ plt.figure(fig.number)
290
+ else:
291
+ fig = plt.figure(dpi=const.DPI)
292
+
293
+ dim = len(img.shape[:-1])
294
+
295
+ if dim == 2:
296
+ ax = fig.add_subplot(111)
297
+ fig.suptitle(title)
298
+ im = ax.imshow(img[..., 0], cmap=cmap_bin)
299
+ ax.tick_params(axis='both',
300
+ which='both',
301
+ bottom=False,
302
+ left=False,
303
+ labelleft=False,
304
+ labelbottom=False)
305
+
306
+ #cb = _set_colorbar(fig, ax, im, False)
307
+ else:
308
+ ax = fig.add_subplot(111, projection='3d')
309
+ fig.suptitle(title)
310
+ im = ax.voxels(img[0, ..., 0] > 0.0)
311
+ _square_3d_plot(np.arange(0, 63), np.arange(0, 63), np.arange(0, 63), ax)
312
+
313
+ ax.tick_params(axis='both',
314
+ which='both',
315
+ bottom=False,
316
+ left=False,
317
+ labelleft=False,
318
+ labelbottom=False)
319
+
320
+ plt.savefig(filename, format='png')
321
+ plt.close()
322
+
323
+
324
+ def save_disp_map_img(disp_map, title, filename, affine_transf=False, fig=None):
325
+ if fig is not None:
326
+ fig.clear()
327
+ plt.figure(fig.number)
328
+ else:
329
+ fig = plt.figure(dpi=const.DPI)
330
+
331
+ dim = disp_map.shape[-1]
332
+
333
+ if dim == 2:
334
+ ax_x = fig.add_subplot(131)
335
+ ax_x.set_title('H displacement')
336
+ im_x = ax_x.imshow(disp_map[..., const.H_DISP])
337
+ ax_x.tick_params(axis='both',
338
+ which='both',
339
+ bottom=False,
340
+ left=False,
341
+ labelleft=False,
342
+ labelbottom=False)
343
+ cb_x = _set_colorbar(fig, ax_x, im_x, False)
344
+
345
+ ax_y = fig.add_subplot(132)
346
+ ax_y.set_title('W displacement')
347
+ im_y = ax_y.imshow(disp_map[..., const.W_DISP])
348
+ ax_y.tick_params(axis='both',
349
+ which='both',
350
+ bottom=False,
351
+ left=False,
352
+ labelleft=False,
353
+ labelbottom=False)
354
+ cb_y = _set_colorbar(fig, ax_y, im_y, False)
355
+
356
+ ax = fig.add_subplot(133)
357
+ if affine_transf:
358
+ fake_bg = np.array([[0.0, 0.0, 0.0, 0.0],
359
+ [0.0, 0.0, 0.0, 0.0],
360
+ [0.0, 0.0, 0.0, 0.0],
361
+ [0.0, 0.0, 0.0, 0.0]])
362
+
363
+ bottom = np.asarray([0, 0, 0, 1])
364
+
365
+ transf_mat = np.reshape(disp_map, (2, 3))
366
+ transf_mat = np.stack([transf_mat, bottom], axis=0)
367
+
368
+ im = ax.imshow(fake_bg)
369
+ for i in range(4):
370
+ for j in range(4):
371
+ ax.text(i, j, transf_mat[i, j], ha="center", va="center", color="b")
372
+
373
+ else:
374
+ c, d, s = _prepare_quiver_map(disp_map, dim=dim)
375
+ im = ax.imshow(s, interpolation='none', aspect='equal')
376
+ ax.quiver(c[const.H_DISP], c[const.W_DISP], d[const.H_DISP], d[const.W_DISP],
377
+ scale=const.QUIVER_PARAMS.arrow_scale)
378
+ cb = _set_colorbar(fig, ax, im, False)
379
+ ax.set_title('Displacement map')
380
+ ax.tick_params(axis='both',
381
+ which='both',
382
+ bottom=False,
383
+ left=False,
384
+ labelleft=False,
385
+ labelbottom=False)
386
+ fig.suptitle(title)
387
+ else:
388
+ ax = fig.add_subplot(111, projection='3d')
389
+ c, d, s = _prepare_quiver_map(disp_map[0, ...], dim=dim)
390
+ ax.quiver(c[const.H_DISP], c[const.W_DISP], c[const.D_DISP], d[const.H_DISP], d[const.W_DISP], d[const.D_DISP],
391
+ norm=True)
392
+ _square_3d_plot(np.arange(0, 63), np.arange(0, 63), np.arange(0, 63), ax)
393
+ fig.suptitle('Displacement map')
394
+ ax.tick_params(axis='both', # Same parameters as in 2D https://matplotlib.org/api/_as_gen/mpl_toolkits.mplot3d.axes3d.Axes3D.html
395
+ which='both',
396
+ bottom=False,
397
+ left=False,
398
+ labelleft=False,
399
+ labelbottom=False)
400
+ fig.suptitle(title)
401
+
402
+ plt.savefig(filename, format='png')
403
+ plt.close()
404
+
405
+
406
+ def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, filename='img', fig=None,
407
+ title_first_row='TRAINING', title_second_row='VALIDATION'):
408
+ if fig is not None:
409
+ fig.clear()
410
+ plt.figure(fig.number)
411
+ else:
412
+ fig = plt.figure(dpi=const.DPI)
413
+
414
+ dim = len(list_imgs[0].shape[:-1])
415
+
416
+ if dim == 2:
417
+ # TRAINING
418
+ ax_input = fig.add_subplot(241)
419
+ ax_input.set_ylabel(title_first_row, fontsize=const.FONT_SIZE)
420
+ im_fix = ax_input.imshow(list_imgs[0][:, :, 0])
421
+ ax_input.set_title('Fix image', fontsize=const.FONT_SIZE)
422
+ ax_input.tick_params(axis='both',
423
+ which='both',
424
+ bottom=False,
425
+ left=False,
426
+ labelleft=False,
427
+ labelbottom=False)
428
+ ax_mov = fig.add_subplot(242)
429
+ im_mov = ax_mov.imshow(list_imgs[1][:, :, 0])
430
+ ax_mov.set_title('Moving image', fontsize=const.FONT_SIZE)
431
+ ax_mov.tick_params(axis='both',
432
+ which='both',
433
+ bottom=False,
434
+ left=False,
435
+ labelleft=False,
436
+ labelbottom=False)
437
+
438
+ ax_pred_im = fig.add_subplot(244)
439
+ im_pred_im = ax_pred_im.imshow(list_imgs[2][:, :, 0])
440
+ ax_pred_im.set_title('Predicted fix image', fontsize=const.FONT_SIZE)
441
+ ax_pred_im.tick_params(axis='both',
442
+ which='both',
443
+ bottom=False,
444
+ left=False,
445
+ labelleft=False,
446
+ labelbottom=False)
447
+
448
+ ax_pred_disp = fig.add_subplot(243)
449
+ if affine_transf:
450
+ fake_bg = np.array([[0.0, 0.0, 0.0, 0.0],
451
+ [0.0, 0.0, 0.0, 0.0],
452
+ [0.0, 0.0, 0.0, 0.0],
453
+ [0.0, 0.0, 0.0, 0.0]])
454
+
455
+ bottom = np.asarray([0, 0, 0, 1])
456
+
457
+ transf_mat = np.reshape(list_imgs[3], (2, 3))
458
+ transf_mat = np.stack([transf_mat, bottom], axis=0)
459
+
460
+ im_pred_disp = ax_pred_disp.imshow(fake_bg)
461
+ for i in range(4):
462
+ for j in range(4):
463
+ ax_pred_disp.text(i, j, transf_mat[i, j], ha="center", va="center", color="b")
464
+
465
+ ax_pred_disp.set_title('Affine transformation matrix')
466
+
467
+ else:
468
+ (cx, cy), (dx, dy), s = _prepare_quiver_map(list_imgs[3], dim=dim)
469
+ im_pred_disp = ax_pred_disp.imshow(s, interpolation='none', aspect='equal')
470
+ ax_pred_disp.quiver(cx, cy, dx, dy, scale=const.QUIVER_PARAMS.arrow_scale)
471
+ ax_pred_disp.set_title('Pred disp map', fontsize=const.FONT_SIZE)
472
+ ax_pred_disp.tick_params(axis='both',
473
+ which='both',
474
+ bottom=False,
475
+ left=False,
476
+ labelleft=False,
477
+ labelbottom=False)
478
+
479
+ # VALIDATION
480
+ axinput_val = fig.add_subplot(245)
481
+ axinput_val.set_ylabel(title_second_row, fontsize=const.FONT_SIZE)
482
+ im_fix_val = axinput_val.imshow(list_imgs[4][:, :, 0])
483
+ axinput_val.set_title('Fix image', fontsize=const.FONT_SIZE)
484
+ axinput_val.tick_params(axis='both',
485
+ which='both',
486
+ bottom=False,
487
+ left=False,
488
+ labelleft=False,
489
+ labelbottom=False)
490
+ ax_mov_val = fig.add_subplot(246)
491
+ im_mov_val = ax_mov_val.imshow(list_imgs[5][:, :, 0])
492
+ ax_mov_val.set_title('Moving image', fontsize=const.FONT_SIZE)
493
+ ax_mov_val.tick_params(axis='both',
494
+ which='both',
495
+ bottom=False,
496
+ left=False,
497
+ labelleft=False,
498
+ labelbottom=False)
499
+
500
+ ax_pred_im_val = fig.add_subplot(248)
501
+ im_pred_im_val = ax_pred_im_val.imshow(list_imgs[6][:, :, 0])
502
+ ax_pred_im_val.set_title('Predicted fix image', fontsize=const.FONT_SIZE)
503
+ ax_pred_im_val.tick_params(axis='both',
504
+ which='both',
505
+ bottom=False,
506
+ left=False,
507
+ labelleft=False,
508
+ labelbottom=False)
509
+
510
+ ax_pred_disp_val = fig.add_subplot(247)
511
+ if affine_transf:
512
+ fake_bg = np.array([[0.0, 0.0, 0.0, 0.0],
513
+ [0.0, 0.0, 0.0, 0.0],
514
+ [0.0, 0.0, 0.0, 0.0],
515
+ [0.0, 0.0, 0.0, 0.0]])
516
+
517
+ bottom = np.asarray([0, 0, 0, 1])
518
+
519
+ transf_mat = np.reshape(list_imgs[7], (2, 3))
520
+ transf_mat = np.stack([transf_mat, bottom], axis=0)
521
+
522
+ im_pred_disp_val = ax_pred_disp_val.imshow(fake_bg)
523
+ for i in range(4):
524
+ for j in range(4):
525
+ ax_pred_disp_val.text(i, j, transf_mat[i, j], ha="center", va="center", color="b")
526
+
527
+ ax_pred_disp_val.set_title('Affine transformation matrix')
528
+
529
+ else:
530
+ c, d, s = _prepare_quiver_map(list_imgs[7], dim=dim)
531
+ im_pred_disp_val = ax_pred_disp_val.imshow(s, interpolation='none', aspect='equal')
532
+ ax_pred_disp_val.quiver(c[0], c[1], d[0], d[1], scale=const.QUIVER_PARAMS.arrow_scale)
533
+ ax_pred_disp_val.set_title('Pred disp map', fontsize=const.FONT_SIZE)
534
+ ax_pred_disp_val.tick_params(axis='both',
535
+ which='both',
536
+ bottom=False,
537
+ left=False,
538
+ labelleft=False,
539
+ labelbottom=False)
540
+
541
+ cb_fix = _set_colorbar(fig, ax_input, im_fix, False)
542
+ cb_mov = _set_colorbar(fig, ax_mov, im_mov, False)
543
+ cb_pred = _set_colorbar(fig, ax_pred_im, im_pred_im, False)
544
+ cb_pred_disp = _set_colorbar(fig, ax_pred_disp, im_pred_disp, False)
545
+
546
+ cd_fix_val = _set_colorbar(fig, axinput_val, im_fix_val, False)
547
+ cb_mov_val = _set_colorbar(fig, ax_mov_val, im_mov_val, False)
548
+ cb_pred_val = _set_colorbar(fig, ax_pred_im_val, im_pred_im_val, False)
549
+ cb_pred_disp_val = _set_colorbar(fig, ax_pred_disp_val, im_pred_disp_val, False)
550
+
551
+ else:
552
+ # 3D
553
+ # TRAINING
554
+ ax_input = fig.add_subplot(231, projection='3d')
555
+ ax_input.set_ylabel(title_first_row, fontsize=const.FONT_SIZE)
556
+ im_fix = ax_input.voxels(list_imgs[0][..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed')
557
+ im_mov = ax_input.voxels(list_imgs[1][..., 0] > 0.0, facecolors='blue', edgecolors='blue', label='Moving')
558
+ ax_input.set_title('Fix image', fontsize=const.FONT_SIZE)
559
+ ax_input.tick_params(axis='both',
560
+ which='both',
561
+ bottom=False,
562
+ left=False,
563
+ labelleft=False,
564
+ labelbottom=False)
565
+
566
+ ax_pred_im = fig.add_subplot(232, projection='3d')
567
+ im_pred_im = ax_input.voxels(list_imgs[2][..., 0] > 0.0, facecolors='green', edgecolors='green', label='Prediction')
568
+ im_fix = ax_input.voxels(list_imgs[0][..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed')
569
+ ax_pred_im.set_title('Predicted fix image', fontsize=const.FONT_SIZE)
570
+ ax_pred_im.tick_params(axis='both',
571
+ which='both',
572
+ bottom=False,
573
+ left=False,
574
+ labelleft=False,
575
+ labelbottom=False)
576
+
577
+ ax_pred_disp = fig.add_subplot(233, projection='3d')
578
+
579
+ c, d, s = _prepare_quiver_map(list_imgs[3], dim=dim)
580
+ im_pred_disp = ax_pred_disp.imshow(s, interpolation='none', aspect='equal')
581
+ ax_pred_disp.quiver(c[const.H_DISP], c[const.W_DISP], c[const.D_DISP],
582
+ d[const.H_DISP], d[const.W_DISP], d[const.D_DISP], scale=const.QUIVER_PARAMS.arrow_scale)
583
+ ax_pred_disp.set_title('Pred disp map', fontsize=const.FONT_SIZE)
584
+ ax_pred_disp.tick_params(axis='both',
585
+ which='both',
586
+ bottom=False,
587
+ left=False,
588
+ labelleft=False,
589
+ labelbottom=False)
590
+
591
+ # VALIDATION
592
+ axinput_val = fig.add_subplot(234, projection='3d')
593
+ axinput_val.set_ylabel(title_second_row, fontsize=const.FONT_SIZE)
594
+ im_fix_val = ax_input.voxels(list_imgs[4][..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed (val)')
595
+ im_mov_val = ax_input.voxels(list_imgs[5][..., 0] > 0.0, facecolors='blue', edgecolors='blue', label='Moving (val)')
596
+ axinput_val.set_title('Fix image', fontsize=const.FONT_SIZE)
597
+ axinput_val.tick_params(axis='both',
598
+ which='both',
599
+ bottom=False,
600
+ left=False,
601
+ labelleft=False,
602
+ labelbottom=False)
603
+
604
+ ax_pred_im_val = fig.add_subplot(235, projection='3d')
605
+ im_pred_im_val = ax_input.voxels(list_imgs[2][..., 0] > 0.0, facecolors='green', edgecolors='green', label='Prediction (val)')
606
+ im_fix_val = ax_input.voxels(list_imgs[0][..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed (val)')
607
+ ax_pred_im_val.set_title('Predicted fix image', fontsize=const.FONT_SIZE)
608
+ ax_pred_im_val.tick_params(axis='both',
609
+ which='both',
610
+ bottom=False,
611
+ left=False,
612
+ labelleft=False,
613
+ labelbottom=False)
614
+
615
+ ax_pred_disp_val = fig.add_subplot(236, projection='3d')
616
+ c, d, s = _prepare_quiver_map(list_imgs[7], dim=dim)
617
+ im_pred_disp_val = ax_pred_disp_val.imshow(s, interpolation='none', aspect='equal')
618
+ ax_pred_disp_val.quiver(c[const.H_DISP], c[const.W_DISP], c[const.D_DISP],
619
+ d[const.H_DISP], d[const.W_DISP], d[const.D_DISP],
620
+ scale=const.QUIVER_PARAMS.arrow_scale)
621
+ ax_pred_disp_val.set_title('Pred disp map', fontsize=const.FONT_SIZE)
622
+ ax_pred_disp_val.tick_params(axis='both',
623
+ which='both',
624
+ bottom=False,
625
+ left=False,
626
+ labelleft=False,
627
+ labelbottom=False)
628
+
629
+ if filename is not None:
630
+ plt.savefig(filename, format='png') # Call before show
631
+ if not const.REMOTE:
632
+ plt.show()
633
+ else:
634
+ plt.close()
635
+ return fig
636
+
637
+
638
+ def _set_colorbar(fig, ax, im, drawedges=True):
639
+ div = make_axes_locatable(ax)
640
+ im_cax = div.append_axes('right', size='5%', pad=0.05)
641
+ im_cb = fig.colorbar(im, cax=im_cax, drawedges=drawedges, shrink=0.5, orientation='vertical')
642
+ im_cb.ax.tick_params(labelsize=5)
643
+
644
+ return im_cb
645
+
646
+
647
+ def _prepare_quiver_map(disp_map: np.ndarray, dim=2, spc=const.QUIVER_PARAMS.spacing):
648
+ if isinstance(disp_map, tf.Tensor):
649
+ if tf.executing_eagerly():
650
+ disp_map = disp_map.numpy()
651
+ else:
652
+ disp_map = disp_map.eval()
653
+ dx = disp_map[..., const.H_DISP]
654
+ dy = disp_map[..., const.W_DISP]
655
+ if dim > 2:
656
+ dz = disp_map[..., const.D_DISP]
657
+
658
+ img_size_x = disp_map.shape[const.H_DISP]
659
+ img_size_y = disp_map.shape[const.W_DISP]
660
+ if dim > 2:
661
+ img_size_z = disp_map.shape[const.D_DISP]
662
+
663
+ if dim > 2:
664
+ s = np.sqrt(np.square(dx) + np.square(dy) + np.square(dz))
665
+ s = np.reshape(s, [img_size_x, img_size_y, img_size_z])
666
+
667
+ cx, cy, cz = np.meshgrid(list(range(0, img_size_x)), list(range(0, img_size_y)), list(range(0, img_size_z)),
668
+ indexing='ij')
669
+ c = [cx[::spc, ::spc, ::spc], cy[::spc, ::spc, ::spc], cz[::spc, ::spc, ::spc]]
670
+ d = [dx[::spc, ::spc, ::spc], dy[::spc, ::spc, ::spc], dz[::spc, ::spc, ::spc]]
671
+ else:
672
+ s = np.sqrt(np.square(dx) + np.square(dy))
673
+ s = np.reshape(s, [img_size_x, img_size_y])
674
+
675
+ cx, cy = np.meshgrid(list(range(0, img_size_x)), list(range(0, img_size_y)))
676
+ c = [cx[::spc, ::spc], cy[::spc, ::spc]]
677
+ d = [dx[::spc, ::spc], dy[::spc, ::spc]]
678
+
679
+ return c, d, s
680
+
681
+
682
+ def _prepare_colormap(disp_map: np.ndarray):
683
+ if isinstance(disp_map, tf.Tensor):
684
+ disp_map = disp_map.eval()
685
+ dx = disp_map[:, :, 0]
686
+ dy = disp_map[:, :, 1]
687
+
688
+ mod_img = np.zeros_like(dx)
689
+
690
+ for i in range(dx.shape[0]):
691
+ for j in range(dx.shape[1]):
692
+ vec = np.asarray([dx[i, j], dy[i, j]])
693
+ mod_img[i, j] = np.linalg.norm(vec, ord=2)
694
+
695
+ p_l, p_h = np.percentile(mod_img, (0, 100))
696
+ mod_img = rescale_intensity(mod_img, in_range=(p_l, p_h), out_range=(0, 255))
697
+
698
+ return mod_img
699
+
700
+
701
+ def plot_input_data(fix_img, mov_img, img_size=(64, 64), title=None, filename=None):
702
+ num_samples = fix_img.shape[0]
703
+
704
+ if num_samples != 16 and num_samples != 32:
705
+ raise ValueError('Only batches of 16 or 32 samples!')
706
+
707
+ fig, ax = plt.subplots(nrows=4 if num_samples == 16 else 8, ncols=4)
708
+ ncol = 0
709
+ nrow = 0
710
+ black_col = np.ones([img_size[0], 0])
711
+ for sample in range(num_samples):
712
+ combined_img = np.hstack([fix_img[sample, :, :, 0], black_col, mov_img[sample, :, :, 0]])
713
+ ax[nrow, ncol].imshow(combined_img, cmap='Greys')
714
+ ax[nrow, ncol].set_ylabel('#{}'.format(sample))
715
+ ax[nrow, ncol].tick_params(axis='both',
716
+ which='both',
717
+ bottom=False,
718
+ left=False,
719
+ labelleft=False,
720
+ labelbottom=False)
721
+ ncol += 1
722
+ if ncol >= 4:
723
+ ncol = 0
724
+ nrow += 1
725
+
726
+ if title is not None:
727
+ fig.suptitle(title)
728
+
729
+ if filename is not None:
730
+ plt.savefig(filename, format='png') # Call before show
731
+ if not const.REMOTE:
732
+ plt.show()
733
+ else:
734
+ plt.close()
735
+ return fig
736
+
737
+
738
+ def plot_dataset_orthographic_views(view_sets: [[np.ndarray]]):
739
+ """
740
+
741
+ :param views_fix: Expected order: top, front, left
742
+ :param views_mov: Expected order: top, front, left
743
+ :return:
744
+ """
745
+ nrows = len(view_sets)
746
+ fig, ax = plt.subplots(nrows=nrows, ncols=3)
747
+ labels = ['top', 'front', 'left']
748
+ for nrow in range(nrows):
749
+ for ncol in range(3):
750
+ if nrows == 1:
751
+ ax[ncol].imshow(view_sets[nrow][ncol][:, :, 0])
752
+ ax[ncol].set_title('Fix ' + labels[ncol])
753
+ ax[ncol].tick_params(axis='both',
754
+ which='both',
755
+ bottom=False,
756
+ left=False,
757
+ labelleft=False,
758
+ labelbottom=False)
759
+
760
+ else:
761
+ ax[nrow, ncol].imshow(view_sets[nrow][ncol][:, :, 0])
762
+ ax[nrow, ncol].set_title('Fix ' + labels[ncol])
763
+ ax[nrow, ncol].tick_params(axis='both',
764
+ which='both',
765
+ bottom=False,
766
+ left=False,
767
+ labelleft=False,
768
+ labelbottom=False)
769
+
770
+ plt.show()
771
+ return fig
772
+
773
+
774
+ def plot_compare_2d_images(img1, img2, img1_name='img1', img2_name='img2'):
775
+ fig, ax = plt.subplots(nrows=1, ncols=2)
776
+ ax[0].imshow(img1[:, :, 0])
777
+ ax[0].set_title(img1_name)
778
+ ax[0].tick_params(axis='both',
779
+ which='both',
780
+ bottom=False,
781
+ left=False,
782
+ labelleft=False,
783
+ labelbottom=False)
784
+
785
+ ax[1].imshow(img2[:, :, 0])
786
+ ax[1].set_title(img2_name)
787
+ ax[1].tick_params(axis='both',
788
+ which='both',
789
+ bottom=False,
790
+ left=False,
791
+ labelleft=False,
792
+ labelbottom=False)
793
+
794
+ plt.show()
795
+ return fig
796
+
797
+
798
+ def plot_dataset_3d(img_sets):
799
+ from mpl_toolkits.mplot3d import Axes3D
800
+ fig = plt.figure()
801
+ ax = fig.add_subplot(111, projection='3d')
802
+
803
+ for idx in range(len(img_sets)):
804
+ ax = _plot_3d(img_sets[idx], ax=ax, name='Set {}'.format(idx))
805
+
806
+ plt.show()
807
+ return fig
808
+
809
+
810
+ def plot_predictions(fix_img_batch, mov_img_batch, disp_map_batch, pred_img_batch, filename='predictions', fig=None):
811
+ num_rows = fix_img_batch.shape[0]
812
+ img_size = fix_img_batch.shape[1:3]
813
+ if fig is not None:
814
+ fig.clear()
815
+ plt.figure(fig.number)
816
+ ax = fig.add_subplot(nrows=num_rows, ncols=4, dpi=const.DPI)
817
+ else:
818
+ fig, ax = plt.subplots(nrows=num_rows, ncols=4, dpi=const.DPI)
819
+
820
+ for row in range(num_rows):
821
+ fix_img = fix_img_batch[row, :, :, 0]
822
+ mov_img = mov_img_batch[row, :, :, 0]
823
+ disp_map = disp_map_batch[row, :, :, :]
824
+ pred_img = pred_img_batch[row, :, :, 0]
825
+ ax[row, 0].imshow(fix_img)
826
+ ax[row, 0].tick_params(axis='both',
827
+ which='both',
828
+ bottom=False,
829
+ left=False,
830
+ labelleft=False,
831
+ labelbottom=False)
832
+ ax[row, 1].imshow(mov_img)
833
+ ax[row, 1].tick_params(axis='both',
834
+ which='both',
835
+ bottom=False,
836
+ left=False,
837
+ labelleft=False,
838
+ labelbottom=False)
839
+
840
+ (cx, cy), (dx, dy), s = _prepare_quiver_map(disp_map)
841
+ disp_map_color = _prepare_colormap(disp_map)
842
+ ax[row, 2].imshow(disp_map_color, interpolation='none', aspect='equal')
843
+ ax[row, 2].quiver(cx, cy, dx, dy, units='xy', scale=const.QUIVER_PARAMS.arrow_scale)
844
+ ax[row, 2].figure.set_size_inches(img_size)
845
+ ax[row, 2].tick_params(axis='both',
846
+ which='both',
847
+ bottom=False,
848
+ left=False,
849
+ labelleft=False,
850
+ labelbottom=False)
851
+
852
+ ax[row, 3].tick_params(axis='both',
853
+ which='both',
854
+ bottom=False,
855
+ left=False,
856
+ labelleft=False,
857
+ labelbottom=False)
858
+
859
+ plt.axis('off')
860
+ ax[0, 0].set_title('Fixed img ($I_f$)', fontsize=const.FONT_SIZE)
861
+ ax[0, 1].set_title('Moving img ($I_m$)', fontsize=const.FONT_SIZE)
862
+ ax[0, 2].set_title('Displacement map ($\delta$)', fontsize=const.FONT_SIZE)
863
+ ax[0, 3].set_title('Updated $I_m$', fontsize=const.FONT_SIZE)
864
+
865
+ if filename is not None:
866
+ plt.savefig(filename, format='png') # Call before show
867
+ if not const.REMOTE:
868
+ plt.show()
869
+ else:
870
+ plt.close()
871
+ return fig
872
+
873
+
874
+ def inspect_disp_map_generation(fix_img, mov_img, disp_map, filename=None, fig=None):
875
+ if fig is not None:
876
+ fig.clear()
877
+ plt.figure(fig.number)
878
+ else:
879
+ fig = plt.figure(dpi=const.DPI)
880
+
881
+ ax0 = fig.add_subplot(221)
882
+ im0 = ax0.imshow(fix_img[..., 0])
883
+ ax0.tick_params(axis='both',
884
+ which='both',
885
+ bottom=False,
886
+ left=False,
887
+ labelleft=False,
888
+ labelbottom=False)
889
+ ax1 = fig.add_subplot(222)
890
+ im1 = ax1.imshow(mov_img[..., 0])
891
+ ax1.tick_params(axis='both',
892
+ which='both',
893
+ bottom=False,
894
+ left=False,
895
+ labelleft=False,
896
+ labelbottom=False)
897
+
898
+ (cx, cy), (dx, dy), s = _prepare_quiver_map(disp_map)
899
+ disp_map_color = _prepare_colormap(disp_map)
900
+ ax2 = fig.add_subplot(223)
901
+ im2 = ax2.imshow(s, interpolation='none', aspect='equal')
902
+
903
+ ax2.quiver(cx, cy, dx, dy, scale=const.QUIVER_PARAMS.arrow_scale)
904
+ # ax2.figure.set_size_inches(img_size)
905
+ ax2.tick_params(axis='both',
906
+ which='both',
907
+ bottom=False,
908
+ left=False,
909
+ labelleft=False,
910
+ labelbottom=False)
911
+
912
+ ax3 = fig.add_subplot(224)
913
+ dif = fix_img[..., 0] - mov_img[..., 0]
914
+ im3 = ax3.imshow(dif)
915
+ ax3.quiver(cx, cy, dx, dy, scale=const.QUIVER_PARAMS.arrow_scale)
916
+ ax3.tick_params(axis='both',
917
+ which='both',
918
+ bottom=False,
919
+ left=False,
920
+ labelleft=False,
921
+ labelbottom=False)
922
+
923
+ plt.axis('off')
924
+ ax0.set_title('Fixed img ($I_f$)', fontsize=const.FONT_SIZE)
925
+ ax1.set_title('Moving img ($I_m$)', fontsize=const.FONT_SIZE)
926
+ ax2.set_title('Displacement map', fontsize=const.FONT_SIZE)
927
+ ax3.set_title('Fix - Mov', fontsize=const.FONT_SIZE)
928
+
929
+ im0_cb = _set_colorbar(fig, ax0, im0, False)
930
+ im1_cb = _set_colorbar(fig, ax1, im1, False)
931
+ disp_cb = _set_colorbar(fig, ax2, im2, False)
932
+ im3_cb = _set_colorbar(fig, ax3, im3, False)
933
+
934
+ if filename is not None:
935
+ plt.savefig(filename, format='png') # Call before show
936
+ if not const.REMOTE:
937
+ plt.show()
938
+ else:
939
+ plt.close()
940
+
941
+ return fig
942
+
943
+
944
+ def inspect_displacement_grid(ctrl_coords, dense_coords, target_coords, disp_coords, disp_map, mask, fix_img, mov_img,
945
+ filename=None, fig=None):
946
+ if fig is not None:
947
+ fig.clear()
948
+ plt.figure(fig.number)
949
+ else:
950
+ fig = plt.figure()
951
+
952
+ ax_grid = fig.add_subplot(231)
953
+ ax_grid.set_title('Grids', fontsize=const.FONT_SIZE)
954
+ ax_grid.scatter(ctrl_coords[:, 0], ctrl_coords[:, 1], marker='+', c='r', s=20)
955
+ ax_grid.scatter(dense_coords[:, 0], dense_coords[:, 1], marker='.', c='r', s=1)
956
+ ax_grid.tick_params(axis='both',
957
+ which='both',
958
+ bottom=False,
959
+ left=False,
960
+ labelleft=False,
961
+ labelbottom=False)
962
+
963
+ ax_grid.scatter(target_coords[:, 0], target_coords[:, 1], marker='+', c='b', s=20)
964
+ ax_grid.scatter(disp_coords[:, 0], disp_coords[:, 1], marker='.', c='b', s=1)
965
+
966
+ ax_grid.set_aspect('equal')
967
+
968
+ ax_disp = fig.add_subplot(232)
969
+ ax_disp.set_title('Displacement map', fontsize=const.FONT_SIZE)
970
+ (cx, cy), (dx, dy), s = _prepare_quiver_map(disp_map)
971
+ ax_disp.imshow(s, interpolation='none', aspect='equal')
972
+ ax_disp.tick_params(axis='both',
973
+ which='both',
974
+ bottom=False,
975
+ left=False,
976
+ labelleft=False,
977
+ labelbottom=False)
978
+
979
+ ax_mask = fig.add_subplot(233)
980
+ ax_mask.set_title('Mask', fontsize=const.FONT_SIZE)
981
+ ax_mask.imshow(mask)
982
+ ax_mask.tick_params(axis='both',
983
+ which='both',
984
+ bottom=False,
985
+ left=False,
986
+ labelleft=False,
987
+ labelbottom=False)
988
+
989
+ ax_fix = fig.add_subplot(234)
990
+ ax_fix.set_title('Fix image', fontsize=const.FONT_SIZE)
991
+ ax_fix.imshow(fix_img[..., 0])
992
+ ax_fix.tick_params(axis='both',
993
+ which='both',
994
+ bottom=False,
995
+ left=False,
996
+ labelleft=False,
997
+ labelbottom=False)
998
+
999
+ ax_mov = fig.add_subplot(235)
1000
+ ax_mov.set_title('Moving image', fontsize=const.FONT_SIZE)
1001
+ ax_mov.imshow(mov_img[..., 0])
1002
+ ax_mov.tick_params(axis='both',
1003
+ which='both',
1004
+ bottom=False,
1005
+ left=False,
1006
+ labelleft=False,
1007
+ labelbottom=False)
1008
+
1009
+ ax_dif = fig.add_subplot(236)
1010
+ ax_dif.set_title('Fix - Moving image', fontsize=const.FONT_SIZE)
1011
+ ax_dif.imshow(fix_img[..., 0] - mov_img[..., 0], cmap=cmap_bin)
1012
+ ax_dif.tick_params(axis='both',
1013
+ which='both',
1014
+ bottom=False,
1015
+ left=False,
1016
+ labelleft=False,
1017
+ labelbottom=False)
1018
+ legend_elems = [Line2D([0], [0], color=cmap_bin(0), lw=2),
1019
+ Line2D([0], [0], color=cmap_bin(2), lw=2)]
1020
+
1021
+ ax_dif.legend(legend_elems, ['Mov', 'Fix'], loc='upper left', bbox_to_anchor=(0., 0., 1., 0.),
1022
+ ncol=2, mode='expand')
1023
+
1024
+ if filename is not None:
1025
+ plt.savefig(filename, format='png') # Call before show
1026
+ if not const.REMOTE:
1027
+ plt.show()
1028
+
1029
+ return fig
1030
+
1031
+
1032
+ def compare_disp_maps(disp_m_f, disp_f_m, fix_img, mov_img, filename=None, fig=None):
1033
+ if fig is not None:
1034
+ fig.clear()
1035
+ plt.figure(fig.number)
1036
+ else:
1037
+ fig = plt.figure()
1038
+
1039
+ ax_d_m_f = fig.add_subplot(131)
1040
+ ax_d_m_f.set_title('Disp M->F', fontsize=const.FONT_SIZE)
1041
+ (cx, cy), (dx, dy), s = _prepare_quiver_map(disp_m_f)
1042
+ ax_d_m_f.imshow(s, interpolation='none', aspect='equal')
1043
+ ax_d_m_f.quiver(cx, cy, dx, dy, scale=const.QUIVER_PARAMS.arrow_scale)
1044
+ ax_d_m_f.tick_params(axis='both',
1045
+ which='both',
1046
+ bottom=False,
1047
+ left=False,
1048
+ labelleft=False,
1049
+ labelbottom=False)
1050
+
1051
+ ax_d_f_m = fig.add_subplot(132)
1052
+ ax_d_f_m.set_title('Disp F->M', fontsize=const.FONT_SIZE)
1053
+ (cx, cy), (dx, dy), s = _prepare_quiver_map(disp_f_m)
1054
+ ax_d_f_m.quiver(cx, cy, dx, dy, scale=const.QUIVER_PARAMS.arrow_scale)
1055
+ ax_d_f_m.imshow(s, interpolation='none', aspect='equal')
1056
+ ax_d_f_m.tick_params(axis='both',
1057
+ which='both',
1058
+ bottom=False,
1059
+ left=False,
1060
+ labelleft=False,
1061
+ labelbottom=False)
1062
+
1063
+ ax_dif = fig.add_subplot(133)
1064
+ ax_dif.set_title('Fix - Moving image', fontsize=const.FONT_SIZE)
1065
+ ax_dif.imshow(fix_img[..., 0] - mov_img[..., 0], cmap=cmap_bin)
1066
+ ax_dif.tick_params(axis='both',
1067
+ which='both',
1068
+ bottom=False,
1069
+ left=False,
1070
+ labelleft=False,
1071
+ labelbottom=False)
1072
+
1073
+ legend_elems = [Line2D([0], [0], color=cmap_bin(0), lw=2),
1074
+ Line2D([0], [0], color=cmap_bin(2), lw=2)]
1075
+
1076
+ ax_dif.legend(legend_elems, ['Mov', 'Fix'], loc='upper left', bbox_to_anchor=(0., 0., 1., 0.),
1077
+ ncol=2, mode='expand')
1078
+
1079
+ if filename is not None:
1080
+ plt.savefig(filename, format='png') # Call before show
1081
+ if not const.REMOTE:
1082
+ plt.show()
1083
+ else:
1084
+ plt.close()
1085
+
1086
+ return fig
1087
+
1088
+
1089
+ def plot_train_step(list_imgs: [np.ndarray], fig_title='TRAINING', dest_folder='.', save_file=True):
1090
+ # list_imgs[0]: fix image
1091
+ # list_imgs[1]: moving image
1092
+ # list_imgs[2]: prediction scale 1
1093
+ # list_imgs[3]: prediction scale 2
1094
+ # list_imgs[4]: prediction scale 3
1095
+ # list_imgs[5]: disp map scale 1
1096
+ # list_imgs[6]: disp map scale 2
1097
+ # list_imgs[7]: disp map scale 3
1098
+ num_imgs = len(list_imgs)
1099
+ num_preds = (num_imgs - 2) // 2
1100
+ num_cols = num_preds + 1
1101
+ # 3D
1102
+ # TRAINING
1103
+ fig = plt.figure(figsize=(12.8, 10.24))
1104
+ fig.tight_layout(pad=5.0)
1105
+ ax = fig.add_subplot(2, num_cols, 1, projection='3d')
1106
+ ax.voxels(list_imgs[0][0, ..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed')
1107
+ ax.set_title('Fix image', fontsize=const.FONT_SIZE)
1108
+ _square_3d_plot(np.arange(0, 63), np.arange(0, 63), np.arange(0, 63), ax)
1109
+
1110
+ for i in range(2, num_preds+2):
1111
+ ax = fig.add_subplot(2, num_cols, i, projection='3d')
1112
+ ax.voxels(list_imgs[0][0, ..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed')
1113
+ ax.voxels(list_imgs[i][0, ..., 0] > 0.0, facecolors='green', edgecolors='green', label='Pred_{}'.format(i - 1))
1114
+ ax.set_title('Pred. #{}'.format(i - 1))
1115
+ _square_3d_plot(np.arange(0, 63), np.arange(0, 63), np.arange(0, 63), ax)
1116
+
1117
+ ax = fig.add_subplot(2, num_cols, num_preds+2, projection='3d')
1118
+ ax.voxels(list_imgs[1][0, ..., 0] > 0.0, facecolors='blue', edgecolors='blue', label='Moving')
1119
+ ax.set_title('Fix image', fontsize=const.FONT_SIZE)
1120
+ _square_3d_plot(np.arange(0, 63), np.arange(0, 63), np.arange(0, 63), ax)
1121
+
1122
+ for i in range(num_preds+2, 2 * num_preds + 2):
1123
+ ax = fig.add_subplot(2, num_cols, i + 1, projection='3d')
1124
+ c, d, s = _prepare_quiver_map(list_imgs[i][0, ...], dim=3)
1125
+ ax.quiver(c[const.H_DISP], c[const.W_DISP], c[const.D_DISP],
1126
+ d[const.H_DISP], d[const.W_DISP], d[const.D_DISP],
1127
+ norm=True)
1128
+ ax.set_title('Disp. #{}'.format(i - 5))
1129
+ _square_3d_plot(np.arange(0, 63), np.arange(0, 63), np.arange(0, 63), ax)
1130
+
1131
+ fig.suptitle(fig_title, fontsize=const.FONT_SIZE)
1132
+
1133
+ if save_file:
1134
+ plt.savefig(os.path.join(dest_folder, fig_title+'.png'), format='png') # Call before show
1135
+ if not const.REMOTE:
1136
+ plt.show()
1137
+ else:
1138
+ plt.close()
1139
+ return fig
1140
+
1141
+
1142
+ def _square_3d_plot(X, Y, Z, ax):
1143
+ max_range = np.array([X.max() - X.min(), Y.max() - Y.min(), Z.max() - Z.min()]).max() / 2.0
1144
+
1145
+ mid_x = (X.max() + X.min()) * 0.5
1146
+ mid_y = (Y.max() + Y.min()) * 0.5
1147
+ mid_z = (Z.max() + Z.min()) * 0.5
1148
+ ax.set_xlim(mid_x - max_range, mid_x + max_range)
1149
+ ax.set_ylim(mid_y - max_range, mid_y + max_range)
1150
+ ax.set_zlim(mid_z - max_range, mid_z + max_range)
1151
+
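A minimal sketch of how the quiver helper above can be used on its own (synthetic displacement field; assumes const.H_DISP and const.W_DISP index channels 0 and 1):

import numpy as np
import matplotlib.pyplot as plt

disp = np.zeros((64, 64, 2), np.float32)
disp[..., 0] = 3. * np.sin(np.linspace(0, np.pi, 64))[None, :]   # H displacement
disp[..., 1] = 2. * np.cos(np.linspace(0, np.pi, 64))[:, None]   # W displacement

c, d, s = _prepare_quiver_map(disp, dim=2)   # sub-sampled coordinates, vectors, magnitude
plt.imshow(s)
plt.quiver(c[0], c[1], d[0], d[1])
plt.show()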
EvaluationScripts/evaluation.py ADDED
@@ -0,0 +1,84 @@
1
+ import os, sys
2
+ currentdir = os.path.dirname(os.path.realpath(__file__))
3
+ parentdir = os.path.dirname(currentdir)
4
+ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
5
+
6
+ PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'
7
+
8
+ import numpy as np
9
+ import tensorflow as tf
10
+ from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
11
+ import voxelmorph as vxm
12
+ import neurite as ne
13
+ import h5py
14
+ from datetime import datetime
15
+
16
+ if PYCHARM_EXEC:
17
+ import scripts.tf.myScript_constants as const
18
+ from scripts.tf.myScript_data_generator import DataGeneratorManager
19
+ from scripts.tf.myScript_utils import save_nifti, try_mkdir
20
+ else:
21
+ import myScript_constants as const
22
+ from myScript_data_generator import DataGeneratorManager
23
+ from myScript_utils import save_nifti, try_mkdir
24
+
25
+ os.environ['CUDA_DEVICE_ORDER'] = const.DEV_ORDER
26
+ os.environ['CUDA_VISIBLE_DEVICES'] = const.GPU_NUM # Check availability before running using 'nvidia-smi'
27
+
28
+ const.TRAINING_DATASET = '/mnt/EncryptedData1/Users/javier/vessel_registration/sanity_dataset_LITS'
29
+ const.BATCH_SIZE = 8
30
+ const.LIMIT_NUM_SAMPLES = None
31
+ const.EPOCHS = 1000
32
+
33
+ if PYCHARM_EXEC:
34
+ path_prefix = os.path.join('scripts', 'tf')
35
+ else:
36
+ path_prefix = ''
37
+
38
+ # Load data
39
+ # Build data generator
40
+ data_generator = DataGeneratorManager(const.TRAINING_DATASET, const.BATCH_SIZE, True, const.LIMIT_NUM_SAMPLES,
41
+ 1 - const.TRAINING_PERC, voxelmorph=True)
42
+
43
+ test_generator = data_generator.get_generator('validation')
44
+ test_fix_img, test_mov_img, *_ = test_generator.get_random_sample(1)
45
+
46
+ # Build model
47
+ in_shape = test_generator.get_input_shape()[1:-1]
48
+ enc_features = [16, 32, 32, 32]# const.ENCODER_FILTERS
49
+ dec_features = [32, 32, 32, 32, 32, 16, 16]# const.ENCODER_FILTERS[::-1]
50
+ nb_features = [enc_features, dec_features]
51
+ vxm_model = vxm.networks.VxmDense(inshape=in_shape, nb_unet_features=nb_features, int_steps=0)
52
+
53
+ weight_files = [os.path.join(path_prefix, 'checkpoints', f) for f in os.listdir(os.path.join(path_prefix, 'checkpoints')) if 'weights' in f]
54
+ weight_files.sort()
55
+ pred_folder = os.path.join(path_prefix, 'predictions')
56
+ try_mkdir(pred_folder)
57
+
58
+ # Prepare the images
59
+ fix_img = test_fix_img.squeeze()
60
+ mid_slice_fix = [np.take(fix_img, fix_img.shape[d]//2, axis=d) for d in range(3)]
61
+ mid_slice_fix[1] = np.rot90(mid_slice_fix[1], 1)
62
+ mid_slice_fix[2] = np.rot90(mid_slice_fix[2], -1)
63
+
64
+ mid_mov_slice = list()
65
+ mid_disp_slice = list()
66
+ # Due to the slicing step, the last checkpoint could be skipped, so always include it explicitly
67
+ step = 5
68
+ for f in weight_files[:-1:step] + [weight_files[-1]]:
69
+ name = os.path.split(f)[-1].split('.h5')[0]
70
+ vxm_model.load_weights(f)
71
+ pred_img, pred_disp = vxm_model.predict([test_mov_img, test_fix_img])
72
+ pred_img = pred_img.squeeze()
73
+
74
+ mov_slices = [np.take(pred_img, pred_img.shape[d]//2, axis=d) for d in range(3)]
75
+ mov_slices[1] = np.rot90(mov_slices[1], 1)
76
+ mov_slices[2] = np.rot90(mov_slices[2], -1)
77
+ mid_mov_slice.append(mov_slices)
78
+
79
+
80
+
81
+
82
+
83
+ # Get sample for testing
84
+ test_sample = test_generator.get_single_sample()
TrainingScripts/Train_2d.py ADDED
@@ -0,0 +1,86 @@
+ import os, sys
+
+ currentdir = os.path.dirname(os.path.realpath(__file__))
+ parentdir = os.path.dirname(currentdir)
+ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
+
+ PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'
+
+ import tensorflow as tf
+ from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
+ import voxelmorph as vxm
+ from datetime import datetime
+
+ import DeepDeformationMapRegistration.utils.constants as C
+ from DeepDeformationMapRegistration.data_generator import DataGeneratorManager2D
+ from DeepDeformationMapRegistration.utils.misc import try_mkdir
+ from DeepDeformationMapRegistration.losses import HausdorffDistance
+
+
+ os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
+ os.environ['CUDA_VISIBLE_DEVICES'] = C.GPU_NUM # Check availability before running using 'nvidia-smi'
+
+ C.TRAINING_DATASET = '/mnt/EncryptedData1/Users/javier/vessel_registration/ov_dataset/training'
+ C.BATCH_SIZE = 256
+ C.LIMIT_NUM_SAMPLES = None
+ C.EPOCHS = 10000
+
+ if PYCHARM_EXEC:
+     path_prefix = os.path.join('scripts', 'tf')
+ else:
+     path_prefix = ''
+
+ # Load data
+ # Build data generator
+ sample_list = [os.path.join(C.TRAINING_DATASET, f) for f in os.listdir(C.TRAINING_DATASET) if
+                f.startswith('sample')]
+ sample_list.sort()
+
+ data_generator = DataGeneratorManager2D(sample_list[:C.LIMIT_NUM_SAMPLES],
+                                         C.BATCH_SIZE, C.TRAINING_PERC,
+                                         (64, 64, 1),
+                                         fix_img_tag='dilated/input/fix',
+                                         mov_img_tag='dilated/input/mov'
+                                         )
+
+ # Build model
+ in_shape = data_generator.train_generator.input_shape[:-1]
+ enc_features = [32, 32, 32, 32, 32, 32]  # const.ENCODER_FILTERS
+ dec_features = [32, 32, 32, 32, 32, 32, 32, 16]  # const.ENCODER_FILTERS[::-1]
+ nb_features = [enc_features, dec_features]
+ vxm_model = vxm.networks.VxmDense(inshape=in_shape, nb_unet_features=nb_features, int_steps=0)
+
+ # Losses and loss weights
+ def comb_loss(y_true, y_pred):
+     return 1e-3 * HausdorffDistance(ndim=2, nerosion=2).loss(y_true, y_pred) + vxm.losses.Dice().loss(y_true, y_pred)
+
+
+ losses = [comb_loss, vxm.losses.Grad('l2').loss]
+ loss_weights = [1, 0.01]
+
+ # Compile the model
+ vxm_model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-4), loss=losses, loss_weights=loss_weights)
+
+ # Train
+ output_folder = os.path.join('train_2d_dice_hausdorff_grad_' + datetime.now().strftime("%H%M%S-%d%m%Y"))
+ try_mkdir(output_folder)
+ try_mkdir(os.path.join(output_folder, 'checkpoints'))
+ try_mkdir(os.path.join(output_folder, 'tensorboard'))
+ my_callbacks = [
+     # EarlyStopping(patience=const.EARLY_STOP_PATIENCE, monitor='dice', mode='max', verbose=1),
+     ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'),
+                     save_best_only=True, monitor='val_loss', verbose=0, mode='min'),
+     ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'weights.{epoch:05d}-{val_loss:.2f}.h5'),
+                     save_best_only=True, save_weights_only=True, monitor='val_loss', verbose=0, mode='min'),
+     # CSVLogger(train_log_name, ';'),
+     # UpdateLossweights([haus_weight, dice_weight], [const.MODEL+'_resampler_seg', const.MODEL+'_resampler_seg'])
+     TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
+                 batch_size=C.BATCH_SIZE, write_images=True, histogram_freq=10, update_freq='epoch',
+                 write_grads=True),
+     EarlyStopping(monitor='val_loss', verbose=1, patience=50, min_delta=0.0001)
+ ]
+ hist = vxm_model.fit_generator(data_generator.train_generator,
+                                epochs=C.EPOCHS,
+                                validation_data=data_generator.validation_generator,
+                                verbose=2,
+                                callbacks=my_callbacks)
TrainingScripts/Train_2d_uncertaintyWeighting.py ADDED
@@ -0,0 +1,103 @@
+ import os, sys
+
+ currentdir = os.path.dirname(os.path.realpath(__file__))
+ parentdir = os.path.dirname(currentdir)
+ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
+
+ PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'
+
+ import numpy as np
+ import tensorflow as tf
+ from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
+ import voxelmorph as vxm
+ import neurite as ne
+ import h5py
+ from datetime import datetime
+
+ import DeepDeformationMapRegistration.utils.constants as C
+ from DeepDeformationMapRegistration.data_generator import DataGeneratorManager2D
+ from DeepDeformationMapRegistration.utils.misc import try_mkdir
+ from DeepDeformationMapRegistration.losses import HausdorffDistance
+ from DeepDeformationMapRegistration.layers import UncertaintyWeighting
+
+
+ os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
+ os.environ['CUDA_VISIBLE_DEVICES'] = '1'  # const.GPU_NUM # Check availability before running using 'nvidia-smi'
+
+ C.TRAINING_DATASET = '/mnt/EncryptedData1/Users/javier/vessel_registration/ov_dataset/training'
+ C.BATCH_SIZE = 256
+ C.LIMIT_NUM_SAMPLES = None
+ C.EPOCHS = 10000
+
+ if PYCHARM_EXEC:
+     path_prefix = os.path.join('scripts', 'tf')
+ else:
+     path_prefix = ''
+
+ # Load data
+ # Build data generator
+ sample_list = [os.path.join(C.TRAINING_DATASET, f) for f in os.listdir(C.TRAINING_DATASET) if
+                f.startswith('sample')]
+ sample_list.sort()
+
+ data_generator = DataGeneratorManager2D(sample_list[:C.LIMIT_NUM_SAMPLES],
+                                         C.BATCH_SIZE, C.TRAINING_PERC,
+                                         (64, 64, 1),
+                                         fix_img_tag='dilated/input/fix',
+                                         mov_img_tag='dilated/input/mov',
+                                         multi_loss=True,
+                                         )
+
+ # Build model
+ in_shape_img, in_shape_grad = data_generator.train_generator.input_shape
+ enc_features = [32, 32, 32, 32, 32, 32]  # const.ENCODER_FILTERS
+ dec_features = [32, 32, 32, 32, 32, 32, 32, 16]  # const.ENCODER_FILTERS[::-1]
+ nb_features = [enc_features, dec_features]
+ vxm_model = vxm.networks.VxmDense(inshape=in_shape_img[:-1], nb_unet_features=nb_features, int_steps=0)
+
+ # moving = tf.keras.Input(shape=in_shape_img, name='multiLoss_moving_input', dtype=tf.float32)
+ # fixed = tf.keras.Input(shape=in_shape_img, name='multiLoss_fixed_input', dtype=tf.float32)
+ grad = tf.keras.Input(shape=(*in_shape_img[:-1], 2), name='multiLoss_grad_input', dtype=tf.float32)
+
+ def dice_loss(y_true, y_pred):
+     # Dice().loss returns -Dice score
+     return 1 + vxm.losses.Dice().loss(y_true, y_pred)
+
+ # fixed_pred, dm_pred = vxm_model([moving, fixed])
+ multiLoss = UncertaintyWeighting(num_loss_fns=2,
+                                  num_reg_fns=1,
+                                  loss_fns=[HausdorffDistance(2, 2).loss, dice_loss],
+                                  reg_fns=[vxm.losses.Grad('l2').loss],
+                                  prior_loss_w=[1., 1.],
+                                  prior_reg_w=[0.01],
+                                  name='MultiLossLayer')
+ loss = multiLoss([vxm_model.inputs[1], vxm_model.inputs[1], vxm_model.references.y_source, vxm_model.references.y_source, grad, vxm_model.references.pos_flow])
+
+ full_model = tf.keras.Model(inputs=vxm_model.inputs + [grad], outputs=vxm_model.outputs + [loss])
+
+ # Compile the model
+ full_model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-4), loss=None)
+
+ # Train
+ output_folder = os.path.join('train_2d_multiloss_haussdorf_dice_grad' + datetime.now().strftime("%H%M%S-%d%m%Y"))
+ try_mkdir(output_folder)
+ try_mkdir(os.path.join(output_folder, 'checkpoints'))
+ try_mkdir(os.path.join(output_folder, 'tensorboard'))
+ my_callbacks = [
+     # EarlyStopping(patience=const.EARLY_STOP_PATIENCE, monitor='dice', mode='max', verbose=1),
+     ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'),
+                     save_best_only=True, monitor='val_loss', verbose=0, mode='min'),
+     ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'weights.{epoch:05d}-{val_loss:.2f}.h5'),
+                     save_best_only=True, save_weights_only=True, monitor='val_loss', verbose=0, mode='min'),
+     # CSVLogger(train_log_name, ';'),
+     # UpdateLossweights([haus_weight, dice_weight], [const.MODEL+'_resampler_seg', const.MODEL+'_resampler_seg'])
+     TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
+                 batch_size=C.BATCH_SIZE, write_images=True, histogram_freq=10, update_freq='epoch',
+                 write_grads=True),
+     EarlyStopping(monitor='val_loss', verbose=1, patience=50, min_delta=0.0001)
+ ]
+ hist = full_model.fit_generator(data_generator.train_generator,
+                                 epochs=C.EPOCHS,
+                                 validation_data=data_generator.validation_generator,
+                                 verbose=2,
+                                 callbacks=my_callbacks)
TrainingScripts/Train_3d.py ADDED
@@ -0,0 +1,76 @@
+ import os, sys
+ currentdir = os.path.dirname(os.path.realpath(__file__))
+ parentdir = os.path.dirname(currentdir)
+ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
+
+ PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'
+
+ import numpy as np
+ import tensorflow as tf
+ from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
+ import voxelmorph as vxm
+ import neurite as ne
+ import h5py
+ from datetime import datetime
+
+ import DeepDeformationMapRegistration.utils.constants as C
+ from DeepDeformationMapRegistration.data_generator import DataGeneratorManager
+ from DeepDeformationMapRegistration.utils.misc import try_mkdir
+
+
+ os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
+ os.environ['CUDA_VISIBLE_DEVICES'] = '1'  # Check availability before running using 'nvidia-smi'
+
+ C.TRAINING_DATASET = '/mnt/EncryptedData1/Users/javier/vessel_registration/sanity_dataset_LITS'
+ C.BATCH_SIZE = 2
+ C.LIMIT_NUM_SAMPLES = None
+ C.EPOCHS = 10000
+
+ # Load data
+ # Build data generator
+ data_generator = DataGeneratorManager(C.TRAINING_DATASET, C.BATCH_SIZE, True, C.LIMIT_NUM_SAMPLES,
+                                       1 - C.TRAINING_PERC, voxelmorph=True)
+
+ train_generator = data_generator.get_generator('train')
+ validation_generator = data_generator.get_generator('validation')
+
+
+ # Build model
+ in_shape = train_generator.get_input_shape()[1:-1]
+ enc_features = [16, 32, 32, 32, 32, 32]  # const.ENCODER_FILTERS
+ dec_features = [32, 32, 32, 32, 32, 32, 32, 16, 16]  # const.ENCODER_FILTERS[::-1]
+ nb_features = [enc_features, dec_features]
+ vxm_model = vxm.networks.VxmDense(inshape=in_shape, nb_unet_features=nb_features, int_steps=7)
+
+
+ # Losses and loss weights
+
+ def comb_loss(y_true, y_pred):
+     return vxm.losses.MSE().loss(y_true, y_pred) + vxm.losses.NCC().loss(y_true, y_pred)
+
+
+ losses = [comb_loss, vxm.losses.Grad('l2').loss]
+ loss_weights = [1., 0.01]
+
+ # Compile the model
+ vxm_model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-4), loss=losses, loss_weights=loss_weights)
+
+ # Train
+ output_folder = os.path.join('train_3d_mse_ncc_grad_' + datetime.now().strftime("%H%M%S-%d%m%Y"))
+ try_mkdir(output_folder)
+ try_mkdir(os.path.join(output_folder, 'checkpoints'))
+ try_mkdir(os.path.join(output_folder, 'tensorboard'))
+ my_callbacks = [
+     # EarlyStopping(patience=const.EARLY_STOP_PATIENCE, monitor='dice', mode='max', verbose=1),
+     ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'),
+                     save_best_only=True, monitor='val_loss', verbose=0, mode='min'),
+     ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'weights.{epoch:05d}-{val_loss:.2f}.h5'),
+                     save_best_only=True, save_weights_only=True, monitor='val_loss', verbose=0, mode='min'),
+     # CSVLogger(train_log_name, ';'),
+     # UpdateLossweights([haus_weight, dice_weight], [const.MODEL+'_resampler_seg', const.MODEL+'_resampler_seg'])
+     TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
+                 batch_size=C.BATCH_SIZE, write_images=False, histogram_freq=10, update_freq='epoch',
+                 write_grads=True),
+     EarlyStopping(monitor='val_loss', verbose=1, patience=50, min_delta=0.0001)
+ ]
+ hist = vxm_model.fit(train_generator, epochs=C.EPOCHS, validation_data=validation_generator, verbose=2, callbacks=my_callbacks)
TrainingScripts/Train_3d_weaklySupervised.py ADDED
@@ -0,0 +1,92 @@
+ import os, sys
+ currentdir = os.path.dirname(os.path.realpath(__file__))
+ parentdir = os.path.dirname(currentdir)
+ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
+
+ PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'
+
+ import numpy as np
+ import tensorflow as tf
+ from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
+ import voxelmorph as vxm
+ import neurite as ne
+ import h5py
+ from datetime import datetime
+
+ import DeepDeformationMapRegistration.utils.constants as C
+ from DeepDeformationMapRegistration.data_generator import DataGeneratorManager
+ from DeepDeformationMapRegistration.utils.misc import try_mkdir
+ from DeepDeformationMapRegistration.networks import VxmWeaklySupervised
+ from DeepDeformationMapRegistration.losses import HausdorffDistance
+ from DeepDeformationMapRegistration.layers import UncertaintyWeighting
+
+
+ os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # Check availability before running using 'nvidia-smi'
+
+ C.TRAINING_DATASET = '/mnt/EncryptedData1/Users/javier/vessel_registration/sanity_dataset_vessels'
+ C.BATCH_SIZE = 2
+ C.LIMIT_NUM_SAMPLES = None
+ C.EPOCHS = 10000
+
+ # Load data
+ # Build data generator
+
+ data_generator = DataGeneratorManager(C.TRAINING_DATASET, C.BATCH_SIZE, True, C.LIMIT_NUM_SAMPLES,
+                                       1 - C.TRAINING_PERC, voxelmorph=True, segmentations=True)
+
+ train_generator = data_generator.get_generator('train')
+ validation_generator = data_generator.get_generator('validation')
+
+
+ # Build model
+ in_shape = train_generator.get_input_shape()[1:-1]
+ enc_features = [16, 32, 32, 32, 32, 32]  # const.ENCODER_FILTERS
+ dec_features = [32, 32, 32, 32, 32, 32, 32, 16, 16]  # const.ENCODER_FILTERS[::-1]
+ nb_features = [enc_features, dec_features]
+ vxm_model = VxmWeaklySupervised(inshape=in_shape, all_labels=[1], nb_unet_features=nb_features, int_steps=5)
+
+ # Losses and loss weights
+
+ grad = tf.keras.Input(shape=(*in_shape, 3), name='multiLoss_grad_input', dtype=tf.float32)
+ fix_img = tf.keras.Input(shape=(*in_shape, 1), name='multiLoss_fix_img_input', dtype=tf.float32)
+ def dice_loss(y_true, y_pred):
+     # Dice().loss returns -Dice score
+     return 1 + vxm.losses.Dice().loss(y_true, y_pred)
+
+ multiLoss = UncertaintyWeighting(num_loss_fns=3,
+                                  num_reg_fns=1,
+                                  loss_fns=[HausdorffDistance(3, 5).loss, dice_loss, vxm.losses.NCC().loss],
+                                  reg_fns=[vxm.losses.Grad('l2').loss],
+                                  prior_loss_w=[1., 1., 1.],
+                                  prior_reg_w=[0.01],
+                                  name='MultiLossLayer')
+ loss = multiLoss([vxm_model.inputs[1], vxm_model.inputs[1], fix_img,
+                   vxm_model.references.pred_segm, vxm_model.references.pred_segm, vxm_model.references.pred_img,
+                   grad,
+                   vxm_model.references.pos_flow])
+
+ full_model = tf.keras.Model(inputs=vxm_model.inputs + [fix_img, grad], outputs=vxm_model.outputs + [loss])
+
+ # Compile the model
+ full_model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-4), loss=None)
+
+ # Train
+ output_folder = os.path.join('train_3d_multiloss_segm_haus_dice_ncc_grad_' + datetime.now().strftime("%H%M%S-%d%m%Y"))
+ try_mkdir(output_folder)
+ try_mkdir(os.path.join(output_folder, 'checkpoints'))
+ try_mkdir(os.path.join(output_folder, 'tensorboard'))
+ my_callbacks = [
+     # EarlyStopping(patience=const.EARLY_STOP_PATIENCE, monitor='dice', mode='max', verbose=1),
+     ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'),
+                     save_best_only=True, monitor='val_loss', verbose=0, mode='min'),
+     ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'weights.{epoch:05d}-{val_loss:.2f}.h5'),
+                     save_best_only=True, save_weights_only=True, monitor='val_loss', verbose=0, mode='min'),
+     # CSVLogger(train_log_name, ';'),
+     # UpdateLossweights([haus_weight, dice_weight], [const.MODEL+'_resampler_seg', const.MODEL+'_resampler_seg'])
+     TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
+                 batch_size=C.BATCH_SIZE, write_images=False, histogram_freq=10, update_freq='epoch',
+                 write_grads=True),
+     EarlyStopping(monitor='val_loss', verbose=1, patience=50, min_delta=0.0001)
+ ]
+ hist = full_model.fit(train_generator, epochs=C.EPOCHS, validation_data=validation_generator, verbose=2, callbacks=my_callbacks)