Commit
·
ab9857f
0
Parent(s):
Initial commit
Browse files- DeepDeformationMapRegistration/__init__.py +0 -0
- DeepDeformationMapRegistration/data_generator.py +497 -0
- DeepDeformationMapRegistration/layers.py +189 -0
- DeepDeformationMapRegistration/losses.py +47 -0
- DeepDeformationMapRegistration/networks.py +63 -0
- DeepDeformationMapRegistration/utils/acummulated_optimizer.py +57 -0
- DeepDeformationMapRegistration/utils/cmd_args_parser.py +111 -0
- DeepDeformationMapRegistration/utils/conf_file_utils.py +52 -0
- DeepDeformationMapRegistration/utils/constants.py +487 -0
- DeepDeformationMapRegistration/utils/misc.py +28 -0
- DeepDeformationMapRegistration/utils/nifty_utils.py +42 -0
- DeepDeformationMapRegistration/utils/operators +28 -0
- DeepDeformationMapRegistration/utils/user_interface.py +23 -0
- DeepDeformationMapRegistration/utils/visualization.py +1151 -0
- EvaluationScripts/evaluation.py +84 -0
- TrainingScripts/Train_2d.py +86 -0
- TrainingScripts/Train_2d_uncertaintyWeighting.py +103 -0
- TrainingScripts/Train_3d.py +76 -0
- TrainingScripts/Train_3d_weaklySupervised.py +92 -0
DeepDeformationMapRegistration/__init__.py
ADDED
File without changes
|
DeepDeformationMapRegistration/data_generator.py
ADDED
@@ -0,0 +1,497 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys, os
|
2 |
+
currentdir = os.path.dirname(os.path.realpath(__file__))
|
3 |
+
parentdir = os.path.dirname(currentdir)
|
4 |
+
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
5 |
+
|
6 |
+
PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
from tensorflow import keras
|
10 |
+
import os
|
11 |
+
import h5py
|
12 |
+
import random
|
13 |
+
from PIL import Image
|
14 |
+
|
15 |
+
import DeepDeformationMapRegistration.utils.constants as C
|
16 |
+
from DeepDeformationMapRegistration.utils.operators import min_max_norm
|
17 |
+
|
18 |
+
|
19 |
+
class DataGeneratorManager(keras.utils.Sequence):
|
20 |
+
def __init__(self, dataset_path, batch_size=32, shuffle=True, num_samples=None, validation_split=None, validation_samples=None,
|
21 |
+
clip_range=[0., 1.], voxelmorph=False, segmentations=False,
|
22 |
+
seg_labels: dict = {'bg': 0, 'vessels': 1, 'tumour': 2, 'parenchyma': 3}):
|
23 |
+
# Get the list of files
|
24 |
+
self.__list_files = self.__get_dataset_files(dataset_path)
|
25 |
+
self.__list_files.sort()
|
26 |
+
self.__dataset_path = dataset_path
|
27 |
+
self.__shuffle = shuffle
|
28 |
+
self.__total_samples = len(self.__list_files)
|
29 |
+
self.__validation_split = validation_split
|
30 |
+
self.__clip_range = clip_range
|
31 |
+
self.__batch_size = batch_size
|
32 |
+
|
33 |
+
self.__validation_samples = validation_samples
|
34 |
+
|
35 |
+
self.__voxelmorph = voxelmorph
|
36 |
+
self.__segmentations = segmentations
|
37 |
+
self.__seg_labels = seg_labels
|
38 |
+
|
39 |
+
if num_samples is not None:
|
40 |
+
self.__num_samples = self.__total_samples if num_samples > self.__total_samples else num_samples
|
41 |
+
else:
|
42 |
+
self.__num_samples = self.__total_samples
|
43 |
+
|
44 |
+
self.__internal_idxs = np.arange(self.__num_samples)
|
45 |
+
|
46 |
+
# Split it accordingly
|
47 |
+
if validation_split is None:
|
48 |
+
self.__validation_num_samples = None
|
49 |
+
self.__validation_idxs = list()
|
50 |
+
if self.__shuffle:
|
51 |
+
random.shuffle(self.__internal_idxs)
|
52 |
+
self.__training_idxs = self.__internal_idxs
|
53 |
+
|
54 |
+
self.__validation_generator = None
|
55 |
+
else:
|
56 |
+
self.__validation_num_samples = int(np.ceil(self.__num_samples * validation_split))
|
57 |
+
if self.__shuffle:
|
58 |
+
self.__validation_idxs = np.random.choice(self.__internal_idxs, self.__validation_num_samples)
|
59 |
+
else:
|
60 |
+
self.__validation_idxs = self.__internal_idxs[0: self.__validation_num_samples]
|
61 |
+
self.__training_idxs = np.asarray([idx for idx in self.__internal_idxs if idx not in self.__validation_idxs])
|
62 |
+
# Build them DataGenerators
|
63 |
+
self.__validation_generator = DataGenerator(self, 'validation')
|
64 |
+
|
65 |
+
self.__train_generator = DataGenerator(self, 'train')
|
66 |
+
self.reshuffle_indices()
|
67 |
+
|
68 |
+
@property
|
69 |
+
def dataset_path(self):
|
70 |
+
return self.__dataset_path
|
71 |
+
|
72 |
+
@property
|
73 |
+
def dataset_list_files(self):
|
74 |
+
return self.__list_files
|
75 |
+
|
76 |
+
@property
|
77 |
+
def train_idxs(self):
|
78 |
+
return self.__training_idxs
|
79 |
+
|
80 |
+
@property
|
81 |
+
def validation_idxs(self):
|
82 |
+
return self.__validation_idxs
|
83 |
+
|
84 |
+
@property
|
85 |
+
def batch_size(self):
|
86 |
+
return self.__batch_size
|
87 |
+
|
88 |
+
@property
|
89 |
+
def clip_rage(self):
|
90 |
+
return self.__clip_range
|
91 |
+
|
92 |
+
@property
|
93 |
+
def shuffle(self):
|
94 |
+
return self.__shuffle
|
95 |
+
|
96 |
+
def get_generator_idxs(self, generator_type):
|
97 |
+
if generator_type == 'train':
|
98 |
+
return self.train_idxs
|
99 |
+
elif generator_type == 'validation':
|
100 |
+
return self.validation_idxs
|
101 |
+
else:
|
102 |
+
raise ValueError('Invalid generator type: ', generator_type)
|
103 |
+
|
104 |
+
@staticmethod
|
105 |
+
def __get_dataset_files(search_path):
|
106 |
+
"""
|
107 |
+
Get the path to the dataset files
|
108 |
+
:param search_path: dir path to search for the hd5 files
|
109 |
+
:return:
|
110 |
+
"""
|
111 |
+
file_list = list()
|
112 |
+
for root, dirs, files in os.walk(search_path):
|
113 |
+
file_list.sort()
|
114 |
+
for data_file in files:
|
115 |
+
file_name, extension = os.path.splitext(data_file)
|
116 |
+
if extension.lower() == '.hd5':
|
117 |
+
file_list.append(os.path.join(root, data_file))
|
118 |
+
|
119 |
+
if not file_list:
|
120 |
+
raise ValueError('No files found to train in ', search_path)
|
121 |
+
|
122 |
+
print('Found {} files in {}'.format(len(file_list), search_path))
|
123 |
+
return file_list
|
124 |
+
|
125 |
+
def reshuffle_indices(self):
|
126 |
+
if self.__validation_num_samples is None:
|
127 |
+
if self.__shuffle:
|
128 |
+
random.shuffle(self.__internal_idxs)
|
129 |
+
self.__training_idxs = self.__internal_idxs
|
130 |
+
else:
|
131 |
+
if self.__shuffle:
|
132 |
+
self.__validation_idxs = np.random.choice(self.__internal_idxs, self.__validation_num_samples)
|
133 |
+
else:
|
134 |
+
self.__validation_idxs = self.__internal_idxs[0: self.__validation_num_samples]
|
135 |
+
self.__training_idxs = np.asarray([idx for idx in self.__internal_idxs if idx not in self.__validation_idxs])
|
136 |
+
|
137 |
+
# Update the indices
|
138 |
+
self.__validation_generator.update_samples(self.__validation_idxs)
|
139 |
+
|
140 |
+
self.__train_generator.update_samples(self.__training_idxs)
|
141 |
+
|
142 |
+
def get_generator(self, type='train'):
|
143 |
+
if type.lower() == 'train':
|
144 |
+
return self.__train_generator
|
145 |
+
elif type.lower() == 'validation':
|
146 |
+
if self.__validation_generator is not None:
|
147 |
+
return self.__validation_generator
|
148 |
+
else:
|
149 |
+
raise Warning('No validation generator available. Set a non-zero validation_split to build one.')
|
150 |
+
else:
|
151 |
+
raise ValueError('Unknown dataset type "{}". Expected "train" or "validation"'.format(type))
|
152 |
+
|
153 |
+
@property
|
154 |
+
def is_voxelmorph(self):
|
155 |
+
return self.__voxelmorph
|
156 |
+
|
157 |
+
@property
|
158 |
+
def give_segmentations(self):
|
159 |
+
return self.__segmentations
|
160 |
+
|
161 |
+
@property
|
162 |
+
def seg_labels(self):
|
163 |
+
return self.__seg_labels
|
164 |
+
|
165 |
+
|
166 |
+
class DataGenerator(DataGeneratorManager):
|
167 |
+
def __init__(self, GeneratorManager: DataGeneratorManager, dataset_type='train'):
|
168 |
+
self.__complete_list_files = GeneratorManager.dataset_list_files
|
169 |
+
self.__list_files = [self.__complete_list_files[idx] for idx in GeneratorManager.get_generator_idxs(dataset_type)]
|
170 |
+
self.__batch_size = GeneratorManager.batch_size
|
171 |
+
self.__total_samples = len(self.__list_files)
|
172 |
+
self.__clip_range = GeneratorManager.clip_rage
|
173 |
+
self.__manager = GeneratorManager
|
174 |
+
self.__shuffle = GeneratorManager.shuffle
|
175 |
+
|
176 |
+
self.__seg_labels = GeneratorManager.seg_labels
|
177 |
+
|
178 |
+
self.__num_samples = len(self.__list_files)
|
179 |
+
self.__internal_idxs = np.arange(self.__num_samples)
|
180 |
+
# These indices are internal to the generator, they are not the same as the dataset_idxs!!
|
181 |
+
|
182 |
+
self.__dataset_type = dataset_type
|
183 |
+
|
184 |
+
self.__last_batch = 0
|
185 |
+
self.__batches_per_epoch = int(np.floor(len(self.__internal_idxs) / self.__batch_size))
|
186 |
+
|
187 |
+
self.__voxelmorph = GeneratorManager.is_voxelmorph
|
188 |
+
self.__segmentations = GeneratorManager.is_voxelmorph and GeneratorManager.give_segmentations
|
189 |
+
|
190 |
+
@staticmethod
|
191 |
+
def __get_dataset_files(search_path):
|
192 |
+
"""
|
193 |
+
Get the path to the dataset files
|
194 |
+
:param search_path: dir path to search for the hd5 files
|
195 |
+
:return:
|
196 |
+
"""
|
197 |
+
file_list = list()
|
198 |
+
for root, dirs, files in os.walk(search_path):
|
199 |
+
for data_file in files:
|
200 |
+
file_name, extension = os.path.splitext(data_file)
|
201 |
+
if extension.lower() == '.hd5':
|
202 |
+
file_list.append(os.path.join(root, data_file))
|
203 |
+
|
204 |
+
if not file_list:
|
205 |
+
raise ValueError('No files found to train in ', search_path)
|
206 |
+
|
207 |
+
print('Found {} files in {}'.format(len(file_list), search_path))
|
208 |
+
return file_list
|
209 |
+
|
210 |
+
def update_samples(self, new_sample_idxs):
|
211 |
+
self.__list_files = [self.__complete_list_files[idx] for idx in new_sample_idxs]
|
212 |
+
self.__num_samples = len(self.__list_files)
|
213 |
+
self.__internal_idxs = np.arange(self.__num_samples)
|
214 |
+
|
215 |
+
def on_epoch_end(self):
|
216 |
+
"""
|
217 |
+
To be executed at the end of each epoch. Reshuffle the assigned samples
|
218 |
+
:return:
|
219 |
+
"""
|
220 |
+
if self.__shuffle:
|
221 |
+
random.shuffle(self.__internal_idxs)
|
222 |
+
self.__last_batch = 0
|
223 |
+
|
224 |
+
def __len__(self):
|
225 |
+
"""
|
226 |
+
Number of batches per epoch
|
227 |
+
:return:
|
228 |
+
"""
|
229 |
+
return self.__batches_per_epoch
|
230 |
+
|
231 |
+
def __getitem__(self, index):
|
232 |
+
"""
|
233 |
+
Generate one batch of data
|
234 |
+
:param index: epoch index
|
235 |
+
:return:
|
236 |
+
"""
|
237 |
+
idxs = self.__internal_idxs[index * self.__batch_size:(index + 1) * self.__batch_size]
|
238 |
+
|
239 |
+
fix_img, mov_img, fix_vessels, mov_vessels, fix_tumour, mov_tumour, disp_map = self.__load_data(idxs)
|
240 |
+
|
241 |
+
try:
|
242 |
+
fix_img = min_max_norm(fix_img).astype(np.float32)
|
243 |
+
mov_img = min_max_norm(mov_img).astype(np.float32)
|
244 |
+
except ValueError:
|
245 |
+
print(idxs, fix_img.shape, mov_img.shape)
|
246 |
+
er_str = 'ERROR:\t[file]:\t{}\t[idx]:\t{}\t[fix_img.shape]:\t{}\t[mov_img.shape]:\t{}\t'.format(self.__list_files[idxs], idxs, fix_img.shape, mov_img.shape)
|
247 |
+
raise ValueError(er_str)
|
248 |
+
|
249 |
+
fix_vessels[fix_vessels > 0.] = self.__seg_labels['vessels']
|
250 |
+
mov_vessels[mov_vessels > 0.] = self.__seg_labels['vessels']
|
251 |
+
|
252 |
+
# fix_tumour[fix_tumour > 0.] = self.__seg_labels['tumour']
|
253 |
+
# mov_tumour[mov_tumour > 0.] = self.__seg_labels['tumour']
|
254 |
+
|
255 |
+
# https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit
|
256 |
+
# A generator or keras.utils.Sequence returning (inputs, targets) or (inputs, targets, sample_weights)
|
257 |
+
# The second element must match the outputs of the model, in this case (image, displacement map)
|
258 |
+
if self.__voxelmorph:
|
259 |
+
zero_grad = np.zeros([fix_img.shape[0], *C.DISP_MAP_SHAPE])
|
260 |
+
if self.__segmentations:
|
261 |
+
inputs = [mov_vessels, fix_vessels, mov_img, fix_img, zero_grad]
|
262 |
+
outputs = [] #[fix_img, zero_grad]
|
263 |
+
else:
|
264 |
+
inputs = [mov_img, fix_img]
|
265 |
+
outputs = [fix_img, zero_grad]
|
266 |
+
return (inputs, outputs)
|
267 |
+
else:
|
268 |
+
return (fix_img, mov_img, fix_vessels, mov_vessels), # (None, fix_seg, fix_seg, fix_img)
|
269 |
+
|
270 |
+
def next_batch(self):
|
271 |
+
if self.__last_batch > self.__batches_per_epoch:
|
272 |
+
raise ValueError('No more batches for this epoch')
|
273 |
+
batch = self.__getitem__(self.__last_batch)
|
274 |
+
self.__last_batch += 1
|
275 |
+
return batch
|
276 |
+
|
277 |
+
def __load_data(self, idx_list):
|
278 |
+
"""
|
279 |
+
Build the batch with the samples in idx_list
|
280 |
+
:param idx_list:
|
281 |
+
:return:
|
282 |
+
"""
|
283 |
+
if isinstance(idx_list, (list, np.ndarray)):
|
284 |
+
fix_img = np.empty((0, ) + C.IMG_SHAPE)
|
285 |
+
mov_img = np.empty((0, ) + C.IMG_SHAPE)
|
286 |
+
disp_map = np.empty((0, ) + C.DISP_MAP_SHAPE)
|
287 |
+
|
288 |
+
# fix_segm = np.empty((0, ) + const.IMG_SHAPE)
|
289 |
+
# mov_segm = np.empty((0, ) + const.IMG_SHAPE)
|
290 |
+
|
291 |
+
fix_vessels = np.empty((0, ) + C.IMG_SHAPE)
|
292 |
+
mov_vessels = np.empty((0, ) + C.IMG_SHAPE)
|
293 |
+
fix_tumors = np.empty((0, ) + C.IMG_SHAPE)
|
294 |
+
mov_tumors = np.empty((0, ) + C.IMG_SHAPE)
|
295 |
+
for idx in idx_list:
|
296 |
+
data_file = h5py.File(self.__list_files[idx], 'r')
|
297 |
+
|
298 |
+
fix_img = np.append(fix_img, [data_file[C.H5_FIX_IMG][:]], axis=0)
|
299 |
+
mov_img = np.append(mov_img, [data_file[C.H5_MOV_IMG][:]], axis=0)
|
300 |
+
|
301 |
+
# fix_segm = np.append(fix_segm, [data_file[const.H5_FIX_PARENCHYMA_MASK][:]], axis=0)
|
302 |
+
# mov_segm = np.append(mov_segm, [data_file[const.H5_MOV_PARENCHYMA_MASK][:]], axis=0)
|
303 |
+
|
304 |
+
disp_map = np.append(disp_map, [data_file[C.H5_GT_DISP][:]], axis=0)
|
305 |
+
|
306 |
+
fix_vessels = np.append(fix_vessels, [data_file[C.H5_FIX_VESSELS_MASK][:]], axis=0)
|
307 |
+
mov_vessels = np.append(mov_vessels, [data_file[C.H5_MOV_VESSELS_MASK][:]], axis=0)
|
308 |
+
fix_tumors = np.append(fix_tumors, [data_file[C.H5_FIX_TUMORS_MASK][:]], axis=0)
|
309 |
+
mov_tumors = np.append(mov_tumors, [data_file[C.H5_MOV_TUMORS_MASK][:]], axis=0)
|
310 |
+
|
311 |
+
data_file.close()
|
312 |
+
else:
|
313 |
+
data_file = h5py.File(self.__list_files[idx_list], 'r')
|
314 |
+
|
315 |
+
fix_img = np.expand_dims(data_file[C.H5_FIX_IMG][:], 0)
|
316 |
+
mov_img = np.expand_dims(data_file[C.H5_MOV_IMG][:], 0)
|
317 |
+
|
318 |
+
# fix_segm = np.expand_dims(data_file[const.H5_FIX_PARENCHYMA_MASK][:], 0)
|
319 |
+
# mov_segm = np.expand_dims(data_file[const.H5_MOV_PARENCHYMA_MASK][:], 0)
|
320 |
+
|
321 |
+
fix_vessels = np.expand_dims(data_file[C.H5_FIX_VESSELS_MASK][:], axis=0)
|
322 |
+
mov_vessels = np.expand_dims(data_file[C.H5_MOV_VESSELS_MASK][:], axis=0)
|
323 |
+
fix_tumors = np.expand_dims(data_file[C.H5_FIX_TUMORS_MASK][:], axis=0)
|
324 |
+
mov_tumors = np.expand_dims(data_file[C.H5_MOV_TUMORS_MASK][:], axis=0)
|
325 |
+
|
326 |
+
disp_map = np.expand_dims(data_file[C.H5_GT_DISP][:], 0)
|
327 |
+
|
328 |
+
data_file.close()
|
329 |
+
|
330 |
+
return fix_img, mov_img, fix_vessels, mov_vessels, fix_tumors, mov_tumors, disp_map
|
331 |
+
|
332 |
+
def get_single_sample(self):
|
333 |
+
fix_img, mov_img, fix_segm, mov_segm, _ = self.__load_data(0)
|
334 |
+
# return X, y
|
335 |
+
return np.expand_dims(np.concatenate([mov_img, fix_img, mov_segm, mov_segm], axis=-1), axis=0)
|
336 |
+
|
337 |
+
def get_random_sample(self, num_samples):
|
338 |
+
idxs = np.random.randint(0, self.__num_samples, num_samples)
|
339 |
+
fix_img, mov_img, fix_segm, mov_segm, disp_map = self.__load_data(idxs)
|
340 |
+
|
341 |
+
return (fix_img, mov_img, fix_segm, mov_segm, disp_map), [self.__list_files[f] for f in idxs]
|
342 |
+
|
343 |
+
def get_input_shape(self):
|
344 |
+
input_batch, _ = self.__getitem__(0)
|
345 |
+
if self.__voxelmorph:
|
346 |
+
ret_val = list(input_batch[0].shape)
|
347 |
+
ret_val[-1] = 2
|
348 |
+
ret_val = (None, ) + tuple(ret_val[1:])
|
349 |
+
else:
|
350 |
+
ret_val = input_batch.shape
|
351 |
+
ret_val = (None, ) + ret_val[1:]
|
352 |
+
return ret_val # const.BATCH_SHAPE_SEGM
|
353 |
+
|
354 |
+
def who_are_you(self):
|
355 |
+
return self.__dataset_type
|
356 |
+
|
357 |
+
def print_datafiles(self):
|
358 |
+
return self.__list_files
|
359 |
+
|
360 |
+
|
361 |
+
class DataGeneratorManager2D:
|
362 |
+
FIX_IMG_H5 = 'input/1'
|
363 |
+
MOV_IMG_H5 = 'input/0'
|
364 |
+
def __init__(self, h5_file_list, batch_size=32, data_split=0.7, img_size=None,
|
365 |
+
fix_img_tag=FIX_IMG_H5, mov_img_tag=MOV_IMG_H5, multi_loss=False):
|
366 |
+
self.__file_list = h5_file_list #h5py.File(h5_file, 'r')
|
367 |
+
self.__batch_size = batch_size
|
368 |
+
self.__data_split = data_split
|
369 |
+
|
370 |
+
self.__initialize()
|
371 |
+
|
372 |
+
self.__train_generator = DataGenerator2D(self.__train_file_list,
|
373 |
+
batch_size=self.__batch_size,
|
374 |
+
img_size=img_size,
|
375 |
+
fix_img_tag=fix_img_tag,
|
376 |
+
mov_img_tag=mov_img_tag,
|
377 |
+
multi_loss=multi_loss)
|
378 |
+
self.__val_generator = DataGenerator2D(self.__val_file_list,
|
379 |
+
batch_size=self.__batch_size,
|
380 |
+
img_size=img_size,
|
381 |
+
fix_img_tag=fix_img_tag,
|
382 |
+
mov_img_tag=mov_img_tag,
|
383 |
+
multi_loss=multi_loss)
|
384 |
+
|
385 |
+
def __initialize(self):
|
386 |
+
num_samples = len(self.__file_list)
|
387 |
+
random.shuffle(self.__file_list)
|
388 |
+
|
389 |
+
data_split = int(np.floor(num_samples * self.__data_split))
|
390 |
+
self.__val_file_list = self.__file_list[0:data_split]
|
391 |
+
self.__train_file_list = self.__file_list[data_split:]
|
392 |
+
|
393 |
+
@property
|
394 |
+
def train_generator(self):
|
395 |
+
return self.__train_generator
|
396 |
+
|
397 |
+
@property
|
398 |
+
def validation_generator(self):
|
399 |
+
return self.__val_generator
|
400 |
+
|
401 |
+
|
402 |
+
class DataGenerator2D(keras.utils.Sequence):
|
403 |
+
FIX_IMG_H5 = 'input/1'
|
404 |
+
MOV_IMG_H5 = 'input/0'
|
405 |
+
|
406 |
+
def __init__(self, file_list: list, batch_size=32, img_size=None, fix_img_tag=FIX_IMG_H5, mov_img_tag=MOV_IMG_H5, multi_loss=False):
|
407 |
+
self.__file_list = file_list # h5py.File(h5_file, 'r')
|
408 |
+
self.__file_list.sort()
|
409 |
+
self.__batch_size = batch_size
|
410 |
+
self.__idx_list = np.arange(0, len(self.__file_list))
|
411 |
+
self.__multi_loss = multi_loss
|
412 |
+
|
413 |
+
self.__tags = {'fix_img': fix_img_tag,
|
414 |
+
'mov_img': mov_img_tag}
|
415 |
+
|
416 |
+
self.__batches_seen = 0
|
417 |
+
self.__batches_per_epoch = 0
|
418 |
+
|
419 |
+
self.__img_size = img_size
|
420 |
+
|
421 |
+
self.__initialize()
|
422 |
+
|
423 |
+
def __len__(self):
|
424 |
+
return self.__batches_per_epoch
|
425 |
+
|
426 |
+
def __initialize(self):
|
427 |
+
random.shuffle(self.__idx_list)
|
428 |
+
|
429 |
+
if self.__img_size is None:
|
430 |
+
f = h5py.File(self.__file_list[0], 'r')
|
431 |
+
self.input_shape = f[self.__tags['fix_img']].shape # Already defined in super()
|
432 |
+
f.close()
|
433 |
+
else:
|
434 |
+
self.input_shape = self.__img_size
|
435 |
+
|
436 |
+
if self.__multi_loss:
|
437 |
+
self.input_shape = (self.input_shape, (*self.input_shape[:-1], 2))
|
438 |
+
|
439 |
+
self.__batches_per_epoch = int(np.ceil(len(self.__file_list) / self.__batch_size))
|
440 |
+
|
441 |
+
def __load_and_preprocess(self, fh, tag):
|
442 |
+
img = fh[tag][:]
|
443 |
+
|
444 |
+
if (self.__img_size is not None) and (img[..., 0].shape != self.__img_size):
|
445 |
+
im = Image.fromarray(img[..., 0]) # Can't handle the 1 channel
|
446 |
+
img = np.array(im.resize(self.__img_size[:-1], Image.LANCZOS)).astype(np.float32)
|
447 |
+
img = img[..., np.newaxis]
|
448 |
+
|
449 |
+
if img.max() > 1. or img.min() < 0.:
|
450 |
+
try:
|
451 |
+
img = min_max_norm(img).astype(np.float32)
|
452 |
+
except ValueError:
|
453 |
+
print(fh, tag, img.shape)
|
454 |
+
er_str = 'ERROR:\t[file]:\t{}\t[tag]:\t{}\t[img.shape]:\t{}\t'.format(fh, tag, img.shape)
|
455 |
+
raise ValueError(er_str)
|
456 |
+
return img.astype(np.float32)
|
457 |
+
|
458 |
+
def __getitem__(self, idx):
|
459 |
+
idxs = self.__idx_list[idx * self.__batch_size:(idx + 1) * self.__batch_size]
|
460 |
+
|
461 |
+
fix_imgs, mov_imgs = self.__load_samples(idxs)
|
462 |
+
|
463 |
+
zero_grad = np.zeros((*fix_imgs.shape[:-1], 2))
|
464 |
+
|
465 |
+
inputs = [mov_imgs, fix_imgs]
|
466 |
+
outputs = [fix_imgs, zero_grad]
|
467 |
+
|
468 |
+
if self.__multi_loss:
|
469 |
+
return [mov_imgs, fix_imgs, zero_grad],
|
470 |
+
else:
|
471 |
+
return (inputs, outputs)
|
472 |
+
|
473 |
+
def __load_samples(self, idx_list):
|
474 |
+
if self.__multi_loss:
|
475 |
+
img_shape = (0, *self.input_shape[0])
|
476 |
+
else:
|
477 |
+
img_shape = (0, *self.input_shape)
|
478 |
+
|
479 |
+
fix_imgs = np.empty(img_shape)
|
480 |
+
mov_imgs = np.empty(img_shape)
|
481 |
+
for i in idx_list:
|
482 |
+
f = h5py.File(self.__file_list[i], 'r')
|
483 |
+
fix_imgs = np.append(fix_imgs, [self.__load_and_preprocess(f, self.__tags['fix_img'])], axis=0)
|
484 |
+
mov_imgs = np.append(mov_imgs, [self.__load_and_preprocess(f, self.__tags['mov_img'])], axis=0)
|
485 |
+
f.close()
|
486 |
+
|
487 |
+
return fix_imgs, mov_imgs
|
488 |
+
|
489 |
+
def on_epoch_end(self):
|
490 |
+
np.random.shuffle(self.__idx_list)
|
491 |
+
|
492 |
+
def get_single_sample(self):
|
493 |
+
idx = random.randint(0, len(self.__idx_list))
|
494 |
+
fix, mov = self.__load_samples([idx])
|
495 |
+
return mov, fix
|
496 |
+
|
497 |
+
|
DeepDeformationMapRegistration/layers.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
|
3 |
+
currentdir = os.path.dirname(os.path.realpath(__file__))
|
4 |
+
parentdir = os.path.dirname(currentdir)
|
5 |
+
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
6 |
+
|
7 |
+
PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'
|
8 |
+
|
9 |
+
import tensorflow.keras.layers as kl
|
10 |
+
import tensorflow.keras.backend as K
|
11 |
+
import tensorflow as tf
|
12 |
+
import numpy as np
|
13 |
+
|
14 |
+
from DeepDeformationMapRegistration.utils.operators import soft_threshold
|
15 |
+
|
16 |
+
|
17 |
+
class UncertaintyWeighting(kl.Layer):
|
18 |
+
def __init__(self, num_loss_fns=1, num_reg_fns=0, loss_fns: list = [tf.keras.losses.mean_squared_error],
|
19 |
+
reg_fns: list = list(), prior_loss_w=[1.], manual_loss_w=[1.], prior_reg_w=[1.], manual_reg_w=[1.],
|
20 |
+
**kwargs):
|
21 |
+
assert isinstance(loss_fns, list) and (num_loss_fns == len(loss_fns) or len(loss_fns) == 1)
|
22 |
+
assert isinstance(reg_fns, list) and (num_reg_fns == len(reg_fns))
|
23 |
+
self.num_loss = num_loss_fns
|
24 |
+
if len(loss_fns) == 1 and self.num_loss > 1:
|
25 |
+
self.loss_fns = loss_fns * self.num_loss
|
26 |
+
else:
|
27 |
+
self.loss_fns = loss_fns
|
28 |
+
|
29 |
+
if len(prior_loss_w) == 1:
|
30 |
+
self.prior_loss_w = prior_loss_w * num_loss_fns
|
31 |
+
else:
|
32 |
+
self.prior_loss_w = prior_loss_w
|
33 |
+
self.prior_loss_w = np.log(self.prior_loss_w)
|
34 |
+
|
35 |
+
if len(manual_loss_w) == 1:
|
36 |
+
self.manual_loss_w = manual_loss_w * num_loss_fns
|
37 |
+
else:
|
38 |
+
self.manual_loss_w = manual_loss_w
|
39 |
+
|
40 |
+
self.num_reg = num_reg_fns
|
41 |
+
if self.num_reg != 0:
|
42 |
+
if len(reg_fns) == 1 and self.num_reg > 1:
|
43 |
+
self.reg_fns = reg_fns * self.num_reg
|
44 |
+
else:
|
45 |
+
self.reg_fns = reg_fns
|
46 |
+
|
47 |
+
self.is_placeholder = True
|
48 |
+
if self.num_reg != 0:
|
49 |
+
if len(prior_reg_w) == 1:
|
50 |
+
self.prior_reg_w = prior_reg_w * num_reg_fns
|
51 |
+
else:
|
52 |
+
self.prior_reg_w = prior_reg_w
|
53 |
+
self.prior_reg_w = np.log(self.prior_reg_w)
|
54 |
+
|
55 |
+
if len(manual_reg_w) == 1:
|
56 |
+
self.manual_reg_w = manual_reg_w * num_reg_fns
|
57 |
+
else:
|
58 |
+
self.manual_reg_w = manual_reg_w
|
59 |
+
|
60 |
+
else:
|
61 |
+
self.prior_reg_w = list()
|
62 |
+
self.manual_reg_w = list()
|
63 |
+
|
64 |
+
super(UncertaintyWeighting, self).__init__(**kwargs)
|
65 |
+
|
66 |
+
def build(self, input_shape=None):
|
67 |
+
self.log_loss_vars = self.add_weight(name='loss_log_vars', shape=(self.num_loss,),
|
68 |
+
initializer=tf.keras.initializers.Constant(self.prior_loss_w),
|
69 |
+
trainable=True)
|
70 |
+
self.loss_weights = tf.math.softmax(self.log_loss_vars, name='SM_loss_weights')
|
71 |
+
|
72 |
+
if self.num_reg != 0:
|
73 |
+
self.log_reg_vars = self.add_weight(name='loss_reg_vars', shape=(self.num_reg,),
|
74 |
+
initializer=tf.keras.initializers.Constant(self.prior_reg_w),
|
75 |
+
trainable=True)
|
76 |
+
if self.num_reg == 1:
|
77 |
+
self.reg_weights = tf.math.exp(self.log_reg_vars, name='EXP_reg_weights')
|
78 |
+
else:
|
79 |
+
self.reg_weights = tf.math.softmax(self.log_reg_vars, name='SM_reg_weights')
|
80 |
+
|
81 |
+
super(UncertaintyWeighting, self).build(input_shape)
|
82 |
+
|
83 |
+
def multi_loss(self, ys_true, ys_pred, regs_true, regs_pred):
|
84 |
+
loss_values = list()
|
85 |
+
loss_names_loss = list()
|
86 |
+
loss_names_reg = list()
|
87 |
+
|
88 |
+
for y_true, y_pred, loss_fn, man_w in zip(ys_true, ys_pred, self.loss_fns, self.manual_loss_w):
|
89 |
+
loss_values.append(tf.keras.backend.mean(man_w * loss_fn(y_true, y_pred)))
|
90 |
+
loss_names_loss.append(loss_fn.__name__)
|
91 |
+
|
92 |
+
loss_values = tf.convert_to_tensor(loss_values, dtype=tf.float32, name="step_loss_values")
|
93 |
+
loss = tf.math.multiply(self.loss_weights, loss_values, name='step_weighted_loss')
|
94 |
+
|
95 |
+
if self.num_reg != 0:
|
96 |
+
loss_reg = list()
|
97 |
+
for reg_true, reg_pred, reg_fn, man_w in zip(regs_true, regs_pred, self.reg_fns, self.manual_reg_w):
|
98 |
+
loss_reg.append(K.mean(man_w * reg_fn(reg_true, reg_pred)))
|
99 |
+
loss_names_reg.append(reg_fn.__name__)
|
100 |
+
|
101 |
+
reg_values = tf.convert_to_tensor(loss_reg, dtype=tf.float32, name="step_reg_values")
|
102 |
+
loss = loss + tf.math.multiply(self.reg_weights, reg_values, name='step_weighted_reg')
|
103 |
+
|
104 |
+
for i, loss_name in enumerate(loss_names_loss):
|
105 |
+
self.add_metric(tf.slice(self.loss_weights, [i], [1]), name='LOSS_WEIGHT_{}_{}'.format(i, loss_name),
|
106 |
+
aggregation='mean')
|
107 |
+
self.add_metric(tf.slice(loss_values, [i], [1]), name='LOSS_VALUE_{}_{}'.format(i, loss_name),
|
108 |
+
aggregation='mean')
|
109 |
+
if self.num_reg != 0:
|
110 |
+
for i, loss_name in enumerate(loss_names_reg):
|
111 |
+
self.add_metric(tf.slice(self.reg_weights, [i], [1]), name='REG_WEIGHT_{}_{}'.format(i, loss_name),
|
112 |
+
aggregation='mean')
|
113 |
+
self.add_metric(tf.slice(reg_values, [i], [1]), name='REG_VALUE_{}_{}'.format(i, loss_name),
|
114 |
+
aggregation='mean')
|
115 |
+
|
116 |
+
return K.sum(loss)
|
117 |
+
|
118 |
+
def call(self, inputs):
|
119 |
+
ys_true = inputs[:self.num_loss]
|
120 |
+
ys_pred = inputs[self.num_loss:self.num_loss*2]
|
121 |
+
reg_true = inputs[-self.num_reg*2:-self.num_reg]
|
122 |
+
reg_pred = inputs[-self.num_reg:] # The last terms are the regularization ones which have no GT
|
123 |
+
loss = self.multi_loss(ys_true, ys_pred, reg_true, reg_pred)
|
124 |
+
self.add_loss(loss, inputs=inputs)
|
125 |
+
# We won't actually use the output, but we need something for the TF graph
|
126 |
+
return K.concatenate(inputs, -1)
|
127 |
+
|
128 |
+
def get_config(self):
|
129 |
+
base_config = super(UncertaintyWeighting, self).get_config()
|
130 |
+
base_config['num_loss_fns'] = self.num_loss
|
131 |
+
base_config['num_reg_fns'] = self.num_reg
|
132 |
+
|
133 |
+
return base_config
|
134 |
+
|
135 |
+
|
136 |
+
def distance_map(coord1, coord2, dist, img_shape_w_channel=(64, 64, 1)):
|
137 |
+
max_dist = np.max(img_shape_w_channel)
|
138 |
+
dm_p = np.ones(img_shape_w_channel, np.float32)*max_dist
|
139 |
+
dm_n = np.ones(img_shape_w_channel, np.float32)*max_dist
|
140 |
+
|
141 |
+
for c1, c2, d in zip(coord1, coord2, dist):
|
142 |
+
dm_p[c1, c2, 0] = d if dm_p[c1, c2, 0] > d else dm_p[c1, c2]
|
143 |
+
d_n = 64. - max_dist
|
144 |
+
dm_n[c1, c2, 0] = d_n if dm_n[c1, c2, 0] > d_n else dm_n[c1, c2]
|
145 |
+
|
146 |
+
return dm_p/max_dist, dm_n/max_dist
|
147 |
+
|
148 |
+
|
149 |
+
def volume_to_ov_and_dm(in_volume: tf.Tensor):
|
150 |
+
# This one is run as a preprocessing step
|
151 |
+
def get_ov_projections_and_dm(volume):
|
152 |
+
# tf.sign returns -1, 0, 1 depending on the sign of the elements of the input (negative, zero, positive)
|
153 |
+
i, j, k, c = tf.where(volume > 0.0)
|
154 |
+
top = tf.sign(tf.reduce_sum(volume, axis=0), name='ov_top')
|
155 |
+
right = tf.sign(tf.reduce_sum(volume, axis=1), name='ov_right')
|
156 |
+
front = tf.sign(tf.reduce_sum(volume, axis=2), name='ov_front')
|
157 |
+
|
158 |
+
top_p, top_n = tf.py_func(distance_map, [j, k, i], tf.float32)
|
159 |
+
right_p, right_n = tf.py_func(distance_map, [i, k, j], tf.float32)
|
160 |
+
front_p, front_n = tf.py_func(distance_map, [i, j, k], tf.float32)
|
161 |
+
|
162 |
+
return [front, right, top], [front_p, front_n, top_p, top_n, right_p, right_n]
|
163 |
+
|
164 |
+
if len(in_volume.shape.as_list()) > 4:
|
165 |
+
return tf.map_fn(get_ov_projections_and_dm, in_volume, [tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32])
|
166 |
+
else:
|
167 |
+
return get_ov_projections_and_dm(in_volume)
|
168 |
+
|
169 |
+
|
170 |
+
def ov_and_dm_to_volume(ov_projections):
|
171 |
+
front, right, top = ov_projections
|
172 |
+
|
173 |
+
def get_volume(front: tf.Tensor, right: tf.Tensor, top: tf.Tensor):
|
174 |
+
front_shape = front.shape.as_list() # Assume (H, W, C)
|
175 |
+
top_shape = top.shape.as_list()
|
176 |
+
|
177 |
+
front_vol = tf.tile(tf.expand_dims(front, 2), [1, 1, top_shape[0], 1])
|
178 |
+
right_vol = tf.tile(tf.expand_dims(right, 1), [1, front_shape[1], 1, 1])
|
179 |
+
top_vol = tf.tile(tf.expand_dims(top, 0), [front_shape[0], 1, 1, 1])
|
180 |
+
sum = tf.add(tf.add(front_vol, right_vol), top_vol)
|
181 |
+
return soft_threshold(sum, 2., 'get_volume')
|
182 |
+
|
183 |
+
if len(front.shape.as_list()) > 3:
|
184 |
+
return tf.map_fn(lambda x: get_volume(x[0], x[1], x[2]), ov_projections, tf.float32)
|
185 |
+
else:
|
186 |
+
return get_volume(front, right, top)
|
187 |
+
|
188 |
+
# TODO: Recovering the coordinates from the distance maps to prevent artifacts
|
189 |
+
# will the gradients be backpropagated??!?!!?!?!
|
DeepDeformationMapRegistration/losses.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
currentdir = os.path.dirname(os.path.realpath(__file__))
|
3 |
+
parentdir = os.path.dirname(currentdir)
|
4 |
+
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
5 |
+
|
6 |
+
PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'
|
7 |
+
|
8 |
+
import tensorflow as tf
|
9 |
+
from scipy.ndimage import generate_binary_structure
|
10 |
+
|
11 |
+
import DeepDeformationMapRegistration.utils.constants as C
|
12 |
+
from DeepDeformationMapRegistration.utils.operators import soft_threshold
|
13 |
+
|
14 |
+
|
15 |
+
class HausdorffDistance:
|
16 |
+
def __init__(self, ndim=3, nerosion=10):
|
17 |
+
self.ndims = ndim
|
18 |
+
self.conv = getattr(tf.nn, 'conv%dd' % self.ndims)
|
19 |
+
self.nerosions = nerosion
|
20 |
+
|
21 |
+
def _erode(self, in_tensor, kernel):
|
22 |
+
out = 1. - tf.squeeze(self.conv(tf.expand_dims(1. - in_tensor, 0), kernel, [1] * (self.ndims + 2), 'SAME'), axis=0)
|
23 |
+
return soft_threshold(out, 0.5, name='soft_thresholding')
|
24 |
+
|
25 |
+
def _erosion_distance_single(self, y_true, y_pred):
|
26 |
+
diff = tf.math.pow(y_pred - y_true, 2)
|
27 |
+
alpha = 2.
|
28 |
+
|
29 |
+
norm = 1 / self.ndims * 2 + 1
|
30 |
+
kernel = generate_binary_structure(self.ndims, 1).astype(int) * norm
|
31 |
+
kernel = tf.constant(kernel, tf.float32)
|
32 |
+
kernel = tf.expand_dims(tf.expand_dims(kernel, -1), -1)
|
33 |
+
|
34 |
+
ret = 0.
|
35 |
+
for i in range(self.nerosions):
|
36 |
+
for j in range(i + 1):
|
37 |
+
er = self._erode(diff, kernel)
|
38 |
+
ret += tf.reduce_sum(tf.multiply(er, tf.pow(i + 1., alpha)))
|
39 |
+
|
40 |
+
return tf.multiply(C.IMG_SIZE ** -self.ndims, ret) # Divide by the image size
|
41 |
+
|
42 |
+
def loss(self, y_true, y_pred):
|
43 |
+
batched_dist = tf.map_fn(lambda x: self._erosion_distance_single(x[0], x[1]), (y_true, y_pred),
|
44 |
+
dtype=tf.float32)
|
45 |
+
|
46 |
+
return batched_dist # tf.reduce_mean(batched_dist)
|
47 |
+
|
DeepDeformationMapRegistration/networks.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
currentdir = os.path.dirname(os.path.realpath(__file__))
|
3 |
+
parentdir = os.path.dirname(currentdir)
|
4 |
+
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
5 |
+
|
6 |
+
PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'
|
7 |
+
|
8 |
+
import tensorflow as tf
|
9 |
+
import voxelmorph as vxm
|
10 |
+
from voxelmorph.tf.modelio import LoadableModel, store_config_args
|
11 |
+
|
12 |
+
|
13 |
+
class VxmWeaklySupervised(LoadableModel):
|
14 |
+
|
15 |
+
@store_config_args
|
16 |
+
def __init__(self, inshape, all_labels: [list, tuple], nb_unet_features=None, int_steps=5, bidir=False, **kwargs):
|
17 |
+
"""
|
18 |
+
Parameters:
|
19 |
+
inshape: Input shape. e.g. (192, 192, 192)
|
20 |
+
all_labels: List of all labels included in training segmentations.
|
21 |
+
hot_labels: List of labels to output as one-hot maps.
|
22 |
+
nb_unet_features: Unet convolutional features. See VxmDense documentation for more information.
|
23 |
+
int_steps: Number of flow integration steps. The warp is non-diffeomorphic when this value is 0.
|
24 |
+
kwargs: Forwarded to the internal VxmDense model.
|
25 |
+
"""
|
26 |
+
|
27 |
+
fix_segm = tf.keras.Input((*inshape, len(all_labels)), name='fix_segmentations_input')
|
28 |
+
mov_segm = tf.keras.Input((*inshape, len(all_labels)), name='mov_segmentations_input')
|
29 |
+
|
30 |
+
mov_img = tf.keras.Input((*inshape, 1), name='mov_image_input')
|
31 |
+
|
32 |
+
unet_input_model = tf.keras.Model(inputs=[mov_segm, fix_segm], outputs=[mov_segm, fix_segm])
|
33 |
+
|
34 |
+
vxm_model = vxm.networks.VxmDense(inshape=inshape,
|
35 |
+
nb_unet_features=nb_unet_features,
|
36 |
+
input_model=unet_input_model,
|
37 |
+
int_steps=int_steps,
|
38 |
+
bidir=bidir, **kwargs)
|
39 |
+
|
40 |
+
pred_img = vxm.layers.SpatialTransformer(interp_method='linear', indexing='ij', name='pred_fix_img')(
|
41 |
+
[mov_img, vxm_model.references.pos_flow])
|
42 |
+
|
43 |
+
inputs = [mov_segm, fix_segm, mov_img] # mov_img, mov_segm, fix_segm
|
44 |
+
outputs = [pred_img] + vxm_model.outputs
|
45 |
+
|
46 |
+
self.references = LoadableModel.ReferenceContainer()
|
47 |
+
self.references.pred_segm = vxm_model.outputs[0]
|
48 |
+
self.references.pred_img = pred_img
|
49 |
+
self.references.pos_flow = vxm_model.references.pos_flow
|
50 |
+
|
51 |
+
super().__init__(inputs=inputs, outputs=outputs)
|
52 |
+
|
53 |
+
def get_registration_model(self):
|
54 |
+
return tf.keras.Model(self.inputs, self.references.pos_flow)
|
55 |
+
|
56 |
+
def register(self, mov_img, mov_segm, fix_segm):
|
57 |
+
return self.get_registration_model().predict([mov_segm, fix_segm, mov_img])
|
58 |
+
|
59 |
+
def apply_transform(self, mov_img, mov_segm, fix_segm, interp_method='linear'):
|
60 |
+
warp_model = self.get_registration_model()
|
61 |
+
img_input = tf.keras.Input(shape=mov_img.shape[1:], name='input_img')
|
62 |
+
pred_img = vxm.layers.SpatialTransformer(interp_method=interp_method)([img_input, warp_model.output])
|
63 |
+
return tf.keras.Model(warp_model.inputs, pred_img).predict([mov_segm, fix_segm, mov_img])
|
DeepDeformationMapRegistration/utils/acummulated_optimizer.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tensorflow.keras.optimizers import Optimizer
|
2 |
+
from tensorflow.keras import backend as K
|
3 |
+
|
4 |
+
|
5 |
+
class AccumOptimizer(Optimizer):
|
6 |
+
"""Optimizer
|
7 |
+
Inheriting Optimizer class, wrapping the original optimizer
|
8 |
+
to achieve a new corresponding optimizer of gradient accumulation.
|
9 |
+
# Arguments
|
10 |
+
optimizer: an instance of keras optimizer (supporting
|
11 |
+
all keras optimizers currently available);
|
12 |
+
steps_per_update: the steps of gradient accumulation
|
13 |
+
# Returns
|
14 |
+
a new keras optimizer.
|
15 |
+
"""
|
16 |
+
def __init__(self, optimizer, steps_per_update=1, **kwargs):
|
17 |
+
super(AccumOptimizer, self).__init__(**kwargs)
|
18 |
+
self.optimizer = optimizer
|
19 |
+
with K.name_scope(self.__class__.__name__):
|
20 |
+
self.steps_per_update = steps_per_update
|
21 |
+
self.iterations = K.variable(0, dtype='int64', name='iterations')
|
22 |
+
self.cond = K.equal(self.iterations % self.steps_per_update, 0)
|
23 |
+
self.lr = self.optimizer.lr
|
24 |
+
self.optimizer.lr = K.switch(self.cond, self.optimizer.lr, 0.)
|
25 |
+
for attr in ['momentum', 'rho', 'beta_1', 'beta_2']:
|
26 |
+
if hasattr(self.optimizer, attr):
|
27 |
+
value = getattr(self.optimizer, attr)
|
28 |
+
setattr(self, attr, value)
|
29 |
+
setattr(self.optimizer, attr, K.switch(self.cond, value, 1 - 1e-7))
|
30 |
+
for attr in self.optimizer.get_config():
|
31 |
+
if not hasattr(self, attr):
|
32 |
+
value = getattr(self.optimizer, attr)
|
33 |
+
setattr(self, attr, value)
|
34 |
+
# Cover the original get_gradients method with accumulative gradients.
|
35 |
+
def get_gradients(loss, params):
|
36 |
+
return [ag / self.steps_per_update for ag in self.accum_grads]
|
37 |
+
self.optimizer.get_gradients = get_gradients
|
38 |
+
def get_updates(self, loss, params):
|
39 |
+
self.updates = [
|
40 |
+
K.update_add(self.iterations, 1),
|
41 |
+
K.update_add(self.optimizer.iterations, K.cast(self.cond, 'int64')),
|
42 |
+
]
|
43 |
+
# gradient accumulation
|
44 |
+
self.accum_grads = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
|
45 |
+
grads = self.get_gradients(loss, params)
|
46 |
+
for g, ag in zip(grads, self.accum_grads):
|
47 |
+
self.updates.append(K.update(ag, K.switch(self.cond, g, ag + g)))
|
48 |
+
# inheriting updates of original optimizer
|
49 |
+
self.updates.extend(self.optimizer.get_updates(loss, params)[1:])
|
50 |
+
self.weights.extend(self.optimizer.weights)
|
51 |
+
return self.updates
|
52 |
+
def get_config(self):
|
53 |
+
iterations = K.eval(self.iterations)
|
54 |
+
K.set_value(self.iterations, 0)
|
55 |
+
config = self.optimizer.get_config()
|
56 |
+
K.set_value(self.iterations, iterations)
|
57 |
+
return config
|
DeepDeformationMapRegistration/utils/cmd_args_parser.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys, getopt
|
2 |
+
import DeepDeformationMapRegistration.utils.constants as C
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
def parse_arguments(argv):
|
7 |
+
|
8 |
+
try:
|
9 |
+
opts, args = getopt.getopt(argv, "hg:b:l:r:d:t:i:f:x:p:q:", ["gpu-num=",
|
10 |
+
"batch-size=",
|
11 |
+
"loss=",
|
12 |
+
"remote=",
|
13 |
+
"debug=",
|
14 |
+
"debug-training=",
|
15 |
+
"debug-input-data=",
|
16 |
+
"destination-folder=",
|
17 |
+
"destination-folder-fix=",
|
18 |
+
"training-dataset=",
|
19 |
+
"test-dataset=",
|
20 |
+
"help"])
|
21 |
+
except getopt.GetoptError:
|
22 |
+
print('\n\t\t--gpu-num:\t\tGPU number to use'
|
23 |
+
'\n\t\t--batch-size:\t\tsize of the training batch'
|
24 |
+
'\n\t\t--loss:\t\tLoss function: ncc, mse, dssim'
|
25 |
+
'\n\t\t--remote:\t\tExecuting the script in The Beast: "True"/"False". Def: False'
|
26 |
+
'\n\t\t--debug:\t\tEnable debugging logs: "True"/"False". Def: False'
|
27 |
+
'\n\t\t--debug-training:\t\tEnable debugging training logs: "True"/"False". Def: False'
|
28 |
+
'\n\t\t--debug-input-data:\t\tEnable debugging input data logs: "True"/"False". Def: False'
|
29 |
+
'\n\t\t--destination-folder:\t\tName of the folder where to save the generated training files'
|
30 |
+
'\n\t\t--destination-folder-fixed:\t\tSame as --destination-folder but do not add the timestamp'
|
31 |
+
'\n\t\t--training-dataset:\t\tPath to the training dataset file'
|
32 |
+
'\n\t\t--test-dataset:\t\tPath to the test dataset file'
|
33 |
+
'\n')
|
34 |
+
sys.exit(2)
|
35 |
+
|
36 |
+
for opt, arg in opts:
|
37 |
+
if opt in ('--help', '-h'):
|
38 |
+
print('\n\t\t--gpu-num:\t\tGPU number to use\n\t\t--batch-size:\t\tsize of the training batch'
|
39 |
+
'\n\t\t--loss:\t\tLoss function: ncc, mse, dssim\n')
|
40 |
+
continue
|
41 |
+
elif opt in ('--gpu_num', '-g'):
|
42 |
+
old = C.GPU_NUM
|
43 |
+
C.GPU_NUM = arg
|
44 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = C.GPU_NUM
|
45 |
+
print('\t\tGPU_NUM: {} -> {}'.format(old, C.GPU_NUM))
|
46 |
+
|
47 |
+
elif opt in ('--batch-size', '-b'):
|
48 |
+
old = C.BATCH_SIZE
|
49 |
+
C.BATCH_SIZE = int(arg)
|
50 |
+
print('\t\tBATCH_SIZE: {} -> {}'.format(old, C.BATCH_SIZE))
|
51 |
+
|
52 |
+
elif opt in ('--destination-folder', '-f'):
|
53 |
+
old = C.DESTINATION_FOLDER
|
54 |
+
C.DESTINATION_FOLDER = arg + '_' + C.CUR_DATETIME
|
55 |
+
print('\t\tDESTINATION_FOLDER: {} -> {}'.format(old, C.DESTINATION_FOLDER))
|
56 |
+
|
57 |
+
elif opt in ('--destination-folder-fixed', '-x'):
|
58 |
+
old = C.DESTINATION_FOLDER
|
59 |
+
C.DESTINATION_FOLDER = arg
|
60 |
+
print('\t\tDESTINATION_FOLDER: {} -> {}'.format(old, C.DESTINATION_FOLDER))
|
61 |
+
|
62 |
+
elif opt in ('--training-dataset', '-p'):
|
63 |
+
old = C.TRAINING_DATASET
|
64 |
+
C.TRAINING_DATASET = arg
|
65 |
+
print('\t\tTRAINING_DATASET: {} -> {}'.format(old, C.TRAINING_DATASET))
|
66 |
+
|
67 |
+
elif opt in ('--test-dataset', '-q'):
|
68 |
+
old = C.TEST_DATASET
|
69 |
+
C.TEST_DATASET = arg
|
70 |
+
print('\t\tTEST_DATASET: {} -> {}'.format(old, C.TEST_DATASET))
|
71 |
+
|
72 |
+
elif opt in ('--remote', '-r'):
|
73 |
+
old = C.REMOTE
|
74 |
+
if arg.lower() in ('1', 'true', 't'):
|
75 |
+
C.REMOTE = True
|
76 |
+
else:
|
77 |
+
C.REMOTE = False
|
78 |
+
print('\t\tREMOTE: {} -> {}'.format(old, C.REMOTE))
|
79 |
+
|
80 |
+
elif opt in ('--debug', '-d'):
|
81 |
+
old = C.DEBUG
|
82 |
+
if arg.lower() in ('1', 'true', 't'):
|
83 |
+
C.DEBUG = True
|
84 |
+
else:
|
85 |
+
C.DEBUG = False
|
86 |
+
print('\t\tDEBUG: {} -> {}'.format(old, C.DEBUG))
|
87 |
+
|
88 |
+
elif opt in ('--debug-training', '-t'):
|
89 |
+
old = C.DEBUG_TRAINING
|
90 |
+
if arg.lower() in ('1', 'true', 't'):
|
91 |
+
C.DEBUG_TRAINING = True
|
92 |
+
else:
|
93 |
+
C.DEBUG_TRAINING = False
|
94 |
+
print('\t\tDEBUG_TRAINING: {} -> {}'.format(old, C.DEBUG_TRAINING))
|
95 |
+
|
96 |
+
elif opt in ('--debug-input-data', '-i'):
|
97 |
+
old = C.DEBUG_INPUT_DATA
|
98 |
+
if arg.lower() in ('1', 'true', 't'):
|
99 |
+
C.DEBUG_INPUT_DATA = True
|
100 |
+
else:
|
101 |
+
C.DEBUG_INPUT_DATA = False
|
102 |
+
print('\t\tDEBUG_INPUT_DATA: {} -> {}'.format(old, C.DEBUG_INPUT_DATA))
|
103 |
+
|
104 |
+
elif opt in ('--loss', '-l'):
|
105 |
+
old = C.LOSS_FNC
|
106 |
+
if arg in ('ncc', 'mse', 'dssim', 'dice'):
|
107 |
+
C.LOSS_FNC = arg
|
108 |
+
else:
|
109 |
+
print('Invalid option for --loss. Expected: "mse", "ncc" or "dssim", got {}'.format(arg))
|
110 |
+
sys.exit(2)
|
111 |
+
print('\t\tLOSS_FNC: {} -> {}'.format(old, C.LOSS_FNC))
|
DeepDeformationMapRegistration/utils/conf_file_utils.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import DeepDeformationMapRegistration.utils.constants as C
|
2 |
+
import re
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
class ConfigurationFile:
|
7 |
+
def __init__(self,
|
8 |
+
file_path: str):
|
9 |
+
self.__file = file_path
|
10 |
+
self.__load_configuration()
|
11 |
+
|
12 |
+
def __load_configuration(self):
|
13 |
+
fd = open(self.__file, 'r')
|
14 |
+
file_lines = fd.readlines()
|
15 |
+
|
16 |
+
for line in file_lines:
|
17 |
+
if '#' not in line and line != '\n':
|
18 |
+
match = re.match('(.*)=(.*)', line)
|
19 |
+
if match[1] in C.__dict__.keys():
|
20 |
+
# Careful with eval!!
|
21 |
+
try:
|
22 |
+
new_val = eval(match[2])
|
23 |
+
except NameError:
|
24 |
+
new_val = match[2]
|
25 |
+
old = C.__dict__[match[1]]
|
26 |
+
C.__dict__[match[1]] = new_val
|
27 |
+
|
28 |
+
# Special case
|
29 |
+
if match[1] == 'GPU_NUM':
|
30 |
+
C.__dict__[match[1]] = str(new_val)
|
31 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = C.GPU_NUM
|
32 |
+
|
33 |
+
if match[1] == 'EPOCHS':
|
34 |
+
C.__dict__[match[1]] = new_val
|
35 |
+
C.__dict__['SAVE_EPOCH'] = new_val // 10
|
36 |
+
C.__dict__['VERBOSE_EPOCH'] = new_val // 10
|
37 |
+
|
38 |
+
if match[1] == 'SAVE_EPOCH' or match[1] == 'VERBOSE_EPOCH':
|
39 |
+
if new_val is not None:
|
40 |
+
C.__dict__[match[1]] = C.__dict__['EPOCHS'] // new_val
|
41 |
+
else:
|
42 |
+
C.__dict__[match[1]] = None
|
43 |
+
|
44 |
+
if match[1] == 'VALIDATION_ERR_LIMIT_COUNTER':
|
45 |
+
C.__dict__[match[1]] = new_val
|
46 |
+
C.__dict__['VALIDATION_ERR_LIMIT_COUNTER_BACKUP'] = new_val
|
47 |
+
|
48 |
+
|
49 |
+
print('INFO: Updating constant {}: {} -> {}'.format(match[1], old, C.__dict__[match[1]]))
|
50 |
+
else:
|
51 |
+
print('ERROR: Unknown constant {}'.format(match[1]))
|
52 |
+
|
DeepDeformationMapRegistration/utils/constants.py
ADDED
@@ -0,0 +1,487 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Constants
|
3 |
+
"""
|
4 |
+
import tensorflow as tf
|
5 |
+
import os
|
6 |
+
import datetime
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
# RUN CONFIG
|
10 |
+
REMOTE = False # os.popen('hostname').read().encode('utf-8') == 'medtech-beast' #os.environ.get('REMOTE') == 'True'
|
11 |
+
|
12 |
+
# Remote execution
|
13 |
+
DEV_ORDER = 'PCI_BUS_ID'
|
14 |
+
GPU_NUM = '0'
|
15 |
+
|
16 |
+
# Dataset generation constants
|
17 |
+
# See batchGenerator __next__ method: return [in_mov, in_fix], [disp_map, out_img]
|
18 |
+
MOVING_IMG = 0
|
19 |
+
FIXED_IMG = 1
|
20 |
+
MOVING_PARENCHYMA_MASK = 2
|
21 |
+
FIXED_PARENCHYMA_MASK = 3
|
22 |
+
MOVING_VESSELS_MASK = 4
|
23 |
+
FIXED_VESSELS_MASK = 5
|
24 |
+
MOVING_TUMORS_MASK = 6
|
25 |
+
FIXED_TUMORS_MASK = 7
|
26 |
+
MOVING_SEGMENTATIONS = 8 # Compination of vessels and tumors
|
27 |
+
FIXED_SEGMENTATIONS = 9 # Compination of vessels and tumors
|
28 |
+
DISP_MAP_GT = 0
|
29 |
+
PRED_IMG_GT = 1
|
30 |
+
DISP_VECT_GT = 2
|
31 |
+
DISP_VECT_LOC_GT = 3
|
32 |
+
|
33 |
+
IMG_SIZE = 64 # Assumed a square image
|
34 |
+
IMG_SHAPE = (IMG_SIZE, IMG_SIZE, IMG_SIZE, 1) # (IMG_SIZE, IMG_SIZE, 1)
|
35 |
+
DISP_MAP_SHAPE = (IMG_SIZE, IMG_SIZE, IMG_SIZE, 3)
|
36 |
+
BATCH_SHAPE = (None, IMG_SIZE, IMG_SIZE, IMG_SIZE, 2) # Expected batch shape by the network
|
37 |
+
BATCH_SHAPE_SEGM = (None, IMG_SIZE, IMG_SIZE, IMG_SIZE, 3) # Expected batch shape by the network
|
38 |
+
IMG_BATCH_SHAPE = (None, IMG_SIZE, IMG_SIZE, IMG_SIZE, 1) # Batch shape for single images
|
39 |
+
|
40 |
+
RAW_DATA_BASE_DIR = './data'
|
41 |
+
DEFORMED_DATA_NAME = 'deformed'
|
42 |
+
GROUND_TRUTH_DATA_NAME = 'groundTruth'
|
43 |
+
GROUND_TRUTH_COORDS_FILE = 'centerlineCoords_GT.txt'
|
44 |
+
DEFORMED_COORDS_FILE = 'centerlineCoords_DF.txt'
|
45 |
+
H5_MOV_IMG = 'input/{}'.format(MOVING_IMG)
|
46 |
+
H5_FIX_IMG = 'input/{}'.format(FIXED_IMG)
|
47 |
+
H5_MOV_PARENCHYMA_MASK = 'input/{}'.format(MOVING_PARENCHYMA_MASK)
|
48 |
+
H5_FIX_PARENCHYMA_MASK = 'input/{}'.format(FIXED_PARENCHYMA_MASK)
|
49 |
+
H5_MOV_VESSELS_MASK = 'input/{}'.format(MOVING_VESSELS_MASK)
|
50 |
+
H5_FIX_VESSELS_MASK = 'input/{}'.format(FIXED_VESSELS_MASK)
|
51 |
+
H5_MOV_TUMORS_MASK = 'input/{}'.format(MOVING_TUMORS_MASK)
|
52 |
+
H5_FIX_TUMORS_MASK = 'input/{}'.format(FIXED_TUMORS_MASK)
|
53 |
+
H5_FIX_SEGMENTATIONS = 'input/{}'.format(FIXED_SEGMENTATIONS)
|
54 |
+
H5_MOV_SEGMENTATIONS = 'input/{}'.format(MOVING_SEGMENTATIONS)
|
55 |
+
|
56 |
+
H5_GT_DISP = 'output/{}'.format(DISP_MAP_GT)
|
57 |
+
H5_GT_IMG = 'output/{}'.format(PRED_IMG_GT)
|
58 |
+
H5_GT_DISP_VECT = 'output/{}'.format(DISP_VECT_GT)
|
59 |
+
H5_GT_DISP_VECT_LOC = 'output/{}'.format(DISP_VECT_LOC_GT)
|
60 |
+
H5_PARAMS_INTENSITY_RANGE = 'parameters/intensity'
|
61 |
+
TRAINING_PERC = 0.8
|
62 |
+
VALIDATION_PERC = 1 - TRAINING_PERC
|
63 |
+
MAX_ANGLE = 45.0 # degrees
|
64 |
+
MAX_FLIPS = 2 # Axes to flip over
|
65 |
+
NUM_ROTATIONS = 5
|
66 |
+
MAX_WORKERS = 10
|
67 |
+
|
68 |
+
# Training constants
|
69 |
+
MODEL = 'unet'
|
70 |
+
BATCH_NORM = False
|
71 |
+
TENSORBOARD = False
|
72 |
+
LIMIT_NUM_SAMPLES = None # If you don't want to use all the samples in the training set. None to use all
|
73 |
+
TRAINING_DATASET = 'data/training.hd5'
|
74 |
+
TEST_DATASET = 'data/test.hd5'
|
75 |
+
VALIDATION_DATASET = 'data/validation.hd5'
|
76 |
+
LOSS_FNC = 'mse'
|
77 |
+
LOSS_SCHEME = 'unidirectional'
|
78 |
+
NUM_EPOCHS = 10
|
79 |
+
DATA_FORMAT = 'channels_last' # or 'channels_fist'
|
80 |
+
DATA_DIR = './data'
|
81 |
+
MODEL_CHECKPOINT = './model_checkpoint'
|
82 |
+
BATCH_SIZE = 8
|
83 |
+
EPOCHS = 100
|
84 |
+
SAVE_EPOCH = EPOCHS // 10 # Epoch when to save the model
|
85 |
+
VERBOSE_EPOCH = EPOCHS // 10
|
86 |
+
VALIDATION_ERR_LIMIT = 0.2 # Stop training if reached this limit
|
87 |
+
VALIDATION_ERR_LIMIT_COUNTER = 10 # Number of successive times the validation error was smaller than the threshold
|
88 |
+
VALIDATION_ERR_LIMIT_COUNTER_BACKUP = 10
|
89 |
+
THRESHOLD = 0.5 # Threshold to select the centerline in the interpolated images
|
90 |
+
RESTORE_TRAINING = True # look for previously saved models to resume training
|
91 |
+
EARLY_STOP_PATIENCE = 10
|
92 |
+
LOG_FIELD_NAMES = ['time', 'epoch', 'step',
|
93 |
+
'training_loss_mean', 'training_loss_std',
|
94 |
+
'training_loss1_mean', 'training_loss1_std',
|
95 |
+
'training_loss2_mean', 'training_loss2_std',
|
96 |
+
'training_loss3_mean', 'training_loss3_std',
|
97 |
+
'training_ncc1_mean', 'training_ncc1_std',
|
98 |
+
'training_ncc2_mean', 'training_ncc2_std',
|
99 |
+
'training_ncc3_mean', 'training_ncc3_std',
|
100 |
+
'validation_loss_mean', 'validation_loss_std',
|
101 |
+
'validation_loss1_mean', 'validation_loss1_std',
|
102 |
+
'validation_loss2_mean', 'validation_loss2_std',
|
103 |
+
'validation_loss3_mean', 'validation_loss3_std',
|
104 |
+
'validation_ncc1_mean', 'validation_ncc1_std',
|
105 |
+
'validation_ncc2_mean', 'validation_ncc2_std',
|
106 |
+
'validation_ncc3_mean', 'validation_ncc3_std']
|
107 |
+
LOG_FIELD_NAMES_SHORT = ['time', 'epoch', 'step',
|
108 |
+
'training_loss_mean', 'training_loss_std',
|
109 |
+
'training_loss1_mean', 'training_loss1_std',
|
110 |
+
'training_loss2_mean', 'training_loss2_std',
|
111 |
+
'training_ncc1_mean', 'training_ncc1_std',
|
112 |
+
'training_ncc2_mean', 'training_ncc2_std',
|
113 |
+
'validation_loss_mean', 'validation_loss_std',
|
114 |
+
'validation_loss1_mean', 'validation_loss1_std',
|
115 |
+
'validation_loss2_mean', 'validation_loss2_std',
|
116 |
+
'validation_ncc1_mean', 'validation_ncc1_std',
|
117 |
+
'validation_ncc2_mean', 'validation_ncc2_std']
|
118 |
+
LOG_FIELD_NAMES_UNET = ['time', 'epoch', 'step', 'reg_smooth_coeff', 'reg_jacob_coeff',
|
119 |
+
'training_loss_mean', 'training_loss_std',
|
120 |
+
'training_loss_dissim_mean', 'training_loss_dissim_std',
|
121 |
+
'training_reg_smooth_mean', 'training_reg_smooth_std',
|
122 |
+
'training_reg_jacob_mean', 'training_reg_jacob_std',
|
123 |
+
'training_ncc_mean', 'training_ncc_std',
|
124 |
+
'training_dice_mean', 'training_dice_std',
|
125 |
+
'training_owo_mean', 'training_owo_std',
|
126 |
+
'validation_loss_mean', 'validation_loss_std',
|
127 |
+
'validation_loss_dissim_mean', 'validation_loss_dissim_std',
|
128 |
+
'validation_reg_smooth_mean', 'validation_reg_smooth_std',
|
129 |
+
'validation_reg_jacob_mean', 'validation_reg_jacob_std',
|
130 |
+
'validation_ncc_mean', 'validation_ncc_std',
|
131 |
+
'validation_dice_mean', 'validation_dice_std',
|
132 |
+
'validation_owo_mean', 'validation_owo_std']
|
133 |
+
CUR_DATETIME = datetime.datetime.now().strftime("%H%M_%d%m%Y")
|
134 |
+
DESTINATION_FOLDER = 'training_log_' + CUR_DATETIME
|
135 |
+
CSV_DELIMITER = ";"
|
136 |
+
CSV_QUOTE_CHAR = '"'
|
137 |
+
REG_SMOOTH = 0.0
|
138 |
+
REG_MAG = 1.0
|
139 |
+
REG_TYPE = 'l2'
|
140 |
+
MAX_DISP_DM = 10.
|
141 |
+
MAX_DISP_DM_TF = tf.constant((MAX_DISP_DM,), tf.float32, name='MAX_DISP_DM')
|
142 |
+
MAX_DISP_DM_PERC = 0.25
|
143 |
+
|
144 |
+
W_SIM = 0.7
|
145 |
+
W_REG = 0.3
|
146 |
+
W_INV = 0.1
|
147 |
+
|
148 |
+
# Loss function parameters
|
149 |
+
REG_SMOOTH1 = 1 / 100000
|
150 |
+
REG_SMOOTH2 = 1 / 5000
|
151 |
+
REG_SMOOTH3 = 1 / 5000
|
152 |
+
LOSS1 = 1.0
|
153 |
+
LOSS2 = 0.6
|
154 |
+
LOSS3 = 0.3
|
155 |
+
REG_JACOBIAN = 0.1
|
156 |
+
|
157 |
+
LOSS_COEFFICIENT = 1.0
|
158 |
+
REG_COEFFICIENT = 1.0
|
159 |
+
|
160 |
+
DICE_SMOOTH = 1.
|
161 |
+
|
162 |
+
CC_WINDOW = [9,9,9]
|
163 |
+
|
164 |
+
# Adam optimizer
|
165 |
+
LEARNING_RATE = 1e-3
|
166 |
+
B1 = 0.9
|
167 |
+
B2 = 0.999
|
168 |
+
LEARNING_RATE_DECAY = 0.01
|
169 |
+
LEARNING_RATE_DECAY_STEP = 10000 # Update the learning rate every LEARNING_RATE_DECAY_STEP steps
|
170 |
+
OPTIMIZER = 'adam'
|
171 |
+
|
172 |
+
# Network architecture constants
|
173 |
+
LAYER_MAXPOOL = 0
|
174 |
+
LAYER_UPSAMP = 1
|
175 |
+
LAYER_CONV = 2
|
176 |
+
AFFINE_TRANSF = False
|
177 |
+
OUTPUT_LAYER = 3
|
178 |
+
DROPOUT = True
|
179 |
+
DROPOUT_RATE = 0.2
|
180 |
+
MAX_DATA_SIZE = (1000, 1000, 1)
|
181 |
+
PLATEAU_THR = 0.01 # A slope between +-PLATEAU_THR will be considered a plateau for the LR updating function
|
182 |
+
ENCODER_FILTERS = [4, 8, 16, 32, 64]
|
183 |
+
|
184 |
+
# SSIM
|
185 |
+
SSIM_FILTER_SIZE = 11 # Size of Gaussian filter
|
186 |
+
SSIM_FILTER_SIGMA = 1.5 # Width of Gaussian filter
|
187 |
+
SSIM_K1 = 0.01 # Def. 0.01
|
188 |
+
SSIM_K2 = 0.03 # Recommended values 0 < K2 < 0.4
|
189 |
+
MAX_VALUE = 1.0 # Maximum intensity values
|
190 |
+
|
191 |
+
# Mathematic constants
|
192 |
+
EPS = 1e-8
|
193 |
+
EPS_tf = tf.constant(EPS, dtype=tf.float32)
|
194 |
+
LOG2 = tf.math.log(tf.constant(2, dtype=tf.float32))
|
195 |
+
|
196 |
+
# Debug constants
|
197 |
+
VERBOSE = False
|
198 |
+
DEBUG = False
|
199 |
+
DEBUG_TRAINING = False
|
200 |
+
DEBUG_INPUT_DATA = False
|
201 |
+
|
202 |
+
# Plotting
|
203 |
+
FONT_SIZE = 10
|
204 |
+
DPI = 200 # Dots Per Inch
|
205 |
+
|
206 |
+
# Coordinates
|
207 |
+
B = 0 # Batch dimension
|
208 |
+
H = 1 # Height dimension
|
209 |
+
W = 2 # Width dimension
|
210 |
+
D = 3 # Depth
|
211 |
+
C = -1 # Channel dimension
|
212 |
+
|
213 |
+
D_DISP = 2
|
214 |
+
W_DISP = 1
|
215 |
+
H_DISP = 0
|
216 |
+
|
217 |
+
DIMENSIONALITY = 3
|
218 |
+
|
219 |
+
# Interpolation type
|
220 |
+
BIL_INTERP = 0
|
221 |
+
TPS_INTERP = 1
|
222 |
+
CUADRATIC_C = 0.5
|
223 |
+
|
224 |
+
# Data augmentation
|
225 |
+
MAX_DISP = 5 # Test = 15
|
226 |
+
NUM_ROT = 5
|
227 |
+
NUM_FLIPS = 2
|
228 |
+
MAX_ANGLE = 10
|
229 |
+
|
230 |
+
# Thin Plate Splines implementation constants
|
231 |
+
TPS_NUM_CTRL_PTS_PER_AXIS = 4
|
232 |
+
TPS_NUM_CTRL_PTS = np.power(TPS_NUM_CTRL_PTS_PER_AXIS, DIMENSIONALITY)
|
233 |
+
TPS_REG = 0.01
|
234 |
+
DISP_SCALE = 2 # Scaling of the output of the CNN to increase the range of tanh
|
235 |
+
|
236 |
+
|
237 |
+
class CoordinatesGrid:
|
238 |
+
def __init__(self):
|
239 |
+
self.__grid = 0
|
240 |
+
self.__grid_fl = 0
|
241 |
+
self.__norm = False
|
242 |
+
self.__num_pts = 0
|
243 |
+
self.__batches = False
|
244 |
+
self.__shape = None
|
245 |
+
self.__shape_flat = None
|
246 |
+
|
247 |
+
def set_coords_grid(self, img_shape: tf.TensorShape, num_ppa: int = None, batches: bool = False,
|
248 |
+
img_type: tf.DType = tf.float32, norm: bool = False):
|
249 |
+
self.__batches = batches
|
250 |
+
not_batches = not batches # Just to not make a too complex code when indexing the values
|
251 |
+
if num_ppa is None:
|
252 |
+
num_ppa = img_shape
|
253 |
+
if norm:
|
254 |
+
x_coords = tf.linspace(-1., 1.,
|
255 |
+
num_ppa[W - int(not_batches)]) # np.linspace works fine, tf had some issues...
|
256 |
+
y_coords = tf.linspace(-1., 1., num_ppa[H - int(not_batches)]) # num_ppa: number of points per axis
|
257 |
+
z_coords = tf.linspace(-1., 1., num_ppa[D - int(not_batches)])
|
258 |
+
else:
|
259 |
+
x_coords = tf.linspace(0., img_shape[W - int(not_batches)] - 1.,
|
260 |
+
num_ppa[W - int(not_batches)]) # np.linspace works fine, tf had some issues...
|
261 |
+
y_coords = tf.linspace(0., img_shape[H - int(not_batches)] - 1.,
|
262 |
+
num_ppa[H - int(not_batches)]) # num_ppa: number of points per axis
|
263 |
+
z_coords = tf.linspace(0., img_shape[D - int(not_batches)] - 1., num_ppa[D - int(not_batches)])
|
264 |
+
|
265 |
+
coords = tf.meshgrid(x_coords, y_coords, z_coords, indexing='ij')
|
266 |
+
self.__num_pts = num_ppa[W - int(not_batches)] * num_ppa[H - int(not_batches)] * num_ppa[D - int(not_batches)]
|
267 |
+
|
268 |
+
grid = tf.stack([coords[0], coords[1], coords[2]], axis=-1)
|
269 |
+
grid = tf.cast(grid, img_type)
|
270 |
+
|
271 |
+
grid_fl = tf.stack([tf.reshape(coords[0], [-1]),
|
272 |
+
tf.reshape(coords[1], [-1]),
|
273 |
+
tf.reshape(coords[2], [-1])], axis=-1)
|
274 |
+
grid_fl = tf.cast(grid_fl, img_type)
|
275 |
+
|
276 |
+
grid_homogeneous = tf.stack([tf.reshape(coords[0], [-1]),
|
277 |
+
tf.reshape(coords[1], [-1]),
|
278 |
+
tf.reshape(coords[2], [-1]),
|
279 |
+
tf.ones_like(tf.reshape(coords[0], [-1]))], axis=-1)
|
280 |
+
|
281 |
+
self.__shape = np.asarray([num_ppa[W - int(not_batches)], num_ppa[H - int(not_batches)], num_ppa[D - int(not_batches)], 3])
|
282 |
+
total_num_pts = np.prod(self.__shape[:-1])
|
283 |
+
self.__shape_flat = np.asarray([total_num_pts, 3])
|
284 |
+
if batches:
|
285 |
+
grid = tf.expand_dims(grid, axis=0)
|
286 |
+
grid = tf.tile(grid, [img_shape[B], 1, 1, 1, 1])
|
287 |
+
grid_fl = tf.expand_dims(grid_fl, axis=0)
|
288 |
+
grid_fl = tf.tile(grid_fl, [img_shape[B], 1, 1])
|
289 |
+
grid_homogeneous = tf.expand_dims(grid_homogeneous, axis=0)
|
290 |
+
grid_homogeneous = tf.tile(grid_homogeneous, [img_shape[B], 1, 1])
|
291 |
+
self.__shape = np.concatenate([np.asarray((img_shape[B],)), self.__shape])
|
292 |
+
self.__shape_flat = np.concatenate([np.asarray((img_shape[B],)), self.__shape_flat])
|
293 |
+
|
294 |
+
self.__norm = norm
|
295 |
+
self.__grid_fl = grid_fl
|
296 |
+
self.__grid = grid
|
297 |
+
self.__grid_homogeneous = grid_homogeneous
|
298 |
+
|
299 |
+
@property
|
300 |
+
def grid(self):
|
301 |
+
return self.__grid
|
302 |
+
|
303 |
+
@property
|
304 |
+
def size(self):
|
305 |
+
return self.__len__()
|
306 |
+
|
307 |
+
def grid_flat(self, transpose=False):
|
308 |
+
if transpose:
|
309 |
+
if self.__batches:
|
310 |
+
ret = tf.transpose(self.__grid_fl, (0, 2, 1))
|
311 |
+
else:
|
312 |
+
ret = tf.transpose(self.__grid_fl)
|
313 |
+
else:
|
314 |
+
ret = self.__grid_fl
|
315 |
+
return ret
|
316 |
+
|
317 |
+
def grid_homogeneous(self, transpose=False):
|
318 |
+
if transpose:
|
319 |
+
if self.__batches:
|
320 |
+
ret = tf.transpose(self.__grid_homogeneous, (0, 2, 1))
|
321 |
+
else:
|
322 |
+
ret = tf.transpose(self.__grid_homogeneous)
|
323 |
+
else:
|
324 |
+
ret = self.__grid_homogeneous
|
325 |
+
return ret
|
326 |
+
|
327 |
+
@property
|
328 |
+
def is_normalized(self):
|
329 |
+
return self.__norm
|
330 |
+
|
331 |
+
def __len__(self):
|
332 |
+
return tf.size(self.__grid)
|
333 |
+
|
334 |
+
@property
|
335 |
+
def number_pts(self):
|
336 |
+
return self.__num_pts
|
337 |
+
|
338 |
+
@property
|
339 |
+
def shape_grid_flat(self):
|
340 |
+
return self.__shape_flat
|
341 |
+
|
342 |
+
@property
|
343 |
+
def shape(self):
|
344 |
+
return self.__shape
|
345 |
+
|
346 |
+
|
347 |
+
|
348 |
+
COORDS_GRID = CoordinatesGrid()
|
349 |
+
|
350 |
+
|
351 |
+
class VisualizationParameters:
|
352 |
+
def __init__(self):
|
353 |
+
self.__scale = None # See https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.axes.Axes.quiver.html
|
354 |
+
self.__spacing = 5
|
355 |
+
|
356 |
+
def set_spacing(self, img_shape: tf.TensorShape):
|
357 |
+
self.__spacing = int(5 * np.log(img_shape[W]))
|
358 |
+
|
359 |
+
@property
|
360 |
+
def spacing(self):
|
361 |
+
return self.__spacing
|
362 |
+
|
363 |
+
def set_arrow_scale(self, scale: int):
|
364 |
+
self.__scale = scale
|
365 |
+
|
366 |
+
@property
|
367 |
+
def arrow_scale(self):
|
368 |
+
return self.__scale
|
369 |
+
|
370 |
+
|
371 |
+
QUIVER_PARAMS = VisualizationParameters()
|
372 |
+
|
373 |
+
# Configuration file
|
374 |
+
CONF_FILE_NAME = 'configuration.txt'
|
375 |
+
|
376 |
+
|
377 |
+
def summary():
|
378 |
+
return '##### CONFIGURATION: REMOTE {} DEBUG {} DEBUG TRAINING {}' \
|
379 |
+
'\n\t\tLEARNING RATE: {}' \
|
380 |
+
'\n\t\tBATCH SIZE: {}' \
|
381 |
+
'\n\t\tLIMIT NUM SAMPLES: {}' \
|
382 |
+
'\n\t\tLOSS_FNC: {}' \
|
383 |
+
'\n\t\tTRAINING_DATASET: {} ({:.1f}%/{:.1f}%)' \
|
384 |
+
'\n\t\tTEST_DATASET: {}'.format(REMOTE, DEBUG, DEBUG_TRAINING, LEARNING_RATE, BATCH_SIZE, LIMIT_NUM_SAMPLES,
|
385 |
+
LOSS_FNC, TRAINING_DATASET, TRAINING_PERC * 100, (1 - TRAINING_PERC) * 100,
|
386 |
+
TEST_DATASET)
|
387 |
+
|
388 |
+
|
389 |
+
# LOG Severity levers
|
390 |
+
# https://docs.python.org/2/library/logging.html#logging-levels
|
391 |
+
INF = 20 # Information
|
392 |
+
WAR = 30 # Warning
|
393 |
+
ERR = 40 # Error
|
394 |
+
DEB = 10 # Debug
|
395 |
+
CRI = 50 # Critical
|
396 |
+
|
397 |
+
SEVERITY_STR = {INF: 'INFO',
|
398 |
+
WAR: 'WARNING',
|
399 |
+
ERR: 'ERROR',
|
400 |
+
DEB: 'DEBUG',
|
401 |
+
CRI: 'CRITICAL'}
|
402 |
+
|
403 |
+
HL_LOG_FIELD_NAMES = ['Time', 'Epoch', 'Step',
|
404 |
+
'train_loss', 'train_loss_std',
|
405 |
+
'train_loss1', 'train_loss1_std',
|
406 |
+
'train_loss2', 'train_loss2_std',
|
407 |
+
'train_loss3', 'train_loss3_std',
|
408 |
+
'train_NCC', 'train_NCC_std',
|
409 |
+
'val_loss', 'val_loss_std',
|
410 |
+
'val_loss1', 'val_loss1_std',
|
411 |
+
'val_loss2', 'val_loss2_std',
|
412 |
+
'val_loss3', 'val_loss3_std',
|
413 |
+
'val_NCC', 'val_NCC_std']
|
414 |
+
|
415 |
+
# Sobel filters
|
416 |
+
SOBEL_W_2D = tf.constant([[-1., 0., 1.],
|
417 |
+
[-2., 0., 2.],
|
418 |
+
[-1., 0., 1.]], dtype=tf.float32, name='sobel_w_2d')
|
419 |
+
SOBEL_W_3D = tf.tile(tf.expand_dims(SOBEL_W_2D, axis=-1), [1, 1, 3])
|
420 |
+
SOBEL_H_3D = tf.transpose(SOBEL_W_3D, [1, 0, 2])
|
421 |
+
SOBEL_D_3D = tf.transpose(SOBEL_W_3D, [2, 1, 0])
|
422 |
+
|
423 |
+
aux = tf.expand_dims(tf.expand_dims(SOBEL_W_3D, axis=-1), axis=-1)
|
424 |
+
SOBEL_FILTER_W_3D_IMAGE = aux
|
425 |
+
SOBEL_FILTER_W_3D = tf.tile(aux, [1, 1, 1, 3, 3])
|
426 |
+
# tf.nn.conv3d expects the filter in [D, H, W, C_in, C_out] order
|
427 |
+
SOBEL_FILTER_W_3D = tf.transpose(SOBEL_FILTER_W_3D, [2, 0, 1, 3, 4], name='sobel_filter_i_3d')
|
428 |
+
|
429 |
+
aux = tf.expand_dims(tf.expand_dims(SOBEL_H_3D, axis=-1), axis=-1)
|
430 |
+
SOBEL_FILTER_H_3D_IMAGE = aux
|
431 |
+
SOBEL_FILTER_H_3D = tf.tile(aux, [1, 1, 1, 3, 3])
|
432 |
+
SOBEL_FILTER_H_3D = tf.transpose(SOBEL_FILTER_H_3D, [2, 0, 1, 3, 4], name='sobel_filter_j_3d')
|
433 |
+
|
434 |
+
aux = tf.expand_dims(tf.expand_dims(SOBEL_D_3D, axis=-1), axis=-1)
|
435 |
+
SOBEL_FILTER_D_3D_IMAGE = aux
|
436 |
+
SOBEL_FILTER_D_3D = tf.tile(aux, [1, 1, 1, 3, 3])
|
437 |
+
SOBEL_FILTER_D_3D = tf.transpose(SOBEL_FILTER_D_3D, [2, 1, 0, 3, 4], name='sobel_filter_k_3d')
|
438 |
+
|
439 |
+
# Filters for spatial integration of the displacement map
|
440 |
+
INTEG_WIND_SIZE = IMG_SIZE
|
441 |
+
INTEG_STEPS = 4 # VoxelMorph default value for the integration of the stationary velocity field. >4 memory alloc issue
|
442 |
+
INTEG_FILTER_D = tf.ones([INTEG_WIND_SIZE, 1, 1, 1, 1], name='integrate_h_filter')
|
443 |
+
INTEG_FILTER_H = tf.ones([1, INTEG_WIND_SIZE, 1, 1, 1], name='integrate_w_filter')
|
444 |
+
INTEG_FILTER_W = tf.ones([1, 1, INTEG_WIND_SIZE, 1, 1], name='integrate_d_filter')
|
445 |
+
|
446 |
+
# Laplacian filter
|
447 |
+
LAPLACIAN_27_P = tf.constant(np.asarray([np.ones((3, 3)),
|
448 |
+
[[1, 1, 1],
|
449 |
+
[1, -26, 1],
|
450 |
+
[1, 1, 1]],
|
451 |
+
np.ones((3, 3))]), tf.float32)
|
452 |
+
LAPLACIAN_27_P = tf.expand_dims(tf.expand_dims(LAPLACIAN_27_P, axis=-1), axis=-1)
|
453 |
+
LAPLACIAN_27_P = tf.tile(LAPLACIAN_27_P, [1, 1, 1, 3, 3], name='laplacian_27_p')
|
454 |
+
|
455 |
+
|
456 |
+
LAPLACIAN_7_P = tf.constant(np.asarray([[[0, 0, 0],
|
457 |
+
[0, 1, 0],
|
458 |
+
[0, 0, 0]],
|
459 |
+
[[0, 1, 0],
|
460 |
+
[1, -6, 1],
|
461 |
+
[0, 1, 0]],
|
462 |
+
[[0, 0, 0],
|
463 |
+
[0, 1, 0],
|
464 |
+
[0, 0, 0]]]), tf.float32)
|
465 |
+
LAPLACIAN_7_P = tf.expand_dims(tf.expand_dims(LAPLACIAN_7_P, axis=-1), axis=-1)
|
466 |
+
LAPLACIAN_7_P = tf.tile(LAPLACIAN_7_P, [1, 1, 1, 3, 3], name='laplacian_7_p')
|
467 |
+
|
468 |
+
# Constants for bias loss
|
469 |
+
ZERO_WARP = tf.zeros((1,) + DISP_MAP_SHAPE, name='zero_warp')
|
470 |
+
BIAS_WARP_WEIGHT = 1e-02
|
471 |
+
BIAS_AFFINE_WEIGHT = 1e-02
|
472 |
+
|
473 |
+
# Overlapping score
|
474 |
+
OS_SCALE = 10
|
475 |
+
EPS_1 = 1.0
|
476 |
+
EPS_1_tf = tf.constant(EPS_1)
|
477 |
+
|
478 |
+
# LDDMM
|
479 |
+
GAUSSIAN_KERNEL_SHAPE = (8, 8, 8)
|
480 |
+
|
481 |
+
# Constants for MultiLoss layer
|
482 |
+
PRIOR_W = [1., 1 / 60, 1.]
|
483 |
+
MANUAL_W = [1.] * len(PRIOR_W)
|
484 |
+
|
485 |
+
REG_PRIOR_W = [1e-3]
|
486 |
+
REG_MANUAL_W = [1.] * len(REG_PRIOR_W)
|
487 |
+
|
DeepDeformationMapRegistration/utils/misc.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import errno
|
3 |
+
import nibabel as nb
|
4 |
+
import numpy as np
|
5 |
+
import re
|
6 |
+
|
7 |
+
def try_mkdir(dir):
|
8 |
+
try:
|
9 |
+
os.makedirs(dir)
|
10 |
+
except OSError as err:
|
11 |
+
if err.errno == errno.EEXIST:
|
12 |
+
print("Directory " + dir + " already exists")
|
13 |
+
else:
|
14 |
+
raise ValueError("Can't create dir " + dir)
|
15 |
+
else:
|
16 |
+
print("Created directory " + dir)
|
17 |
+
|
18 |
+
|
19 |
+
def function_decorator(new_name):
|
20 |
+
""""
|
21 |
+
Change the __name__ property of a function using new_name.
|
22 |
+
:param new_name:
|
23 |
+
:return:
|
24 |
+
"""
|
25 |
+
def decorator(func):
|
26 |
+
func.__name__ = new_name
|
27 |
+
return func
|
28 |
+
return decorator
|
DeepDeformationMapRegistration/utils/nifty_utils.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import errno
|
3 |
+
import nibabel as nb
|
4 |
+
import numpy as np
|
5 |
+
import re
|
6 |
+
import zipfile
|
7 |
+
import tensorflow as tf
|
8 |
+
|
9 |
+
|
10 |
+
TEMP_UNZIP_PATH = '/mnt/EncryptedData1/Users/javier/ext_datasets/LITS17/temp'
|
11 |
+
NII_EXTENSION = '.nii'
|
12 |
+
|
13 |
+
|
14 |
+
def save_nifti(data, save_path):
|
15 |
+
data_nifti = nb.Nifti1Image(data, affine=np.eye(4))
|
16 |
+
|
17 |
+
data_nifti.header.get_xyzt_units()
|
18 |
+
try:
|
19 |
+
data_nifti.to_filename(save_path) # Save as NiBabel file
|
20 |
+
print('Saved {}'.format(save_path))
|
21 |
+
except ValueError:
|
22 |
+
print('Could not save {}'.format(save_path))
|
23 |
+
|
24 |
+
|
25 |
+
def unzip_nii_file(file_path):
|
26 |
+
file_dir, file_name = os.path.split(file_path)
|
27 |
+
file_name = file_name.split('.zip')[0]
|
28 |
+
|
29 |
+
dest_path = os.path.join(TEMP_UNZIP_PATH, file_name)
|
30 |
+
zipfile.ZipFile(file_path).extractall(TEMP_UNZIP_PATH)
|
31 |
+
|
32 |
+
if not os.path.exists(dest_path):
|
33 |
+
print('ERR: File {} not unzip-ed!'.format(file_path))
|
34 |
+
dest_path = None
|
35 |
+
return dest_path
|
36 |
+
|
37 |
+
|
38 |
+
def delete_temp(file_path, verbose=False):
|
39 |
+
assert NII_EXTENSION in file_path
|
40 |
+
os.remove(file_path)
|
41 |
+
if verbose:
|
42 |
+
print('Deleted file: ', file_path)
|
DeepDeformationMapRegistration/utils/operators
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import tensorflow as tf
|
3 |
+
|
4 |
+
|
5 |
+
def min_max_norm(img: np.ndarray, out_max_val=1.):
|
6 |
+
out_img = img
|
7 |
+
max_val = np.amax(img)
|
8 |
+
min_val = np.amin(img)
|
9 |
+
if (max_val - min_val) != 0:
|
10 |
+
out_img = (img - min_val) / (max_val - min_val)
|
11 |
+
return out_img * out_max_val
|
12 |
+
|
13 |
+
|
14 |
+
def soft_threshold(x, threshold, name=None):
|
15 |
+
# https://www.tensorflow.org/probability/api_docs/python/tfp/math/soft_threshold
|
16 |
+
with tf.name_scope(name or 'soft_threshold'):
|
17 |
+
x = tf.convert_to_tensor(x, name='x')
|
18 |
+
threshold = tf.convert_to_tensor(threshold, dtype=x.dtype, name='threshold')
|
19 |
+
return tf.sign(x) * tf.maximum(tf.abs(x) - threshold, 0.)
|
20 |
+
|
21 |
+
|
22 |
+
def binary_activation(x):
|
23 |
+
# https://stackoverflow.com/questions/37743574/hard-limiting-threshold-activation-function-in-tensorflow
|
24 |
+
cond = tf.less(x, tf.zeros(tf.shape(x)))
|
25 |
+
out = tf.where(cond, tf.zeros(tf.shape(x)), tf.ones(tf.shape(x)))
|
26 |
+
|
27 |
+
return out
|
28 |
+
|
DeepDeformationMapRegistration/utils/user_interface.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import re
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
def show_and_select(file_list, msg='Select a file by the number: ', int_if_single=True):
|
7 |
+
# If the selection is a single number, then return that number instead of the list of length 1
|
8 |
+
invalid_selection = True
|
9 |
+
while invalid_selection:
|
10 |
+
for i, f in enumerate(file_list):
|
11 |
+
print('{:03d}) {}'. format(i+1, os.path.split(f)[-1]))
|
12 |
+
|
13 |
+
sel = np.asarray(re.split(',\s|,|\s',input(msg)), np.int) - 1
|
14 |
+
|
15 |
+
if (np.all(sel >= 0)) and (np.all(sel <= len(file_list))):
|
16 |
+
invalid_selection = False
|
17 |
+
sel = [file_list[s] for s in sel]
|
18 |
+
print('Selected: ', ', '.join([os.path.split(f)[-1] for f in sel]))
|
19 |
+
|
20 |
+
if int_if_single:
|
21 |
+
if len(sel) == 1:
|
22 |
+
sel = sel[0]
|
23 |
+
return sel
|
DeepDeformationMapRegistration/utils/visualization.py
ADDED
@@ -0,0 +1,1151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# import matplotlib
|
2 |
+
# matplotlib.use('TkAgg')
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
from mpl_toolkits.mplot3d import Axes3D
|
5 |
+
import matplotlib.colors as mcolors
|
6 |
+
from matplotlib.lines import Line2D
|
7 |
+
from matplotlib import cm
|
8 |
+
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
9 |
+
import tensorflow as tf
|
10 |
+
import numpy as np
|
11 |
+
import pyVesselRegistration_constants as const
|
12 |
+
from skimage.exposure import rescale_intensity
|
13 |
+
import scipy.misc as scpmisc
|
14 |
+
import os
|
15 |
+
|
16 |
+
THRES = 0.9
|
17 |
+
|
18 |
+
# COLOR MAPS
|
19 |
+
chunks = np.linspace(0, 1, 10)
|
20 |
+
cmap1 = plt.get_cmap('hsv', 4)
|
21 |
+
# cmaplist = [cmap1(i) for i in range(cmap1.N)]
|
22 |
+
cmaplist = [(1, 1, 1, 1), (0, 0, 1, 1), (230 / 255, 97 / 255, 1 / 255, 1), (128 / 255, 0 / 255, 32 / 255, 1)]
|
23 |
+
cmaplist[0] = (1, 1, 1, 1.0)
|
24 |
+
cmap1 = mcolors.LinearSegmentedColormap.from_list('custom', cmaplist, cmap1.N)
|
25 |
+
|
26 |
+
colors = [(0, 0, 1, i) for i in chunks]
|
27 |
+
cmap2 = mcolors.LinearSegmentedColormap.from_list('mycmap', colors, N=100)
|
28 |
+
|
29 |
+
colors = [(230 / 255, 97 / 255, 1 / 255, i) for i in chunks]
|
30 |
+
cmap3 = mcolors.LinearSegmentedColormap.from_list('mycmap', colors, N=100)
|
31 |
+
|
32 |
+
colors = [(128 / 255, 0 / 255, 32 / 255, i) for i in chunks]
|
33 |
+
cmap4 = mcolors.LinearSegmentedColormap.from_list('mycmap', colors, N=100)
|
34 |
+
|
35 |
+
cmap_bin = cm.get_cmap('viridis', 3) # viridis is the default colormap
|
36 |
+
|
37 |
+
|
38 |
+
def view_centerline_sample(sample: np.ndarray, dimensionality: int, ax=None, c=None, name=None):
|
39 |
+
if dimensionality == 2:
|
40 |
+
_plot_2d(sample, ax, c, name=name)
|
41 |
+
elif dimensionality == 3:
|
42 |
+
_plot_3d(sample, ax, c, name=name)
|
43 |
+
else:
|
44 |
+
raise ValueError('Invalid valud for dimensionality. Expected int 2 or 3')
|
45 |
+
|
46 |
+
|
47 |
+
def matrix_to_orthographicProjection(matrix: np.ndarray, ret_list=False):
|
48 |
+
""" Given a 3D matrix, returns the three orthographic projections: top, front, right.
|
49 |
+
Top corresponds to dimensions 1 and 2
|
50 |
+
Front corresponds to dimensions 0 and 1
|
51 |
+
Right corresponds to dimensions 0 and 2
|
52 |
+
|
53 |
+
:param matrix: 3D matrix
|
54 |
+
:param ret_list: return a list instead of an array (optional)
|
55 |
+
:return: list or array with the three views [top, front, right]
|
56 |
+
"""
|
57 |
+
top = _getProjection(matrix, dim=0) # YZ
|
58 |
+
front = _getProjection(matrix, dim=2) # XY
|
59 |
+
right = _getProjection(matrix, dim=1) # XZ
|
60 |
+
|
61 |
+
if ret_list:
|
62 |
+
return top, front, right
|
63 |
+
else:
|
64 |
+
return np.asarray([top, front, right])
|
65 |
+
|
66 |
+
|
67 |
+
def _getProjection(matrix: np.ndarray, dim: int):
|
68 |
+
orth_view = matrix.sum(axis=dim, dtype=float)
|
69 |
+
orth_view = orth_view > 0.0
|
70 |
+
orth_view.astype(np.float)
|
71 |
+
|
72 |
+
return orth_view
|
73 |
+
|
74 |
+
|
75 |
+
def orthographicProjection_to_matrix(top: np.ndarray, front: np.ndarray, right: np.ndarray):
|
76 |
+
""" Given the three orthographic projections, it returns a 3D-view of the object based on back projection
|
77 |
+
|
78 |
+
:param top: 2D view top view
|
79 |
+
:param front: 2D front view
|
80 |
+
:param right: 2D right view
|
81 |
+
:return: matrix with the 3D-view
|
82 |
+
"""
|
83 |
+
top_mat = np.tile(top, (front.shape[0], 1, 1))
|
84 |
+
front_mat = np.tile(top, (right.shape[1], 1, 1))
|
85 |
+
right_mat = np.tile(top, (top.shape[0], 1, 1))
|
86 |
+
|
87 |
+
reconstruction = np.zeros((front.shape[0], right.shape[1], top.shape[0]))
|
88 |
+
iter = np.nditer([top_mat, front_mat, right_mat, reconstruction], flags=['multi_index'], op_flags=['readwrite'])
|
89 |
+
while not iter.finished:
|
90 |
+
if iter[0] and iter[1] and iter[2]:
|
91 |
+
iter[3] = 1
|
92 |
+
iter.iternext()
|
93 |
+
|
94 |
+
return reconstruction
|
95 |
+
|
96 |
+
|
97 |
+
def _plot_2d(sample: np.ndarray, ax=None, c=None, name=None):
|
98 |
+
if isinstance(sample, tf.Tensor):
|
99 |
+
sample = sample.eval(session=tf.Session())
|
100 |
+
|
101 |
+
x_range = list()
|
102 |
+
y_range = list()
|
103 |
+
marker_size = list()
|
104 |
+
for idx, val in np.ndenumerate(sample):
|
105 |
+
if val >= THRES:
|
106 |
+
x_range.append(idx[0])
|
107 |
+
y_range.append(idx[1])
|
108 |
+
marker_size.append(val ** 2)
|
109 |
+
|
110 |
+
if not ax:
|
111 |
+
fig = plt.figure()
|
112 |
+
ax = fig.add_subplot(111)
|
113 |
+
|
114 |
+
if c:
|
115 |
+
ax.scatter(x_range, y_range, c=c, s=marker_size)
|
116 |
+
else:
|
117 |
+
ax.scatter(x_range, y_range, s=marker_size)
|
118 |
+
|
119 |
+
ax.set_xlabel('X')
|
120 |
+
ax.set_ylabel('Y')
|
121 |
+
if name:
|
122 |
+
ax.set_title(name)
|
123 |
+
|
124 |
+
return ax
|
125 |
+
|
126 |
+
|
127 |
+
def _plot_3d(sample: np.ndarray, ax=None, c=None, name=None):
|
128 |
+
from mpl_toolkits.mplot3d import Axes3D
|
129 |
+
if isinstance(sample, tf.Tensor):
|
130 |
+
sample = sample.eval(session=tf.Session())
|
131 |
+
|
132 |
+
x_range = list()
|
133 |
+
y_range = list()
|
134 |
+
z_range = list()
|
135 |
+
marker_size = list()
|
136 |
+
for idx, val in np.ndenumerate(sample):
|
137 |
+
if val >= THRES:
|
138 |
+
x_range.append(idx[0])
|
139 |
+
y_range.append(idx[1])
|
140 |
+
z_range.append(idx[2])
|
141 |
+
marker_size.append(val ** 2)
|
142 |
+
|
143 |
+
print('Found ', len(x_range), ' points')
|
144 |
+
# x_range = np.linspace(start=0, stop=sample.shape[0], num=sample.shape[0])
|
145 |
+
# y_range = np.linspace(start=0, stop=sample.shape[1], num=sample.shape[1])
|
146 |
+
# z_range = np.linspace(start=0, stop=sample.shape[2], num=sample.shape[2])
|
147 |
+
#
|
148 |
+
# sample_flat = sample.flatten(order='C')
|
149 |
+
|
150 |
+
if len(x_range):
|
151 |
+
if not ax:
|
152 |
+
fig = plt.figure()
|
153 |
+
ax = fig.add_subplot(111, projection='3d')
|
154 |
+
|
155 |
+
if c:
|
156 |
+
ax.scatter(x_range, y_range, z_range, c=c, s=marker_size)
|
157 |
+
else:
|
158 |
+
ax.scatter(x_range, y_range, z_range, s=marker_size)
|
159 |
+
# ax.scatter(x_range, y_range, z_range, s=marker_size)#, c=sample_flat)
|
160 |
+
|
161 |
+
ax.set_xlabel('X')
|
162 |
+
ax.set_ylabel('Y')
|
163 |
+
ax.set_zlabel('Z')
|
164 |
+
if name:
|
165 |
+
ax.set_title(name)
|
166 |
+
|
167 |
+
return ax
|
168 |
+
else:
|
169 |
+
print('Nothing to plot')
|
170 |
+
return None
|
171 |
+
|
172 |
+
|
173 |
+
def plot_training(list_imgs: [np.ndarray], affine_transf=True, filename='img', fig=None):
|
174 |
+
if fig is not None:
|
175 |
+
fig.clear()
|
176 |
+
plt.figure(fig.number)
|
177 |
+
else:
|
178 |
+
fig = plt.figure(dpi=const.DPI)
|
179 |
+
|
180 |
+
ax_fix = fig.add_subplot(231)
|
181 |
+
im_fix = ax_fix.imshow(list_imgs[0][:, :, 0])
|
182 |
+
ax_fix.set_title('Fix image', fontsize=const.FONT_SIZE)
|
183 |
+
ax_fix.tick_params(axis='both',
|
184 |
+
which='both',
|
185 |
+
bottom=False,
|
186 |
+
left=False,
|
187 |
+
labelleft=False,
|
188 |
+
labelbottom=False)
|
189 |
+
ax_mov = fig.add_subplot(232)
|
190 |
+
im_mov = ax_mov.imshow(list_imgs[1][:, :, 0])
|
191 |
+
ax_mov.set_title('Moving image', fontsize=const.FONT_SIZE)
|
192 |
+
ax_mov.tick_params(axis='both',
|
193 |
+
which='both',
|
194 |
+
bottom=False,
|
195 |
+
left=False,
|
196 |
+
labelleft=False,
|
197 |
+
labelbottom=False)
|
198 |
+
|
199 |
+
ax_pred_im = fig.add_subplot(233)
|
200 |
+
im_pred_im = ax_pred_im.imshow(list_imgs[2][:, :, 0])
|
201 |
+
ax_pred_im.set_title('Prediction', fontsize=const.FONT_SIZE)
|
202 |
+
ax_pred_im.tick_params(axis='both',
|
203 |
+
which='both',
|
204 |
+
bottom=False,
|
205 |
+
left=False,
|
206 |
+
labelleft=False,
|
207 |
+
labelbottom=False)
|
208 |
+
|
209 |
+
ax_pred_disp = fig.add_subplot(234)
|
210 |
+
if affine_transf:
|
211 |
+
fake_bg = np.array([[0.0, 0.0, 0.0, 0.0],
|
212 |
+
[0.0, 0.0, 0.0, 0.0],
|
213 |
+
[0.0, 0.0, 0.0, 0.0],
|
214 |
+
[0.0, 0.0, 0.0, 0.0]])
|
215 |
+
|
216 |
+
bottom = np.asarray([0, 0, 0, 1])
|
217 |
+
|
218 |
+
transf_mat = np.reshape(list_imgs[3], (2, 3))
|
219 |
+
transf_mat = np.stack([transf_mat, bottom], axis=0)
|
220 |
+
|
221 |
+
im_pred_disp = ax_pred_disp.imshow(fake_bg)
|
222 |
+
for i in range(4):
|
223 |
+
for j in range(4):
|
224 |
+
ax_pred_disp.text(i, j, transf_mat[i, j], ha="center", va="center", color="b")
|
225 |
+
|
226 |
+
ax_pred_disp.set_title('Affine transformation matrix')
|
227 |
+
|
228 |
+
else:
|
229 |
+
cx, cy, dx, dy, s = _prepare_quiver_map(list_imgs[3])
|
230 |
+
im_pred_disp = ax_pred_disp.imshow(s, interpolation='none', aspect='equal')
|
231 |
+
ax_pred_disp.quiver(cx, cy, dx, dy, scale=const.QUIVER_PARAMS.arrow_scale)
|
232 |
+
ax_pred_disp.set_title('Pred disp map', fontsize=const.FONT_SIZE)
|
233 |
+
ax_pred_disp.tick_params(axis='both',
|
234 |
+
which='both',
|
235 |
+
bottom=False,
|
236 |
+
left=False,
|
237 |
+
labelleft=False,
|
238 |
+
labelbottom=False)
|
239 |
+
|
240 |
+
ax_gt_disp = fig.add_subplot(235)
|
241 |
+
if affine_transf:
|
242 |
+
fake_bg = np.array([[0.0, 0.0, 0.0, 0.0],
|
243 |
+
[0.0, 0.0, 0.0, 0.0],
|
244 |
+
[0.0, 0.0, 0.0, 0.0],
|
245 |
+
[0.0, 0.0, 0.0, 0.0]])
|
246 |
+
|
247 |
+
bottom = np.asarray([0, 0, 0, 1])
|
248 |
+
|
249 |
+
transf_mat = np.reshape(list_imgs[4], (2, 3))
|
250 |
+
transf_mat = np.stack([transf_mat, bottom], axis=0)
|
251 |
+
|
252 |
+
im_gt_disp = ax_pred_disp.imshow(fake_bg)
|
253 |
+
for i in range(4):
|
254 |
+
for j in range(4):
|
255 |
+
ax_pred_disp.text(i, j, transf_mat[i, j], ha="center", va="center", color="b")
|
256 |
+
|
257 |
+
ax_pred_disp.set_title('Affine transformation matrix')
|
258 |
+
|
259 |
+
else:
|
260 |
+
cx, cy, dx, dy, s = _prepare_quiver_map(list_imgs[4])
|
261 |
+
im_gt_disp = ax_gt_disp.imshow(s, interpolation='none', aspect='equal')
|
262 |
+
ax_gt_disp.quiver(cx, cy, dx, dy, scale=const.QUIVER_PARAMS.arrow_scale)
|
263 |
+
ax_gt_disp.set_title('GT disp map', fontsize=const.FONT_SIZE)
|
264 |
+
ax_gt_disp.tick_params(axis='both',
|
265 |
+
which='both',
|
266 |
+
bottom=False,
|
267 |
+
left=False,
|
268 |
+
labelleft=False,
|
269 |
+
labelbottom=False)
|
270 |
+
|
271 |
+
cb_fix = _set_colorbar(fig, ax_fix, im_fix, False)
|
272 |
+
cb_mov = _set_colorbar(fig, ax_mov, im_mov, False)
|
273 |
+
cb_pred = _set_colorbar(fig, ax_pred_im, im_pred_im, False)
|
274 |
+
cb_pred_disp = _set_colorbar(fig, ax_pred_disp, im_pred_disp, False)
|
275 |
+
cd_gt_disp = _set_colorbar(fig, ax_gt_disp, im_gt_disp, False)
|
276 |
+
|
277 |
+
if filename is not None:
|
278 |
+
plt.savefig(filename, format='png') # Call before show
|
279 |
+
if not const.REMOTE:
|
280 |
+
plt.show()
|
281 |
+
else:
|
282 |
+
plt.close()
|
283 |
+
return fig
|
284 |
+
|
285 |
+
|
286 |
+
def save_centreline_img(img, title, filename, fig=None):
|
287 |
+
if fig is not None:
|
288 |
+
fig.clear()
|
289 |
+
plt.figure(fig.number)
|
290 |
+
else:
|
291 |
+
fig = plt.figure(dpi=const.DPI)
|
292 |
+
|
293 |
+
dim = len(img.shape[:-1])
|
294 |
+
|
295 |
+
if dim == 2:
|
296 |
+
ax = fig.add_subplot(111)
|
297 |
+
fig.suptitle(title)
|
298 |
+
im = ax.imshow(img[..., 0], cmap=cmap_bin)
|
299 |
+
ax.tick_params(axis='both',
|
300 |
+
which='both',
|
301 |
+
bottom=False,
|
302 |
+
left=False,
|
303 |
+
labelleft=False,
|
304 |
+
labelbottom=False)
|
305 |
+
|
306 |
+
#cb = _set_colorbar(fig, ax, im, False)
|
307 |
+
else:
|
308 |
+
ax = fig.add_subplot(111, projection='3d')
|
309 |
+
fig.suptitle(title)
|
310 |
+
im = ax.voxels(img[0, ..., 0] > 0.0)
|
311 |
+
_square_3d_plot(np.arange(0, 63), np.arange(0, 63), np.arange(0, 63), ax)
|
312 |
+
|
313 |
+
ax.tick_params(axis='both',
|
314 |
+
which='both',
|
315 |
+
bottom=False,
|
316 |
+
left=False,
|
317 |
+
labelleft=False,
|
318 |
+
labelbottom=False)
|
319 |
+
|
320 |
+
plt.savefig(filename, format='png')
|
321 |
+
plt.close()
|
322 |
+
|
323 |
+
|
324 |
+
def save_disp_map_img(disp_map, title, filename, affine_transf=False, fig=None):
|
325 |
+
if fig is not None:
|
326 |
+
fig.clear()
|
327 |
+
plt.figure(fig.number)
|
328 |
+
else:
|
329 |
+
fig = plt.figure(dpi=const.DPI)
|
330 |
+
|
331 |
+
dim = disp_map.shape[-1]
|
332 |
+
|
333 |
+
if dim == 2:
|
334 |
+
ax_x = fig.add_subplot(131)
|
335 |
+
ax_x.set_title('H displacement')
|
336 |
+
im_x = ax_x.imshow(disp_map[..., const.H_DISP])
|
337 |
+
ax_x.tick_params(axis='both',
|
338 |
+
which='both',
|
339 |
+
bottom=False,
|
340 |
+
left=False,
|
341 |
+
labelleft=False,
|
342 |
+
labelbottom=False)
|
343 |
+
cb_x = _set_colorbar(fig, ax_x, im_x, False)
|
344 |
+
|
345 |
+
ax_y = fig.add_subplot(132)
|
346 |
+
ax_y.set_title('W displacement')
|
347 |
+
im_y = ax_y.imshow(disp_map[..., const.W_DISP])
|
348 |
+
ax_y.tick_params(axis='both',
|
349 |
+
which='both',
|
350 |
+
bottom=False,
|
351 |
+
left=False,
|
352 |
+
labelleft=False,
|
353 |
+
labelbottom=False)
|
354 |
+
cb_y = _set_colorbar(fig, ax_y, im_y, False)
|
355 |
+
|
356 |
+
ax = fig.add_subplot(133)
|
357 |
+
if affine_transf:
|
358 |
+
fake_bg = np.array([[0.0, 0.0, 0.0, 0.0],
|
359 |
+
[0.0, 0.0, 0.0, 0.0],
|
360 |
+
[0.0, 0.0, 0.0, 0.0],
|
361 |
+
[0.0, 0.0, 0.0, 0.0]])
|
362 |
+
|
363 |
+
bottom = np.asarray([0, 0, 0, 1])
|
364 |
+
|
365 |
+
transf_mat = np.reshape(disp_map, (2, 3))
|
366 |
+
transf_mat = np.stack([transf_mat, bottom], axis=0)
|
367 |
+
|
368 |
+
im = ax.imshow(fake_bg)
|
369 |
+
for i in range(4):
|
370 |
+
for j in range(4):
|
371 |
+
ax.text(i, j, transf_mat[i, j], ha="center", va="center", color="b")
|
372 |
+
|
373 |
+
else:
|
374 |
+
c, d, s = _prepare_quiver_map(disp_map, dim=dim)
|
375 |
+
im = ax.imshow(s, interpolation='none', aspect='equal')
|
376 |
+
ax.quiver(c[const.H_DISP], c[const.W_DISP], d[const.H_DISP], d[const.W_DISP],
|
377 |
+
scale=const.QUIVER_PARAMS.arrow_scale)
|
378 |
+
cb = _set_colorbar(fig, ax, im, False)
|
379 |
+
ax.set_title('Displacement map')
|
380 |
+
ax.tick_params(axis='both',
|
381 |
+
which='both',
|
382 |
+
bottom=False,
|
383 |
+
left=False,
|
384 |
+
labelleft=False,
|
385 |
+
labelbottom=False)
|
386 |
+
fig.suptitle(title)
|
387 |
+
else:
|
388 |
+
ax = fig.add_subplot(111, projection='3d')
|
389 |
+
c, d, s = _prepare_quiver_map(disp_map[0, ...], dim=dim)
|
390 |
+
ax.quiver(c[const.H_DISP], c[const.W_DISP], c[const.D_DISP], d[const.H_DISP], d[const.W_DISP], d[const.D_DISP],
|
391 |
+
norm=True)
|
392 |
+
_square_3d_plot(np.arange(0, 63), np.arange(0, 63), np.arange(0, 63), ax)
|
393 |
+
fig.suptitle('Displacement map')
|
394 |
+
ax.tick_params(axis='both', # Same parameters as in 2D https://matplotlib.org/api/_as_gen/mpl_toolkits.mplot3d.axes3d.Axes3D.html
|
395 |
+
which='both',
|
396 |
+
bottom=False,
|
397 |
+
left=False,
|
398 |
+
labelleft=False,
|
399 |
+
labelbottom=False)
|
400 |
+
fig.suptitle(title)
|
401 |
+
|
402 |
+
plt.savefig(filename, format='png')
|
403 |
+
plt.close()
|
404 |
+
|
405 |
+
|
406 |
+
def plot_training_and_validation(list_imgs: [np.ndarray], affine_transf=True, filename='img', fig=None,
|
407 |
+
title_first_row='TRAINING', title_second_row='VALIDATION'):
|
408 |
+
if fig is not None:
|
409 |
+
fig.clear()
|
410 |
+
plt.figure(fig.number)
|
411 |
+
else:
|
412 |
+
fig = plt.figure(dpi=const.DPI)
|
413 |
+
|
414 |
+
dim = len(list_imgs[0].shape[:-1])
|
415 |
+
|
416 |
+
if dim == 2:
|
417 |
+
# TRAINING
|
418 |
+
ax_input = fig.add_subplot(241)
|
419 |
+
ax_input.set_ylabel(title_first_row, fontsize=const.FONT_SIZE)
|
420 |
+
im_fix = ax_input.imshow(list_imgs[0][:, :, 0])
|
421 |
+
ax_input.set_title('Fix image', fontsize=const.FONT_SIZE)
|
422 |
+
ax_input.tick_params(axis='both',
|
423 |
+
which='both',
|
424 |
+
bottom=False,
|
425 |
+
left=False,
|
426 |
+
labelleft=False,
|
427 |
+
labelbottom=False)
|
428 |
+
ax_mov = fig.add_subplot(242)
|
429 |
+
im_mov = ax_mov.imshow(list_imgs[1][:, :, 0])
|
430 |
+
ax_mov.set_title('Moving image', fontsize=const.FONT_SIZE)
|
431 |
+
ax_mov.tick_params(axis='both',
|
432 |
+
which='both',
|
433 |
+
bottom=False,
|
434 |
+
left=False,
|
435 |
+
labelleft=False,
|
436 |
+
labelbottom=False)
|
437 |
+
|
438 |
+
ax_pred_im = fig.add_subplot(244)
|
439 |
+
im_pred_im = ax_pred_im.imshow(list_imgs[2][:, :, 0])
|
440 |
+
ax_pred_im.set_title('Predicted fix image', fontsize=const.FONT_SIZE)
|
441 |
+
ax_pred_im.tick_params(axis='both',
|
442 |
+
which='both',
|
443 |
+
bottom=False,
|
444 |
+
left=False,
|
445 |
+
labelleft=False,
|
446 |
+
labelbottom=False)
|
447 |
+
|
448 |
+
ax_pred_disp = fig.add_subplot(243)
|
449 |
+
if affine_transf:
|
450 |
+
fake_bg = np.array([[0.0, 0.0, 0.0, 0.0],
|
451 |
+
[0.0, 0.0, 0.0, 0.0],
|
452 |
+
[0.0, 0.0, 0.0, 0.0],
|
453 |
+
[0.0, 0.0, 0.0, 0.0]])
|
454 |
+
|
455 |
+
bottom = np.asarray([0, 0, 0, 1])
|
456 |
+
|
457 |
+
transf_mat = np.reshape(list_imgs[3], (2, 3))
|
458 |
+
transf_mat = np.stack([transf_mat, bottom], axis=0)
|
459 |
+
|
460 |
+
im_pred_disp = ax_pred_disp.imshow(fake_bg)
|
461 |
+
for i in range(4):
|
462 |
+
for j in range(4):
|
463 |
+
ax_pred_disp.text(i, j, transf_mat[i, j], ha="center", va="center", color="b")
|
464 |
+
|
465 |
+
ax_pred_disp.set_title('Affine transformation matrix')
|
466 |
+
|
467 |
+
else:
|
468 |
+
cx, cy, dx, dy, s = _prepare_quiver_map(list_imgs[3], dim=dim)
|
469 |
+
im_pred_disp = ax_pred_disp.imshow(s, interpolation='none', aspect='equal')
|
470 |
+
ax_pred_disp.quiver(cx, cy, dx, dy, scale=const.QUIVER_PARAMS.arrow_scale)
|
471 |
+
ax_pred_disp.set_title('Pred disp map', fontsize=const.FONT_SIZE)
|
472 |
+
ax_pred_disp.tick_params(axis='both',
|
473 |
+
which='both',
|
474 |
+
bottom=False,
|
475 |
+
left=False,
|
476 |
+
labelleft=False,
|
477 |
+
labelbottom=False)
|
478 |
+
|
479 |
+
# VALIDATION
|
480 |
+
axinput_val = fig.add_subplot(245)
|
481 |
+
axinput_val.set_ylabel(title_second_row, fontsize=const.FONT_SIZE)
|
482 |
+
im_fix_val = axinput_val.imshow(list_imgs[4][:, :, 0])
|
483 |
+
axinput_val.set_title('Fix image', fontsize=const.FONT_SIZE)
|
484 |
+
axinput_val.tick_params(axis='both',
|
485 |
+
which='both',
|
486 |
+
bottom=False,
|
487 |
+
left=False,
|
488 |
+
labelleft=False,
|
489 |
+
labelbottom=False)
|
490 |
+
ax_mov_val = fig.add_subplot(246)
|
491 |
+
im_mov_val = ax_mov_val.imshow(list_imgs[5][:, :, 0])
|
492 |
+
ax_mov_val.set_title('Moving image', fontsize=const.FONT_SIZE)
|
493 |
+
ax_mov_val.tick_params(axis='both',
|
494 |
+
which='both',
|
495 |
+
bottom=False,
|
496 |
+
left=False,
|
497 |
+
labelleft=False,
|
498 |
+
labelbottom=False)
|
499 |
+
|
500 |
+
ax_pred_im_val = fig.add_subplot(248)
|
501 |
+
im_pred_im_val = ax_pred_im_val.imshow(list_imgs[6][:, :, 0])
|
502 |
+
ax_pred_im_val.set_title('Predicted fix image', fontsize=const.FONT_SIZE)
|
503 |
+
ax_pred_im_val.tick_params(axis='both',
|
504 |
+
which='both',
|
505 |
+
bottom=False,
|
506 |
+
left=False,
|
507 |
+
labelleft=False,
|
508 |
+
labelbottom=False)
|
509 |
+
|
510 |
+
ax_pred_disp_val = fig.add_subplot(247)
|
511 |
+
if affine_transf:
|
512 |
+
fake_bg = np.array([[0.0, 0.0, 0.0, 0.0],
|
513 |
+
[0.0, 0.0, 0.0, 0.0],
|
514 |
+
[0.0, 0.0, 0.0, 0.0],
|
515 |
+
[0.0, 0.0, 0.0, 0.0]])
|
516 |
+
|
517 |
+
bottom = np.asarray([0, 0, 0, 1])
|
518 |
+
|
519 |
+
transf_mat = np.reshape(list_imgs[7], (2, 3))
|
520 |
+
transf_mat = np.stack([transf_mat, bottom], axis=0)
|
521 |
+
|
522 |
+
im_pred_disp_val = ax_pred_disp_val.imshow(fake_bg)
|
523 |
+
for i in range(4):
|
524 |
+
for j in range(4):
|
525 |
+
ax_pred_disp_val.text(i, j, transf_mat[i, j], ha="center", va="center", color="b")
|
526 |
+
|
527 |
+
ax_pred_disp_val.set_title('Affine transformation matrix')
|
528 |
+
|
529 |
+
else:
|
530 |
+
c, d, s = _prepare_quiver_map(list_imgs[7], dim=dim)
|
531 |
+
im_pred_disp_val = ax_pred_disp_val.imshow(s, interpolation='none', aspect='equal')
|
532 |
+
ax_pred_disp_val.quiver(c[0], c[1], d[0], d[1], scale=const.QUIVER_PARAMS.arrow_scale)
|
533 |
+
ax_pred_disp_val.set_title('Pred disp map', fontsize=const.FONT_SIZE)
|
534 |
+
ax_pred_disp_val.tick_params(axis='both',
|
535 |
+
which='both',
|
536 |
+
bottom=False,
|
537 |
+
left=False,
|
538 |
+
labelleft=False,
|
539 |
+
labelbottom=False)
|
540 |
+
|
541 |
+
cb_fix = _set_colorbar(fig, ax_input, im_fix, False)
|
542 |
+
cb_mov = _set_colorbar(fig, ax_mov, im_mov, False)
|
543 |
+
cb_pred = _set_colorbar(fig, ax_pred_im, im_pred_im, False)
|
544 |
+
cb_pred_disp = _set_colorbar(fig, ax_pred_disp, im_pred_disp, False)
|
545 |
+
|
546 |
+
cd_fix_val = _set_colorbar(fig, axinput_val, im_fix_val, False)
|
547 |
+
cb_mov_val = _set_colorbar(fig, ax_mov_val, im_mov_val, False)
|
548 |
+
cb_pred_val = _set_colorbar(fig, ax_pred_im_val, im_pred_im_val, False)
|
549 |
+
cb_pred_disp_val = _set_colorbar(fig, ax_pred_disp_val, im_pred_disp_val, False)
|
550 |
+
|
551 |
+
else:
|
552 |
+
# 3D
|
553 |
+
# TRAINING
|
554 |
+
ax_input = fig.add_subplot(231, projection='3d')
|
555 |
+
ax_input.set_ylabel(title_first_row, fontsize=const.FONT_SIZE)
|
556 |
+
im_fix = ax_input.voxels(list_imgs[0][..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed')
|
557 |
+
im_mov = ax_input.voxels(list_imgs[1][..., 0] > 0.0, facecolors='blue', edgecolors='blue', label='Moving')
|
558 |
+
ax_input.set_title('Fix image', fontsize=const.FONT_SIZE)
|
559 |
+
ax_input.tick_params(axis='both',
|
560 |
+
which='both',
|
561 |
+
bottom=False,
|
562 |
+
left=False,
|
563 |
+
labelleft=False,
|
564 |
+
labelbottom=False)
|
565 |
+
|
566 |
+
ax_pred_im = fig.add_subplot(232, projection='3d')
|
567 |
+
im_pred_im = ax_input.voxels(list_imgs[2][..., 0] > 0.0, facecolors='green', edgecolors='green', label='Prediction')
|
568 |
+
im_fix = ax_input.voxels(list_imgs[0][..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed')
|
569 |
+
ax_pred_im.set_title('Predicted fix image', fontsize=const.FONT_SIZE)
|
570 |
+
ax_pred_im.tick_params(axis='both',
|
571 |
+
which='both',
|
572 |
+
bottom=False,
|
573 |
+
left=False,
|
574 |
+
labelleft=False,
|
575 |
+
labelbottom=False)
|
576 |
+
|
577 |
+
ax_pred_disp = fig.add_subplot(233, projection='3d')
|
578 |
+
|
579 |
+
c, d, s = _prepare_quiver_map(list_imgs[3], dim=dim)
|
580 |
+
im_pred_disp = ax_pred_disp.imshow(s, interpolation='none', aspect='equal')
|
581 |
+
ax_pred_disp.quiver(c[const.H_DISP], c[const.W_DISP], c[const.D_DISP],
|
582 |
+
d[const.H_DISP], d[const.W_DISP], d[const.D_DISP], scale=const.QUIVER_PARAMS.arrow_scale)
|
583 |
+
ax_pred_disp.set_title('Pred disp map', fontsize=const.FONT_SIZE)
|
584 |
+
ax_pred_disp.tick_params(axis='both',
|
585 |
+
which='both',
|
586 |
+
bottom=False,
|
587 |
+
left=False,
|
588 |
+
labelleft=False,
|
589 |
+
labelbottom=False)
|
590 |
+
|
591 |
+
# VALIDATION
|
592 |
+
axinput_val = fig.add_subplot(234, projection='3d')
|
593 |
+
axinput_val.set_ylabel(title_second_row, fontsize=const.FONT_SIZE)
|
594 |
+
im_fix_val = ax_input.voxels(list_imgs[4][..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed (val)')
|
595 |
+
im_mov_val = ax_input.voxels(list_imgs[5][..., 0] > 0.0, facecolors='blue', edgecolors='blue', label='Moving (val)')
|
596 |
+
axinput_val.set_title('Fix image', fontsize=const.FONT_SIZE)
|
597 |
+
axinput_val.tick_params(axis='both',
|
598 |
+
which='both',
|
599 |
+
bottom=False,
|
600 |
+
left=False,
|
601 |
+
labelleft=False,
|
602 |
+
labelbottom=False)
|
603 |
+
|
604 |
+
ax_pred_im_val = fig.add_subplot(235, projection='3d')
|
605 |
+
im_pred_im_val = ax_input.voxels(list_imgs[2][..., 0] > 0.0, facecolors='green', edgecolors='green', label='Prediction (val)')
|
606 |
+
im_fix_val = ax_input.voxels(list_imgs[0][..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed (val)')
|
607 |
+
ax_pred_im_val.set_title('Predicted fix image', fontsize=const.FONT_SIZE)
|
608 |
+
ax_pred_im_val.tick_params(axis='both',
|
609 |
+
which='both',
|
610 |
+
bottom=False,
|
611 |
+
left=False,
|
612 |
+
labelleft=False,
|
613 |
+
labelbottom=False)
|
614 |
+
|
615 |
+
ax_pred_disp_val = fig.add_subplot(236, projection='3d')
|
616 |
+
c, d, s = _prepare_quiver_map(list_imgs[7], dim=dim)
|
617 |
+
im_pred_disp_val = ax_pred_disp_val.imshow(s, interpolation='none', aspect='equal')
|
618 |
+
ax_pred_disp_val.quiver(c[const.H_DISP], c[const.W_DISP], c[const.D_DISP],
|
619 |
+
d[const.H_DISP], d[const.W_DISP], d[const.D_DISP],
|
620 |
+
scale=const.QUIVER_PARAMS.arrow_scale)
|
621 |
+
ax_pred_disp_val.set_title('Pred disp map', fontsize=const.FONT_SIZE)
|
622 |
+
ax_pred_disp_val.tick_params(axis='both',
|
623 |
+
which='both',
|
624 |
+
bottom=False,
|
625 |
+
left=False,
|
626 |
+
labelleft=False,
|
627 |
+
labelbottom=False)
|
628 |
+
|
629 |
+
if filename is not None:
|
630 |
+
plt.savefig(filename, format='png') # Call before show
|
631 |
+
if not const.REMOTE:
|
632 |
+
plt.show()
|
633 |
+
else:
|
634 |
+
plt.close()
|
635 |
+
return fig
|
636 |
+
|
637 |
+
|
638 |
+
def _set_colorbar(fig, ax, im, drawedges=True):
|
639 |
+
div = make_axes_locatable(ax)
|
640 |
+
im_cax = div.append_axes('right', size='5%', pad=0.05)
|
641 |
+
im_cb = fig.colorbar(im, cax=im_cax, drawedges=drawedges, shrink=0.5, orientation='vertical')
|
642 |
+
im_cb.ax.tick_params(labelsize=5)
|
643 |
+
|
644 |
+
return im_cb
|
645 |
+
|
646 |
+
|
647 |
+
def _prepare_quiver_map(disp_map: np.ndarray, dim=2, spc=const.QUIVER_PARAMS.spacing):
|
648 |
+
if isinstance(disp_map, tf.Tensor):
|
649 |
+
if tf.executing_eagerly():
|
650 |
+
disp_map = disp_map.numpy()
|
651 |
+
else:
|
652 |
+
disp_map = disp_map.eval()
|
653 |
+
dx = disp_map[..., const.H_DISP]
|
654 |
+
dy = disp_map[..., const.W_DISP]
|
655 |
+
if dim > 2:
|
656 |
+
dz = disp_map[..., const.D_DISP]
|
657 |
+
|
658 |
+
img_size_x = disp_map.shape[const.H_DISP]
|
659 |
+
img_size_y = disp_map.shape[const.W_DISP]
|
660 |
+
if dim > 2:
|
661 |
+
img_size_z = disp_map.shape[const.D_DISP]
|
662 |
+
|
663 |
+
if dim > 2:
|
664 |
+
s = np.sqrt(np.square(dx) + np.square(dy) + np.square(dz))
|
665 |
+
s = np.reshape(s, [img_size_x, img_size_y, img_size_z])
|
666 |
+
|
667 |
+
cx, cy, cz = np.meshgrid(list(range(0, img_size_x)), list(range(0, img_size_y)), list(range(0, img_size_z)),
|
668 |
+
indexing='ij')
|
669 |
+
c = [cx[::spc, ::spc, ::spc], cy[::spc, ::spc, ::spc], cz[::spc, ::spc, ::spc]]
|
670 |
+
d = [dx[::spc, ::spc, ::spc], dy[::spc, ::spc, ::spc], dz[::spc, ::spc, ::spc]]
|
671 |
+
else:
|
672 |
+
s = np.sqrt(np.square(dx) + np.square(dy))
|
673 |
+
s = np.reshape(s, [img_size_x, img_size_y])
|
674 |
+
|
675 |
+
cx, cy = np.meshgrid(list(range(0, img_size_x)), list(range(0, img_size_y)))
|
676 |
+
c = [cx[::spc, ::spc], cy[::spc, ::spc]]
|
677 |
+
d = [dx[::spc, ::spc], dy[::spc, ::spc]]
|
678 |
+
|
679 |
+
return c, d, s
|
680 |
+
|
681 |
+
|
682 |
+
def _prepare_colormap(disp_map: np.ndarray):
|
683 |
+
if isinstance(disp_map, tf.Tensor):
|
684 |
+
disp_map = disp_map.eval()
|
685 |
+
dx = disp_map[:, :, 0]
|
686 |
+
dy = disp_map[:, :, 1]
|
687 |
+
|
688 |
+
mod_img = np.zeros_like(dx)
|
689 |
+
|
690 |
+
for i in range(dx.shape[0]):
|
691 |
+
for j in range(dx.shape[1]):
|
692 |
+
vec = np.asarray([dx[i, j], dy[i, j]])
|
693 |
+
mod_img[i, j] = np.linalg.norm(vec, ord=2)
|
694 |
+
|
695 |
+
p_l, p_h = np.percentile(mod_img, (0, 100))
|
696 |
+
mod_img = rescale_intensity(mod_img, in_range=(p_l, p_h), out_range=(0, 255))
|
697 |
+
|
698 |
+
return mod_img
|
699 |
+
|
700 |
+
|
701 |
+
def plot_input_data(fix_img, mov_img, img_size=(64, 64), title=None, filename=None):
|
702 |
+
num_samples = fix_img.shape[0]
|
703 |
+
|
704 |
+
if num_samples != 16 and num_samples != 32:
|
705 |
+
raise ValueError('Only batches of 16 or 32 samples!')
|
706 |
+
|
707 |
+
fig, ax = plt.subplots(nrows=4 if num_samples == 16 else 8, ncols=4)
|
708 |
+
ncol = 0
|
709 |
+
nrow = 0
|
710 |
+
black_col = np.ones([img_size[0], 0])
|
711 |
+
for sample in range(num_samples):
|
712 |
+
combined_img = np.hstack([fix_img[sample, :, :, 0], black_col, mov_img[sample, :, :, 0]])
|
713 |
+
ax[nrow, ncol].imshow(combined_img, cmap='Greys')
|
714 |
+
ax[nrow, ncol].set_ylabel('#{}'.format(sample))
|
715 |
+
ax[nrow, ncol].tick_params(axis='both',
|
716 |
+
which='both',
|
717 |
+
bottom=False,
|
718 |
+
left=False,
|
719 |
+
labelleft=False,
|
720 |
+
labelbottom=False)
|
721 |
+
ncol += 1
|
722 |
+
if ncol >= 4:
|
723 |
+
ncol = 0
|
724 |
+
nrow += 1
|
725 |
+
|
726 |
+
if title is not None:
|
727 |
+
fig.suptitle(title)
|
728 |
+
|
729 |
+
if filename is not None:
|
730 |
+
plt.savefig(filename, format='png') # Call before show
|
731 |
+
if not const.REMOTE:
|
732 |
+
plt.show()
|
733 |
+
else:
|
734 |
+
plt.close()
|
735 |
+
return fig
|
736 |
+
|
737 |
+
|
738 |
+
def plot_dataset_orthographic_views(view_sets: [[np.ndarray]]):
|
739 |
+
"""
|
740 |
+
|
741 |
+
:param views_fix: Expected order: top, front, left
|
742 |
+
:param views_mov: Expected order: top, front, left
|
743 |
+
:return:
|
744 |
+
"""
|
745 |
+
nrows = len(view_sets)
|
746 |
+
fig, ax = plt.subplots(nrows=nrows, ncols=3)
|
747 |
+
labels = ['top', 'front', 'left']
|
748 |
+
for nrow in range(nrows):
|
749 |
+
for ncol in range(3):
|
750 |
+
if nrows == 1:
|
751 |
+
ax[ncol].imshow(view_sets[nrow][ncol][:, :, 0])
|
752 |
+
ax[ncol].set_title('Fix ' + labels[ncol])
|
753 |
+
ax[ncol].tick_params(axis='both',
|
754 |
+
which='both',
|
755 |
+
bottom=False,
|
756 |
+
left=False,
|
757 |
+
labelleft=False,
|
758 |
+
labelbottom=False)
|
759 |
+
|
760 |
+
else:
|
761 |
+
ax[nrow, ncol].imshow(view_sets[nrow][ncol][:, :, 0])
|
762 |
+
ax[nrow, ncol].set_title('Fix ' + labels[ncol])
|
763 |
+
ax[nrow, ncol].tick_params(axis='both',
|
764 |
+
which='both',
|
765 |
+
bottom=False,
|
766 |
+
left=False,
|
767 |
+
labelleft=False,
|
768 |
+
labelbottom=False)
|
769 |
+
|
770 |
+
plt.show()
|
771 |
+
return fig
|
772 |
+
|
773 |
+
|
774 |
+
def plot_compare_2d_images(img1, img2, img1_name='img1', img2_name='img2'):
|
775 |
+
fig, ax = plt.subplots(nrows=1, ncols=2)
|
776 |
+
ax[0].imshow(img1[:, :, 0])
|
777 |
+
ax[0].set_title(img1_name)
|
778 |
+
ax[0].tick_params(axis='both',
|
779 |
+
which='both',
|
780 |
+
bottom=False,
|
781 |
+
left=False,
|
782 |
+
labelleft=False,
|
783 |
+
labelbottom=False)
|
784 |
+
|
785 |
+
ax[1].imshow(img2[:, :, 0])
|
786 |
+
ax[1].set_title(img2_name)
|
787 |
+
ax[1].tick_params(axis='both',
|
788 |
+
which='both',
|
789 |
+
bottom=False,
|
790 |
+
left=False,
|
791 |
+
labelleft=False,
|
792 |
+
labelbottom=False)
|
793 |
+
|
794 |
+
plt.show()
|
795 |
+
return fig
|
796 |
+
|
797 |
+
|
798 |
+
def plot_dataset_3d(img_sets):
|
799 |
+
from mpl_toolkits.mplot3d import Axes3D
|
800 |
+
fig = plt.figure()
|
801 |
+
ax = fig.add_subplot(111, projection='3d')
|
802 |
+
|
803 |
+
for idx in range(len(img_sets)):
|
804 |
+
ax = _plot_3d(img_sets[idx], ax=ax, name='Set {}'.format(idx))
|
805 |
+
|
806 |
+
plt.show()
|
807 |
+
return fig
|
808 |
+
|
809 |
+
|
810 |
+
def plot_predictions(fix_img_batch, mov_img_batch, disp_map_batch, pred_img_batch, filename='predictions', fig=None):
|
811 |
+
num_rows = fix_img_batch.shape[0]
|
812 |
+
img_size = fix_img_batch.shape[1:3]
|
813 |
+
if fig is not None:
|
814 |
+
fig.clear()
|
815 |
+
plt.figure(fig.number)
|
816 |
+
ax = fig.add_subplot(nrows=num_rows, ncols=4, dpi=const.DPI)
|
817 |
+
else:
|
818 |
+
fig, ax = plt.subplots(nrows=num_rows, ncols=4, dpi=const.DPI)
|
819 |
+
|
820 |
+
for row in range(num_rows):
|
821 |
+
fix_img = fix_img_batch[row, :, :, 0]
|
822 |
+
mov_img = mov_img_batch[row, :, :, 0]
|
823 |
+
disp_map = disp_map_batch[row, :, :, :]
|
824 |
+
pred_img = pred_img_batch[row, :, :, 0]
|
825 |
+
ax[row, 0].imshow(fix_img)
|
826 |
+
ax[row, 0].tick_params(axis='both',
|
827 |
+
which='both',
|
828 |
+
bottom=False,
|
829 |
+
left=False,
|
830 |
+
labelleft=False,
|
831 |
+
labelbottom=False)
|
832 |
+
ax[row, 1].imshow(mov_img)
|
833 |
+
ax[row, 1].tick_params(axis='both',
|
834 |
+
which='both',
|
835 |
+
bottom=False,
|
836 |
+
left=False,
|
837 |
+
labelleft=False,
|
838 |
+
labelbottom=False)
|
839 |
+
|
840 |
+
cx, cy, dx, dy, s = _prepare_quiver_map(disp_map)
|
841 |
+
disp_map_color = _prepare_colormap(disp_map)
|
842 |
+
ax[row, 2].imshow(disp_map_color, interpolation='none', aspect='equal')
|
843 |
+
ax[row, 2].quiver(cx.eval(), cy.eval(), dx.eval(), dy.eval(), units='xy', scale=const.QUIVER_PARAMS.arrow_scale)
|
844 |
+
ax[row, 2].figure.set_size_inches(img_size)
|
845 |
+
ax[row, 2].tick_params(axis='both',
|
846 |
+
which='both',
|
847 |
+
bottom=False,
|
848 |
+
left=False,
|
849 |
+
labelleft=False,
|
850 |
+
labelbottom=False)
|
851 |
+
|
852 |
+
ax[row, 3].tick_params(axis='both',
|
853 |
+
which='both',
|
854 |
+
bottom=False,
|
855 |
+
left=False,
|
856 |
+
labelleft=False,
|
857 |
+
labelbottom=False)
|
858 |
+
|
859 |
+
plt.axis('off')
|
860 |
+
ax[0, 0].set_title('Fixed img ($I_f$)', fontsize=const.FONT_SIZE)
|
861 |
+
ax[0, 1].set_title('Moving img ($I_m$)', fontsize=const.FONT_SIZE)
|
862 |
+
ax[0, 2].set_title('Displacement map ($\delta$)', fontsize=const.FONT_SIZE)
|
863 |
+
ax[0, 3].set_title('Updated $I_m$', fontsize=const.FONT_SIZE)
|
864 |
+
|
865 |
+
if filename is not None:
|
866 |
+
plt.savefig(filename, format='png') # Call before show
|
867 |
+
if not const.REMOTE:
|
868 |
+
plt.show()
|
869 |
+
else:
|
870 |
+
plt.close()
|
871 |
+
return fig
|
872 |
+
|
873 |
+
|
874 |
+
def inspect_disp_map_generation(fix_img, mov_img, disp_map, filename=None, fig=None):
|
875 |
+
if fig is not None:
|
876 |
+
fig.clear()
|
877 |
+
plt.figure(fig.number)
|
878 |
+
else:
|
879 |
+
fig = plt.figure(dpi=const.DPI)
|
880 |
+
|
881 |
+
ax0 = fig.add_subplot(221)
|
882 |
+
im0 = ax0.imshow(fix_img[..., 0])
|
883 |
+
ax0.tick_params(axis='both',
|
884 |
+
which='both',
|
885 |
+
bottom=False,
|
886 |
+
left=False,
|
887 |
+
labelleft=False,
|
888 |
+
labelbottom=False)
|
889 |
+
ax1 = fig.add_subplot(222)
|
890 |
+
im1 = ax1.imshow(mov_img[..., 0])
|
891 |
+
ax1.tick_params(axis='both',
|
892 |
+
which='both',
|
893 |
+
bottom=False,
|
894 |
+
left=False,
|
895 |
+
labelleft=False,
|
896 |
+
labelbottom=False)
|
897 |
+
|
898 |
+
cx, cy, dx, dy, s = _prepare_quiver_map(disp_map)
|
899 |
+
disp_map_color = _prepare_colormap(disp_map)
|
900 |
+
ax2 = fig.add_subplot(223)
|
901 |
+
im2 = ax2.imshow(s, interpolation='none', aspect='equal')
|
902 |
+
|
903 |
+
ax2.quiver(cx, cy, dx, dy, scale=const.QUIVER_PARAMS.arrow_scale)
|
904 |
+
# ax2.figure.set_size_inches(img_size)
|
905 |
+
ax2.tick_params(axis='both',
|
906 |
+
which='both',
|
907 |
+
bottom=False,
|
908 |
+
left=False,
|
909 |
+
labelleft=False,
|
910 |
+
labelbottom=False)
|
911 |
+
|
912 |
+
ax3 = fig.add_subplot(224)
|
913 |
+
dif = fix_img[..., 0] - mov_img[..., 0]
|
914 |
+
im3 = ax3.imshow(dif)
|
915 |
+
ax3.quiver(cx, cy, dx, dy, scale=const.QUIVER_PARAMS.arrow_scale)
|
916 |
+
ax3.tick_params(axis='both',
|
917 |
+
which='both',
|
918 |
+
bottom=False,
|
919 |
+
left=False,
|
920 |
+
labelleft=False,
|
921 |
+
labelbottom=False)
|
922 |
+
|
923 |
+
plt.axis('off')
|
924 |
+
ax0.set_title('Fixed img ($I_f$)', fontsize=const.FONT_SIZE)
|
925 |
+
ax1.set_title('Moving img ($I_m$)', fontsize=const.FONT_SIZE)
|
926 |
+
ax2.set_title('Displacement map', fontsize=const.FONT_SIZE)
|
927 |
+
ax3.set_title('Fix - Mov', fontsize=const.FONT_SIZE)
|
928 |
+
|
929 |
+
im0_cb = _set_colorbar(fig, ax0, im0, False)
|
930 |
+
im1_cb = _set_colorbar(fig, ax1, im1, False)
|
931 |
+
disp_cb = _set_colorbar(fig, ax2, im2, False)
|
932 |
+
im3_cb = _set_colorbar(fig, ax3, im3, False)
|
933 |
+
|
934 |
+
if filename is not None:
|
935 |
+
plt.savefig(filename, format='png') # Call before show
|
936 |
+
if not const.REMOTE:
|
937 |
+
plt.show()
|
938 |
+
else:
|
939 |
+
plt.close()
|
940 |
+
|
941 |
+
return fig
|
942 |
+
|
943 |
+
|
944 |
+
def inspect_displacement_grid(ctrl_coords, dense_coords, target_coords, disp_coords, disp_map, mask, fix_img, mov_img,
|
945 |
+
filename=None, fig=None):
|
946 |
+
if fig is not None:
|
947 |
+
fig.clear()
|
948 |
+
plt.figure(fig.number)
|
949 |
+
else:
|
950 |
+
fig = plt.figure()
|
951 |
+
|
952 |
+
ax_grid = fig.add_subplot(231)
|
953 |
+
ax_grid.set_title('Grids', fontsize=const.FONT_SIZE)
|
954 |
+
ax_grid.scatter(ctrl_coords[:, 0], ctrl_coords[:, 1], marker='+', c='r', s=20)
|
955 |
+
ax_grid.scatter(dense_coords[:, 0], dense_coords[:, 1], marker='.', c='r', s=1)
|
956 |
+
ax_grid.tick_params(axis='both',
|
957 |
+
which='both',
|
958 |
+
bottom=False,
|
959 |
+
left=False,
|
960 |
+
labelleft=False,
|
961 |
+
labelbottom=False)
|
962 |
+
|
963 |
+
ax_grid.scatter(target_coords[:, 0], target_coords[:, 1], marker='+', c='b', s=20)
|
964 |
+
ax_grid.scatter(disp_coords[:, 0], disp_coords[:, 1], marker='.', c='b', s=1)
|
965 |
+
|
966 |
+
ax_grid.set_aspect('equal')
|
967 |
+
|
968 |
+
ax_disp = fig.add_subplot(232)
|
969 |
+
ax_disp.set_title('Displacement map', fontsize=const.FONT_SIZE)
|
970 |
+
cx, cy, dx, dy, s = _prepare_quiver_map(disp_map)
|
971 |
+
ax_disp.imshow(s, interpolation='none', aspect='equal')
|
972 |
+
ax_disp.tick_params(axis='both',
|
973 |
+
which='both',
|
974 |
+
bottom=False,
|
975 |
+
left=False,
|
976 |
+
labelleft=False,
|
977 |
+
labelbottom=False)
|
978 |
+
|
979 |
+
ax_mask = fig.add_subplot(233)
|
980 |
+
ax_mask.set_title('Mask', fontsize=const.FONT_SIZE)
|
981 |
+
ax_mask.imshow(mask)
|
982 |
+
ax_mask.tick_params(axis='both',
|
983 |
+
which='both',
|
984 |
+
bottom=False,
|
985 |
+
left=False,
|
986 |
+
labelleft=False,
|
987 |
+
labelbottom=False)
|
988 |
+
|
989 |
+
ax_fix = fig.add_subplot(234)
|
990 |
+
ax_fix.set_title('Fix image', fontsize=const.FONT_SIZE)
|
991 |
+
ax_fix.imshow(fix_img[..., 0])
|
992 |
+
ax_fix.tick_params(axis='both',
|
993 |
+
which='both',
|
994 |
+
bottom=False,
|
995 |
+
left=False,
|
996 |
+
labelleft=False,
|
997 |
+
labelbottom=False)
|
998 |
+
|
999 |
+
ax_mov = fig.add_subplot(235)
|
1000 |
+
ax_mov.set_title('Moving image', fontsize=const.FONT_SIZE)
|
1001 |
+
ax_mov.imshow(mov_img[..., 0])
|
1002 |
+
ax_mov.tick_params(axis='both',
|
1003 |
+
which='both',
|
1004 |
+
bottom=False,
|
1005 |
+
left=False,
|
1006 |
+
labelleft=False,
|
1007 |
+
labelbottom=False)
|
1008 |
+
|
1009 |
+
ax_dif = fig.add_subplot(236)
|
1010 |
+
ax_dif.set_title('Fix - Moving image', fontsize=const.FONT_SIZE)
|
1011 |
+
ax_dif.imshow(fix_img[..., 0] - mov_img[..., 0], cmap=cmap_bin)
|
1012 |
+
ax_dif.tick_params(axis='both',
|
1013 |
+
which='both',
|
1014 |
+
bottom=False,
|
1015 |
+
left=False,
|
1016 |
+
labelleft=False,
|
1017 |
+
labelbottom=False)
|
1018 |
+
legend_elems = [Line2D([0], [0], color=cmap_bin(0), lw=2),
|
1019 |
+
Line2D([0], [0], color=cmap_bin(2), lw=2)]
|
1020 |
+
|
1021 |
+
ax_dif.legend(legend_elems, ['Mov', 'Fix'], loc='upper left', bbox_to_anchor=(0., 0., 1., 0.),
|
1022 |
+
ncol=2, mode='expand')
|
1023 |
+
|
1024 |
+
if filename is not None:
|
1025 |
+
plt.savefig(filename, format='png') # Call before show
|
1026 |
+
if not const.REMOTE:
|
1027 |
+
plt.show()
|
1028 |
+
|
1029 |
+
return fig
|
1030 |
+
|
1031 |
+
|
1032 |
+
def compare_disp_maps(disp_m_f, disp_f_m, fix_img, mov_img, filename=None, fig=None):
|
1033 |
+
if fig is not None:
|
1034 |
+
fig.clear()
|
1035 |
+
plt.figure(fig.number)
|
1036 |
+
else:
|
1037 |
+
fig = plt.figure()
|
1038 |
+
|
1039 |
+
ax_d_m_f = fig.add_subplot(131)
|
1040 |
+
ax_d_m_f.set_title('Disp M->F', fontsize=const.FONT_SIZE)
|
1041 |
+
cx, cy, dx, dy, s = _prepare_quiver_map(disp_m_f)
|
1042 |
+
ax_d_m_f.imshow(s, interpolation='none', aspect='equal')
|
1043 |
+
ax_d_m_f.quiver(cx, cy, dx, dy, scale=const.QUIVER_PARAMS.arrow_scale)
|
1044 |
+
ax_d_m_f.tick_params(axis='both',
|
1045 |
+
which='both',
|
1046 |
+
bottom=False,
|
1047 |
+
left=False,
|
1048 |
+
labelleft=False,
|
1049 |
+
labelbottom=False)
|
1050 |
+
|
1051 |
+
ax_d_f_m = fig.add_subplot(132)
|
1052 |
+
ax_d_f_m.set_title('Disp F->M', fontsize=const.FONT_SIZE)
|
1053 |
+
cx, cy, dx, dy, s = _prepare_quiver_map(disp_f_m)
|
1054 |
+
ax_d_f_m.quiver(cx, cy, dx, dy, scale=const.QUIVER_PARAMS.arrow_scale)
|
1055 |
+
ax_d_f_m.imshow(s, interpolation='none', aspect='equal')
|
1056 |
+
ax_d_f_m.tick_params(axis='both',
|
1057 |
+
which='both',
|
1058 |
+
bottom=False,
|
1059 |
+
left=False,
|
1060 |
+
labelleft=False,
|
1061 |
+
labelbottom=False)
|
1062 |
+
|
1063 |
+
ax_dif = fig.add_subplot(133)
|
1064 |
+
ax_dif.set_title('Fix - Moving image', fontsize=const.FONT_SIZE)
|
1065 |
+
ax_dif.imshow(fix_img[..., 0] - mov_img[..., 0], cmap=cmap_bin)
|
1066 |
+
ax_dif.tick_params(axis='both',
|
1067 |
+
which='both',
|
1068 |
+
bottom=False,
|
1069 |
+
left=False,
|
1070 |
+
labelleft=False,
|
1071 |
+
labelbottom=False)
|
1072 |
+
|
1073 |
+
legend_elems = [Line2D([0], [0], color=cmap_bin(0), lw=2),
|
1074 |
+
Line2D([0], [0], color=cmap_bin(2), lw=2)]
|
1075 |
+
|
1076 |
+
ax_dif.legend(legend_elems, ['Mov', 'Fix'], loc='upper left', bbox_to_anchor=(0., 0., 1., 0.),
|
1077 |
+
ncol=2, mode='expand')
|
1078 |
+
|
1079 |
+
if filename is not None:
|
1080 |
+
plt.savefig(filename, format='png') # Call before show
|
1081 |
+
if not const.REMOTE:
|
1082 |
+
plt.show()
|
1083 |
+
else:
|
1084 |
+
plt.close()
|
1085 |
+
|
1086 |
+
return fig
|
1087 |
+
|
1088 |
+
|
1089 |
+
def plot_train_step(list_imgs: [np.ndarray], fig_title='TRAINING', dest_folder='.', save_file=True):
|
1090 |
+
# list_imgs[0]: fix image
|
1091 |
+
# list_imgs[1]: moving image
|
1092 |
+
# list_imgs[2]: prediction scale 1
|
1093 |
+
# list_imgs[3]: prediction scale 2
|
1094 |
+
# list_imgs[4]: prediction scale 3
|
1095 |
+
# list_imgs[5]: disp map scale 1
|
1096 |
+
# list_imgs[6]: disp map scale 2
|
1097 |
+
# list_imgs[7]: disp map scale 3
|
1098 |
+
num_imgs = len(list_imgs)
|
1099 |
+
num_preds = (num_imgs - 2) // 2
|
1100 |
+
num_cols = num_preds + 1
|
1101 |
+
# 3D
|
1102 |
+
# TRAINING
|
1103 |
+
fig = plt.figure(figsize=(12.8, 10.24))
|
1104 |
+
fig.tight_layout(pad=5.0)
|
1105 |
+
ax = fig.add_subplot(2, num_cols, 1, projection='3d')
|
1106 |
+
ax.voxels(list_imgs[0][0, ..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed')
|
1107 |
+
ax.set_title('Fix image', fontsize=const.FONT_SIZE)
|
1108 |
+
_square_3d_plot(np.arange(0, 63), np.arange(0, 63), np.arange(0, 63), ax)
|
1109 |
+
|
1110 |
+
for i in range(2, num_preds+2):
|
1111 |
+
ax = fig.add_subplot(2, num_cols, i, projection='3d')
|
1112 |
+
ax.voxels(list_imgs[0][0, ..., 0] > 0.0, facecolors='red', edgecolors='red', label='Fixed')
|
1113 |
+
ax.voxels(list_imgs[i][0, ..., 0] > 0.0, facecolors='green', edgecolors='green', label='Pred_{}'.format(i - 1))
|
1114 |
+
ax.set_title('Pred. #{}'.format(i - 1))
|
1115 |
+
_square_3d_plot(np.arange(0, 63), np.arange(0, 63), np.arange(0, 63), ax)
|
1116 |
+
|
1117 |
+
ax = fig.add_subplot(2, num_cols, num_preds+2, projection='3d')
|
1118 |
+
ax.voxels(list_imgs[1][0, ..., 0] > 0.0, facecolors='blue', edgecolors='blue', label='Moving')
|
1119 |
+
ax.set_title('Fix image', fontsize=const.FONT_SIZE)
|
1120 |
+
_square_3d_plot(np.arange(0, 63), np.arange(0, 63), np.arange(0, 63), ax)
|
1121 |
+
|
1122 |
+
for i in range(num_preds+2, 2 * num_preds + 2):
|
1123 |
+
ax = fig.add_subplot(2, num_cols, i + 1, projection='3d')
|
1124 |
+
c, d, s = _prepare_quiver_map(list_imgs[i][0, ...], dim=3)
|
1125 |
+
ax.quiver(c[const.H_DISP], c[const.W_DISP], c[const.D_DISP],
|
1126 |
+
d[const.H_DISP], d[const.W_DISP], d[const.D_DISP],
|
1127 |
+
norm=True)
|
1128 |
+
ax.set_title('Disp. #{}'.format(i - 5))
|
1129 |
+
_square_3d_plot(np.arange(0, 63), np.arange(0, 63), np.arange(0, 63), ax)
|
1130 |
+
|
1131 |
+
fig.suptitle(fig_title, fontsize=const.FONT_SIZE)
|
1132 |
+
|
1133 |
+
if save_file:
|
1134 |
+
plt.savefig(os.path.join(dest_folder, fig_title+'.png'), format='png') # Call before show
|
1135 |
+
if not const.REMOTE:
|
1136 |
+
plt.show()
|
1137 |
+
else:
|
1138 |
+
plt.close()
|
1139 |
+
return fig
|
1140 |
+
|
1141 |
+
|
1142 |
+
def _square_3d_plot(X, Y, Z, ax):
|
1143 |
+
max_range = np.array([X.max() - X.min(), Y.max() - Y.min(), Z.max() - Z.min()]).max() / 2.0
|
1144 |
+
|
1145 |
+
mid_x = (X.max() + X.min()) * 0.5
|
1146 |
+
mid_y = (Y.max() + Y.min()) * 0.5
|
1147 |
+
mid_z = (Z.max() + Z.min()) * 0.5
|
1148 |
+
ax.set_xlim(mid_x - max_range, mid_x + max_range)
|
1149 |
+
ax.set_ylim(mid_y - max_range, mid_y + max_range)
|
1150 |
+
ax.set_zlim(mid_z - max_range, mid_z + max_range)
|
1151 |
+
|
EvaluationScripts/evaluation.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
currentdir = os.path.dirname(os.path.realpath(__file__))
|
3 |
+
parentdir = os.path.dirname(currentdir)
|
4 |
+
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
5 |
+
|
6 |
+
PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import tensorflow as tf
|
10 |
+
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
|
11 |
+
import voxelmorph as vxm
|
12 |
+
import neurite as ne
|
13 |
+
import h5py
|
14 |
+
from datetime import datetime
|
15 |
+
|
16 |
+
if PYCHARM_EXEC:
|
17 |
+
import scripts.tf.myScript_constants as const
|
18 |
+
from scripts.tf.myScript_data_generator import DataGeneratorManager
|
19 |
+
from scripts.tf.myScript_utils import save_nifti, try_mkdir
|
20 |
+
else:
|
21 |
+
import myScript_constants as const
|
22 |
+
from myScript_data_generator import DataGeneratorManager
|
23 |
+
from myScript_utils import save_nifti, try_mkdir
|
24 |
+
|
25 |
+
os.environ['CUDA_DEVICE_ORDER'] = const.DEV_ORDER
|
26 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = const.GPU_NUM # Check availability before running using 'nvidia-smi'
|
27 |
+
|
28 |
+
const.TRAINING_DATASET = '/mnt/EncryptedData1/Users/javier/vessel_registration/sanity_dataset_LITS'
|
29 |
+
const.BATCH_SIZE = 8
|
30 |
+
const.LIMIT_NUM_SAMPLES = None
|
31 |
+
const.EPOCHS = 1000
|
32 |
+
|
33 |
+
if PYCHARM_EXEC:
|
34 |
+
path_prefix = os.path.join('scripts', 'tf')
|
35 |
+
else:
|
36 |
+
path_prefix = ''
|
37 |
+
|
38 |
+
# Load data
|
39 |
+
# Build data generator
|
40 |
+
data_generator = DataGeneratorManager(const.TRAINING_DATASET, const.BATCH_SIZE, True, const.LIMIT_NUM_SAMPLES,
|
41 |
+
1 - const.TRAINING_PERC, voxelmorph=True)
|
42 |
+
|
43 |
+
test_generator = data_generator.get_generator('validation')
|
44 |
+
test_fix_img, test_mov_img, *_ = test_generator.get_random_sample(1)
|
45 |
+
|
46 |
+
# Build model
|
47 |
+
in_shape = test_generator.get_input_shape()[1:-1]
|
48 |
+
enc_features = [16, 32, 32, 32]# const.ENCODER_FILTERS
|
49 |
+
dec_features = [32, 32, 32, 32, 32, 16, 16]# const.ENCODER_FILTERS[::-1]
|
50 |
+
nb_features = [enc_features, dec_features]
|
51 |
+
vxm_model = vxm.networks.VxmDense(inshape=in_shape, nb_unet_features=nb_features, int_steps=0)
|
52 |
+
|
53 |
+
weight_files = [os.path.join(path_prefix, 'checkpoints', f) for f in os.listdir(os.path.join(path_prefix, 'checkpoints')) if 'weights' in f]
|
54 |
+
weight_files.sort()
|
55 |
+
pred_folder = os.path.join(path_prefix, 'predictions')
|
56 |
+
try_mkdir(pred_folder)
|
57 |
+
|
58 |
+
# Prepare the images
|
59 |
+
fix_img = test_fix_img.squeeze()
|
60 |
+
mid_slice_fix = [np.take(fix_img, fix_img.shape[d]//2, axis=d) for d in range(3)]
|
61 |
+
mid_slice_fix[1] = np.rot90(mid_slice_fix[1], 1)
|
62 |
+
mid_slice_fix[2] = np.rot90(mid_slice_fix[2], -1)
|
63 |
+
|
64 |
+
mid_mov_slice = list()
|
65 |
+
mid_disp_slice = list()
|
66 |
+
# Due to slicing, it can happen that the last file is not tested. So include it always
|
67 |
+
slice = 5
|
68 |
+
for f in weight_files[:-1:slice] + [weight_files[-1]]:
|
69 |
+
name = os.path.split(f)[-1].split('.h5')[0]
|
70 |
+
vxm_model.load_weights(f)
|
71 |
+
pred_img, pred_disp = vxm_model.predict([test_mov_img, test_fix_img])
|
72 |
+
pred_img = pred_img.squeeze()
|
73 |
+
|
74 |
+
mov_slices = [np.take(pred_img, pred_img.shape[d]//2, axis=d) for d in range(3)]
|
75 |
+
mov_slices[1] = np.rot90(mov_slices[1], 1)
|
76 |
+
mov_slices[2] = np.rot90(mov_slices[2], -1)
|
77 |
+
mid_mov_slice.append(mov_slices)
|
78 |
+
|
79 |
+
|
80 |
+
|
81 |
+
|
82 |
+
|
83 |
+
# Get sample for testing
|
84 |
+
test_sample = test_generator.get_single_sample()
|
TrainingScripts/Train_2d.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
|
3 |
+
currentdir = os.path.dirname(os.path.realpath(__file__))
|
4 |
+
parentdir = os.path.dirname(currentdir)
|
5 |
+
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
6 |
+
|
7 |
+
PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'
|
8 |
+
|
9 |
+
import tensorflow as tf
|
10 |
+
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
|
11 |
+
import voxelmorph as vxm
|
12 |
+
from datetime import datetime
|
13 |
+
|
14 |
+
import DeepDeformationMapRegistration.utils.constants as C
|
15 |
+
from DeepDeformationMapRegistration.data_generator import DataGeneratorManager2D
|
16 |
+
from DeepDeformationMapRegistration.utils.misc import try_mkdir
|
17 |
+
from DeepDeformationMapRegistration.losses import HausdorffDistance
|
18 |
+
|
19 |
+
|
20 |
+
os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
|
21 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = C.GPU_NUM # Check availability before running using 'nvidia-smi'
|
22 |
+
|
23 |
+
C.TRAINING_DATASET = '/mnt/EncryptedData1/Users/javier/vessel_registration/ov_dataset/training'
|
24 |
+
C.BATCH_SIZE = 256
|
25 |
+
C.LIMIT_NUM_SAMPLES = None
|
26 |
+
C.EPOCHS = 10000
|
27 |
+
|
28 |
+
if PYCHARM_EXEC:
|
29 |
+
path_prefix = os.path.join('scripts', 'tf')
|
30 |
+
else:
|
31 |
+
path_prefix = ''
|
32 |
+
|
33 |
+
# Load data
|
34 |
+
# Build data generator
|
35 |
+
sample_list = [os.path.join(C.TRAINING_DATASET, f) for f in os.listdir(C.TRAINING_DATASET) if
|
36 |
+
f.startswith('sample')]
|
37 |
+
sample_list.sort()
|
38 |
+
|
39 |
+
data_generator = DataGeneratorManager2D(sample_list[:C.LIMIT_NUM_SAMPLES],
|
40 |
+
C.BATCH_SIZE, C.TRAINING_PERC,
|
41 |
+
(64, 64, 1),
|
42 |
+
fix_img_tag='dilated/input/fix',
|
43 |
+
mov_img_tag='dilated/input/mov'
|
44 |
+
)
|
45 |
+
|
46 |
+
# Build model
|
47 |
+
in_shape = data_generator.train_generator.input_shape[:-1]
|
48 |
+
enc_features = [32, 32, 32, 32, 32, 32] # const.ENCODER_FILTERS
|
49 |
+
dec_features = [32, 32, 32, 32, 32, 32, 32, 16] # const.ENCODER_FILTERS[::-1]
|
50 |
+
nb_features = [enc_features, dec_features]
|
51 |
+
vxm_model = vxm.networks.VxmDense(inshape=in_shape, nb_unet_features=nb_features, int_steps=0)
|
52 |
+
|
53 |
+
# Losses and loss weights
|
54 |
+
def comb_loss(y_true, y_pred):
|
55 |
+
return 1e-3 * HausdorffDistance(ndim=2, nerosion=2).loss(y_true, y_pred) + vxm.losses.Dice().loss(y_true, y_pred)
|
56 |
+
|
57 |
+
|
58 |
+
losses = [comb_loss, vxm.losses.Grad('l2').loss]
|
59 |
+
loss_weights = [1, 0.01]
|
60 |
+
|
61 |
+
# Compile the model
|
62 |
+
vxm_model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-4), loss=losses, loss_weights=loss_weights)
|
63 |
+
|
64 |
+
# Train
|
65 |
+
output_folder = os.path.join('train_2d_dice_hausdorff_grad_'+datetime.now().strftime("%H%M%S-%d%m%Y"))
|
66 |
+
try_mkdir(output_folder)
|
67 |
+
try_mkdir(os.path.join(output_folder, 'checkpoints'))
|
68 |
+
try_mkdir(os.path.join(output_folder, 'tensorboard'))
|
69 |
+
my_callbacks = [
|
70 |
+
# EarlyStopping(patience=const.EARLY_STOP_PATIENCE, monitor='dice', mode='max', verbose=1),
|
71 |
+
ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'),
|
72 |
+
save_best_only=True, monitor='val_loss', verbose=0, mode='min'),
|
73 |
+
ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'weights.{epoch:05d}-{val_loss:.2f}.h5'),
|
74 |
+
save_best_only=True, save_weights_only=True, monitor='val_loss', verbose=0, mode='min'),
|
75 |
+
# CSVLogger(train_log_name, ';'),
|
76 |
+
# UpdateLossweights([haus_weight, dice_weight], [const.MODEL+'_resampler_seg', const.MODEL+'_resampler_seg'])
|
77 |
+
TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
|
78 |
+
batch_size=C.BATCH_SIZE, write_images=True, histogram_freq=10, update_freq='epoch',
|
79 |
+
write_grads=True),
|
80 |
+
EarlyStopping(monitor='val_loss', verbose=1, patience=50, min_delta=0.0001)
|
81 |
+
]
|
82 |
+
hist = vxm_model.fit_generator(data_generator.train_generator,
|
83 |
+
epochs=C.EPOCHS,
|
84 |
+
validation_data=data_generator.validation_generator,
|
85 |
+
verbose=2,
|
86 |
+
callbacks=my_callbacks)
|
TrainingScripts/Train_2d_uncertaintyWeighting.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
|
3 |
+
currentdir = os.path.dirname(os.path.realpath(__file__))
|
4 |
+
parentdir = os.path.dirname(currentdir)
|
5 |
+
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
6 |
+
|
7 |
+
PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import tensorflow as tf
|
11 |
+
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
|
12 |
+
import voxelmorph as vxm
|
13 |
+
import neurite as ne
|
14 |
+
import h5py
|
15 |
+
from datetime import datetime
|
16 |
+
|
17 |
+
import DeepDeformationMapRegistration.utils.constants as C
|
18 |
+
from DeepDeformationMapRegistration.data_generator import DataGeneratorManager2D
|
19 |
+
from DeepDeformationMapRegistration.utils.misc import try_mkdir
|
20 |
+
from DeepDeformationMapRegistration.losses import HausdorffDistance
|
21 |
+
from DeepDeformationMapRegistration.layers import UncertaintyWeighting
|
22 |
+
|
23 |
+
|
24 |
+
os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
|
25 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '1' # const.GPU_NUM # Check availability before running using 'nvidia-smi'
|
26 |
+
|
27 |
+
C.TRAINING_DATASET = '/mnt/EncryptedData1/Users/javier/vessel_registration/ov_dataset/training'
|
28 |
+
C.BATCH_SIZE = 256
|
29 |
+
C.LIMIT_NUM_SAMPLES = None
|
30 |
+
C.EPOCHS = 10000
|
31 |
+
|
32 |
+
if PYCHARM_EXEC:
|
33 |
+
path_prefix = os.path.join('scripts', 'tf')
|
34 |
+
else:
|
35 |
+
path_prefix = ''
|
36 |
+
|
37 |
+
# Load data
|
38 |
+
# Build data generator
|
39 |
+
sample_list = [os.path.join(C.TRAINING_DATASET, f) for f in os.listdir(C.TRAINING_DATASET) if
|
40 |
+
f.startswith('sample')]
|
41 |
+
sample_list.sort()
|
42 |
+
|
43 |
+
data_generator = DataGeneratorManager2D(sample_list[:C.LIMIT_NUM_SAMPLES],
|
44 |
+
C.BATCH_SIZE, C.TRAINING_PERC,
|
45 |
+
(64, 64, 1),
|
46 |
+
fix_img_tag='dilated/input/fix',
|
47 |
+
mov_img_tag='dilated/input/mov',
|
48 |
+
multi_loss=True,
|
49 |
+
)
|
50 |
+
|
51 |
+
# Build model
|
52 |
+
in_shape_img, in_shape_grad = data_generator.train_generator.input_shape
|
53 |
+
enc_features = [32, 32, 32, 32, 32, 32] # const.ENCODER_FILTERS
|
54 |
+
dec_features = [32, 32, 32, 32, 32, 32, 32, 16] # const.ENCODER_FILTERS[::-1]
|
55 |
+
nb_features = [enc_features, dec_features]
|
56 |
+
vxm_model = vxm.networks.VxmDense(inshape=in_shape_img[:-1], nb_unet_features=nb_features, int_steps=0)
|
57 |
+
|
58 |
+
#moving = tf.keras.Input(shape=in_shape_img, name='multiLoss_moving_input', dtype=tf.float32)
|
59 |
+
#fixed = tf.keras.Input(shape=in_shape_img, name='multiLoss_fixed_input', dtype=tf.float32)
|
60 |
+
grad = tf.keras.Input(shape=(*in_shape_img[:-1], 2), name='multiLoss_grad_input', dtype=tf.float32)
|
61 |
+
|
62 |
+
def dice_loss(y_true, y_pred):
|
63 |
+
# Dice().loss returns -Dice score
|
64 |
+
return 1 + vxm.losses.Dice().loss(y_true, y_pred)
|
65 |
+
|
66 |
+
#fixed_pred, dm_pred = vxm_model([moving, fixed])
|
67 |
+
multiLoss = UncertaintyWeighting(num_loss_fns=2,
|
68 |
+
num_reg_fns=1,
|
69 |
+
loss_fns=[HausdorffDistance(2, 2).loss, dice_loss],
|
70 |
+
reg_fns=[vxm.losses.Grad('l2').loss],
|
71 |
+
prior_loss_w=[1., 1.],
|
72 |
+
prior_reg_w=[0.01],
|
73 |
+
name='MultiLossLayer')
|
74 |
+
loss = multiLoss([vxm_model.inputs[1], vxm_model.inputs[1], vxm_model.references.y_source, vxm_model.references.y_source, grad, vxm_model.references.pos_flow])
|
75 |
+
|
76 |
+
full_model = tf.keras.Model(inputs=vxm_model.inputs + [grad], outputs=vxm_model.outputs + [loss])
|
77 |
+
|
78 |
+
# Compile the model
|
79 |
+
full_model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-4), loss=None)
|
80 |
+
|
81 |
+
# Train
|
82 |
+
output_folder = os.path.join('train_2d_multiloss_haussdorf_dice_grad' + datetime.now().strftime("%H%M%S-%d%m%Y"))
|
83 |
+
try_mkdir(output_folder)
|
84 |
+
try_mkdir(os.path.join(output_folder, 'checkpoints'))
|
85 |
+
try_mkdir(os.path.join(output_folder, 'tensorboard'))
|
86 |
+
my_callbacks = [
|
87 |
+
# EarlyStopping(patience=const.EARLY_STOP_PATIENCE, monitor='dice', mode='max', verbose=1),
|
88 |
+
ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'),
|
89 |
+
save_best_only=True, monitor='val_loss', verbose=0, mode='min'),
|
90 |
+
ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'weights.{epoch:05d}-{val_loss:.2f}.h5'),
|
91 |
+
save_best_only=True, save_weights_only=True, monitor='val_loss', verbose=0, mode='min'),
|
92 |
+
# CSVLogger(train_log_name, ';'),
|
93 |
+
# UpdateLossweights([haus_weight, dice_weight], [const.MODEL+'_resampler_seg', const.MODEL+'_resampler_seg'])
|
94 |
+
TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
|
95 |
+
batch_size=C.BATCH_SIZE, write_images=True, histogram_freq=10, update_freq='epoch',
|
96 |
+
write_grads=True),
|
97 |
+
EarlyStopping(monitor='val_loss', verbose=1, patience=50, min_delta=0.0001)
|
98 |
+
]
|
99 |
+
hist = full_model.fit_generator(data_generator.train_generator,
|
100 |
+
epochs=C.EPOCHS,
|
101 |
+
validation_data=data_generator.validation_generator,
|
102 |
+
verbose=2,
|
103 |
+
callbacks=my_callbacks)
|
TrainingScripts/Train_3d.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
currentdir = os.path.dirname(os.path.realpath(__file__))
|
3 |
+
parentdir = os.path.dirname(currentdir)
|
4 |
+
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
5 |
+
|
6 |
+
PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import tensorflow as tf
|
10 |
+
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
|
11 |
+
import voxelmorph as vxm
|
12 |
+
import neurite as ne
|
13 |
+
import h5py
|
14 |
+
from datetime import datetime
|
15 |
+
|
16 |
+
import DeepDeformationMapRegistration.utils.constants as C
|
17 |
+
from DeepDeformationMapRegistration.data_generator import DataGeneratorManager
|
18 |
+
from DeepDeformationMapRegistration.utils.misc import try_mkdir
|
19 |
+
|
20 |
+
|
21 |
+
os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
|
22 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '1' # Check availability before running using 'nvidia-smi'
|
23 |
+
|
24 |
+
C.TRAINING_DATASET = '/mnt/EncryptedData1/Users/javier/vessel_registration/sanity_dataset_LITS'
|
25 |
+
C.BATCH_SIZE = 2
|
26 |
+
C.LIMIT_NUM_SAMPLES = None
|
27 |
+
C.EPOCHS = 10000
|
28 |
+
|
29 |
+
# Load data
|
30 |
+
# Build data generator
|
31 |
+
data_generator = DataGeneratorManager(C.TRAINING_DATASET, C.BATCH_SIZE, True, C.LIMIT_NUM_SAMPLES,
|
32 |
+
1 - C.TRAINING_PERC, voxelmorph=True)
|
33 |
+
|
34 |
+
train_generator = data_generator.get_generator('train')
|
35 |
+
validation_generator = data_generator.get_generator('validation')
|
36 |
+
|
37 |
+
|
38 |
+
# Build model
|
39 |
+
in_shape = train_generator.get_input_shape()[1:-1]
|
40 |
+
enc_features = [16, 32, 32, 32, 32, 32]# const.ENCODER_FILTERS
|
41 |
+
dec_features = [32, 32, 32, 32, 32, 32, 32, 16, 16]# const.ENCODER_FILTERS[::-1]
|
42 |
+
nb_features = [enc_features, dec_features]
|
43 |
+
vxm_model = vxm.networks.VxmDense(inshape=in_shape, nb_unet_features=nb_features, int_steps=7)
|
44 |
+
|
45 |
+
|
46 |
+
# Losses and loss weights
|
47 |
+
|
48 |
+
def comb_loss(y_true, y_pred):
|
49 |
+
return vxm.losses.MSE().loss(y_true, y_pred) + vxm.losses.NCC().loss(y_true, y_pred)
|
50 |
+
|
51 |
+
|
52 |
+
losses = [comb_loss, vxm.losses.Grad('l2').loss]
|
53 |
+
loss_weights = [1., 0.01]
|
54 |
+
|
55 |
+
# Compile the model
|
56 |
+
vxm_model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-4), loss=losses, loss_weights=loss_weights)
|
57 |
+
|
58 |
+
# Train
|
59 |
+
output_folder = os.path.join('train_3d_mse_ncc_grad_'+datetime.now().strftime("%H%M%S-%d%m%Y"))
|
60 |
+
try_mkdir(output_folder)
|
61 |
+
try_mkdir(os.path.join(output_folder, 'checkpoints'))
|
62 |
+
try_mkdir(os.path.join(output_folder, 'tensorboard'))
|
63 |
+
my_callbacks = [
|
64 |
+
#EarlyStopping(patience=const.EARLY_STOP_PATIENCE, monitor='dice', mode='max', verbose=1),
|
65 |
+
ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'),
|
66 |
+
save_best_only=True, monitor='val_loss', verbose=0, mode='min'),
|
67 |
+
ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'weights.{epoch:05d}-{val_loss:.2f}.h5'),
|
68 |
+
save_best_only=True, save_weights_only=True, monitor='val_loss', verbose=0, mode='min'),
|
69 |
+
# CSVLogger(train_log_name, ';'),
|
70 |
+
# UpdateLossweights([haus_weight, dice_weight], [const.MODEL+'_resampler_seg', const.MODEL+'_resampler_seg'])
|
71 |
+
TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
|
72 |
+
batch_size=C.BATCH_SIZE, write_images=False, histogram_freq=10, update_freq='epoch',
|
73 |
+
write_grads=True),
|
74 |
+
EarlyStopping(monitor='val_loss', verbose=1, patience=50, min_delta=0.0001)
|
75 |
+
]
|
76 |
+
hist = vxm_model.fit(train_generator, epochs=C.EPOCHS, validation_data=validation_generator, verbose=2, callbacks=my_callbacks)
|
TrainingScripts/Train_3d_weaklySupervised.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
currentdir = os.path.dirname(os.path.realpath(__file__))
|
3 |
+
parentdir = os.path.dirname(currentdir)
|
4 |
+
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
5 |
+
|
6 |
+
PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import tensorflow as tf
|
10 |
+
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
|
11 |
+
import voxelmorph as vxm
|
12 |
+
import neurite as ne
|
13 |
+
import h5py
|
14 |
+
from datetime import datetime
|
15 |
+
|
16 |
+
import DeepDeformationMapRegistration.utils.constants as C
|
17 |
+
from DeepDeformationMapRegistration.data_generator import DataGeneratorManager
|
18 |
+
from DeepDeformationMapRegistration.utils.misc import try_mkdir
|
19 |
+
from DeepDeformationMapRegistration.networks import VxmWeaklySupervised
|
20 |
+
from DeepDeformationMapRegistration.losses import HausdorffDistance
|
21 |
+
from DeepDeformationMapRegistration.layers import UncertaintyWeighting
|
22 |
+
|
23 |
+
|
24 |
+
os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
|
25 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '0' # Check availability before running using 'nvidia-smi'
|
26 |
+
|
27 |
+
C.TRAINING_DATASET = '/mnt/EncryptedData1/Users/javier/vessel_registration/sanity_dataset_vessels'
|
28 |
+
C.BATCH_SIZE = 2
|
29 |
+
C.LIMIT_NUM_SAMPLES = None
|
30 |
+
C.EPOCHS = 10000
|
31 |
+
|
32 |
+
# Load data
|
33 |
+
# Build data generator
|
34 |
+
|
35 |
+
data_generator = DataGeneratorManager(C.TRAINING_DATASET, C.BATCH_SIZE, True, C.LIMIT_NUM_SAMPLES,
|
36 |
+
1 - C.TRAINING_PERC, voxelmorph=True, segmentations=True)
|
37 |
+
|
38 |
+
train_generator = data_generator.get_generator('train')
|
39 |
+
validation_generator = data_generator.get_generator('validation')
|
40 |
+
|
41 |
+
|
42 |
+
# Build model
|
43 |
+
in_shape = train_generator.get_input_shape()[1:-1]
|
44 |
+
enc_features = [16, 32, 32, 32, 32, 32]# const.ENCODER_FILTERS
|
45 |
+
dec_features = [32, 32, 32, 32, 32, 32, 32, 16, 16]# const.ENCODER_FILTERS[::-1]
|
46 |
+
nb_features = [enc_features, dec_features]
|
47 |
+
vxm_model = VxmWeaklySupervised(inshape=in_shape, all_labels=[1], nb_unet_features=nb_features, int_steps=5)
|
48 |
+
|
49 |
+
# Losses and loss weights
|
50 |
+
|
51 |
+
grad = tf.keras.Input(shape=(*in_shape, 3), name='multiLoss_grad_input', dtype=tf.float32)
|
52 |
+
fix_img = tf.keras.Input(shape=(*in_shape, 1), name='multiLoss_fix_img_input', dtype=tf.float32)
|
53 |
+
def dice_loss(y_true, y_pred):
|
54 |
+
# Dice().loss returns -Dice score
|
55 |
+
return 1 + vxm.losses.Dice().loss(y_true, y_pred)
|
56 |
+
|
57 |
+
multiLoss = UncertaintyWeighting(num_loss_fns=3,
|
58 |
+
num_reg_fns=1,
|
59 |
+
loss_fns=[HausdorffDistance(3, 5).loss, dice_loss, vxm.losses.NCC().loss],
|
60 |
+
reg_fns=[vxm.losses.Grad('l2').loss],
|
61 |
+
prior_loss_w=[1., 1., 1.],
|
62 |
+
prior_reg_w=[0.01],
|
63 |
+
name='MultiLossLayer')
|
64 |
+
loss = multiLoss([vxm_model.inputs[1], vxm_model.inputs[1], fix_img,
|
65 |
+
vxm_model.references.pred_segm, vxm_model.references.pred_segm, vxm_model.references.pred_img,
|
66 |
+
grad,
|
67 |
+
vxm_model.references.pos_flow])
|
68 |
+
|
69 |
+
full_model = tf.keras.Model(inputs=vxm_model.inputs + [fix_img, grad], outputs=vxm_model.outputs + [loss])
|
70 |
+
|
71 |
+
# Compile the model
|
72 |
+
full_model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-4), loss=None)
|
73 |
+
|
74 |
+
# Train
|
75 |
+
output_folder = os.path.join('train_3d_multiloss_segm_haus_dice_ncc_grad_'+datetime.now().strftime("%H%M%S-%d%m%Y"))
|
76 |
+
try_mkdir(output_folder)
|
77 |
+
try_mkdir(os.path.join(output_folder, 'checkpoints'))
|
78 |
+
try_mkdir(os.path.join(output_folder, 'tensorboard'))
|
79 |
+
my_callbacks = [
|
80 |
+
#EarlyStopping(patience=const.EARLY_STOP_PATIENCE, monitor='dice', mode='max', verbose=1),
|
81 |
+
ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'),
|
82 |
+
save_best_only=True, monitor='val_loss', verbose=0, mode='min'),
|
83 |
+
ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'weights.{epoch:05d}-{val_loss:.2f}.h5'),
|
84 |
+
save_best_only=True, save_weights_only=True, monitor='val_loss', verbose=0, mode='min'),
|
85 |
+
# CSVLogger(train_log_name, ';'),
|
86 |
+
# UpdateLossweights([haus_weight, dice_weight], [const.MODEL+'_resampler_seg', const.MODEL+'_resampler_seg'])
|
87 |
+
TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
|
88 |
+
batch_size=C.BATCH_SIZE, write_images=False, histogram_freq=10, update_freq='epoch',
|
89 |
+
write_grads=True),
|
90 |
+
EarlyStopping(monitor='val_loss', verbose=1, patience=50, min_delta=0.0001)
|
91 |
+
]
|
92 |
+
hist = full_model.fit(train_generator, epochs=C.EPOCHS, validation_data=validation_generator, verbose=2, callbacks=my_callbacks)
|