jpdefrutos commited on
Commit
6a4f823
·
1 Parent(s): 74c6a32

Scripts for training on the IXI T1 MRI Dataset

Browse files
Brain_study/Build_test_set.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+
3
+ import shutil
4
+
5
+ import matplotlib.pyplot as plt
6
+
7
+ currentdir = os.path.dirname(os.path.realpath(__file__))
8
+ parentdir = os.path.dirname(currentdir)
9
+ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
10
+
11
+ import tensorflow as tf
12
+ # tf.enable_eager_execution(config=config)
13
+
14
+ import numpy as np
15
+ import h5py
16
+
17
+ import DeepDeformationMapRegistration.utils.constants as C
18
+ from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
19
+ from DeepDeformationMapRegistration.layers import AugmentationLayer
20
+ from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
21
+ from DeepDeformationMapRegistration.utils.misc import get_segmentations_centroids
22
+ from tqdm import tqdm
23
+
24
+ from Brain_study.data_generator import BatchGenerator
25
+
26
+ from skimage.measure import regionprops
27
+ from scipy.interpolate import griddata
28
+
29
+ import argparse
30
+
31
+
32
+ DATASET = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/training'
33
+ MODEL_FILE = '/mnt/EncryptedData1/Users/javier/train_output/Brain_study/ERASE/MS_SSIM/BASELINE_L_ssim__MET_mse_ncc_ssim_162756-29062021/checkpoints/best_model.h5'
34
+ DATA_ROOT_DIR = '/mnt/EncryptedData1/Users/javier/train_output/Brain_study/ERASE/MS_SSIM/BASELINE_L_ssim__MET_mse_ncc_ssim_162756-29062021/'
35
+
36
+ POINTS = None
37
+ MISSING_CENTROID = np.asarray([[np.nan]*3])
38
+
39
+
40
+ def get_mov_centroids(fix_seg, disp_map):
41
+ fix_centroids, _ = get_segmentations_centroids(fix_seg[0, ...], ohe=True, expected_lbls=range(0, 28))
42
+ disp = griddata(POINTS, disp_map.reshape([-1, 3]), fix_centroids, method='linear')
43
+ return fix_centroids, fix_centroids + disp, disp
44
+
45
+
46
+ if __name__ == '__main__':
47
+ parser = argparse.ArgumentParser()
48
+ parser.add_argument('-d', '--dir', type=str, help='Directory where to store the files', default='')
49
+ parser.add_argument('--reldir', type=str, help='Relative path to dataset, in where to store the files', default='')
50
+ parser.add_argument('--gpu', type=int, help='GPU', default=0)
51
+ parser.add_argument('--dataset', type=str, help='Dataset to build the test set', default='')
52
+ parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
53
+ args = parser.parse_args()
54
+
55
+ assert args.dataset != '', "Missing original dataset dataset"
56
+ if args.dir == '' and args.reldir != '':
57
+ OUTPUT_FOLDER_DIR = os.path.join(args.dataset, 'test_dataset')
58
+ elif args.dir != '' and args.reldir == '':
59
+ OUTPUT_FOLDER_DIR = args.dir
60
+ else:
61
+ raise ValueError("Either provide 'dir' or 'reldir'")
62
+
63
+ if args.erase:
64
+ shutil.rmtree(OUTPUT_FOLDER_DIR, ignore_errors=True)
65
+ os.makedirs(OUTPUT_FOLDER_DIR, exist_ok=True)
66
+ print('DESTINATION FOLDER: ', OUTPUT_FOLDER_DIR)
67
+
68
+ DATASET = args.dataset
69
+
70
+ os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
71
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) # Check availability before running using 'nvidia-smi'
72
+
73
+ data_generator = BatchGenerator(DATASET, 1, False, 1.0, False, ['all'])
74
+
75
+ img_generator = data_generator.get_train_generator()
76
+ nb_labels = len(img_generator.get_segmentation_labels())
77
+ image_input_shape = img_generator.get_data_shape()[-1][:-1]
78
+ image_output_shape = [64] * 3
79
+ # Build model
80
+
81
+ xx = np.linspace(0, image_output_shape[0], image_output_shape[0], endpoint=False)
82
+ yy = np.linspace(0, image_output_shape[1], image_output_shape[2], endpoint=False)
83
+ zz = np.linspace(0, image_output_shape[2], image_output_shape[1], endpoint=False)
84
+
85
+ xx, yy, zz = np.meshgrid(xx, yy, zz)
86
+
87
+ POINTS = np.stack([xx.flatten(), yy.flatten(), zz.flatten()], axis=0).T
88
+
89
+ input_augm = tf.keras.Input(shape=img_generator.get_data_shape()[0], name='input_augm')
90
+ augm_layer = AugmentationLayer(max_displacement=C.MAX_AUG_DISP, # Max 30 mm in isotropic space
91
+ max_deformation=C.MAX_AUG_DEF, # Max 6 mm in isotropic space
92
+ max_rotation=C.MAX_AUG_ANGLE, # Max 10 deg in isotropic space
93
+ num_control_points=C.NUM_CONTROL_PTS_AUG,
94
+ num_augmentations=C.NUM_AUGMENTATIONS,
95
+ gamma_augmentation=C.GAMMA_AUGMENTATION,
96
+ brightness_augmentation=C.BRIGHTNESS_AUGMENTATION,
97
+ in_img_shape=image_input_shape,
98
+ out_img_shape=image_output_shape,
99
+ only_image=False,
100
+ only_resize=False,
101
+ trainable=False,
102
+ return_displacement_map=True)
103
+ augm_model = tf.keras.Model(inputs=input_augm, outputs=augm_layer(input_augm))
104
+
105
+ config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
106
+ config.gpu_options.allow_growth = True
107
+ config.log_device_placement = False ## to log device placement (on which device the operation ran)
108
+
109
+ sess = tf.Session(config=config)
110
+ tf.keras.backend.set_session(sess)
111
+ with sess.as_default():
112
+ sess.run(tf.global_variables_initializer())
113
+ progress_bar = tqdm(enumerate(img_generator, 1), desc='Generating samples', total=len(img_generator))
114
+ for step, (in_batch, _) in progress_bar:
115
+ fix_img, mov_img, fix_seg, mov_seg, disp_map = augm_model.predict(in_batch)
116
+
117
+ fix_centroids, mov_centroids, disp_centroids = get_mov_centroids(fix_seg, disp_map)
118
+
119
+ out_file = os.path.join(OUTPUT_FOLDER_DIR, 'test_sample_{:04d}.h5'.format(step))
120
+ out_file_dm = os.path.join(OUTPUT_FOLDER_DIR, 'test_sample_dm_{:04d}.h5'.format(step))
121
+ img_shape = fix_img.shape
122
+ segm_shape = fix_seg.shape
123
+ disp_shape = disp_map.shape
124
+ centroids_shape = fix_centroids.shape
125
+ with h5py.File(out_file, 'w') as f:
126
+ f.create_dataset('fix_image', shape=img_shape[1:], dtype=np.float32, data=fix_img[0, ...])
127
+ f.create_dataset('mov_image', shape=img_shape[1:], dtype=np.float32, data=mov_img[0, ...])
128
+ f.create_dataset('fix_segmentations', shape=segm_shape[1:], dtype=np.uint8, data=fix_seg[0, ...])
129
+ f.create_dataset('mov_segmentations', shape=segm_shape[1:], dtype=np.uint8, data=mov_seg[0, ...])
130
+ f.create_dataset('fix_centroids', shape=centroids_shape, dtype=np.float32, data=fix_centroids)
131
+ f.create_dataset('mov_centroids', shape=centroids_shape, dtype=np.float32, data=mov_centroids)
132
+
133
+ with h5py.File(out_file_dm, 'w') as f:
134
+ f.create_dataset('disp_map', shape=disp_shape[1:], dtype=np.float32, data=disp_map)
135
+ f.create_dataset('disp_centroids', shape=centroids_shape, dtype=np.float32, data=disp_centroids)
136
+
137
+ print('Done')
Brain_study/Evaluate_network.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+
3
+ import shutil
4
+
5
+ import h5py
6
+ import matplotlib.pyplot as plt
7
+
8
+ currentdir = os.path.dirname(os.path.realpath(__file__))
9
+ parentdir = os.path.dirname(currentdir)
10
+ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
11
+
12
+ import tensorflow as tf
13
+ # tf.enable_eager_execution(config=config)
14
+
15
+ import numpy as np
16
+ import pandas as pd
17
+ import voxelmorph as vxm
18
+
19
+ import DeepDeformationMapRegistration.utils.constants as C
20
+ from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
21
+ from DeepDeformationMapRegistration.layers import AugmentationLayer
22
+ from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplified, NCC, GeneralizedDICEScore, HausdorffDistanceErosion
23
+ from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
24
+ from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
25
+ from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
26
+ from EvaluationScripts.Evaluate_class import EvaluationFigures, resize_pts_to_original_space, resize_img_to_original_space, resize_transformation
27
+ from scipy.interpolate import RegularGridInterpolator
28
+ from tqdm import tqdm
29
+
30
+ import h5py
31
+
32
+ from Brain_study.data_generator import BatchGenerator
33
+
34
+ import argparse
35
+
36
+ from skimage.transform import warp
37
+ import neurite as ne
38
+
39
+ DATASET = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/training'
40
+ MODEL_FILE = '/mnt/EncryptedData1/Users/javier/train_output/Brain_study/ERASE/MS_SSIM/BASELINE_L_ssim__MET_mse_ncc_ssim_162756-29062021/checkpoints/best_model.h5'
41
+ DATA_ROOT_DIR = '/mnt/EncryptedData1/Users/javier/train_output/Brain_study/ERASE/MS_SSIM/BASELINE_L_ssim__MET_mse_ncc_ssim_162756-29062021/'
42
+
43
+ OUTPUT_FOLDER_NAME = 'Evaluate'
44
+
45
+ if __name__ == '__main__':
46
+ parser = argparse.ArgumentParser()
47
+ parser.add_argument('-m', '--model', type=str, help='.h5 of the model', default='')
48
+ parser.add_argument('-d', '--dir', type=str, help='Directory where ./checkpoints/best_model.h5 is located', default='')
49
+ parser.add_argument('--gpu', type=int, help='GPU', default=0)
50
+ parser.add_argument('--dataset', type=str, help='Dataset to run predictions on',
51
+ default='/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/training')
52
+ parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
53
+ args = parser.parse_args()
54
+ if args.model != '':
55
+ assert '.h5' in args.model, 'No checkpoint file provided, use -d/--dir instead'
56
+ MODEL_FILE = args.model
57
+ DATA_ROOT_DIR = os.path.split(args.model)[0]
58
+ elif args.dir != '':
59
+ assert '.h5' not in args.model, 'Provided checkpoint file, user -m/--model instead'
60
+ MODEL_FILE = os.path.join(args.dir, 'checkpoints', 'best_model.h5')
61
+ DATA_ROOT_DIR = args.dir
62
+ else:
63
+ raise ValueError("Provide either the model file or the directory ./containing checkpoints/best_model.h5")
64
+
65
+ os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
66
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) # Check availability before running using 'nvidia-smi'
67
+ DATASET = args.dataset
68
+
69
+ print('MODEL LOCATION: ', MODEL_FILE)
70
+
71
+ # data_folder = '/mnt/EncryptedData1/Users/javier/train_output/DDMR/THESIS/BASELINE_Affine_ncc___mse_ncc_160606-25022021'
72
+ output_folder = os.path.join(DATA_ROOT_DIR, OUTPUT_FOLDER_NAME) # '/mnt/EncryptedData1/Users/javier/train_output/DDMR/THESIS/eval/BASELINE_TRAIN_Affine_ncc_EVAL_Affine'
73
+ # os.makedirs(os.path.join(output_folder, 'images'), exist_ok=True)
74
+ if args.erase:
75
+ shutil.rmtree(output_folder, ignore_errors=True)
76
+ os.makedirs(output_folder, exist_ok=True)
77
+ print('DESTINATION FOLDER: ', output_folder)
78
+
79
+ data_generator = BatchGenerator(DATASET, 1, False, 1.0, False, ['all'])
80
+
81
+ img_generator = data_generator.get_train_generator()
82
+ nb_labels = len(img_generator.get_segmentation_labels())
83
+ image_input_shape = img_generator.get_data_shape()[-1][:-1]
84
+ image_output_shape = [64] * 3
85
+
86
+ # Build model
87
+
88
+ input_augm = tf.keras.Input(shape=img_generator.get_data_shape()[0], name='input_augm')
89
+ augm_layer = AugmentationLayer(max_displacement=C.MAX_AUG_DISP, # Max 30 mm in isotropic space
90
+ max_deformation=C.MAX_AUG_DEF, # Max 6 mm in isotropic space
91
+ max_rotation=C.MAX_AUG_ANGLE, # Max 10 deg in isotropic space
92
+ num_control_points=C.NUM_CONTROL_PTS_AUG,
93
+ num_augmentations=C.NUM_AUGMENTATIONS,
94
+ gamma_augmentation=C.GAMMA_AUGMENTATION,
95
+ brightness_augmentation=C.BRIGHTNESS_AUGMENTATION,
96
+ in_img_shape=image_input_shape,
97
+ out_img_shape=image_output_shape,
98
+ only_image=False,
99
+ only_resize=False,
100
+ trainable=False)
101
+ augm_model = tf.keras.Model(inputs=input_augm, outputs=augm_layer(input_augm))
102
+
103
+ loss_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).loss,
104
+ NCC(image_input_shape).loss,
105
+ vxm.losses.MSE().loss,
106
+ MultiScaleStructuralSimilarity(max_val=1., filter_size=3).loss]
107
+
108
+ metric_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).metric,
109
+ NCC(image_input_shape).metric,
110
+ vxm.losses.MSE().loss,
111
+ MultiScaleStructuralSimilarity(max_val=1., filter_size=3).metric,
112
+ GeneralizedDICEScore(image_output_shape + [img_generator.get_data_shape()[2][-1]], num_labels=nb_labels).loss,
113
+ HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [img_generator.get_data_shape()[2][-1]]).loss]
114
+
115
+ network = tf.keras.models.load_model(MODEL_FILE, {'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
116
+ 'VxmDense': vxm.networks.VxmDense,
117
+ 'AdamAccumulated': AdamAccumulated,
118
+ 'loss': loss_fncs,
119
+ 'metric': metric_fncs},
120
+ compile=False)
121
+
122
+ # Needed for VxmDense type of network
123
+ warp_segmentation = vxm.networks.Transform(image_output_shape, interp_method='nearest', nb_feats=nb_labels)
124
+
125
+ # Record metrics
126
+ metrics = pd.DataFrame(columns=['File', 'SSIM', 'MS-SSIM', 'MSE', 'DICE', 'HD'])
127
+ config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
128
+ config.gpu_options.allow_growth = True
129
+ config.log_device_placement = False ## to log device placement (on which device the operation ran)
130
+
131
+ sess = tf.Session(config=config)
132
+ tf.keras.backend.set_session(sess)
133
+ with sess.as_default():
134
+ sess.run(tf.global_variables_initializer())
135
+ network.load_weights(MODEL_FILE, by_name=True)
136
+ progress_bar = tqdm(enumerate(img_generator, 1), desc='Evaluation', total=len(img_generator))
137
+ for step, (in_batch, _) in progress_bar:
138
+ fix_img, mov_img, fix_seg, mov_seg = augm_model.predict(in_batch)
139
+
140
+ if network.name == 'vxm_dense_semi_supervised_seg':
141
+ pred_img, disp_map, pred_seg = network.predict([mov_img, fix_img, mov_seg, fix_seg]) # predict([source, target])
142
+ else:
143
+ pred_img, disp_map = network.predict([mov_img, fix_img])
144
+ pred_seg = warp_segmentation.predict([mov_seg, disp_map])
145
+
146
+ # I need the labels to be OHE to compute the segmentation metrics.
147
+ dice = GeneralizedDICEScore(image_output_shape + [img_generator.get_data_shape()[2][-1]], num_labels=nb_labels).metric(fix_seg, pred_seg).eval()
148
+ hd = HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [img_generator.get_data_shape()[2][-1]]).metric(fix_seg, pred_seg).eval()
149
+
150
+ pred_seg = np.argmax(pred_seg, axis=-1)[..., np.newaxis].astype(np.float32)
151
+ mov_seg = np.argmax(mov_seg, axis=-1)[..., np.newaxis].astype(np.float32)
152
+ fix_seg = np.argmax(fix_seg, axis=-1)[..., np.newaxis].astype(np.float32)
153
+
154
+ mov_coords = np.stack(np.meshgrid(*[np.arange(0, 64)]*3), axis=-1)
155
+ dest_coords = mov_coords + disp_map[0, ...]
156
+
157
+ ssim = StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).metric(fix_img, pred_img).eval()
158
+ ms_ssim = MultiScaleStructuralSimilarity(max_val=1., filter_size=3).metric(fix_img, pred_img).eval()[0]
159
+ mse = vxm.losses.MSE().loss(fix_img, pred_img).eval()
160
+
161
+ metrics.append({'File': step,
162
+ 'SSIM': ssim,
163
+ 'MS-SSIM': ms_ssim,
164
+ 'MSE': mse,
165
+ 'DICE': dice,
166
+ 'HD': hd}, ignore_index=True)
167
+ save_nifti(fix_img[0, ...], os.path.join(output_folder, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
168
+ save_nifti(mov_img[0, ...], os.path.join(output_folder, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
169
+ save_nifti(pred_img[0, ...], os.path.join(output_folder, '{:03d}_pred_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
170
+ save_nifti(fix_seg[0, ...], os.path.join(output_folder, '{:03d}_fix_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
171
+ save_nifti(mov_seg[0, ...], os.path.join(output_folder, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
172
+ save_nifti(pred_seg[0, ...], os.path.join(output_folder, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
173
+
174
+ magnitude = np.sqrt(np.sum(disp_map[0, ...] ** 2, axis=-1))
175
+ _ = plt.hist(magnitude.flatten())
176
+ plt.title('Histogram of disp. magnitudes')
177
+ # plt.show(block=False)
178
+ plt.savefig(os.path.join(output_folder, '{:03d}_hist_mag_ssim_{:.03f}_dice_{:.03f}.png'.format(step, ssim, dice)))
179
+ plt.close()
180
+
181
+ plot_predictions(fix_img, mov_img, disp_map, pred_img, os.path.join(output_folder, '{:03d}_figures.png'.format(step)), show=False)
182
+ save_disp_map_img(disp_map, 'Displacement map', os.path.join(output_folder, '{:03d}_disp_map_fig.png'.format(step)), show=False)
183
+
184
+ progress_bar.set_description('SSIM {:.04f}\tDICE: {:.04f}'.format(ssim, dice))
185
+
186
+ metrics.to_csv(os.path.join(output_folder, 'metrics.csv'))
187
+ print('Done')
Brain_study/Evaluate_network__test_fixed.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+
3
+ import shutil
4
+ import time
5
+
6
+ import h5py
7
+ import matplotlib.pyplot as plt
8
+
9
+ currentdir = os.path.dirname(os.path.realpath(__file__))
10
+ parentdir = os.path.dirname(currentdir)
11
+ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
12
+
13
+ import tensorflow as tf
14
+ # tf.enable_eager_execution(config=config)
15
+
16
+ import numpy as np
17
+ import pandas as pd
18
+ import voxelmorph as vxm
19
+
20
+ import DeepDeformationMapRegistration.utils.constants as C
21
+ from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
22
+ from DeepDeformationMapRegistration.layers import AugmentationLayer, UncertaintyWeighting
23
+ from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplified, NCC, GeneralizedDICEScore, HausdorffDistanceErosion, target_registration_error
24
+ from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
25
+ from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
26
+ from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
27
+ from DeepDeformationMapRegistration.utils.misc import DisplacementMapInterpolator, get_segmentations_centroids, segmentation_ohe_to_cardinal
28
+ from EvaluationScripts.Evaluate_class import EvaluationFigures, resize_pts_to_original_space, resize_img_to_original_space, resize_transformation
29
+ from scipy.interpolate import RegularGridInterpolator
30
+ from tqdm import tqdm
31
+
32
+ import h5py
33
+ import re
34
+ from Brain_study.data_generator import BatchGenerator
35
+
36
+ import argparse
37
+
38
+ from skimage.transform import warp
39
+ import neurite as ne
40
+
41
+ DATASET = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/training'
42
+ MODEL_FILE = '/mnt/EncryptedData1/Users/javier/train_output/Brain_study/ERASE/MS_SSIM/BASELINE_L_ssim__MET_mse_ncc_ssim_162756-29062021/checkpoints/best_model.h5'
43
+ DATA_ROOT_DIR = '/mnt/EncryptedData1/Users/javier/train_output/Brain_study/ERASE/MS_SSIM/BASELINE_L_ssim__MET_mse_ncc_ssim_162756-29062021/'
44
+
45
+
46
+ if __name__ == '__main__':
47
+ parser = argparse.ArgumentParser()
48
+ parser.add_argument('-m', '--model', nargs='+', type=str, help='.h5 of the model', default=None)
49
+ parser.add_argument('-d', '--dir', nargs='+', type=str, help='Directory where ./checkpoints/best_model.h5 is located', default=None)
50
+ parser.add_argument('--gpu', type=int, help='GPU', default=0)
51
+ parser.add_argument('--dataset', type=str, help='Dataset to run predictions on',
52
+ default='/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/training')
53
+ parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
54
+ parser.add_argument('--outdirname', type=str, default='Evaluate')
55
+ args = parser.parse_args()
56
+ if args.model is not None:
57
+ assert '.h5' in args.model[0], 'No checkpoint file provided, use -d/--dir instead'
58
+ MODEL_FILE_LIST = args.model
59
+ DATA_ROOT_DIR_LIST = [os.path.split(model_path)[0] for model_path in args.model]
60
+ elif args.dir is not None:
61
+ assert '.h5' not in args.dir[0], 'Provided checkpoint file, user -m/--model instead'
62
+ MODEL_FILE_LIST = [os.path.join(dir_path, 'checkpoints', 'best_model.h5') for dir_path in args.dir]
63
+ DATA_ROOT_DIR_LIST = args.dir
64
+ else:
65
+ raise ValueError("Provide either the model file or the directory ./containing checkpoints/best_model.h5")
66
+
67
+ os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
68
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) # Check availability before running using 'nvidia-smi'
69
+ DATASET = args.dataset
70
+ list_test_files = [os.path.join(DATASET, f) for f in os.listdir(DATASET) if f.endswith('h5') and 'dm' not in f]
71
+ list_test_files.sort()
72
+
73
+ with h5py.File(list_test_files[0], 'r') as f:
74
+ image_input_shape = image_output_shape = list(f['fix_image'][:].shape[:-1])
75
+ nb_labels = f['fix_segmentations'][:].shape[-1]
76
+
77
+ # Header of the metrics csv file
78
+ csv_header = ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'DICE_MACRO', 'HD', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']
79
+
80
+ # TF stuff
81
+ config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
82
+ config.gpu_options.allow_growth = True
83
+ config.log_device_placement = False ## to log device placement (on which device the operation ran)
84
+ config.allow_soft_placement = True
85
+
86
+ sess = tf.Session(config=config)
87
+ tf.keras.backend.set_session(sess)
88
+
89
+ # Loss and metric functions. Common to all models
90
+ loss_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).loss,
91
+ NCC(image_input_shape).loss,
92
+ vxm.losses.MSE().loss,
93
+ MultiScaleStructuralSimilarity(max_val=1., filter_size=3).loss]
94
+
95
+ metric_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).metric,
96
+ NCC(image_input_shape).metric,
97
+ vxm.losses.MSE().loss,
98
+ MultiScaleStructuralSimilarity(max_val=1., filter_size=3).metric,
99
+ GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric,
100
+ HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [nb_labels]).metric,
101
+ GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric_macro]
102
+
103
+ ### METRICS GRAPH ###
104
+ fix_img_ph = tf.placeholder(tf.float32, (1, *image_output_shape, 1), name='fix_img')
105
+ pred_img_ph = tf.placeholder(tf.float32, (1, *image_output_shape, 1), name='pred_img')
106
+ fix_seg_ph = tf.placeholder(tf.float32, (1, *image_output_shape, nb_labels), name='fix_seg')
107
+ pred_seg_ph = tf.placeholder(tf.float32, (1, *image_output_shape, nb_labels), name='pred_seg')
108
+
109
+ ssim_tf = metric_fncs[0](fix_img_ph, pred_img_ph)
110
+ ncc_tf = metric_fncs[1](fix_img_ph, pred_img_ph)
111
+ mse_tf = metric_fncs[2](fix_img_ph, pred_img_ph)
112
+ ms_ssim_tf = metric_fncs[3](fix_img_ph, pred_img_ph)
113
+ dice_tf = metric_fncs[4](fix_seg_ph, pred_seg_ph)
114
+ hd_tf = metric_fncs[5](fix_seg_ph, pred_seg_ph)
115
+ dice_macro_tf = metric_fncs[6](fix_seg_ph, pred_seg_ph)
116
+ # hd_exact_tf = HausdorffDistance_exact(fix_seg_ph, pred_seg_ph, ohe=True)
117
+
118
+ # Needed for VxmDense type of network
119
+ warp_segmentation = vxm.networks.Transform(image_output_shape, interp_method='nearest', nb_feats=nb_labels)
120
+
121
+ dm_interp = DisplacementMapInterpolator(image_output_shape, 'griddata')
122
+
123
+ for MODEL_FILE, DATA_ROOT_DIR in zip(MODEL_FILE_LIST, DATA_ROOT_DIR_LIST):
124
+ print('MODEL LOCATION: ', MODEL_FILE)
125
+
126
+ # data_folder = '/mnt/EncryptedData1/Users/javier/train_output/DDMR/THESIS/BASELINE_Affine_ncc___mse_ncc_160606-25022021'
127
+ output_folder = os.path.join(DATA_ROOT_DIR, args.outdirname) # '/mnt/EncryptedData1/Users/javier/train_output/DDMR/THESIS/eval/BASELINE_TRAIN_Affine_ncc_EVAL_Affine'
128
+ # os.makedirs(os.path.join(output_folder, 'images'), exist_ok=True)
129
+ if args.erase:
130
+ shutil.rmtree(output_folder, ignore_errors=True)
131
+ os.makedirs(output_folder, exist_ok=True)
132
+ print('DESTINATION FOLDER: ', output_folder)
133
+
134
+ try:
135
+ network = tf.keras.models.load_model(MODEL_FILE, {'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
136
+ 'VxmDense': vxm.networks.VxmDense,
137
+ 'AdamAccumulated': AdamAccumulated,
138
+ 'loss': loss_fncs,
139
+ 'metric': metric_fncs},
140
+ compile=False)
141
+ except ValueError as e:
142
+ enc_features = [16, 32, 32, 32] # const.ENCODER_FILTERS
143
+ dec_features = [32, 32, 32, 32, 32, 16, 16] # const.ENCODER_FILTERS[::-1]
144
+ nb_features = [enc_features, dec_features]
145
+ if re.search('^UW|SEGGUIDED_', MODEL_FILE):
146
+ network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
147
+ nb_labels=nb_labels,
148
+ nb_unet_features=nb_features,
149
+ int_steps=0,
150
+ int_downsize=1,
151
+ seg_downsize=1)
152
+ else:
153
+ network = vxm.networks.VxmDense(inshape=image_output_shape,
154
+ nb_unet_features=nb_features,
155
+ int_steps=0)
156
+ network.load_weights(MODEL_FILE, by_name=True)
157
+ # Record metrics
158
+ metrics_file = os.path.join(output_folder, 'metrics.csv')
159
+ with open(metrics_file, 'w') as f:
160
+ f.write(';'.join(csv_header)+'\n')
161
+
162
+ ssim = ncc = mse = ms_ssim = dice = hd = 0
163
+ with sess.as_default():
164
+ sess.run(tf.global_variables_initializer())
165
+ network.load_weights(MODEL_FILE, by_name=True)
166
+ progress_bar = tqdm(enumerate(list_test_files, 1), desc='Evaluation', total=len(list_test_files))
167
+ for step, in_batch in progress_bar:
168
+ with h5py.File(in_batch, 'r') as f:
169
+ fix_img = f['fix_image'][:][np.newaxis, ...] # Add batch axis
170
+ mov_img = f['mov_image'][:][np.newaxis, ...]
171
+ fix_seg = f['fix_segmentations'][:][np.newaxis, ...].astype(np.float32)
172
+ mov_seg = f['mov_segmentations'][:][np.newaxis, ...].astype(np.float32)
173
+ fix_centroids = f['fix_centroids'][:]
174
+
175
+ if network.name == 'vxm_dense_semi_supervised_seg':
176
+ t0 = time.time()
177
+ pred_img, disp_map, pred_seg = network.predict([mov_img, fix_img, mov_seg, fix_seg]) # predict([source, target])
178
+ t1 = time.time()
179
+ else:
180
+ t0 = time.time()
181
+ pred_img, disp_map = network.predict([mov_img, fix_img])
182
+ pred_seg = warp_segmentation.predict([mov_seg, disp_map])
183
+ t1 = time.time()
184
+
185
+ mov_centroids, missing_lbls = get_segmentations_centroids(mov_seg[0, ...], ohe=True, expected_lbls=range(0, 28))
186
+ # pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) # with tps, it returns the pred_centroids directly
187
+ pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) + mov_centroids
188
+
189
+ # I need the labels to be OHE to compute the segmentation metrics.
190
+ dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf], {'fix_seg:0': fix_seg, 'pred_seg:0': pred_seg})
191
+
192
+ pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
193
+ mov_seg_card = segmentation_ohe_to_cardinal(mov_seg).astype(np.float32)
194
+ fix_seg_card = segmentation_ohe_to_cardinal(fix_seg).astype(np.float32)
195
+
196
+ ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf], {'fix_img:0': fix_img, 'pred_img:0': pred_img})
197
+ ms_ssim = ms_ssim[0]
198
+
199
+ # Rescale the points back to isotropic space, where we have a correspondence voxel <-> mm
200
+ upsample_scale = 128 / 64
201
+ fix_centroids_isotropic = fix_centroids * upsample_scale
202
+ # mov_centroids_isotropic = mov_centroids * upsample_scale
203
+ pred_centroids_isotropic = pred_centroids * upsample_scale
204
+
205
+ fix_centroids_isotropic = np.divide(fix_centroids_isotropic, C.IXI_DATASET_iso_to_cubic_scales)
206
+ # mov_centroids_isotropic = np.divide(mov_centroids_isotropic, C.IXI_DATASET_iso_to_cubic_scales)
207
+ pred_centroids_isotropic = np.divide(pred_centroids_isotropic, C.IXI_DATASET_iso_to_cubic_scales)
208
+ # Now we can measure the TRE in mm
209
+ tre_array = target_registration_error(fix_centroids_isotropic, pred_centroids_isotropic, False).eval()
210
+ tre = np.mean([v for v in tre_array if not np.isnan(v)])
211
+ # ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'HD', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']
212
+
213
+ new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, t1-t0, tre, len(missing_lbls), missing_lbls]
214
+ with open(metrics_file, 'a') as f:
215
+ f.write(';'.join(map(str, new_line))+'\n')
216
+
217
+ save_nifti(fix_img[0, ...], os.path.join(output_folder, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
218
+ save_nifti(mov_img[0, ...], os.path.join(output_folder, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
219
+ save_nifti(pred_img[0, ...], os.path.join(output_folder, '{:03d}_pred_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
220
+ save_nifti(fix_seg[0, ...], os.path.join(output_folder, '{:03d}_fix_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
221
+ save_nifti(mov_seg[0, ...], os.path.join(output_folder, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
222
+ save_nifti(pred_seg[0, ...], os.path.join(output_folder, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
223
+
224
+ # with h5py.File(os.path.join(output_folder, '{:03d}_centroids.h5'.format(step)), 'w') as f:
225
+ # f.create_dataset('fix_centroids', dtype=np.float32, data=fix_centroids)
226
+ # f.create_dataset('mov_centroids', dtype=np.float32, data=mov_centroids)
227
+ # f.create_dataset('pred_centroids', dtype=np.float32, data=pred_centroids)
228
+ # f.create_dataset('fix_centroids_isotropic', dtype=np.float32, data=fix_centroids_isotropic)
229
+ # f.create_dataset('mov_centroids_isotropic', dtype=np.float32, data=mov_centroids_isotropic)
230
+
231
+ # magnitude = np.sqrt(np.sum(disp_map[0, ...] ** 2, axis=-1))
232
+ # _ = plt.hist(magnitude.flatten())
233
+ # plt.title('Histogram of disp. magnitudes')
234
+ # plt.show(block=False)
235
+ # plt.savefig(os.path.join(output_folder, '{:03d}_hist_mag_ssim_{:.03f}_dice_{:.03f}.png'.format(step, ssim, dice)))
236
+ # plt.close()
237
+
238
+ plot_predictions(fix_img, mov_img, disp_map, pred_img, os.path.join(output_folder, '{:03d}_figures.png'.format(step)), show=False)
239
+ save_disp_map_img(disp_map, 'Displacement map', os.path.join(output_folder, '{:03d}_disp_map_fig.png'.format(step)), show=False)
240
+
241
+ progress_bar.set_description('SSIM {:.04f}\tDICE: {:.04f}'.format(ssim, dice))
242
+
243
+ print('Summary\n=======\n')
244
+ print(pd.read_csv(metrics_file, sep=';', header=0).mean(axis=0))
245
+ print('\n=======\n')
246
+ tf.keras.backend.clear_session()
247
+ # sess.close()
248
+ del network
249
+ print('Done')
Brain_study/MultiTrain_Baseline.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ currentdir = os.path.dirname(os.path.realpath(__file__))
3
+ parentdir = os.path.dirname(currentdir)
4
+ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
5
+
6
+ from Brain_study.Train_Baseline import launch_train
7
+ import argparse
8
+
9
+ TRAIN_DATASET = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/training'
10
+
11
+ err = list()
12
+
13
+ if __name__ == '__main__':
14
+ parser = argparse.ArgumentParser()
15
+ parser.add_argument('--dataset', type=str, help='Location of the training data', default=TRAIN_DATASET)
16
+ parser.add_argument('--validation', type=str, help='Location of the validation data', default=None)
17
+ parser.add_argument('--similarity', type=str, help='Similarity metric: mse, ncc, ssim')
18
+ parser.add_argument('--output', type=str, help='Output directory', default=TRAIN_DATASET)
19
+ parser.add_argument('--gpu', type=str, help='GPU number', default='0')
20
+ parser.add_argument('--lr', type=float, help='Learning rate', default=1e-4)
21
+ parser.add_argument('--rw', type=float, help='Regularization weigh', default=5e-3)
22
+
23
+ args = parser.parse_args()
24
+
25
+ print('TRAIN ' + args.dataset)
26
+ launch_train(dataset_folder=args.dataset,
27
+ validation_folder=args.validation,
28
+ output_folder=os.path.join(args.output, 'BASELINE_L_{}__MET_mse_ncc_ssim'.format(args.similarity)),
29
+ gpu_num=args.gpu,
30
+ lr=args.lr,
31
+ rw=args.rw,
32
+ simil=args.similarity)
Brain_study/MultiTrain_SegGuided.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ currentdir = os.path.dirname(os.path.realpath(__file__))
3
+ parentdir = os.path.dirname(currentdir)
4
+ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
5
+
6
+ from Brain_study.Train_SegmentationGuided import launch_train
7
+ import argparse
8
+
9
+ TRAIN_DATASET = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/training'
10
+
11
+ err = list()
12
+
13
+ if __name__ == '__main__':
14
+ parser = argparse.ArgumentParser()
15
+ parser.add_argument('--dataset', type=str, help='Location of the training data', default=TRAIN_DATASET)
16
+ parser.add_argument('--validation', type=str, help='Location of the validation data', default=None)
17
+ parser.add_argument('--similarity', type=str, help='Similarity loss function: mse, ncc, ssim')
18
+ parser.add_argument('--segmentation', type=str, help='Segmentation loss function: hd, dice')
19
+ parser.add_argument('--output', type=str, help='Output directory', default=TRAIN_DATASET)
20
+ parser.add_argument('--gpu', type=str, help='GPU number', default='0')
21
+ parser.add_argument('--lr', type=float, help='Learning rate', default=1e-4)
22
+ parser.add_argument('--rw', type=float, help='Regularization weigh', default=2e-2)
23
+
24
+ args = parser.parse_args()
25
+
26
+ print('TRAIN ' + args.dataset)
27
+ launch_train(dataset_folder=args.dataset,
28
+ validation_folder=args.validation,
29
+ output_folder=os.path.join(args.output, 'SEGGUIDED_Lsim_{}__Lseg_{}__MET_mse_ncc_ssim'.format(args.similarity, args.segmentation)),
30
+ gpu_num=args.gpu,
31
+ lr=args.lr,
32
+ rw=args.rw,
33
+ simil=args.similarity,
34
+ segm=args.segmentation)
Brain_study/MultiTrain_UW.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ currentdir = os.path.dirname(os.path.realpath(__file__))
3
+ parentdir = os.path.dirname(currentdir)
4
+ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
5
+
6
+ from Brain_study.Train_UncertaintyWeighted import launch_train
7
+ import argparse
8
+
9
+ TRAIN_DATASET = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/training'
10
+
11
+ err = list()
12
+
13
+ if __name__ == '__main__':
14
+ parser = argparse.ArgumentParser()
15
+ parser.add_argument('--dataset', type=str, help='Location of the training data', default=TRAIN_DATASET)
16
+ parser.add_argument('--validation', type=str, help='Location of the validation data', default=None)
17
+ parser.add_argument('--similarity', nargs='+', type=str, help='Similarity loss function: mse, ncc, ssim', default=[])
18
+ parser.add_argument('--segmentation', nargs='+', type=str, help='Segmentation loss function: hd, dice', default=[])
19
+ parser.add_argument('--output', type=str, help='Output directory', default=TRAIN_DATASET)
20
+ parser.add_argument('--gpu', type=str, help='GPU number', default='0')
21
+ parser.add_argument('--lr', type=float, help='Learning rate', default=1e-4)
22
+ parser.add_argument('--rw', type=float, help='Regularization weigh', default=5e-3)
23
+
24
+ args = parser.parse_args()
25
+
26
+ output_folder = os.path.join(args.output, 'UW_Lsim_{}__Lseg_{}__MET_mse_ncc_ssim'.format('__'.join(args.similarity),
27
+ '__'.join(args.segmentation)))
28
+ print('TRAIN ' + args.dataset)
29
+ launch_train(dataset_folder=args.dataset,
30
+ validation_folder=args.validation,
31
+ output_folder=output_folder,
32
+ gpu_num=args.gpu,
33
+ prior_reg_w=args.rw,
34
+ lr=args.lr,
35
+ simil=args.similarity,
36
+ segm=args.segmentation)
37
+
38
+
Brain_study/Train_Baseline.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ currentdir = os.path.dirname(os.path.realpath(__file__))
3
+ parentdir = os.path.dirname(currentdir)
4
+ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
5
+
6
+ import numpy as np
7
+ import tensorflow as tf
8
+
9
+ from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
10
+ from tensorflow.keras import Input
11
+ from tensorflow.keras.models import Model
12
+ from tensorflow.python.keras.utils import Progbar
13
+ from tensorflow.python.framework.errors import InvalidArgumentError
14
+
15
+ import voxelmorph as vxm
16
+ import neurite as ne
17
+ import h5py
18
+ import pickle
19
+
20
+ import DeepDeformationMapRegistration.utils.constants as C
21
+ from DeepDeformationMapRegistration.losses import NCC, StructuralSimilarity, StructuralSimilarity_simplified
22
+ from DeepDeformationMapRegistration.utils.misc import try_mkdir, DatasetCopy, function_decorator
23
+ from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
24
+ from DeepDeformationMapRegistration.layers import AugmentationLayer
25
+ from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
26
+ from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity, _MSSSIM_WEIGHTS
27
+
28
+ from Brain_study.data_generator import BatchGenerator
29
+ from Brain_study.utils import SummaryDictionary, named_logs
30
+
31
+ from tqdm import tqdm
32
+ from datetime import datetime
33
+
34
+
35
+ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr=1e-4, rw=5e-3, simil='ssim'):
36
+ assert dataset_folder is not None and output_folder is not None
37
+
38
+ os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
39
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_num) # Check availability before running using 'nvidia-smi'
40
+ C.GPU_NUM = str(gpu_num)
41
+
42
+ output_folder = os.path.join(output_folder + '_' + datetime.now().strftime("%H%M%S-%d%m%Y"))
43
+ os.makedirs(output_folder, exist_ok=True)
44
+ # dataset_copy = DatasetCopy(dataset_folder, os.path.join(output_folder, 'temp'))
45
+ log_file = open(os.path.join(output_folder, 'log.txt'), 'w')
46
+ C.TRAINING_DATASET = dataset_folder #dataset_copy.copy_dataset()
47
+ C.VALIDATION_DATASET = validation_folder
48
+ C.ACCUM_GRADIENT_STEP = 16
49
+ C.BATCH_SIZE = 16 if C.ACCUM_GRADIENT_STEP == 1 else C.ACCUM_GRADIENT_STEP
50
+ C.EARLY_STOP_PATIENCE = 5 * (C.ACCUM_GRADIENT_STEP / 2 if C.ACCUM_GRADIENT_STEP != 1 else 1)
51
+ C.LEARNING_RATE = lr
52
+ C.LIMIT_NUM_SAMPLES = None
53
+ C.EPOCHS = 10000
54
+
55
+ aux = "[{}]\tINFO:\nTRAIN DATASET: {}\nVALIDATION DATASET: {}\n" \
56
+ "GPU: {}\n" \
57
+ "BATCH SIZE: {}\n" \
58
+ "LR: {}\n" \
59
+ "SIMILARITY: {}\n" \
60
+ "REG. WEIGHT: {}\n" \
61
+ "EPOCHS: {:d}\n" \
62
+ "ACCUM. GRAD: {}\n" \
63
+ "EARLY STOP PATIENCE: {}".format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'),
64
+ C.TRAINING_DATASET,
65
+ C.VALIDATION_DATASET,
66
+ C.GPU_NUM,
67
+ C.BATCH_SIZE,
68
+ C.LEARNING_RATE,
69
+ simil,
70
+ rw,
71
+ C.EPOCHS,
72
+ C.ACCUM_GRADIENT_STEP,
73
+ C.EARLY_STOP_PATIENCE)
74
+
75
+ log_file.write(aux)
76
+ print(aux)
77
+
78
+ # Load data
79
+ # Build data generator
80
+ data_generator = BatchGenerator(C.TRAINING_DATASET, C.BATCH_SIZE if C.ACCUM_GRADIENT_STEP == 1 else 1, True,
81
+ C.TRAINING_PERC, True, ['none'], directory_val=C.VALIDATION_DATASET)
82
+
83
+ train_generator = data_generator.get_train_generator()
84
+ validation_generator = data_generator.get_validation_generator()
85
+
86
+ image_input_shape = train_generator.get_data_shape()[-1][:-1]
87
+ image_output_shape = [64] * 3
88
+
89
+ # Config the training sessions
90
+ config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
91
+ config.gpu_options.allow_growth = True
92
+ config.log_device_placement = False ## to log device placement (on which device the operation ran)
93
+ sess = tf.Session(config=config)
94
+ tf.keras.backend.set_session(sess)
95
+
96
+ # Build model
97
+ input_layer_train = Input(shape=train_generator.get_data_shape()[-1], name='input_train')
98
+ augm_layer_train = AugmentationLayer(max_displacement=C.MAX_AUG_DISP, # Max 30 mm in isotropic space
99
+ max_deformation=C.MAX_AUG_DEF, # Max 6 mm in isotropic space
100
+ max_rotation=C.MAX_AUG_ANGLE, # Max 10 deg in isotropic space
101
+ num_control_points=C.NUM_CONTROL_PTS_AUG,
102
+ num_augmentations=C.NUM_AUGMENTATIONS,
103
+ gamma_augmentation=C.GAMMA_AUGMENTATION,
104
+ brightness_augmentation=C.BRIGHTNESS_AUGMENTATION,
105
+ in_img_shape=image_input_shape,
106
+ out_img_shape=image_output_shape,
107
+ only_image=True,
108
+ only_resize=False,
109
+ trainable=False)
110
+ augm_model_train = Model(inputs=input_layer_train, outputs=augm_layer_train(input_layer_train))
111
+
112
+ input_layer_valid = Input(shape=validation_generator.get_data_shape()[0], name='input_valid')
113
+ augm_layer_valid = AugmentationLayer(max_displacement=C.MAX_AUG_DISP, # Max 30 mm in isotropic space
114
+ max_deformation=C.MAX_AUG_DEF, # Max 6 mm in isotropic space
115
+ max_rotation=C.MAX_AUG_ANGLE, # Max 10 deg in isotropic space
116
+ num_control_points=C.NUM_CONTROL_PTS_AUG,
117
+ num_augmentations=C.NUM_AUGMENTATIONS,
118
+ gamma_augmentation=C.GAMMA_AUGMENTATION,
119
+ brightness_augmentation=C.BRIGHTNESS_AUGMENTATION,
120
+ in_img_shape=image_input_shape,
121
+ out_img_shape=image_output_shape,
122
+ only_image=False,
123
+ only_resize=False,
124
+ trainable=False)
125
+ augm_model_valid = Model(inputs=input_layer_valid, outputs=augm_layer_valid(input_layer_valid))
126
+
127
+ # Build model
128
+ enc_features = [16, 32, 32, 32] # const.ENCODER_FILTERS
129
+ dec_features = [32, 32, 32, 32, 32, 16, 16] # const.ENCODER_FILTERS[::-1]
130
+ nb_features = [enc_features, dec_features]
131
+ network = vxm.networks.VxmDense(inshape=image_output_shape,
132
+ nb_unet_features=nb_features,
133
+ int_steps=0)
134
+
135
+ # Losses and loss weights
136
+ SSIM_KER_SIZE = 5
137
+ MS_SSIM_WEIGHTS = _MSSSIM_WEIGHTS[:3]
138
+ MS_SSIM_WEIGHTS /= np.sum(MS_SSIM_WEIGHTS)
139
+ if simil.lower() == 'mse':
140
+ loss_fnc = vxm.losses.MSE().loss
141
+ elif simil.lower() == 'ncc':
142
+ loss_fnc = NCC(image_input_shape).loss
143
+ elif simil.lower() == 'ssim':
144
+ loss_fnc = StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss
145
+ elif simil.lower() == 'ms_ssim':
146
+ loss_fnc = MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).loss
147
+ elif simil.lower() == 'mse__ms_ssim' or simil.lower() == 'ms_ssim__mse':
148
+ @function_decorator('MSSSIM_MSE__loss')
149
+ def loss_fnc(y_true, y_pred):
150
+ return vxm.losses.MSE().loss(y_true, y_pred) +\
151
+ MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).loss(y_true, y_pred)
152
+ elif simil.lower() == 'ncc__ms_ssim' or simil.lower() == 'ms_ssim__ncc':
153
+ @function_decorator('MSSSIM_NCC__loss')
154
+ def loss_fnc(y_true, y_pred):
155
+ return NCC(image_input_shape).loss(y_true, y_pred) +\
156
+ MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).loss(y_true, y_pred)
157
+ elif simil.lower() == 'mse__ssim' or simil.lower() == 'ssim__mse':
158
+ @function_decorator('SSIM_MSE__loss')
159
+ def loss_fnc(y_true, y_pred):
160
+ return vxm.losses.MSE().loss(y_true, y_pred) +\
161
+ StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss(y_true, y_pred)
162
+ elif simil.lower() == 'ncc__ssim' or simil.lower() == 'ssim__ncc':
163
+ @function_decorator('SSIM_NCC__loss')
164
+ def loss_fnc(y_true, y_pred):
165
+ return NCC(image_input_shape).loss(y_true, y_pred) +\
166
+ StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss(y_true, y_pred)
167
+ else:
168
+ raise ValueError('Unknown similarity metric: ' + simil)
169
+
170
+ # Train
171
+ os.makedirs(output_folder, exist_ok=True)
172
+ os.makedirs(os.path.join(output_folder, 'checkpoints'), exist_ok=True)
173
+ os.makedirs(os.path.join(output_folder, 'tensorboard'), exist_ok=True)
174
+ os.makedirs(os.path.join(output_folder, 'history'), exist_ok=True)
175
+
176
+ callback_best_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'),
177
+ save_best_only=True, monitor='val_loss', verbose=1, mode='min')
178
+ # callback_save_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'weights.{epoch:05d}-{val_loss:.2f}.h5'),
179
+ # save_best_only=True, save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
180
+ # CSVLogger(train_log_name, ';'),
181
+ # UpdateLossweights([haus_weight, dice_weight], [const.MODEL+'_resampler_seg', const.MODEL+'_resampler_seg'])
182
+ callback_tensorboard = TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
183
+ batch_size=C.BATCH_SIZE, write_images=False, histogram_freq=0,
184
+ update_freq='epoch', # or 'batch' or integer
185
+ write_graph=True, write_grads=True
186
+ )
187
+ callback_early_stop = EarlyStopping(monitor='val_loss', verbose=1, patience=C.EARLY_STOP_PATIENCE, min_delta=0.00001)
188
+
189
+ losses = {'transformer': loss_fnc,
190
+ 'flow': vxm.losses.Grad('l2').loss}
191
+ metrics = {'transformer': [StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).metric,
192
+ MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).metric,
193
+ tf.keras.losses.MSE,
194
+ NCC(image_input_shape).metric],
195
+ #'flow': vxm.losses.Grad('l2').loss
196
+ }
197
+ loss_weights = {'transformer': 1.,
198
+ 'flow': rw}
199
+
200
+ # Compile the model
201
+ optimizer = AdamAccumulated(C.ACCUM_GRADIENT_STEP, C.LEARNING_RATE)
202
+ network.compile(optimizer=optimizer,
203
+ loss=losses,
204
+ loss_weights=loss_weights,
205
+ metrics=metrics)
206
+
207
+ callback_tensorboard.set_model(network)
208
+ callback_best_model.set_model(network)
209
+ # callback_save_model.set_model(network)
210
+ callback_early_stop.set_model(network)
211
+ # TODO: https://towardsdatascience.com/writing-tensorflow-2-custom-loops-438b1ab6eb6c
212
+
213
+ summary = SummaryDictionary(network, C.BATCH_SIZE)
214
+ names = network.metrics_names # It give both the loss and metric names
215
+ log_file.write('\n\n[{}]\tINFO:\tStart training\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y')))
216
+ with sess.as_default():
217
+ callback_tensorboard.on_train_begin()
218
+ callback_early_stop.on_train_begin()
219
+ callback_best_model.on_train_begin()
220
+ # callback_save_model.on_train_begin()
221
+ for epoch in range(C.EPOCHS):
222
+ callback_tensorboard.on_epoch_begin(epoch)
223
+ callback_early_stop.on_epoch_begin(epoch)
224
+ callback_best_model.on_epoch_begin(epoch)
225
+ # callback_save_model.on_epoch_begin(epoch)
226
+
227
+ print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
228
+ print('TRAINING')
229
+
230
+ log_file.write('\n\n[{}]\tINFO:\tTraining epoch {}\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), epoch))
231
+ progress_bar = Progbar(len(train_generator), width=30, verbose=1)
232
+ for step, (in_batch, _) in enumerate(train_generator, 1):
233
+ # callback_tensorboard.on_train_batch_begin(step)
234
+ callback_best_model.on_train_batch_begin(step)
235
+ # callback_save_model.on_train_batch_begin(step)
236
+ callback_early_stop.on_train_batch_begin(step)
237
+
238
+ try:
239
+ fix_img, mov_img, *_ = augm_model_train.predict(in_batch)
240
+ np.nan_to_num(fix_img, copy=False)
241
+ np.nan_to_num(mov_img, copy=False)
242
+ if np.isnan(np.sum(mov_img)) or np.isnan(np.sum(fix_img)) or np.isinf(np.sum(mov_img)) or np.isinf(np.sum(fix_img)):
243
+ msg = 'CORRUPTED DATA!! Unique: Fix: {}\tMoving: {}'.format(np.unique(fix_img),
244
+ np.unique(mov_img))
245
+ print(msg)
246
+ log_file.write('\n\n[{}]\tWAR: {}'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), msg))
247
+
248
+ except InvalidArgumentError as err:
249
+ print('TF Error : {}'.format(str(err)))
250
+ continue
251
+
252
+ ret = network.train_on_batch(x=(mov_img, fix_img),
253
+ y=(fix_img, fix_img)) # The second element doesn't matter
254
+ if np.isnan(ret).any():
255
+ os.makedirs(os.path.join(output_folder, 'corrupted'), exist_ok=True)
256
+ save_nifti(mov_img, os.path.join(output_folder, 'corrupted', 'mov_img_nan.nii.gz'))
257
+ save_nifti(fix_img, os.path.join(output_folder, 'corrupted', 'fix_img_nan.nii.gz'))
258
+ pred_img, dm = network((mov_img, fix_img))
259
+ save_nifti(pred_img, os.path.join(output_folder, 'corrupted', 'pred_img_nan.nii.gz'))
260
+ save_nifti(dm, os.path.join(output_folder, 'corrupted', 'dm_nan.nii.gz'))
261
+ log_file.write('\n\n[{}]\tERR: Corruption error'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y')))
262
+ raise ValueError('CORRUPTION ERROR: Halting training')
263
+
264
+ summary.on_train_batch_end(ret)
265
+ # callback_tensorboard.on_train_batch_end(step, named_logs(network, ret))
266
+ callback_best_model.on_train_batch_end(step, named_logs(network, ret))
267
+ # callback_save_model.on_train_batch_end(step, named_logs(network, ret))
268
+ callback_early_stop.on_train_batch_end(step, named_logs(network, ret))
269
+ progress_bar.update(step, zip(names, ret))
270
+ log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
271
+ val_values = progress_bar._values.copy()
272
+ ret = [val_values[x][0]/val_values[x][1] for x in names]
273
+
274
+ print('\nVALIDATION')
275
+ log_file.write('\n\n[{}]\tINFO:\tValidation epoch {}\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), epoch))
276
+ progress_bar = Progbar(len(validation_generator), width=30, verbose=1)
277
+ for step, (in_batch, _) in enumerate(validation_generator, 1):
278
+ # callback_tensorboard.on_test_batch_begin(step) # This is cursed, don't do it again
279
+ # callback_early_stop.on_test_batch_begin(step)
280
+ try:
281
+ fix_img, mov_img, *_ = augm_model_valid.predict(in_batch)
282
+ except InvalidArgumentError as err:
283
+ print('TF Error : {}'.format(str(err)))
284
+ continue
285
+
286
+ ret = network.test_on_batch(x=(mov_img, fix_img),
287
+ y=(fix_img, fix_img))
288
+ # pred_segm = network.register(mov_segm, fix_segm)
289
+ summary.on_validation_batch_end(ret)
290
+ # callback_early_stop.on_test_batch_end(step, named_logs(network, ret))
291
+ # callback_tensorboard.on_test_batch_end(step, named_logs(network, ret)) # This is cursed, don't do it again
292
+ progress_bar.update(step, zip(names, ret))
293
+ log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
294
+ val_values = progress_bar._values.copy()
295
+ ret = [val_values[x][0]/val_values[x][1] for x in names]
296
+
297
+ train_generator.on_epoch_end()
298
+ validation_generator.on_epoch_end()
299
+ epoch_summary = summary.on_epoch_end() # summary resets after on_epoch_end() call
300
+ callback_tensorboard.on_epoch_end(epoch, epoch_summary)
301
+ callback_early_stop.on_epoch_end(epoch, epoch_summary)
302
+ callback_best_model.on_epoch_end(epoch, epoch_summary)
303
+ # callback_save_model.on_epoch_end(epoch, epoch_summary)
304
+ print('End of epoch {}: '.format(epoch), ret, '\n')
305
+
306
+ callback_tensorboard.on_train_end()
307
+ # callback_save_model.on_train_end()
308
+ callback_best_model.on_train_end()
309
+ callback_early_stop.on_train_end()
310
+
311
+
312
+ if __name__ == '__main__':
313
+ os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
314
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0' # Check availability before running using 'nvidia-smi'
315
+
316
+ config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
317
+ config.gpu_options.allow_growth = True
318
+ config.log_device_placement = False ## to log device placement (on which device the operation ran)
319
+ tf.keras.backend.set_session(tf.Session(config=config))
320
+
321
+ launch_train('/mnt/EncryptedData1/Users/javier/vessel_registration/LiTS/None',
322
+ 'TrainOutput/THESIS/UW_None_mse_ssim_haus', 0, mse=True)
Brain_study/Train_SegmentationGuided.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ currentdir = os.path.dirname(os.path.realpath(__file__))
3
+ parentdir = os.path.dirname(currentdir)
4
+ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
5
+
6
+ import numpy as np
7
+ import tensorflow as tf
8
+ from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
9
+ from tensorflow.keras import Input
10
+ from tensorflow.keras.models import Model
11
+ from tensorflow.python.keras.utils import Progbar
12
+ from tensorflow.python.framework.errors import InvalidArgumentError
13
+
14
+ import voxelmorph as vxm
15
+ import neurite as ne
16
+ import h5py
17
+ from datetime import datetime
18
+ import pickle
19
+
20
+ import DeepDeformationMapRegistration.utils.constants as C
21
+ from DeepDeformationMapRegistration.utils.misc import try_mkdir, function_decorator
22
+ from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
23
+ from DeepDeformationMapRegistration.losses import NCC, HausdorffDistanceErosion, GeneralizedDICEScore, StructuralSimilarity_simplified
24
+ from DeepDeformationMapRegistration.layers import AugmentationLayer
25
+ from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity, _MSSSIM_WEIGHTS
26
+ from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
27
+
28
+ from Brain_study.data_generator import BatchGenerator
29
+ from Brain_study.utils import SummaryDictionary, named_logs
30
+
31
+ import time
32
+
33
+ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr=1e-4, rw=5e-3, simil='ssim', segm='hd'):
34
+ assert dataset_folder is not None and output_folder is not None
35
+
36
+ os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
37
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_num) # Check availability before running using 'nvidia-smi'
38
+ C.GPU_NUM = str(gpu_num)
39
+
40
+ output_folder = os.path.join(output_folder + '_' + datetime.now().strftime("%H%M%S-%d%m%Y"))
41
+ os.makedirs(output_folder, exist_ok=True)
42
+ log_file = open(os.path.join(output_folder, 'log.txt'), 'w')
43
+ C.TRAINING_DATASET = dataset_folder
44
+ C.VALIDATION_DATASET = validation_folder
45
+ C.ACCUM_GRADIENT_STEP = 16
46
+ C.BATCH_SIZE = 2 if C.ACCUM_GRADIENT_STEP == 1 else C.ACCUM_GRADIENT_STEP
47
+ C.EARLY_STOP_PATIENCE = 10 * (C.ACCUM_GRADIENT_STEP / 2 if C.ACCUM_GRADIENT_STEP != 1 else 1)
48
+ C.LEARNING_RATE = lr
49
+ C.LIMIT_NUM_SAMPLES = None
50
+ C.EPOCHS = 10000
51
+
52
+ aux = "[{}]\tINFO:\nTRAIN DATASET: {}\nVALIDATION DATASET: {}\n" \
53
+ "GPU: {}\n" \
54
+ "BATCH SIZE: {}\n" \
55
+ "LR: {}\n" \
56
+ "SIMILARITY: {}\n" \
57
+ "REG. WEIGHT: {}\n" \
58
+ "EPOCHS: {:d}\n" \
59
+ "ACCUM. GRAD: {}\n" \
60
+ "EARLY STOP PATIENCE: {}".format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'),
61
+ C.TRAINING_DATASET,
62
+ C.VALIDATION_DATASET,
63
+ C.GPU_NUM,
64
+ C.BATCH_SIZE,
65
+ C.LEARNING_RATE,
66
+ simil,
67
+ rw,
68
+ C.EPOCHS,
69
+ C.ACCUM_GRADIENT_STEP,
70
+ C.EARLY_STOP_PATIENCE)
71
+ log_file.write(aux)
72
+ print(aux)
73
+
74
+ # Load data
75
+ # Build data generator
76
+ data_generator = BatchGenerator(C.TRAINING_DATASET, C.BATCH_SIZE if C.ACCUM_GRADIENT_STEP == 1 else 1, True,
77
+ C.TRAINING_PERC, labels=['all'], combine_segmentations=False,
78
+ directory_val=C.VALIDATION_DATASET)
79
+
80
+ train_generator = data_generator.get_train_generator()
81
+ validation_generator = data_generator.get_validation_generator()
82
+
83
+ image_input_shape = train_generator.get_data_shape()[1][:-1]
84
+ image_output_shape = [64] * 3
85
+
86
+ nb_labels = len(train_generator.get_segmentation_labels())
87
+
88
+ # Config the training sessions
89
+ config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
90
+ config.gpu_options.allow_growth = True
91
+ config.log_device_placement = False ## to log device placement (on which device the operation ran)
92
+ config.allow_soft_placement = True # https://github.com/tensorflow/tensorflow/issues/30782
93
+ sess = tf.Session(config=config)
94
+ tf.keras.backend.set_session(sess)
95
+
96
+ # Build model
97
+ input_layer_augm = Input(shape=train_generator.get_data_shape()[0], name='input_augmentation')
98
+ augm_layer = AugmentationLayer(max_displacement=C.MAX_AUG_DISP, # Max 30 mm in isotropic space
99
+ max_deformation=C.MAX_AUG_DEF, # Max 6 mm in isotropic space
100
+ max_rotation=C.MAX_AUG_ANGLE, # Max 10 deg in isotropic space
101
+ num_control_points=C.NUM_CONTROL_PTS_AUG,
102
+ num_augmentations=C.NUM_AUGMENTATIONS,
103
+ gamma_augmentation=C.GAMMA_AUGMENTATION,
104
+ brightness_augmentation=C.BRIGHTNESS_AUGMENTATION,
105
+ in_img_shape=image_input_shape,
106
+ out_img_shape=image_output_shape,
107
+ only_image=False,
108
+ only_resize=False,
109
+ trainable=False)
110
+ augm_model = Model(inputs=input_layer_augm, outputs=augm_layer(input_layer_augm))
111
+
112
+ enc_features = [16, 32, 32, 32]# const.ENCODER_FILTERS
113
+ dec_features = [32, 32, 32, 32, 32, 16, 16]# const.ENCODER_FILTERS[::-1]
114
+ nb_features = [enc_features, dec_features]
115
+
116
+ network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
117
+ nb_labels=nb_labels,
118
+ nb_unet_features=nb_features,
119
+ int_steps=0,
120
+ int_downsize=1,
121
+ seg_downsize=1)
122
+
123
+ # Compile the model
124
+ SSIM_KER_SIZE = 5
125
+ MS_SSIM_WEIGHTS = _MSSSIM_WEIGHTS[:3]
126
+ MS_SSIM_WEIGHTS /= np.sum(MS_SSIM_WEIGHTS)
127
+ if simil=='ssim':
128
+ loss_simil = StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss
129
+ elif simil=='ms_ssim':
130
+ loss_simil = MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).loss
131
+ elif simil=='ncc':
132
+ loss_simil = NCC(image_input_shape).loss
133
+ elif simil=='ms_ssim__ncc' or simil=='ncc__ms_ssim':
134
+ @function_decorator('MS_SSIM_NCC__loss')
135
+ def loss_simil(y_true, y_pred):
136
+ return MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).loss(y_true, y_pred) + NCC(image_input_shape).loss(y_true, y_pred)
137
+ elif simil=='ms_ssim__mse' or simil=='mse__ms_ssim':
138
+ @function_decorator('MS_SSIM_MSE__loss')
139
+ def loss_simil(y_true, y_pred):
140
+ return MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).loss(y_true, y_pred) + vxm.losses.MSE().loss(y_true, y_pred)
141
+ elif simil=='ssim__ncc' or simil=='ncc__ssim' :
142
+ @function_decorator('SSIM_NCC__loss')
143
+ def loss_simil(y_true, y_pred):
144
+ return StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss(y_true, y_pred) + NCC(image_input_shape).loss(y_true, y_pred)
145
+ elif simil=='ssim__mse' or simil=='mse__ssim':
146
+ @function_decorator('SSIM_MSE__loss')
147
+ def loss_simil(y_true, y_pred):
148
+ return StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss(y_true, y_pred) + vxm.losses.MSE().loss(y_true, y_pred)
149
+ else:
150
+ loss_simil = vxm.losses.MSE().loss
151
+
152
+ if segm == 'hd':
153
+ loss_segm = HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [nb_labels]).loss
154
+ elif segm == 'dice':
155
+ loss_segm = GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).loss
156
+ else:
157
+ raise ValueError('No valid value for segm')
158
+
159
+ losses = {'transformer': loss_simil,
160
+ 'seg_transformer': loss_segm,
161
+ 'flow': vxm.losses.Grad('l2').loss}
162
+ loss_weights = {'transformer': 1,
163
+ 'seg_transformer': 1.,
164
+ 'flow': 5e-3}
165
+ metrics = {'transformer': [vxm.losses.MSE().loss, NCC(image_input_shape).metric, MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).metric],
166
+ 'seg_transformer': [GeneralizedDICEScore(image_output_shape + [train_generator.get_data_shape()[2][-1]], num_labels=nb_labels).metric,
167
+ HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [train_generator.get_data_shape()[2][-1]]).metric
168
+ ]}
169
+ metrics_weights = {'transformer': 1,
170
+ 'seg_transformer': 1,
171
+ 'flow': rw}
172
+
173
+ # Train
174
+ os.makedirs(output_folder, exist_ok=True)
175
+ os.makedirs(os.path.join(output_folder, 'checkpoints'), exist_ok=True)
176
+ os.makedirs(os.path.join(output_folder, 'tensorboard'), exist_ok=True)
177
+ os.makedirs(os.path.join(output_folder, 'history'), exist_ok=True)
178
+
179
+ callback_best_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'),
180
+ save_best_only=True, monitor='val_loss', verbose=1, mode='min')
181
+ # callback_save_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'weights.{epoch:05d}-{val_loss:.2f}.h5'),
182
+ # save_best_only=True, save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
183
+ # CSVLogger(train_log_name, ';'),
184
+ # UpdateLossweights([haus_weight, dice_weight], [const.MODEL+'_resampler_seg', const.MODEL+'_resampler_seg'])
185
+ callback_tensorboard = TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
186
+ batch_size=C.BATCH_SIZE, write_images=False, histogram_freq=0,
187
+ update_freq='epoch', # or 'batch' or integer
188
+ write_graph=True, write_grads=True
189
+ )
190
+ callback_early_stop = EarlyStopping(monitor='val_loss', verbose=1, patience=C.EARLY_STOP_PATIENCE, min_delta=0.00001)
191
+
192
+ # Compile the model
193
+ optimizer = AdamAccumulated(C.ACCUM_GRADIENT_STEP, lr=C.LEARNING_RATE)
194
+ network.compile(optimizer=optimizer,
195
+ loss=losses,
196
+ loss_weights=loss_weights,
197
+ metrics=metrics)
198
+
199
+ callback_tensorboard.set_model(network)
200
+ callback_best_model.set_model(network)
201
+ # callback_save_model.set_model(network)
202
+ callback_early_stop.set_model(network)
203
+
204
+ summary = SummaryDictionary(network, C.BATCH_SIZE, C.ACCUM_GRADIENT_STEP)
205
+ names = network.metrics_names # It give both the loss and metric names
206
+ log_file.write('\n\n[{}]\tINFO:\tStart training\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y')))
207
+ with sess.as_default():
208
+ #sess.run(tf.global_variables_initializer())
209
+ callback_tensorboard.on_train_begin()
210
+ callback_early_stop.on_train_begin()
211
+ callback_best_model.on_train_begin()
212
+ # callback_save_model.on_train_begin()
213
+ for epoch in range(C.EPOCHS):
214
+ callback_tensorboard.on_epoch_begin(epoch)
215
+ callback_early_stop.on_epoch_begin(epoch)
216
+ callback_best_model.on_epoch_begin(epoch)
217
+ # callback_save_model.on_epoch_begin(epoch)
218
+
219
+ print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
220
+ print('TRAINING')
221
+
222
+ log_file.write('\n\n[{}]\tINFO:\tTraining epoch {}\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), epoch))
223
+ progress_bar = Progbar(len(train_generator), width=30, verbose=1)
224
+ t0 = time.time()
225
+ for step, (in_batch, _) in enumerate(train_generator, 1):
226
+ #print('Loaded in {} s'.format(time.time() - t0))
227
+ # callback_tensorboard.on_train_batch_begin(step)
228
+ callback_best_model.on_train_batch_begin(step)
229
+ # callback_save_model.on_train_batch_begin(step)
230
+ callback_early_stop.on_train_batch_begin(step)
231
+
232
+ try:
233
+ t0 = time.time()
234
+ fix_img, mov_img, fix_seg, mov_seg = augm_model.predict(in_batch)
235
+ #print('Augmented in {} s'.format(time.time() - t0))
236
+ np.nan_to_num(fix_img, copy=False)
237
+ np.nan_to_num(mov_img, copy=False)
238
+ if np.isnan(np.sum(mov_img)) or np.isnan(np.sum(fix_img)) or np.isinf(np.sum(mov_img)) or np.isinf(np.sum(fix_img)):
239
+ msg = 'CORRUPTED DATA!! Unique: Fix: {}\tMoving: {}'.format(np.unique(fix_img),
240
+ np.unique(mov_img))
241
+ print(msg)
242
+ log_file.write('\n\n[{}]\tWAR: {}'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), msg))
243
+
244
+ except InvalidArgumentError as err:
245
+ print('TF Error : {}'.format(str(err)))
246
+ continue
247
+
248
+ t0 = time.time()
249
+ ret = network.train_on_batch(x=(mov_img, fix_img, mov_seg),
250
+ y=(fix_img, fix_img, fix_seg))
251
+ # print("Trained on batch in {} s".format(time.time() - t0))
252
+
253
+ if np.isnan(ret).any():
254
+ os.makedirs(os.path.join(output_folder, 'corrupted'), exist_ok=True)
255
+ save_nifti(mov_img, os.path.join(output_folder, 'corrupted', 'mov_img_nan.nii.gz'))
256
+ save_nifti(fix_img, os.path.join(output_folder, 'corrupted', 'fix_img_nan.nii.gz'))
257
+ pred_img, dm = network((mov_img, fix_img))
258
+ save_nifti(pred_img, os.path.join(output_folder, 'corrupted', 'pred_img_nan.nii.gz'))
259
+ save_nifti(dm, os.path.join(output_folder, 'corrupted', 'dm_nan.nii.gz'))
260
+ log_file.write('\n\n[{}]\tERR: Corruption error'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y')))
261
+ raise ValueError('CORRUPTION ERROR: Halting training')
262
+
263
+ summary.on_train_batch_end(ret)
264
+ # callback_tensorboard.on_train_batch_end(step, named_logs(network, ret))
265
+ callback_best_model.on_train_batch_end(step, named_logs(network, ret))
266
+ # callback_save_model.on_train_batch_end(step, named_logs(network, ret))
267
+ callback_early_stop.on_train_batch_end(step, named_logs(network, ret))
268
+ progress_bar.update(step, zip(names, ret))
269
+ log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
270
+ t0 = time.time()
271
+ print('End of epoch{}: '.format(step), ret, '\n')
272
+ val_values = progress_bar._values.copy()
273
+ ret = [val_values[x][0]/val_values[x][1] for x in names]
274
+
275
+ print('\nVALIDATION')
276
+ log_file.write('\n\n[{}]\tINFO:\tValidation epoch {}\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), epoch))
277
+ progress_bar = Progbar(len(validation_generator), width=30, verbose=1)
278
+ for step, (in_batch, _) in enumerate(validation_generator, 1):
279
+ # callback_tensorboard.on_test_batch_begin(step) # This is cursed, don't do it again
280
+ # callback_early_stop.on_test_batch_begin(step)
281
+ try:
282
+ fix_img, mov_img, fix_seg, mov_seg = augm_model.predict(in_batch)
283
+ except InvalidArgumentError as err:
284
+ print('TF Error : {}'.format(str(err)))
285
+ continue
286
+
287
+ ret = network.test_on_batch(x=(mov_img, fix_img, mov_seg),
288
+ y=(fix_img, fix_img, fix_seg))
289
+ # pred_segm = network.register(mov_segm, fix_segm)
290
+ summary.on_validation_batch_end(ret)
291
+ # callback_early_stop.on_test_batch_end(step, named_logs(network, ret))
292
+ # callback_tensorboard.on_test_batch_end(step, named_logs(network, ret)) # This is cursed, don't do it again
293
+ progress_bar.update(step, zip(names, ret))
294
+ log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
295
+ val_values = progress_bar._values.copy()
296
+ ret = [val_values[x][0]/val_values[x][1] for x in names]
297
+
298
+ train_generator.on_epoch_end()
299
+ validation_generator.on_epoch_end()
300
+ epoch_summary = summary.on_epoch_end()
301
+ callback_tensorboard.on_epoch_end(epoch, epoch_summary)
302
+ callback_early_stop.on_epoch_end(epoch, epoch_summary)
303
+ callback_best_model.on_epoch_end(epoch, epoch_summary)
304
+ # callback_save_model.on_epoch_end(epoch, named_logs(network, ret, True))
305
+
306
+ callback_tensorboard.on_train_end()
307
+ # callback_save_model.on_train_end()
308
+ callback_best_model.on_train_end()
309
+ callback_early_stop.on_train_end()
310
+
311
+
312
+ if __name__ == '__main__':
313
+ os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
314
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0' # Check availability before running using 'nvidia-smi'
315
+
316
+ config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
317
+ config.gpu_options.allow_growth = True
318
+ config.log_device_placement = False ## to log device placement (on which device the operation ran)
319
+ tf.keras.backend.set_session(tf.Session(config=config))
320
+
321
+ launch_train('/mnt/EncryptedData1/Users/javier/Brain_study/ERASE',
322
+ 'TrainOutput/THESIS/UW_None_mse_ssim_haus',
323
+ 0)
Brain_study/Train_UncertaintyWeighted.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ currentdir = os.path.dirname(os.path.realpath(__file__))
3
+ parentdir = os.path.dirname(currentdir)
4
+ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
5
+
6
+ import numpy as np
7
+ import tensorflow as tf
8
+ from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
9
+ from tensorflow.keras import Input
10
+ from tensorflow.keras.models import Model
11
+ from tensorflow.python.keras.utils import Progbar
12
+ from tensorflow.python.framework.errors import InvalidArgumentError
13
+ import voxelmorph as vxm
14
+ import neurite as ne
15
+ import h5py
16
+ from datetime import datetime
17
+ import pickle
18
+
19
+ import DeepDeformationMapRegistration.utils.constants as C
20
+ from DeepDeformationMapRegistration.utils.misc import try_mkdir, DatasetCopy, function_decorator
21
+ from DeepDeformationMapRegistration.networks import WeaklySupervised
22
+ from DeepDeformationMapRegistration.losses import HausdorffDistanceErosion, NCC, StructuralSimilarity_simplified, GeneralizedDICEScore
23
+ from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity, _MSSSIM_WEIGHTS
24
+ from DeepDeformationMapRegistration.layers import UncertaintyWeighting, AugmentationLayer
25
+ from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
26
+
27
+ from Brain_study.data_generator import BatchGenerator
28
+ from Brain_study.utils import SummaryDictionary, named_logs
29
+
30
+
31
+ def launch_train(dataset_folder, validation_folder, output_folder, prior_reg_w=5e-3, lr=1e-4,
32
+ gpu_num=0, simil=['mse'], segm=['dice']):
33
+ assert dataset_folder is not None and output_folder is not None
34
+
35
+ os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
36
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_num) # Check availability before running using 'nvidia-smi'
37
+ C.GPU_NUM = str(gpu_num)
38
+
39
+ output_folder = os.path.join(output_folder + '_' + datetime.now().strftime("%H%M%S-%d%m%Y"))
40
+ # dataset_copy = DatasetCopy(dataset_folder, os.path.join(output_folder, 'temp'))
41
+ os.makedirs(output_folder, exist_ok=True)
42
+ log_file = open(os.path.join(output_folder, 'log.txt'), 'w')
43
+ C.TRAINING_DATASET = dataset_folder #dataset_copy.copy_dataset()
44
+ C.VALIDATION_DATASET = validation_folder
45
+ C.ACCUM_GRADIENT_STEP = 16
46
+ C.BATCH_SIZE = 2 if C.ACCUM_GRADIENT_STEP == 1 else C.ACCUM_GRADIENT_STEP
47
+ C.EARLY_STOP_PATIENCE = 10 * (C.ACCUM_GRADIENT_STEP/2 if C.ACCUM_GRADIENT_STEP != 1 else 1)
48
+ C.LEARNING_RATE = lr
49
+ C.LIMIT_NUM_SAMPLES = None
50
+ C.EPOCHS = 10000
51
+
52
+ aux = "[{}]\tINFO:\nTRAIN DATASET: {}\nVALIDATION DATASET: {}\n" \
53
+ "GPU: {}\n" \
54
+ "BATCH SIZE: {}\n" \
55
+ "LR: {}\n" \
56
+ "SIMILARITY {:d}: {}\n" \
57
+ "SEGMENTATION {:d}: {}\n" \
58
+ "EPOCHS: {:d}" \
59
+ "ACCUM. GRAD: {}" \
60
+ "EARLY STOP PATIENCE: {}".format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'),
61
+ C.TRAINING_DATASET,
62
+ C.VALIDATION_DATASET,
63
+ C.GPU_NUM,
64
+ C.BATCH_SIZE,
65
+ C.LEARNING_RATE,
66
+ len(simil), ', '.join(simil),
67
+ len(segm), ', '.join(segm),
68
+ C.EPOCHS,
69
+ C.ACCUM_GRADIENT_STEP,
70
+ C.EARLY_STOP_PATIENCE)
71
+ log_file.write(aux)
72
+ print(aux)
73
+
74
+ # Load data
75
+ # Build data generator
76
+ data_generator = BatchGenerator(C.TRAINING_DATASET, C.BATCH_SIZE if C.ACCUM_GRADIENT_STEP == 1 else 1, True,
77
+ C.TRAINING_PERC, labels=['all'], combine_segmentations=False,
78
+ directory_val=C.VALIDATION_DATASET)
79
+
80
+ train_generator = data_generator.get_train_generator()
81
+ validation_generator = data_generator.get_validation_generator()
82
+
83
+ image_input_shape = train_generator.get_data_shape()[-1][:-1]
84
+ image_output_shape = [64] * 3
85
+
86
+ nb_labels = len(train_generator.get_segmentation_labels())
87
+
88
+ # Config the training sessions
89
+ config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
90
+ config.gpu_options.allow_growth = True
91
+ config.log_device_placement = False ## to log device placement (on which device the operation ran)
92
+ config.allow_soft_placement = True
93
+ sess = tf.Session(config=config)
94
+ tf.keras.backend.set_session(sess)
95
+
96
+ # Losses and loss weights
97
+ SSIM_KER_SIZE = 5
98
+ MS_SSIM_WEIGHTS = _MSSSIM_WEIGHTS[:3]
99
+ MS_SSIM_WEIGHTS /= np.sum(MS_SSIM_WEIGHTS)
100
+
101
+ loss_simil = []
102
+ prior_loss_w = []
103
+ for s in simil:
104
+ if s=='ssim':
105
+ loss_simil.append(StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss)
106
+ prior_loss_w.append(1.)
107
+ elif s=='ms_ssim':
108
+ loss_simil.append(MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).loss)
109
+ prior_loss_w.append(1.)
110
+ elif s=='ncc':
111
+ loss_simil.append(NCC(image_input_shape).loss)
112
+ prior_loss_w.append(1.)
113
+ elif s=='mse':
114
+ loss_simil.append(vxm.losses.MSE().loss)
115
+ prior_loss_w.append(1.)
116
+ else:
117
+ raise ValueError('Unknown similarity function: ', s)
118
+
119
+ loss_segm = []
120
+ for s in segm:
121
+ if s=='dice':
122
+ loss_segm.append(GeneralizedDICEScore(image_output_shape + [train_generator.get_data_shape()[2][-1]], num_labels=nb_labels).loss)
123
+ prior_loss_w.append(1.)
124
+ elif s=='hd':
125
+ loss_segm.append(HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [train_generator.get_data_shape()[2][-1]]).loss)
126
+ prior_loss_w.append(1.)
127
+ else:
128
+ raise ValueError('Unknown similarity function: ', s)
129
+
130
+ # Build augmentation layer model
131
+ input_layer_augm = Input(shape=train_generator.get_data_shape()[0], name='input_augmentation')
132
+ augm_layer = AugmentationLayer(max_displacement=C.MAX_AUG_DISP, # Max 30 mm in isotropic space
133
+ max_deformation=C.MAX_AUG_DEF, # Max 6 mm in isotropic space
134
+ max_rotation=C.MAX_AUG_ANGLE, # Max 10 deg in isotropic space
135
+ num_control_points=C.NUM_CONTROL_PTS_AUG,
136
+ num_augmentations=C.NUM_AUGMENTATIONS,
137
+ gamma_augmentation=C.GAMMA_AUGMENTATION,
138
+ in_img_shape=image_input_shape,
139
+ out_img_shape=image_output_shape,
140
+ only_image=False,
141
+ only_resize=False,
142
+ trainable=False)
143
+ augmentation_model = Model(inputs=input_layer_augm, outputs=augm_layer(input_layer_augm))
144
+
145
+ enc_features = [16, 32, 32, 32] # const.ENCODER_FILTERS
146
+ dec_features = [32, 32, 32, 32, 32, 16, 16] # const.ENCODER_FILTERS[::-1]
147
+ nb_features = [enc_features, dec_features]
148
+ network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
149
+ nb_labels=nb_labels,
150
+ nb_unet_features=nb_features,
151
+ int_steps=0,
152
+ int_downsize=1,
153
+ seg_downsize=1)
154
+ # Network inputs: mov_img, fix_img, mov_seg
155
+ # Network outputs: pred_img, disp_map, pred_seg
156
+ grad = tf.keras.Input(shape=(*image_output_shape, 3), name='multiLoss_grad_input', dtype=tf.float32)
157
+ fix_seg = tf.keras.Input(shape=(*image_output_shape, len(train_generator.get_segmentation_labels())),
158
+ name='multiLoss_fix_seg_input', dtype=tf.float32)
159
+
160
+ multiLoss = UncertaintyWeighting(num_loss_fns=len(loss_simil) + len(loss_segm),
161
+ num_reg_fns=1,
162
+ loss_fns=[*loss_simil,
163
+ *loss_segm],
164
+ reg_fns=[vxm.losses.Grad('l2').loss],
165
+ prior_loss_w=prior_loss_w,
166
+ # prior_loss_w=[1., 0.1, 1., 1.],
167
+ prior_reg_w=[prior_reg_w],
168
+ name='MultiLossLayer')
169
+ loss = multiLoss([*[network.inputs[1]]*len(loss_simil), *[fix_seg]*len(loss_segm),
170
+ *[network.outputs[0]]*len(loss_simil), *[network.outputs[2]]*len(loss_simil),
171
+ grad,
172
+ network.outputs[1]])
173
+
174
+ # inputs = [mov_img, fix_img, mov_segm, fix_segm, zero_grads]
175
+ # outputs = [pred_img, flow, pred_segm, loss]
176
+ full_model = tf.keras.Model(inputs=network.inputs + [fix_seg, grad],
177
+ outputs=network.outputs + [loss])
178
+
179
+ # Train
180
+ os.makedirs(output_folder, exist_ok=True)
181
+ os.makedirs(os.path.join(output_folder, 'checkpoints'), exist_ok=True)
182
+ os.makedirs(os.path.join(output_folder, 'tensorboard'), exist_ok=True)
183
+ os.makedirs(os.path.join(output_folder, 'history'), exist_ok=True)
184
+
185
+ callback_best_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'),
186
+ save_best_only=True, monitor='val_loss', verbose=1, mode='min')
187
+ # callback_save_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'weights.{epoch:05d}-{val_loss:.2f}.h5'),
188
+ # save_best_only=True, save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
189
+ # CSVLogger(train_log_name, ';'),
190
+ # UpdateLossweights([haus_weight, dice_weight], [const.MODEL+'_resampler_seg', const.MODEL+'_resampler_seg'])
191
+ callback_tensorboard = TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
192
+ batch_size=C.BATCH_SIZE, write_images=False, histogram_freq=0,
193
+ update_freq='epoch', # or 'batch' or integer
194
+ write_graph=True, write_grads=True
195
+ )
196
+ callback_early_stop = EarlyStopping(monitor='val_loss', verbose=1, patience=C.EARLY_STOP_PATIENCE, min_delta=0.00001)
197
+
198
+ # Compile the model
199
+ optimizer = AdamAccumulated(C.ACCUM_GRADIENT_STEP, lr=C.LEARNING_RATE)
200
+ full_model.compile(optimizer=optimizer, loss=None)
201
+
202
+ callback_tensorboard.set_model(full_model)
203
+ callback_best_model.set_model(network) # ONLY SAVE THE NETWORK!!!
204
+ # callback_save_model.set_model(network)
205
+ callback_early_stop.set_model(full_model)
206
+ # TODO: https://towardsdatascience.com/writing-tensorflow-2-custom-loops-438b1ab6eb6c
207
+
208
+ summary = SummaryDictionary(full_model, C.BATCH_SIZE)
209
+ names = full_model.metrics_names # It give both the loss and metric names
210
+ zero_grads = tf.zeros_like(network.references.pos_flow, name='dummy_zero_grads') # Dummy zeros-tensor
211
+ log_file.write('\n\n[{}]\tINFO:\tStart training\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y')))
212
+ with sess.as_default():
213
+ callback_tensorboard.on_train_begin()
214
+ callback_early_stop.on_train_begin()
215
+ callback_best_model.on_train_begin()
216
+ # callback_save_model.on_train_begin()
217
+ for epoch in range(C.EPOCHS):
218
+ callback_tensorboard.on_epoch_begin(epoch)
219
+ callback_early_stop.on_epoch_begin(epoch)
220
+ callback_best_model.on_epoch_begin(epoch)
221
+ # callback_save_model.on_epoch_begin(epoch)
222
+
223
+ print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
224
+ print('TRAINING')
225
+
226
+ log_file.write('\n\n[{}]\tINFO:\tTraining epoch {}\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), epoch))
227
+ progress_bar = Progbar(len(train_generator), width=30, verbose=1)
228
+ for step, (in_batch, _) in enumerate(train_generator, 1):
229
+ # callback_tensorboard.on_train_batch_begin(step)
230
+ callback_best_model.on_train_batch_begin(step)
231
+ # callback_save_model.on_train_batch_begin(step)
232
+ callback_early_stop.on_train_batch_begin(step)
233
+
234
+ try:
235
+ fix_img, mov_img, fix_seg, mov_seg = augmentation_model.predict(in_batch)
236
+ np.nan_to_num(fix_img, copy=False)
237
+ np.nan_to_num(mov_img, copy=False)
238
+ except InvalidArgumentError as err:
239
+ print('TF Error : {}'.format(str(err)))
240
+ continue
241
+ # inputs = [mov_img, fix_img, mov_segm, fix_segm, zero_grads]
242
+ # outputs = [pred_img, flow, pred_segm, loss]
243
+ ret = full_model.train_on_batch(x=(mov_img, fix_img, mov_seg, fix_seg, zero_grads))
244
+
245
+ summary.on_train_batch_end(ret)
246
+ # callback_tensorboard.on_train_batch_end(step, named_logs(full_model, ret))
247
+ callback_best_model.on_train_batch_end(step, named_logs(full_model, ret))
248
+ # callback_save_model.on_train_batch_end(step, named_logs(network, ret))
249
+ callback_early_stop.on_train_batch_end(step, named_logs(full_model, ret))
250
+ log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
251
+ # print(ret, '\n')
252
+ progress_bar.update(step, zip(names, ret))
253
+ print('End of epoch{}: '.format(step), ret, '\n')
254
+ val_values = progress_bar._values.copy()
255
+ ret = [val_values[x][0]/val_values[x][1] for x in names]
256
+
257
+ print('\nVALIDATION')
258
+ log_file.write('\n\n[{}]\tINFO:\tValidation epoch {}\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), epoch))
259
+ progress_bar = Progbar(len(validation_generator), width=30, verbose=1)
260
+ for step, (in_batch, _) in enumerate(validation_generator, 1):
261
+ # callback_tensorboard.on_test_batch_begin(step) # This is cursed, don't do it again
262
+ # callback_early_stop.on_test_batch_begin(step)
263
+ try:
264
+ fix_img, mov_img, fix_seg, mov_seg = augmentation_model.predict(in_batch)
265
+ except InvalidArgumentError as err:
266
+ print('TF Error : {}'.format(str(err)))
267
+ continue
268
+
269
+ ret = full_model.test_on_batch(x=(mov_img, fix_img, mov_seg, fix_seg, zero_grads))
270
+ # pred_segm = network.register(mov_segm, fix_segm)
271
+ summary.on_validation_batch_end(ret)
272
+ # callback_early_stop.on_test_batch_end(step, named_logs(full_model, ret))
273
+ # callback_tensorboard.on_test_batch_end(step, named_logs(network, ret)) # This is cursed, don't do it again
274
+ progress_bar.update(step, zip(names, ret))
275
+ log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
276
+ val_values = progress_bar._values.copy()
277
+ ret = [val_values[x][0]/val_values[x][1] for x in names]
278
+
279
+ train_generator.on_epoch_end()
280
+ validation_generator.on_epoch_end()
281
+ epoch_summary = summary.on_epoch_end() # summary resets after on_epoch_end() call
282
+ callback_tensorboard.on_epoch_end(epoch, epoch_summary)
283
+ callback_best_model.on_epoch_end(epoch, epoch_summary)
284
+ callback_early_stop.on_epoch_end(epoch, epoch_summary)
285
+ # callback_save_model.on_train_end(epoch, epoch_summary)
286
+
287
+ callback_tensorboard.on_train_end()
288
+ # callback_save_model.on_train_end()
289
+ callback_best_model.on_train_end()
290
+ callback_early_stop.on_train_end()
291
+
292
+
293
+ if __name__ == '__main__':
294
+ os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
295
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0' # Check availability before running using 'nvidia-smi'
296
+
297
+ config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
298
+ config.gpu_options.allow_growth = True
299
+ config.log_device_placement = False ## to log device placement (on which device the operation ran)
300
+ tf.keras.backend.set_session(tf.Session(config=config))
301
+
302
+ launch_train('/mnt/EncryptedData1/Users/javier/Brain_study/ERASE',
303
+ 'TrainOutput/THESIS/UW_None_mse_ssim_haus',
304
+ 0)
Brain_study/__init__.py ADDED
File without changes
Brain_study/data_generator.py ADDED
@@ -0,0 +1,453 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+
3
+ import numpy as np
4
+ from tensorflow import keras
5
+ import os
6
+ import h5py
7
+ import random
8
+ from PIL import Image
9
+ import nibabel as nib
10
+ from nilearn.image import resample_img
11
+ from skimage.exposure import equalize_adapthist
12
+ from scipy.ndimage import zoom
13
+ import tensorflow as tf
14
+
15
+ import DeepDeformationMapRegistration.utils.constants as C
16
+ from DeepDeformationMapRegistration.utils.operators import min_max_norm
17
+ from DeepDeformationMapRegistration.utils.thin_plate_splines import ThinPlateSplines
18
+ from voxelmorph.tf.layers import SpatialTransformer
19
+ from Brain_study.format_dataset import SEGMENTATION_NR2LBL_LUT, SEGMENTATION_LBL2NR_LUT
20
+
21
+ from tensorflow.python.keras.preprocessing.image import Iterator
22
+ from tensorflow.python.keras.utils import Sequence
23
+ import sys
24
+
25
+ #import concurrent.futures
26
+ #import multiprocessing as mp
27
+ import time
28
+
29
+ class BatchGenerator:
30
+ def __init__(self,
31
+ directory,
32
+ batch_size,
33
+ shuffle=True,
34
+ split=0.7,
35
+ combine_segmentations=True,
36
+ labels=['all'],
37
+ directory_val=None):
38
+ self.file_directory = directory
39
+ self.batch_size = batch_size
40
+ self.combine_segmentations = combine_segmentations
41
+ self.labels = labels
42
+ self.shuffle = shuffle
43
+ self.split = split
44
+
45
+ if directory_val is None:
46
+ self.file_list = [os.path.join(directory, f) for f in os.listdir(directory) if f.endswith(('h5', 'hd5'))]
47
+ random.shuffle(self.file_list) if self.shuffle else self.file_list.sort()
48
+ self.num_samples = len(self.file_list)
49
+ training_samples = self.file_list[:int(self.num_samples * self.split)]
50
+
51
+ self.train_iter = BatchIterator(training_samples, batch_size, shuffle, combine_segmentations, labels)
52
+ if self.split < 1.:
53
+ validation_samples = list(set(self.file_list) - set(training_samples))
54
+ self.validation_iter = BatchIterator(validation_samples, batch_size, shuffle, combine_segmentations, ['all'],
55
+ validation=True)
56
+ else:
57
+ self.validation_iter = None
58
+ else:
59
+ training_samples = [os.path.join(directory, f) for f in os.listdir(directory) if f.endswith(('h5', 'hd5'))]
60
+ random.shuffle(training_samples) if self.shuffle else training_samples.sort()
61
+
62
+ validation_samples = [os.path.join(directory_val, f) for f in os.listdir(directory_val) if f.endswith(('h5', 'hd5'))]
63
+ random.shuffle(validation_samples) if self.shuffle else validation_samples.sort()
64
+
65
+ self.num_samples = len(training_samples) + len(validation_samples)
66
+ self.file_list = training_samples + validation_samples
67
+
68
+ self.train_iter = BatchIterator(training_samples, batch_size, shuffle, combine_segmentations, labels)
69
+ self.validation_iter = BatchIterator(validation_samples, batch_size, shuffle, combine_segmentations, ['all'],
70
+ validation=True)
71
+
72
+ def get_train_generator(self):
73
+ return self.train_iter
74
+
75
+ def get_validation_generator(self):
76
+ if self.validation_iter is not None:
77
+ return self.validation_iter
78
+ else:
79
+ raise ValueError('No validation iterator. Split must be < 1.0')
80
+
81
+ def get_file_list(self):
82
+ return self.file_list
83
+
84
+ def get_data_shape(self):
85
+ return self.train_iter.get_data_shape()
86
+
87
+
88
+ ALL_LABELS = {2., 3., 4., 6., 8., 9., 11., 12., 14., 16., 20., 23., 29., 33., 39., 53., 67., 76., 102., 203., 210.,
89
+ 211., 218., 219., 232., 233., 254., 255.}
90
+ ALL_LABELS_LOC = {label: loc for label, loc in zip(ALL_LABELS, range(0, len(ALL_LABELS)))}
91
+
92
+
93
+ class BatchIterator(Sequence):
94
+ def __init__(self, file_list, batch_size, shuffle, combine_segmentations=True, labels=['all'],
95
+ zero_grads=[64, 64, 64, 3], validation=False, **kwargs):
96
+ # super(BatchIterator, self).__init__(n=len(file_list),
97
+ # batch_size=batch_size,
98
+ # shuffle=shuffle,
99
+ # seed=None,
100
+ # **kwargs)
101
+ self.batch_size = batch_size
102
+ self.shuffle = shuffle
103
+ self.file_list = file_list
104
+ self.combine_segmentations = combine_segmentations
105
+ self.labels = labels
106
+ self.zero_grads = zero_grads
107
+ self.idx_list = np.arange(0, len(self.file_list))
108
+ self.validation = validation
109
+ self._initialize()
110
+ self.shuffle_samples()
111
+
112
+ def _initialize(self):
113
+ with h5py.File(self.file_list[0], 'r') as f:
114
+ self.image_shape = list(f['image'][:].shape)
115
+ self.segm_shape = list(f['segmentation'][:].shape)
116
+ if not self.combine_segmentations:
117
+ self.segm_shape[-1] = len(f['segmentation_labels'][:]) if self.labels[0].lower() == 'all' else len(self.labels)
118
+
119
+ self.batch_shape = self.image_shape.copy()
120
+ if self.labels[0].lower() != 'none':
121
+ self.batch_shape[-1] = 2 if self.combine_segmentations else 1 + self.segm_shape[-1] # +1 because we have the fix and the moving images
122
+
123
+ if self.labels[0] != 'all':
124
+ if isinstance(self.labels[0], str):
125
+ self.labels = [SEGMENTATION_LBL2NR_LUT[lbl] for lbl in self.labels]
126
+
127
+ self.num_steps = len(self.file_list) // self.batch_size + (1 if len(self.file_list) % self.batch_size else 0)
128
+ #self.executor = concurrent.futures.ProcessPoolExecutor(max_workers=self.batch_size)
129
+ #self.mp_pool = mp.Pool(self.batch_size)
130
+
131
+ def shuffle_samples(self):
132
+ np.random.shuffle(self.idx_list)
133
+
134
+ def __len__(self):
135
+ return self.num_steps
136
+
137
+ def _filter_segmentations(self, segm, segm_labels):
138
+ if self.combine_segmentations:
139
+ # TODO
140
+ warnings.warn('Cannot select labels when combinine_segmentations options is active')
141
+ if self.labels[0] != 'all':
142
+ if set(self.labels).issubset(set(segm_labels)):
143
+ # If labels in self.labels are in segm
144
+ idx = [ALL_LABELS_LOC[l] for l in self.labels]
145
+ segm = segm[..., idx]
146
+ else:
147
+ # Else we have to collect those labels that are contained and complete with zeros
148
+ idx = [ALL_LABELS_LOC[l] for l in list(set(self.labels).intersection(set(segm_labels)))]
149
+ aux = segm.copy()
150
+ segm = np.zeros(self.segm_shape)
151
+ segm[..., :len(idx)] = aux[..., idx]
152
+ # TODO: leave the zero-ed segmentations before or after the selected labels based on the order
153
+ return segm
154
+
155
+ def _load_sample(self, file_path):
156
+ with h5py.File(file_path, 'r') as f:
157
+ img = f['image'][:]
158
+ segm_labels = f['segmentation_labels'][:]
159
+ if self.combine_segmentations:
160
+ segm = f['segmentation'][:]
161
+ else:
162
+ segm = f['segmentation_expanded'][:]
163
+ if segm.shape[-1] != self.segm_shape[-1]:
164
+ aux = np.zeros(self.segm_shape)
165
+ aux[..., :segm.shape[-1]] = segm # Ensure the same shape in case there are missing labels in aux
166
+ segm = aux
167
+ # TODO: selection label segm = aux[..., self.labels] but:
168
+ # what if aux does not have a label in self.labels??
169
+
170
+ if self.labels[0].lower() != 'none' or self.validation: # I expect to ask for the segmentations during val
171
+ segm = self._filter_segmentations(segm, segm_labels)
172
+
173
+ if self.validation:
174
+ ret_val = np.concatenate([img, segm], axis=-1), (img, segm, np.zeros(self.zero_grads))
175
+ else:
176
+ ret_val = np.concatenate([img, segm], axis=-1), (img, np.zeros(self.zero_grads))
177
+ else:
178
+ ret_val = img, (img, np.zeros(self.zero_grads))
179
+ return ret_val
180
+
181
+ def __getitem__(self, idx):
182
+ in_batch = list()
183
+ # out_batch = list()
184
+
185
+ batch_idxs = self.idx_list[idx * self.batch_size:(idx + 1) * self.batch_size]
186
+ file_list = [self.file_list[i] for i in batch_idxs]
187
+ # if self.batch_size > 1:
188
+ # # Multiprocessing to speed up laoding
189
+ #
190
+ # for ret in self.executor.map(self._load_sample, file_list):
191
+ # b, i = ret
192
+ # in_batch.append(b)
193
+ # # out_batch.append(i)
194
+ # else:
195
+ # No need for multithreading, we are loading a single file
196
+ for f in file_list:
197
+ b, i = self._load_sample(f)
198
+ in_batch.append(b)
199
+ # out_batch.append(i)
200
+
201
+ in_batch = np.asarray(in_batch)
202
+ # out_batch = np.asarray(out_batch)
203
+ return in_batch, in_batch
204
+
205
+ def __iter__(self):
206
+ """Create a generator that iterate over the Sequence."""
207
+ for item in (self[i] for i in range(len(self))):
208
+ yield item
209
+
210
+ def get_data_shape(self):
211
+ return self.batch_shape, self.image_shape, self.segm_shape
212
+
213
+ def on_epoch_end(self):
214
+ self.shuffle_samples()
215
+
216
+ def get_segmentation_labels(self):
217
+ if self.combine_segmentations:
218
+ labels = [1]
219
+ else:
220
+ with h5py.File(self.file_list[0], 'r') as f:
221
+ labels = np.unique(f['segmentation'][:])
222
+ labels = np.sort(labels)[1:] # Ignore the background
223
+ return labels
224
+
225
+
226
+
227
+
228
+
229
+
230
+ '''
231
+ def get_size(obj, seen=None):
232
+ """Recursively finds size of objects"""
233
+ size = sys.getsizeof(obj)
234
+ if seen is None:
235
+ seen = set()
236
+ obj_id = id(obj)
237
+ if obj_id in seen:
238
+ return 0
239
+ # Important mark as seen *before* entering recursion to gracefully handle
240
+ # self-referential objects
241
+ seen.add(obj_id)
242
+ if isinstance(obj, dict):
243
+ size += sum([get_size(v, seen) for v in obj.values()])
244
+ size += sum([get_size(k, seen) for k in obj.keys()])
245
+ elif hasattr(obj, '__dict__'):
246
+ size += get_size(obj.__dict__, seen)
247
+ elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)):
248
+ size += sum([get_size(i, seen) for i in obj])
249
+ return size
250
+
251
+
252
+ class BatchIterator(Iterator):
253
+ def __init__(self, generator, file_list, input_shape, output_shape, batch_size, shuffle, all_files_in_batch):
254
+ self.file_list = file_list
255
+ self.generator = generator
256
+ self.input_shape = input_shape
257
+ self.nr_of_inputs = len(input_shape)
258
+ self.output_shape = output_shape
259
+ self.nr_of_outputs = len(output_shape)
260
+ self.all_files_in_batch = all_files_in_batch
261
+ self.preload_to_memory = False
262
+ self.file_cache = {}
263
+ self.max_cache_size = 10*1024
264
+ self.verbose = False
265
+ if self.preload_to_memory:
266
+ for filename, file_index in self.file_list:
267
+ file = h5py.File(filename, 'r')
268
+ inputs = {}
269
+ for name, data in file['input'].items():
270
+ inputs[name] = np.copy(data)
271
+ self.file_cache[filename] = {'input': inputs, 'output': np.copy(file['output'])}
272
+ file.close()
273
+ if get_size(self.file_cache) / (1024*1024) >= self.max_cache_size:
274
+ print('File cache has reached limit of', self.max_cache_size, 'MBs')
275
+ break
276
+ epoch_size = len(file_list)
277
+ if all_files_in_batch:
278
+ epoch_size = len(file_list) * 10
279
+ super(BatchIterator, self).__init__(epoch_size, batch_size, shuffle, None)
280
+
281
+ def _get_sample(self, index):
282
+ filename, file_index = self.file_list[index]
283
+ if filename in self.file_cache:
284
+ file = self.file_cache[filename]
285
+ else:
286
+ file = h5py.File(filename, 'r')
287
+ inputs = []
288
+ outputs = []
289
+ for name, data in file['input'].items():
290
+ inputs.append(data[file_index, :])
291
+ for name, data in file['output'].items():
292
+ if len(data.shape) > 1:
293
+ outputs.append(data[file_index, :])
294
+ else:
295
+ outputs.append(data[file_index])
296
+ #outputs.append(file['output'][file_index, :]) # TODO fix
297
+ if filename not in self.file_cache:
298
+ file.close()
299
+ return inputs, outputs
300
+
301
+ def _get_random_sample_in_file(self, file_index):
302
+ filename = self.file_list[file_index]
303
+ file = h5py.File(filename, 'r')
304
+ x = file['output/0']
305
+ sample = np.random.randint(0, x.shape[0])
306
+ #print('Sampling image', sample, 'from file', filename)
307
+ inputs = []
308
+ outputs = []
309
+ for name, data in file['input'].items():
310
+ inputs.append(data[sample, :])
311
+ for name, data in file['output'].items():
312
+ outputs.append(data[file_index, :])
313
+ #outputs.append(file['output'][sample, :]) # TODO FIX output
314
+ file.close()
315
+ return inputs, outputs
316
+
317
+ def next(self):
318
+
319
+ with self.lock:
320
+ index_array = next(self.index_generator)
321
+
322
+ #print(len(index_array))
323
+ return self._get_batches_of_transformed_samples(index_array)
324
+
325
+ def _get_batches_of_transformed_samples(self, index_array):
326
+ start_batch = time.time()
327
+ batches_x = []
328
+ batches_y = []
329
+ for input_index in range(self.nr_of_inputs):
330
+ batches_x.append(np.zeros(tuple([len(index_array)] + list(self.input_shape[input_index]))))
331
+ for output_index in range(self.nr_of_outputs):
332
+ batches_y.append(np.zeros(tuple([len(index_array)] + list(self.output_shape[output_index]))))
333
+
334
+ timings_sampling = np.zeros((len(index_array,)))
335
+ timings_transform = np.zeros((len(index_array,)))
336
+ for batch_index, sample_index in enumerate(index_array):
337
+ # Have to copy here in order to not modify original data
338
+ start = time.time()
339
+ if self.all_files_in_batch:
340
+ input, output = self._get_random_sample_in_file(batch_index)
341
+ else:
342
+ input, output = self._get_sample(sample_index)
343
+ timings_sampling[batch_index] = time.time() - start
344
+ start = time.time()
345
+ input, output = self.generator.transform(input, output)
346
+ timings_transform[batch_index] = time.time() - start
347
+
348
+ #print('inputs', self.nr_of_inputs, len(input))
349
+ for input_index in range(self.nr_of_inputs):
350
+ batches_x[input_index][batch_index] = input[input_index]
351
+ for output_index in range(self.nr_of_outputs):
352
+ batches_y[output_index][batch_index] = output[output_index]
353
+
354
+ elapsed = time.time() - start_batch
355
+ if self.verbose:
356
+ print('Time to prepare batch:', round(elapsed,3), 'seconds')
357
+ print('Sampling mean:', round(timings_sampling.mean(), 3), 'seconds')
358
+ print('Transform mean:', round(timings_transform.mean(), 3), 'seconds')
359
+
360
+ return batches_x, batches_y
361
+
362
+
363
+ CLASSIFICATION = 'classification'
364
+ SEGMENTATION = 'segmentation'
365
+
366
+
367
+ class BatchGenerator():
368
+ def __init__(self, filelist, all_files_in_batch=False):
369
+ self.methods = []
370
+ self.args = []
371
+ self.crop_width_to = None
372
+ self.image_list = []
373
+ self.input_shape = []
374
+ self.output_shape = []
375
+ self.all_files_in_batch = all_files_in_batch
376
+ self.transforms = []
377
+
378
+ if all_files_in_batch:
379
+ file = h5py.File(filelist[0], 'r')
380
+ for name, data in file['input'].items():
381
+ self.input_shape.append(data.shape[1:])
382
+ for name, data in file['output'].items():
383
+ self.output_shape.append(data.shape[1:])
384
+ # TODO fix
385
+ #self.output_shape.append(file['output'].shape[1:])
386
+ file.close()
387
+ self.image_list = filelist
388
+ return
389
+
390
+ # Go through filelist
391
+ first = True
392
+ for filename in filelist:
393
+ samples = None
394
+ # Open file to see how many samples it has
395
+ file = h5py.File(filename, 'r')
396
+ for name, data in file['input'].items():
397
+ if first:
398
+ self.input_shape.append(data.shape[1:])
399
+ samples = data.shape[0]
400
+ # TODO fix
401
+ for name, data in file['output'].items():
402
+ if first:
403
+ self.output_shape.append(data.shape[1:])
404
+ if samples != data.shape[0]:
405
+ raise ValueError()
406
+ #self.output_shape.append(file['output'].shape[1:])
407
+ if len(self.output_shape) == 1:
408
+ self.problem_type = CLASSIFICATION
409
+ else:
410
+ self.problem_type = SEGMENTATION
411
+
412
+ file.close()
413
+ if samples is None:
414
+ raise ValueError()
415
+ # Append a tuple to image_list for each image consisting of filename and index
416
+ print(filename, samples)
417
+ for i in range(samples):
418
+ self.image_list.append((filename, i))
419
+ first = False
420
+
421
+ print('Image generator with', len(self.image_list), ' image samples created')
422
+
423
+ def flow(self, batch_size, shuffle=True):
424
+
425
+ return BatchIterator(self, self.image_list, self.input_shape, self.output_shape, batch_size, shuffle, self.all_files_in_batch)
426
+
427
+ def transform(self, inputs, outputs):
428
+ #input = input.astype(np.float32) # TODO
429
+ #output = output.astype(np.float32)
430
+ for input_indices, output_indices, transform in self.transforms:
431
+ transform.randomize()
432
+ inputs, outputs = transform.transform_all(inputs, outputs, input_indices, output_indices)
433
+ return inputs, outputs
434
+
435
+ def add_transform(self, input_indices: Union[int, List[int], None], output_indices: Union[int, List[int], None], transform: Transform):
436
+ if type(input_indices) is int:
437
+ input_indices = [input_indices]
438
+ if type(output_indices) is int:
439
+ output_indices = [output_indices]
440
+
441
+ self.transforms.append((
442
+ input_indices,
443
+ output_indices,
444
+ transform
445
+ ))
446
+
447
+ def get_size(self):
448
+ if self.all_files_in_batch:
449
+ return 10*len(self.image_list)
450
+ else:
451
+ return len(self.image_list)
452
+
453
+ '''
Brain_study/format_dataset.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import h5py
2
+ import nibabel as nib
3
+ from nilearn.image import resample_img
4
+ import os
5
+ import re
6
+ import numpy as np
7
+ from scipy.ndimage import zoom
8
+ from tqdm import tqdm
9
+
10
+ SEGMENTATION_NR2LBL_LUT = {0: 'background',
11
+ 2: 'parietal-right-gm',
12
+ 3: 'lateral-ventricle-left',
13
+ 4: 'occipital-right-gm',
14
+ 6: 'parietal-left-gm',
15
+ 8: 'occipital-left-gm',
16
+ 9: 'lateral-ventricle-right',
17
+ 11: 'globus-pallidus-right',
18
+ 12: 'globus-pallidus-left',
19
+ 14: 'putamen-left',
20
+ 16: 'putamen-right',
21
+ 20: 'brain-stem',
22
+ 23: 'subthalamic-nucleus-right',
23
+ 29: 'fornix-left',
24
+ 33: 'subthalamic-nucleus-left',
25
+ 39: 'caudate-left',
26
+ 53: 'caudate-right',
27
+ 67: 'cerebellum-left',
28
+ 76: 'cerebellum-right',
29
+ 102: 'thalamus-left',
30
+ 203: 'thalamus-right',
31
+ 210: 'frontal-left-gm',
32
+ 211: 'frontal-right-gm',
33
+ 218: 'temporal-left-gm',
34
+ 219: 'temporal-right-gm',
35
+ 232: '3rd-ventricle',
36
+ 233: '4th-ventricle',
37
+ 254: 'fornix-right',
38
+ 255: 'csf'}
39
+ SEGMENTATION_LBL2NR_LUT = {v: k for k, v in SEGMENTATION_NR2LBL_LUT.items()}
40
+
41
+ IMG_DIRECTORY = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1'
42
+ SEG_DIRECTORY = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/anatomical_masks'
43
+
44
+ IMG_NAME_PATTERN = '(.*).nii.gz'
45
+ SEG_NAME_PATTERN = '(.*)_lobes.nii.gz'
46
+
47
+ OUT_DIRECTORY = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/test'
48
+
49
+ if __name__ == '__main__':
50
+ img_list = [os.path.join(IMG_DIRECTORY, f) for f in os.listdir(IMG_DIRECTORY) if f.endswith('.nii.gz')]
51
+ img_list.sort()
52
+
53
+ seg_list = [os.path.join(SEG_DIRECTORY, f) for f in os.listdir(SEG_DIRECTORY) if f.endswith('.nii.gz')]
54
+ seg_list.sort()
55
+
56
+ os.makedirs(OUT_DIRECTORY, exist_ok=True)
57
+ for seg_file in tqdm(seg_list):
58
+ img_name = re.match(SEG_NAME_PATTERN, os.path.split(seg_file)[-1])[1]
59
+ img_file = os.path.join(IMG_DIRECTORY, img_name + '.nii.gz')
60
+
61
+ img = resample_img(nib.load(img_file), np.eye(3))
62
+ seg = resample_img(nib.load(seg_file), np.eye(3), interpolation='nearest')
63
+
64
+ isot_shape = img.shape
65
+
66
+ # Resize to 128x128x128
67
+ img = np.asarray(img.dataobj)
68
+ img = zoom(img, np.asarray([128]*3) / np.asarray(isot_shape), order=3)
69
+
70
+ seg = np.asarray(seg.dataobj)
71
+ seg = zoom(seg, np.asarray([128]*3) / np.asarray(isot_shape), order=0)
72
+
73
+ unique_lbls = np.unique(seg)[1:] # Omit background
74
+ seg_expanded = np.tile(np.zeros_like(seg)[..., np.newaxis], (1, 1, 1, len(unique_lbls)))
75
+ for ch, lbl in enumerate(unique_lbls):
76
+ seg_expanded[seg == lbl, ch] = 1
77
+
78
+ h5_file = h5py.File(os.path.join(OUT_DIRECTORY, img_name + '.h5'), 'w')
79
+
80
+ h5_file.create_dataset('image', data=img[..., np.newaxis], dtype=np.float32)
81
+ h5_file.create_dataset('segmentation', data=seg[..., np.newaxis].astype(np.uint8), dtype=np.uint8)
82
+ h5_file.create_dataset('segmentation_expanded', data=seg_expanded.astype(np.uint8), dtype=np.uint8)
83
+ h5_file.create_dataset('segmentation_labels', data=unique_lbls)
84
+ h5_file.create_dataset('isotropic_shape', data=isot_shape)
85
+
86
+ h5_file.close()
87
+
88
+
89
+
90
+
91
+
Brain_study/split_dataset.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import random
4
+ import warnings
5
+
6
+ import math
7
+ from shutil import copyfile
8
+ from tqdm import tqdm
9
+ import concurrent.futures
10
+ import numpy as np
11
+
12
+
13
+ def copy_file(s_d):
14
+ s, d = s_d
15
+ file_name = os.path.split(s)[-1]
16
+ copyfile(s, os.path.join(d, file_name))
17
+ return int(os.path.exists(d))
18
+
19
+
20
+ if __name__ == '__main__':
21
+ parser = argparse.ArgumentParser()
22
+ parser.add_argument('--train', '-t', type=float, default=.70, help='Train percentage. Default: 0.70')
23
+ parser.add_argument('--validation', '-v', type=float, default=0.15, help='Validation percentage. Default: 0.15')
24
+ parser.add_argument('--test', '-s', type=float, default=0.15, help='Test percentage. Default: 0.15')
25
+ parser.add_argument('-d', '--dir', type=str, help='Directory where the data is')
26
+ parser.add_argument('-f', '--format', type=str, help='Format of the data files. Default: h5', default='h5')
27
+ parser.add_argument('-r', '--random', type=bool, help='Randomly split the dataset or not. Default: True', default=True)
28
+
29
+ args = parser.parse_args()
30
+
31
+ assert args.train + args.validation + args.test == 1.0, 'Train+Validation+Test != 1 (100%)'
32
+
33
+ file_set = [os.path.join(args.dir, f) for f in os.listdir(args.dir) if f.endswith(args.format)]
34
+ random.shuffle(file_set) if args.random else file_set.sort()
35
+
36
+ num_files = len(file_set)
37
+ num_validation = math.floor(num_files * args.validation)
38
+ num_test = math.floor(num_files * args.test)
39
+ num_train = num_files - num_test - num_validation
40
+
41
+ dataset_root, dataset_name = os.path.split(args.dir)
42
+ dst_train = os.path.join(dataset_root, 'SPLIT_'+dataset_name, 'train_set')
43
+ dst_validation = os.path.join(dataset_root, 'SPLIT_'+dataset_name, 'validation_set')
44
+ dst_test = os.path.join(dataset_root, 'SPLIT_'+dataset_name, 'test_set')
45
+
46
+ print('OUTPUT INFORMATION\n=============')
47
+ print('Train:\t\t{}'.format(num_train))
48
+ print('Validation:\t{}'.format(num_validation))
49
+ print('Test:\t\t{}'.format(num_test))
50
+ print('Num. samples\t{}'.format(num_files))
51
+ print('Path:\t\t', os.path.join(dataset_root, 'SPLIT_'+dataset_name))
52
+
53
+ dest = [dst_train] * num_train + [dst_validation] * num_validation + [dst_test] * num_test
54
+
55
+ os.makedirs(dst_train, exist_ok=True)
56
+ os.makedirs(dst_validation, exist_ok=True)
57
+ os.makedirs(dst_test, exist_ok=True)
58
+
59
+ progress_bar = tqdm(zip(file_set, dest), desc='Copying files', total=num_files)
60
+ with concurrent.futures.ProcessPoolExecutor(max_workers=10) as ex:
61
+ results = list(tqdm(ex.map(copy_file, zip(file_set, dest)), desc='Copying files', total=num_files))
62
+
63
+ num_copies = np.sum(results)
64
+ if num_copies == num_files:
65
+ print('Done successfully')
66
+ else:
67
+ warnings.warn('Missing files: {}'.format(num_files - num_copies))
68
+
Brain_study/test_datagenerator.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from Brain_study.data_generator import BatchGenerator
2
+
3
+ import DeepDeformationMapRegistration.utils.constants as C
4
+ from tqdm import tqdm
5
+
6
+ from tensorflow import keras
7
+ import tensorflow as tf
8
+ from tensorflow.keras.callbacks import TensorBoard
9
+ import os
10
+ import voxelmorph as vxm
11
+ from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
12
+ from DeepDeformationMapRegistration.losses import NCC, StructuralSimilarity, StructuralSimilarity_simplified
13
+
14
+
15
+ def named_logs(model, logs, validation=False):
16
+ result = {'size': C.BATCH_SIZE} # https://gist.github.com/erenon/91f526302cd8e9d21b73f24c0f9c4bb8#gistcomment-3041181
17
+ for l in zip(model.metrics_names, logs):
18
+ k = ('val_' if validation else '') + l[0]
19
+ result[k] = l[1]
20
+ return result
21
+
22
+ if __name__ == '__main__':
23
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(1)
24
+
25
+ C.BATCH_SIZE = 12
26
+ C.TRAINING_DATASET = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/training'
27
+ output_folder = "/mnt/EncryptedData1/Users/javier/train_output/Brain_study/ERASE"
28
+
29
+ data_generator = BatchGenerator(C.TRAINING_DATASET, C.BATCH_SIZE, True, C.TRAINING_PERC, True, ['none'])
30
+ train_generator = data_generator.get_train_generator()
31
+ val_generator = data_generator.get_validation_generator()
32
+
33
+ e_iter = tqdm(range(100))
34
+ t_iter = tqdm(train_generator)
35
+ v_iter = tqdm(val_generator)
36
+
37
+ e_iter.set_description('Epoch')
38
+ t_iter.set_description('Train')
39
+ v_iter.set_description('Val')
40
+ #
41
+ # for s in e_iter:
42
+ # for b in t_iter:
43
+ # continue
44
+ #
45
+ # for b in v_iter:
46
+ # continue
47
+
48
+ # Build model
49
+ enc_features = [16, 32, 32, 32] # const.ENCODER_FILTERS
50
+ dec_features = [32, 32, 32, 32, 32, 16, 16] # const.ENCODER_FILTERS[::-1]
51
+ nb_features = [enc_features, dec_features]
52
+ network = vxm.networks.VxmDense(inshape=(64, 64, 64),
53
+ nb_unet_features=nb_features,
54
+ int_steps=0)
55
+
56
+ d = os.path.join(os.getcwd(), 'tensorboard_test')
57
+ os.makedirs(d, exist_ok=True)
58
+ callback_tensorboard = TensorBoard(log_dir=d,
59
+ batch_size=C.BATCH_SIZE, write_images=False, histogram_freq=0, update_freq='epoch',
60
+ write_graph=True,
61
+ write_grads=True)
62
+
63
+ losses = {'transformer': StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).loss,
64
+ 'flow': vxm.losses.Grad('l2').loss}
65
+ metrics = {'transformer': [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).metric,
66
+ tf.keras.losses.MSE],
67
+ # 'flow': vxm.losses.Grad('l2').loss
68
+ }
69
+ loss_weights = {'transformer': 1.,
70
+ 'flow': 5e-3}
71
+
72
+ optimizer = AdamAccumulated(C.ACCUM_GRADIENT_STEP, C.LEARNING_RATE)
73
+
74
+ network.compile(optimizer=optimizer,
75
+ loss=losses,
76
+ loss_weights=loss_weights,
77
+ metrics=metrics)
78
+ callback_tensorboard.set_model(network)
79
+ dummy = lambda x: named_logs(network, [x, 0, x, 0, 0])
80
+
81
+ callback_tensorboard.on_train_begin()
82
+ for s in e_iter:
83
+ callback_tensorboard.on_epoch_begin(s)
84
+
85
+ for n in range(100):
86
+ callback_tensorboard.on_train_batch_begin(n)
87
+ input('Press enter')
88
+ callback_tensorboard.on_train_batch_end(n, dummy(n))
89
+
90
+ callback_tensorboard.on_epoch_end(s, dummy)
91
+ callback_tensorboard.on_train_end()
Brain_study/utils.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import DeepDeformationMapRegistration.utils.constants as C
3
+
4
+ class SummaryDictionary:
5
+ def __init__(self, model, batch_size, accumulative_gradients_step=None):
6
+ self.train_names = model.metrics_names
7
+ self.val_names = ['val_'+n for n in self.train_names]
8
+ self.batch_size = batch_size
9
+ self.acc_grad_step = accumulative_gradients_step
10
+ self._reset()
11
+
12
+ def _reset(self):
13
+ self.summary_dict = {'size': self.batch_size}
14
+ if self.acc_grad_step is not None:
15
+ self.summary_dict = {'accumulative_grad_step': self.acc_grad_step}
16
+ for k in self.train_names + self.val_names:
17
+ self.summary_dict[k] = list()
18
+
19
+ def on_train_batch_end(self, values):
20
+ for k, v in zip(self.train_names, values):
21
+ self.summary_dict[k].append(v)
22
+
23
+ def on_validation_batch_end(self, values):
24
+ for k, v in zip(self.val_names, values):
25
+ self.summary_dict[k].append(v)
26
+
27
+ def on_epoch_end(self):
28
+ for k, v in self.summary_dict.items():
29
+ self.summary_dict[k] = np.asarray(v).mean()
30
+
31
+ ret_val = self.summary_dict.copy()
32
+ self._reset()
33
+ return ret_val
34
+
35
+
36
+ def named_logs(model, logs, validation=False):
37
+ result = {'size': C.BATCH_SIZE} # https://gist.github.com/erenon/91f526302cd8e9d21b73f24c0f9c4bb8#gistcomment-3041181
38
+ for l in zip(model.metrics_names, logs):
39
+ k = ('val_' if validation else '') + l[0]
40
+ result[k] = l[1]
41
+ return result
EvaluationScripts/Evaluate_class.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from scipy.signal import correlate as cc
2
+ from scipy.spatial.distance import euclidean
3
+ from skimage.metrics import mean_squared_error as mse
4
+ from skimage.metrics import structural_similarity as ssim
5
+ from medpy.metric.binary import dc, hd95
6
+ import numpy as np
7
+ import pandas as pd
8
+ import os
9
+ from DeepDeformationMapRegistration.utils.constants import EPS
10
+ from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
11
+ from skimage.transform import resize
12
+ from skimage.measure import regionprops, label
13
+
14
+
15
+ def ncc(y_true, y_pred, eps=EPS):
16
+ f_yt = np.reshape(y_true, [-1])
17
+ f_yp = np.reshape(y_pred, [-1])
18
+ mean_yt = np.mean(f_yt)
19
+ mean_yp = np.mean(f_yp)
20
+
21
+ n_f_yt = f_yt - mean_yt
22
+ n_f_yp = f_yp - mean_yp
23
+ norm_yt = np.linalg.norm(f_yt, ord=2)
24
+ norm_yp = np.linalg.norm(f_yp, ord=2)
25
+ numerator = np.sum(np.multiply(n_f_yt, n_f_yp))
26
+ denominator = norm_yt * norm_yp + eps
27
+ return np.divide(numerator, denominator)
28
+
29
+
30
+ class EvaluationFigures:
31
+ def __init__(self, output_folder):
32
+ pd.set_option('display.max_columns', None)
33
+ self.__metrics_df = pd.DataFrame(columns=['Name', 'Train_ds', 'Eval_ds', 'MSE', 'NCC', 'SSIM',
34
+ 'DICE_PAR', 'DICE_TUM', 'DICE_VES',
35
+ 'HD95_PAR', 'HD95_TUM', 'HD95_VES', 'TRE'])
36
+ self.__output_folder = output_folder
37
+
38
+ def add_sample(self, name, train_ds, eval_ds, fix_t, par_t, tum_t, ves_t, centroid_t, fix_p, par_p, tum_p, ves_p, centroid_p, scale_transform=None):
39
+
40
+ n_fix_t = self.__mean_centred_img(fix_t)
41
+ n_fix_p = self.__mean_centred_img(fix_p)
42
+
43
+ if scale_transform is not None:
44
+ s_centroid_t = self.__scale_point(centroid_t, scale_transform)
45
+ s_centroid_p = self.__scale_point(centroid_p, scale_transform)
46
+ else:
47
+ s_centroid_t = centroid_t
48
+ s_centroid_p = centroid_p
49
+
50
+ new_row = {'Name': name,
51
+ 'Train_ds': train_ds,
52
+ 'Eval_ds': eval_ds,
53
+ 'MSE': mse(fix_t, fix_p),
54
+ 'NCC': ncc(n_fix_t, n_fix_p),
55
+ 'SSIM': ssim(fix_t, fix_p, multichannel=True),
56
+ 'DICE_PAR': dc(par_p, par_t),
57
+ 'DICE_TUM': dc(tum_p, tum_t),
58
+ 'DICE_VES': dc(ves_p, ves_t),
59
+ 'HD95_PAR': hd95(par_p, par_t) if np.sum(par_p) else 64,
60
+ 'HD95_TUM': hd95(tum_p, tum_t) if np.sum(tum_p) else 64,
61
+ 'HD95_VES': hd95(ves_p, ves_t) if np.sum(ves_p) else 64,
62
+ 'TRE': euclidean(s_centroid_t, s_centroid_p)}
63
+
64
+ self.__metrics_df = self.__metrics_df.append(new_row, ignore_index=True)
65
+
66
+ @staticmethod
67
+ def __mean_centred_img(img):
68
+ return img - np.mean(img)
69
+
70
+ @staticmethod
71
+ def __scale_point(point, scale_matrix):
72
+ assert scale_matrix.shape == (4, 4), 'Transformation matrix is expected to have shape (4, 4)'
73
+ aux_aug = np.ones((4,))
74
+ aux_aug[:3] = point
75
+ return np.matmul(scale_matrix, aux_aug)[:1]
76
+
77
+ def save_metrics(self, dest_folder=None):
78
+ if dest_folder is None:
79
+ dest_folder = self.__output_folder
80
+ self.__metrics_df.to_csv(os.path.join(dest_folder, 'metrics.csv'))
81
+ self.__metrics_df.to_latex(os.path.join(dest_folder, 'table.txt'), sparsify=True)
82
+ print('Metrics saved in: ' + os.path.join(dest_folder))
83
+
84
+ def print_summary(self):
85
+ print(self.__metrics_df[['MSE', 'NCC', 'SSIM',
86
+ 'DICE_PAR', 'DICE_TUM', 'DICE_VES',
87
+ 'HD95_PAR', 'HD95_TUM', 'HD95_VES', 'TRE']].describe())
88
+
89
+
90
+ def resize_img_to_original_space(img, bb, first_reshape, original_shape, clip_img=False, flow=False):
91
+ first_reshape = first_reshape.astype(int)
92
+ bb = bb.astype(int)
93
+ original_shape = original_shape.astype(int)
94
+
95
+ if flow:
96
+ # Multiply before resizing to reduce the number of multiplications
97
+ img = _rescale_flow_values(img, bb, img.shape, first_reshape, original_shape)
98
+
99
+ min_i, min_j, min_k, bb_i, bb_j, bb_k = bb
100
+ max_i = min_i + bb_i
101
+ max_j = min_j + bb_j
102
+ max_k = min_k + bb_k
103
+
104
+ img_bb = resize(img, (bb_i, bb_j, bb_k)) # Get the original bounding box shape
105
+
106
+ # Place the bounding box again in the cubic volume
107
+ img_copy = np.zeros((*first_reshape, img.shape[-1]) if len(img.shape) > 3 else first_reshape) # Get channels if any
108
+ img_copy[min_i:max_i, min_j:max_j, min_k:max_k, ...] = img_bb
109
+
110
+ # Now resize to the original shape
111
+ resized_img = resize(img_copy, original_shape, preserve_range=True, anti_aliasing=False)
112
+ if clip_img or flow:
113
+ # clip_mask = np.zeros(img_copy.shape[:3], np.int)
114
+ # clip_mask[min_i:max_i, min_j:max_j, min_k:max_k] = 1
115
+ # clip_mask = resize(clip_mask, original_shape, preserve_range=True, anti_aliasing=False)
116
+ # clip_mask[clip_mask > 0.5] = 1
117
+ # clip_mask[clip_mask < 1] = 0
118
+ #
119
+ # [min_i, min_j, min_k, max_i, max_j, max_k] = regionprops(label(clip_mask))[0].bbox
120
+ #
121
+ # resized_img = resized_img[min_i:max_i, min_j:max_j, min_k:max_k, ...]
122
+
123
+ # Compute the coordinates of the boundix box in the upsampled volume, instead of resizing a mask image
124
+ S = resize_transformation(img.shape, bb=None, first_reshape=first_reshape, original_shape=original_shape, translate=True)
125
+ bb_coords = np.asarray([[min_i, min_j, min_k], [max_i, max_j, max_k]])
126
+ bb_coords = np.hstack([bb_coords, np.ones((2, 1))])
127
+
128
+ upsamp_bbox_coords = np.around(np.matmul(S, bb_coords.T)[:-1, :].T).astype(np.int)
129
+ min_i = upsamp_bbox_coords[0][0]
130
+ min_j = upsamp_bbox_coords[0][1]
131
+ min_k = upsamp_bbox_coords[0][2]
132
+ max_i = upsamp_bbox_coords[1][0]
133
+ max_j = upsamp_bbox_coords[1][1]
134
+ max_k = upsamp_bbox_coords[1][2]
135
+ resized_img = resized_img[min_i:max_i, min_j:max_j, min_k:max_k, ...]
136
+
137
+ if flow:
138
+ # Return also the origin of the bb in the resized volume for the following interpolation
139
+ return resized_img, np.asarray([min_i, min_j, min_k])
140
+ return resized_img
141
+ # This is supposed to be an isotropic image with voxel size 1 mm
142
+
143
+
144
+ def _rescale_flow_values(flow, bb, current_img_shape, first_reshape, original_shape):
145
+ S = resize_transformation(current_img_shape, bb, first_reshape, original_shape, translate=False)
146
+
147
+ [si, sj, sk] = np.diag(S[:3, :3])
148
+ flow[..., 0] *= si
149
+ flow[..., 1] *= sj
150
+ flow[..., 2] *= sk
151
+
152
+ return flow
153
+
154
+
155
+ def resize_pts_to_original_space(pt, bb, current_img_shape, first_reshape, original_shape):
156
+ T = resize_transformation(current_img_shape, bb, first_reshape, original_shape)
157
+ if len(pt.shape) > 1:
158
+ pt_aug = np.ones((4, pt.shape[0]))
159
+ pt_aug[0:3, :] = pt.T
160
+ else:
161
+ pt_aug = np.ones((4,))
162
+ pt_aug[0:3] = pt
163
+ trf_pt = np.matmul(T, pt_aug)[:-1, ...].T
164
+
165
+ return trf_pt
166
+
167
+
168
+ def resize_transformation(current_img_shape, bb=None, first_reshape=None, original_shape=None, translate=True):
169
+ first_reshape = first_reshape.astype(int)
170
+ original_shape = original_shape.astype(int)
171
+
172
+ first_resize_trf = np.eye(4)
173
+ if bb is not None:
174
+ bb = bb.astype(int)
175
+ min_i, min_j, min_k, bb_i, bb_j, bb_k = bb
176
+ np.fill_diagonal(first_resize_trf, [bb_i / current_img_shape[0], bb_j / current_img_shape[1], bb_k / current_img_shape[2], 1])
177
+ if translate:
178
+ first_resize_trf[:3, -1] = np.asarray([min_i, min_j, min_k])
179
+
180
+ original_resize_trf = np.eye(4)
181
+ np.fill_diagonal(original_resize_trf, [original_shape[0] / first_reshape[0], original_shape[1] / first_reshape[1], original_shape[2] / first_reshape[2], 1])
182
+
183
+ return np.matmul(original_resize_trf, first_resize_trf)
184
+