jpdefrutos commited on
Commit
286a978
·
1 Parent(s): 7a8ed91

Updating latest changes

Browse files
Brain_study/ABSTRACT/figures.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import re
4
+ import warnings
5
+
6
+ import nibabel as nib
7
+ import numpy as np
8
+ import matplotlib.pyplot as plt
9
+ from matplotlib import cm
10
+ from matplotlib.colors import ListedColormap, LinearSegmentedColormap
11
+
12
+ segm_cm = cm.get_cmap('Dark2', 256)
13
+ segm_cm = segm_cm(np.linspace(0, 1, 28))
14
+ segm_cm[0, :] = np.asarray([0, 0, 0, 0])
15
+ segm_cm = ListedColormap(segm_cm)
16
+
17
+ if __name__ == '__main__':
18
+ parser = argparse.ArgumentParser()
19
+
20
+ parser.add_argument('-d', '--dir', type=str, help='Directories where the models are stored', default=None)
21
+ parser.add_argument('-o', '--output', type=str, help='Output directory', default=os.getcwd())
22
+ parser.add_argument('--overwrite', type=bool, default=True)
23
+ args = parser.parse_args()
24
+ assert args.dir is not None, "No directories provided. Stopping"
25
+
26
+ list_fix_img = list()
27
+ list_mov_img = list()
28
+ list_fix_seg = list()
29
+ list_mov_seg = list()
30
+ list_pred_img = list()
31
+ list_pred_seg = list()
32
+ print('Fetching data...')
33
+ for r, d, f in os.walk(args.dir):
34
+ if os.path.split(r)[1] == 'Evaluation_paper':
35
+ for name in f:
36
+ if re.search('^050', name) and name.endswith('nii.gz'):
37
+ if re.search('fix_img', name) and name.endswith('nii.gz'):
38
+ list_fix_img.append(os.path.join(r, name))
39
+ elif re.search('mov_img', name):
40
+ list_mov_img.append(os.path.join(r, name))
41
+ elif re.search('fix_seg', name):
42
+ list_fix_seg.append(os.path.join(r, name))
43
+ elif re.search('mov_seg', name):
44
+ list_mov_seg.append(os.path.join(r, name))
45
+ elif re.search('pred_img', name):
46
+ list_pred_img.append(os.path.join(r, name))
47
+ elif re.search('pred_seg', name):
48
+ list_pred_seg.append(os.path.join(r, name))
49
+
50
+ # Figure: all coronal views
51
+ # Fix img | Mov img
52
+ # BASELINE 1 | BASELINE 2 | SEGGUIDED
53
+ # UW 1 | UW 2 | UW 3
54
+ list_fix_img.sort()
55
+ list_fix_seg.sort()
56
+ list_mov_img.sort()
57
+ list_mov_seg.sort()
58
+ list_pred_img.sort()
59
+ list_pred_seg.sort()
60
+ print('Making Test_data.png...')
61
+ selected_slice = 30
62
+ fix_img = np.asarray(nib.load(list_fix_img[0]).dataobj)[..., selected_slice, 0]
63
+ mov_img = np.asarray(nib.load(list_mov_img[0]).dataobj)[..., selected_slice, 0]
64
+ fix_seg = np.asarray(nib.load(list_fix_seg[0]).dataobj)[..., selected_slice, 0]
65
+ mov_seg = np.asarray(nib.load(list_mov_seg[0]).dataobj)[..., selected_slice, 0]
66
+
67
+ fig, ax = plt.subplots(nrows=1, ncols=4, figsize=(9, 3), dpi=200)
68
+
69
+ for i, (img, title) in enumerate(zip([(fix_img, fix_seg), (mov_img, mov_seg)],
70
+ [('Fixed image', 'Fixed Segms.'), ('Moving image', 'Moving Segms.')])):
71
+
72
+ ax[i].imshow(img[0], origin='lower', cmap='Greys_r')
73
+ ax[i+2].imshow(img[0], origin='lower', cmap='Greys_r')
74
+ ax[i+2].imshow(img[1], origin='lower', cmap=segm_cm, alpha=0.6)
75
+
76
+ ax[i].tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False)
77
+ ax[i+2].tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False)
78
+
79
+ ax[i].set_xlabel(title[0], fontsize=16)
80
+ ax[i+2].set_xlabel(title[1], fontsize=16)
81
+
82
+ plt.tight_layout()
83
+ if not args.overwrite and os.path.exists(os.path.join(args.output, 'Test_data.png')):
84
+ warnings.warn('File Test_data.png already exists. Skipping')
85
+ else:
86
+ plt.savefig(os.path.join(args.output, 'Test_data.png'), format='png')
87
+ plt.close()
88
+
89
+ print('Making Pred_data.png...')
90
+ fig, ax = plt.subplots(nrows=2, ncols=6, figsize=(9, 3), dpi=200)
91
+
92
+ for i, (pred_img_path, pred_seg_path) in enumerate(zip(list_pred_img, list_pred_seg)):
93
+ img = np.asarray(nib.load(pred_img_path).dataobj)[..., selected_slice, 0]
94
+ seg = np.asarray(nib.load(pred_seg_path).dataobj)[..., selected_slice, 0]
95
+
96
+ ax[0, i].imshow(img, origin='lower', cmap='Greys_r')
97
+ ax[1, i].imshow(img, origin='lower', cmap='Greys_r')
98
+ ax[1, i].imshow(seg, origin='lower', cmap=segm_cm, alpha=0.6)
99
+
100
+ ax[0, i].tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False)
101
+ ax[1, i].tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False)
102
+
103
+ model = re.search('((UW|SEGGUIDED|BASELINE).*)_{2,}MET', pred_img_path).group(1).rstrip('_')
104
+ model = model.replace('_Lsim', ' ')
105
+ model = model.replace('_Lseg', ' ')
106
+ model = model.replace('_L', ' ')
107
+ model = model.replace('_', ' ')
108
+ model = model.upper()
109
+ model = ' '.join(model.split())
110
+
111
+ ax[1, i].set_xlabel(model, fontsize=9)
112
+ plt.tight_layout()
113
+ if not args.overwrite and os.path.exists(os.path.join(args.output, 'Pred_data.png')):
114
+ warnings.warn('File Pred_data.png already exists. Skipping')
115
+ else:
116
+ plt.savefig(os.path.join(args.output, 'Pred_data.png'), format='png')
117
+ plt.close()
118
+
119
+ print('Making Pred_data_large.png...')
120
+ fig, ax = plt.subplots(nrows=2, ncols=8, figsize=(9, 3), dpi=200)
121
+ list_pred_img = [list_mov_img[0]] + list_pred_img
122
+ list_pred_img = [list_fix_img[0]] + list_pred_img
123
+ list_pred_seg = [list_mov_seg[0]] + list_pred_seg
124
+ list_pred_seg = [list_fix_seg[0]] + list_pred_seg
125
+
126
+ for i, (pred_img_path, pred_seg_path) in enumerate(zip(list_pred_img, list_pred_seg)):
127
+ img = np.asarray(nib.load(pred_img_path).dataobj)[..., selected_slice, 0]
128
+ seg = np.asarray(nib.load(pred_seg_path).dataobj)[..., selected_slice, 0]
129
+
130
+ ax[0, i].imshow(img, origin='lower', cmap='Greys_r')
131
+ ax[1, i].imshow(img, origin='lower', cmap='Greys_r')
132
+ ax[1, i].imshow(seg, origin='lower', cmap=segm_cm, alpha=0.6)
133
+
134
+ ax[0, i].tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False)
135
+ ax[1, i].tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False)
136
+
137
+ if i > 1:
138
+ model = re.search('((UW|SEGGUIDED|BASELINE).*)_{2,}MET', pred_img_path).group(1).rstrip('_')
139
+ model = model.replace('_Lsim', ' ')
140
+ model = model.replace('_Lseg', ' ')
141
+ model = model.replace('_L', ' ')
142
+ model = model.replace('_', ' ')
143
+ model = model.upper()
144
+ model = ' '.join(model.split())
145
+ elif i == 0:
146
+ model = 'Moving image'
147
+ else:
148
+ model = 'Fixed image'
149
+
150
+ ax[1, i].set_xlabel(model, fontsize=7)
151
+ plt.tight_layout()
152
+ if not args.overwrite and os.path.exists(os.path.join(args.output, 'Pred_data_large.png')):
153
+ warnings.warn('File Pred_data.png already exists. Skipping')
154
+ else:
155
+ plt.savefig(os.path.join(args.output, 'Pred_data_large.png'), format='png')
156
+ plt.close()
157
+
158
+ print('...done!')
Brain_study/ABSTRACT/format_tables_abstract.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+
3
+ import pandas as pd
4
+ import os
5
+ import argparse
6
+ import re
7
+ import shutil
8
+
9
+ DICT_MODEL_NAMES = {'BASELINE': 'BL',
10
+ 'SEGGUIDED': 'SG',
11
+ 'UW': 'UW'}
12
+
13
+ DICT_METRICS_NAMES = {'NCC': 'N',
14
+ 'SSIM': 'S',
15
+ 'DICE': 'D',
16
+ 'DICE MACRO': 'D',
17
+ 'HD': 'H', }
18
+
19
+
20
+ def row_name(in_path: str):
21
+ model = re.search('((UW|SEGGUIDED|BASELINE).*)_\d', in_path).group(1).rstrip('_')
22
+ model = model.replace('_Lsim', '')
23
+ model = model.replace('_Lseg', '')
24
+ model = model.replace('_L', '')
25
+ model = model.replace('_', ' ')
26
+ model = model.upper()
27
+ elements = model.split()
28
+ model = elements[0]
29
+ metrics = list()
30
+ model = DICT_MODEL_NAMES[model]
31
+ for m in elements[1:]:
32
+ if m != 'MACRO':
33
+ metrics.append(DICT_METRICS_NAMES[m])
34
+
35
+ return '{}-{}'.format(model, ''.join(metrics))
36
+
37
+
38
+ if __name__ == '__main__':
39
+ parser = argparse.ArgumentParser()
40
+
41
+ parser.add_argument('-d', '--dir', nargs='+', type=str, help='List of directories where metrics.csv file is',
42
+ default=None)
43
+ parser.add_argument('-o', '--output', type=str, help='Output directory', default=os.getcwd())
44
+ parser.add_argument('--overwrite', type=bool, default=True)
45
+ parser.add_argument('--filename', type=str, help='Output file name', default='metrics')
46
+ parser.add_argument('--removemetrics', nargs='+', type=str, default=None)
47
+ args = parser.parse_args()
48
+ assert args.dir is not None, "No directories provided. Stopping"
49
+
50
+ if len(args.dir) == 1:
51
+ list_files = list()
52
+ for r, d, f in os.walk(args.dir[0]):
53
+ for name in f:
54
+ if 'metrics.csv' == name: # and os.path.split(r)[1] == 'Evaluation_paper':
55
+ list_files.append(os.path.join(r, name))
56
+ else:
57
+ list_files = [os.path.join(d, 'metrics.csv') for d in args.dir]
58
+
59
+ for d in list_files:
60
+ assert os.path.exists(d), "Missing metrics.csv file in: " + os.path.split(d)[0]
61
+
62
+ print('Metric files found: {}'.format(list_files))
63
+
64
+ dataframes = list()
65
+ if len(list_files):
66
+ for d in list_files:
67
+ df = pd.read_csv(d, sep=';', header=0)
68
+ model = row_name(d)
69
+
70
+ df.insert(0, "Model", model)
71
+ df.drop(columns=list(df.filter(regex='Unnamed')), inplace=True)
72
+ df.drop(columns=['File', 'MSE', 'No_missing_lbls'], inplace=True)
73
+ dataframes.append(df)
74
+
75
+ full_table = pd.concat(dataframes)
76
+ if args.removemetrics is not None:
77
+ full_table = full_table.drop(columns=args.removemetrics)
78
+ mean_table = full_table.copy()
79
+ # mean_table.insert(column='Type', value='Avg.', loc=1)
80
+ # mean_table = mean_table.groupby(['Type', 'Model']).mean().round(3)
81
+ mean_table = mean_table.groupby(['Model'])
82
+ hd95 = mean_table.HD.quantile(0.95).map('{:.2f}'.format)
83
+ mean_table = mean_table.mean().round(3)
84
+
85
+ std_table = full_table.copy()
86
+ # std_table.insert(column='Type', value='STD', loc=1)
87
+ # std_table = std_table.groupby(['Type', 'Model']).std().round(3)
88
+ std_table = std_table.groupby(['Model']).std().round(3)
89
+
90
+ # metrics_table = pd.concat([mean_table, std_table]).swaplevel(axis='rows')
91
+ metrics_table = mean_table.applymap('{:.2f}'.format) + u"\u00B1" + std_table.applymap('{:.2f}'.format)
92
+ time_col = metrics_table.pop('Time')
93
+ metrics_table.insert(len(metrics_table.columns), 'Time', time_col)
94
+ metrics_table.insert(5, 'HD 95%ile', hd95)
95
+
96
+ metrics_file = os.path.join(args.output, args.filename + '.tex')
97
+ if os.path.exists(metrics_file) and args.overwrite:
98
+ shutil.rmtree(metrics_file, ignore_errors=True)
99
+ metrics_table.to_latex(metrics_file,
100
+ column_format='l' + 'c' * len(metrics_table.columns),
101
+ caption='Average and standard deviation of the metrics: MSE, NCC, SSIM, DICE and HD. As well as the number of missing labels in the predicted images.')
102
+ elif os.path.exists(metrics_file):
103
+ warnings.warn('File {} already exists. Skipping'.format(metrics_file))
104
+ else:
105
+ metrics_table.to_latex(metrics_file,
106
+ column_format='l' + 'c' * len(metrics_table.columns),
107
+ caption='Average and standard deviation of the metrics: MSE, NCC, SSIM, DICE and HD. As well as the number of missing labels in the predicted images.')
108
+
109
+ print('Done')
110
+ else:
111
+ print('No files found in {}!'.format(args.dir))
Brain_study/Build_test_set.py CHANGED
@@ -18,7 +18,8 @@ import DeepDeformationMapRegistration.utils.constants as C
18
  from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
19
  from DeepDeformationMapRegistration.layers import AugmentationLayer
20
  from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
21
- from DeepDeformationMapRegistration.utils.misc import get_segmentations_centroids
 
22
  from tqdm import tqdm
23
 
24
  from Brain_study.data_generator import BatchGenerator
@@ -37,9 +38,15 @@ POINTS = None
37
  MISSING_CENTROID = np.asarray([[np.nan]*3])
38
 
39
 
40
- def get_mov_centroids(fix_seg, disp_map):
41
- fix_centroids, _ = get_segmentations_centroids(fix_seg[0, ...], ohe=True, expected_lbls=range(0, 28))
42
- disp = griddata(POINTS, disp_map.reshape([-1, 3]), fix_centroids, method='linear')
 
 
 
 
 
 
43
  return fix_centroids, fix_centroids + disp, disp
44
 
45
 
@@ -50,6 +57,7 @@ if __name__ == '__main__':
50
  parser.add_argument('--gpu', type=int, help='GPU', default=0)
51
  parser.add_argument('--dataset', type=str, help='Dataset to build the test set', default='')
52
  parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
 
53
  args = parser.parse_args()
54
 
55
  assert args.dataset != '', "Missing original dataset dataset"
@@ -70,12 +78,20 @@ if __name__ == '__main__':
70
  os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
71
  os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) # Check availability before running using 'nvidia-smi'
72
 
73
- data_generator = BatchGenerator(DATASET, 1, False, 1.0, False, ['all'])
74
 
75
  img_generator = data_generator.get_train_generator()
76
  nb_labels = len(img_generator.get_segmentation_labels())
77
  image_input_shape = img_generator.get_data_shape()[-1][:-1]
78
- image_output_shape = [64] * 3
 
 
 
 
 
 
 
 
79
  # Build model
80
 
81
  xx = np.linspace(0, image_output_shape[0], image_output_shape[0], endpoint=False)
@@ -102,19 +118,27 @@ if __name__ == '__main__':
102
  return_displacement_map=True)
103
  augm_model = tf.keras.Model(inputs=input_augm, outputs=augm_layer(input_augm))
104
 
 
 
 
 
105
  config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
106
  config.gpu_options.allow_growth = True
107
  config.log_device_placement = False ## to log device placement (on which device the operation ran)
108
 
 
 
109
  sess = tf.Session(config=config)
110
  tf.keras.backend.set_session(sess)
111
  with sess.as_default():
112
  sess.run(tf.global_variables_initializer())
113
  progress_bar = tqdm(enumerate(img_generator, 1), desc='Generating samples', total=len(img_generator))
114
- for step, (in_batch, _) in progress_bar:
115
- fix_img, mov_img, fix_seg, mov_seg, disp_map = augm_model.predict(in_batch)
 
 
116
 
117
- fix_centroids, mov_centroids, disp_centroids = get_mov_centroids(fix_seg, disp_map)
118
 
119
  out_file = os.path.join(OUTPUT_FOLDER_DIR, 'test_sample_{:04d}.h5'.format(step))
120
  out_file_dm = os.path.join(OUTPUT_FOLDER_DIR, 'test_sample_dm_{:04d}.h5'.format(step))
@@ -129,7 +153,7 @@ if __name__ == '__main__':
129
  f.create_dataset('mov_segmentations', shape=segm_shape[1:], dtype=np.uint8, data=mov_seg[0, ...])
130
  f.create_dataset('fix_centroids', shape=centroids_shape, dtype=np.float32, data=fix_centroids)
131
  f.create_dataset('mov_centroids', shape=centroids_shape, dtype=np.float32, data=mov_centroids)
132
-
133
  with h5py.File(out_file_dm, 'w') as f:
134
  f.create_dataset('disp_map', shape=disp_shape[1:], dtype=np.float32, data=disp_map)
135
  f.create_dataset('disp_centroids', shape=centroids_shape, dtype=np.float32, data=disp_centroids)
 
18
  from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
19
  from DeepDeformationMapRegistration.layers import AugmentationLayer
20
  from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
21
+ from DeepDeformationMapRegistration.utils.misc import get_segmentations_centroids, DisplacementMapInterpolator
22
+
23
  from tqdm import tqdm
24
 
25
  from Brain_study.data_generator import BatchGenerator
 
38
  MISSING_CENTROID = np.asarray([[np.nan]*3])
39
 
40
 
41
+ def get_mov_centroids(fix_seg, disp_map, nb_labels=28, exclude_background_lbl=False, brain_study=True, dm_interp=None):
42
+ if exclude_background_lbl:
43
+ fix_centroids, _ = get_segmentations_centroids(fix_seg[0, ..., 1:], ohe=True, expected_lbls=range(1, nb_labels), brain_study=brain_study)
44
+ else:
45
+ fix_centroids, _ = get_segmentations_centroids(fix_seg[0, ...], ohe=True, expected_lbls=range(1, nb_labels), brain_study=brain_study)
46
+ if dm_interp is None:
47
+ disp = griddata(POINTS, disp_map.reshape([-1, 3]), fix_centroids, method='linear')
48
+ else:
49
+ disp = dm_interp(disp_map, fix_centroids, backwards=False)
50
  return fix_centroids, fix_centroids + disp, disp
51
 
52
 
 
57
  parser.add_argument('--gpu', type=int, help='GPU', default=0)
58
  parser.add_argument('--dataset', type=str, help='Dataset to build the test set', default='')
59
  parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
60
+ parser.add_argument('--output_shape', help='If an int, a cubic shape is presumed. Otherwise provide it as a space separated sequence', nargs='+', default=128)
61
  args = parser.parse_args()
62
 
63
  assert args.dataset != '', "Missing original dataset dataset"
 
78
  os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
79
  os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) # Check availability before running using 'nvidia-smi'
80
 
81
+ data_generator = BatchGenerator(DATASET, 1, False, 1.0, False, ['all'], return_isotropic_shape=True)
82
 
83
  img_generator = data_generator.get_train_generator()
84
  nb_labels = len(img_generator.get_segmentation_labels())
85
  image_input_shape = img_generator.get_data_shape()[-1][:-1]
86
+
87
+ if isinstance(args.output_shape, int):
88
+ image_output_shape = [args.output_shape] * 3
89
+ elif isinstance(args.output_shape, list):
90
+ assert len(args.output_shape) == 3, 'Invalid output shape, expected three values and got {}'.format(len(args.output_shape))
91
+ image_output_shape = [int(s) for s in args.output_shape]
92
+ else:
93
+ raise ValueError('Invalid output_shape. Must be an int or a space-separated sequence of ints')
94
+ print('Scaling to: ', image_output_shape)
95
  # Build model
96
 
97
  xx = np.linspace(0, image_output_shape[0], image_output_shape[0], endpoint=False)
 
118
  return_displacement_map=True)
119
  augm_model = tf.keras.Model(inputs=input_augm, outputs=augm_layer(input_augm))
120
 
121
+ fix_img_ph = tf.placeholder(dtype=tf.float32, shape=[1,] + image_input_shape + [1+nb_labels,], name='fix_image')
122
+
123
+ augmentation_pipeline = augm_model(fix_img_ph)
124
+
125
  config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
126
  config.gpu_options.allow_growth = True
127
  config.log_device_placement = False ## to log device placement (on which device the operation ran)
128
 
129
+ dm_interp = DisplacementMapInterpolator(image_output_shape, 'griddata', step=8)
130
+
131
  sess = tf.Session(config=config)
132
  tf.keras.backend.set_session(sess)
133
  with sess.as_default():
134
  sess.run(tf.global_variables_initializer())
135
  progress_bar = tqdm(enumerate(img_generator, 1), desc='Generating samples', total=len(img_generator))
136
+ for step, (in_batch, _, isotropic_shape) in progress_bar:
137
+ # fix_img, mov_img, fix_seg, mov_seg, disp_map = augm_model.predict(in_batch)
138
+ fix_img, mov_img, fix_seg, mov_seg, disp_map = sess.run(augmentation_pipeline,
139
+ feed_dict={'fix_image:0': in_batch})
140
 
141
+ fix_centroids, mov_centroids, disp_centroids = get_mov_centroids(fix_seg, disp_map, nb_labels, dm_interp=dm_interp)
142
 
143
  out_file = os.path.join(OUTPUT_FOLDER_DIR, 'test_sample_{:04d}.h5'.format(step))
144
  out_file_dm = os.path.join(OUTPUT_FOLDER_DIR, 'test_sample_dm_{:04d}.h5'.format(step))
 
153
  f.create_dataset('mov_segmentations', shape=segm_shape[1:], dtype=np.uint8, data=mov_seg[0, ...])
154
  f.create_dataset('fix_centroids', shape=centroids_shape, dtype=np.float32, data=fix_centroids)
155
  f.create_dataset('mov_centroids', shape=centroids_shape, dtype=np.float32, data=mov_centroids)
156
+ f.create_dataset('isotropic_shape', data=np.squeeze(isotropic_shape))
157
  with h5py.File(out_file_dm, 'w') as f:
158
  f.create_dataset('disp_map', shape=disp_shape[1:], dtype=np.float32, data=disp_map)
159
  f.create_dataset('disp_centroids', shape=centroids_shape, dtype=np.float32, data=disp_centroids)
Brain_study/Evaluate_network.py CHANGED
@@ -23,6 +23,7 @@ from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplifie
23
  from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
24
  from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
25
  from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
 
26
  from EvaluationScripts.Evaluate_class import EvaluationFigures, resize_pts_to_original_space, resize_img_to_original_space, resize_transformation
27
  from scipy.interpolate import RegularGridInterpolator
28
  from tqdm import tqdm
@@ -147,9 +148,9 @@ if __name__ == '__main__':
147
  dice = GeneralizedDICEScore(image_output_shape + [img_generator.get_data_shape()[2][-1]], num_labels=nb_labels).metric(fix_seg, pred_seg).eval()
148
  hd = HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [img_generator.get_data_shape()[2][-1]]).metric(fix_seg, pred_seg).eval()
149
 
150
- pred_seg = np.argmax(pred_seg, axis=-1)[..., np.newaxis].astype(np.float32)
151
- mov_seg = np.argmax(mov_seg, axis=-1)[..., np.newaxis].astype(np.float32)
152
- fix_seg = np.argmax(fix_seg, axis=-1)[..., np.newaxis].astype(np.float32)
153
 
154
  mov_coords = np.stack(np.meshgrid(*[np.arange(0, 64)]*3), axis=-1)
155
  dest_coords = mov_coords + disp_map[0, ...]
@@ -178,8 +179,8 @@ if __name__ == '__main__':
178
  plt.savefig(os.path.join(output_folder, '{:03d}_hist_mag_ssim_{:.03f}_dice_{:.03f}.png'.format(step, ssim, dice)))
179
  plt.close()
180
 
181
- plot_predictions(fix_img, mov_img, disp_map, pred_img, os.path.join(output_folder, '{:03d}_figures.png'.format(step)), show=False)
182
- plot_predictions(fix_seg, mov_seg, disp_map, pred_seg, os.path.join(output_folder, '{:03d}_figures_seg.png'.format(step)), show=False)
183
  save_disp_map_img(disp_map, 'Displacement map', os.path.join(output_folder, '{:03d}_disp_map_fig.png'.format(step)), show=False)
184
 
185
  progress_bar.set_description('SSIM {:.04f}\tDICE: {:.04f}'.format(ssim, dice))
 
23
  from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
24
  from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
25
  from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
26
+ from DeepDeformationMapRegistration.utils.misc import segmentation_ohe_to_cardinal
27
  from EvaluationScripts.Evaluate_class import EvaluationFigures, resize_pts_to_original_space, resize_img_to_original_space, resize_transformation
28
  from scipy.interpolate import RegularGridInterpolator
29
  from tqdm import tqdm
 
148
  dice = GeneralizedDICEScore(image_output_shape + [img_generator.get_data_shape()[2][-1]], num_labels=nb_labels).metric(fix_seg, pred_seg).eval()
149
  hd = HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [img_generator.get_data_shape()[2][-1]]).metric(fix_seg, pred_seg).eval()
150
 
151
+ pred_seg = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
152
+ mov_seg = segmentation_ohe_to_cardinal(mov_seg).astype(np.float32)
153
+ fix_seg = segmentation_ohe_to_cardinal(fix_seg).astype(np.float32)
154
 
155
  mov_coords = np.stack(np.meshgrid(*[np.arange(0, 64)]*3), axis=-1)
156
  dest_coords = mov_coords + disp_map[0, ...]
 
179
  plt.savefig(os.path.join(output_folder, '{:03d}_hist_mag_ssim_{:.03f}_dice_{:.03f}.png'.format(step, ssim, dice)))
180
  plt.close()
181
 
182
+ plot_predictions(img_batches=[fix_img, mov_img, pred_img], disp_map_batch=disp_map, seg_batches=[fix_seg, mov_seg, pred_seg], filename=os.path.join(output_folder, '{:03d}_figures_seg.png'.format(step)), show=False)
183
+ plot_predictions(img_batches=[fix_img, mov_img, pred_img], disp_map_batch=disp_map, filename=os.path.join(output_folder, '{:03d}_figures_img.png'.format(step)), show=False)
184
  save_disp_map_img(disp_map, 'Displacement map', os.path.join(output_folder, '{:03d}_disp_map_fig.png'.format(step)), show=False)
185
 
186
  progress_bar.set_description('SSIM {:.04f}\tDICE: {:.04f}'.format(ssim, dice))
Brain_study/Evaluate_network__test_fixed.py CHANGED
@@ -18,20 +18,23 @@ import pandas as pd
18
  import voxelmorph as vxm
19
 
20
  import DeepDeformationMapRegistration.utils.constants as C
 
21
  from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
22
  from DeepDeformationMapRegistration.layers import AugmentationLayer, UncertaintyWeighting
23
  from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplified, NCC, GeneralizedDICEScore, HausdorffDistanceErosion, target_registration_error
24
  from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
25
  from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
26
  from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
 
27
  from DeepDeformationMapRegistration.utils.misc import DisplacementMapInterpolator, get_segmentations_centroids, segmentation_ohe_to_cardinal
28
  from EvaluationScripts.Evaluate_class import EvaluationFigures, resize_pts_to_original_space, resize_img_to_original_space, resize_transformation
29
- from scipy.interpolate import RegularGridInterpolator
30
  from tqdm import tqdm
31
-
32
  import h5py
33
  import re
34
  from Brain_study.data_generator import BatchGenerator
 
35
 
36
  import argparse
37
 
@@ -49,9 +52,10 @@ if __name__ == '__main__':
49
  parser.add_argument('-d', '--dir', nargs='+', type=str, help='Directory where ./checkpoints/best_model.h5 is located', default=None)
50
  parser.add_argument('--gpu', type=int, help='GPU', default=0)
51
  parser.add_argument('--dataset', type=str, help='Dataset to run predictions on',
52
- default='/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/training')
53
  parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
54
  parser.add_argument('--outdirname', type=str, default='Evaluate')
 
55
  args = parser.parse_args()
56
  if args.model is not None:
57
  assert '.h5' in args.model[0], 'No checkpoint file provided, use -d/--dir instead'
@@ -83,8 +87,8 @@ if __name__ == '__main__':
83
  config.log_device_placement = False ## to log device placement (on which device the operation ran)
84
  config.allow_soft_placement = True
85
 
86
- sess = tf.Session(config=config)
87
- tf.keras.backend.set_session(sess)
88
 
89
  # Loss and metric functions. Common to all models
90
  loss_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).loss,
@@ -101,10 +105,10 @@ if __name__ == '__main__':
101
  GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric_macro]
102
 
103
  ### METRICS GRAPH ###
104
- fix_img_ph = tf.placeholder(tf.float32, (1, *image_output_shape, 1), name='fix_img')
105
- pred_img_ph = tf.placeholder(tf.float32, (1, *image_output_shape, 1), name='pred_img')
106
- fix_seg_ph = tf.placeholder(tf.float32, (1, *image_output_shape, nb_labels), name='fix_seg')
107
- pred_seg_ph = tf.placeholder(tf.float32, (1, *image_output_shape, nb_labels), name='pred_seg')
108
 
109
  ssim_tf = metric_fncs[0](fix_img_ph, pred_img_ph)
110
  ncc_tf = metric_fncs[1](fix_img_ph, pred_img_ph)
@@ -118,7 +122,7 @@ if __name__ == '__main__':
118
  # Needed for VxmDense type of network
119
  warp_segmentation = vxm.networks.Transform(image_output_shape, interp_method='nearest', nb_feats=nb_labels)
120
 
121
- dm_interp = DisplacementMapInterpolator(image_output_shape, 'griddata')
122
 
123
  for MODEL_FILE, DATA_ROOT_DIR in zip(MODEL_FILE_LIST, DATA_ROOT_DIR_LIST):
124
  print('MODEL LOCATION: ', MODEL_FILE)
@@ -171,6 +175,8 @@ if __name__ == '__main__':
171
  fix_seg = f['fix_segmentations'][:][np.newaxis, ...].astype(np.float32)
172
  mov_seg = f['mov_segmentations'][:][np.newaxis, ...].astype(np.float32)
173
  fix_centroids = f['fix_centroids'][:]
 
 
174
 
175
  if network.name == 'vxm_dense_semi_supervised_seg':
176
  t0 = time.time()
@@ -182,29 +188,42 @@ if __name__ == '__main__':
182
  pred_seg = warp_segmentation.predict([mov_seg, disp_map])
183
  t1 = time.time()
184
 
185
- mov_centroids, missing_lbls = get_segmentations_centroids(mov_seg[0, ...], ohe=True, expected_lbls=range(0, 28))
 
186
  # pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) # with tps, it returns the pred_centroids directly
187
  pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) + mov_centroids
188
 
 
 
 
 
 
 
 
 
189
  # I need the labels to be OHE to compute the segmentation metrics.
190
- dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf], {'fix_seg:0': fix_seg, 'pred_seg:0': pred_seg})
 
 
 
191
 
192
  pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
193
  mov_seg_card = segmentation_ohe_to_cardinal(mov_seg).astype(np.float32)
194
  fix_seg_card = segmentation_ohe_to_cardinal(fix_seg).astype(np.float32)
195
 
196
- ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf], {'fix_img:0': fix_img, 'pred_img:0': pred_img})
197
- ms_ssim = ms_ssim[0]
 
198
 
199
  # Rescale the points back to isotropic space, where we have a correspondence voxel <-> mm
200
- upsample_scale = 128 / 64
201
- fix_centroids_isotropic = fix_centroids * upsample_scale
202
  # mov_centroids_isotropic = mov_centroids * upsample_scale
203
- pred_centroids_isotropic = pred_centroids * upsample_scale
204
 
205
- fix_centroids_isotropic = np.divide(fix_centroids_isotropic, C.IXI_DATASET_iso_to_cubic_scales)
206
- # mov_centroids_isotropic = np.divide(mov_centroids_isotropic, C.IXI_DATASET_iso_to_cubic_scales)
207
- pred_centroids_isotropic = np.divide(pred_centroids_isotropic, C.IXI_DATASET_iso_to_cubic_scales)
208
  # Now we can measure the TRE in mm
209
  tre_array = target_registration_error(fix_centroids_isotropic, pred_centroids_isotropic, False).eval()
210
  tre = np.mean([v for v in tre_array if not np.isnan(v)])
@@ -235,14 +254,20 @@ if __name__ == '__main__':
235
  # plt.savefig(os.path.join(output_folder, '{:03d}_hist_mag_ssim_{:.03f}_dice_{:.03f}.png'.format(step, ssim, dice)))
236
  # plt.close()
237
 
238
- plot_predictions(fix_img, mov_img, disp_map, pred_img, os.path.join(output_folder, '{:03d}_figures.png'.format(step)), show=False)
239
- plot_predictions(fix_seg, mov_seg, disp_map, pred_seg, os.path.join(output_folder, '{:03d}_figures_seg.png'.format(step)), show=False)
240
- save_disp_map_img(disp_map, 'Displacement map', os.path.join(output_folder, '{:03d}_disp_map_fig.png'.format(step)), show=False)
241
 
242
  progress_bar.set_description('SSIM {:.04f}\tDICE: {:.04f}'.format(ssim, dice))
243
 
244
  print('Summary\n=======\n')
245
- print(pd.read_csv(metrics_file, sep=';', header=0).mean(axis=0))
 
 
 
 
 
 
246
  print('\n=======\n')
247
  tf.keras.backend.clear_session()
248
  # sess.close()
 
18
  import voxelmorph as vxm
19
 
20
  import DeepDeformationMapRegistration.utils.constants as C
21
+ from DeepDeformationMapRegistration.utils.operators import min_max_norm, safe_medpy_metric
22
  from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
23
  from DeepDeformationMapRegistration.layers import AugmentationLayer, UncertaintyWeighting
24
  from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplified, NCC, GeneralizedDICEScore, HausdorffDistanceErosion, target_registration_error
25
  from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
26
  from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
27
  from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
28
+ from DeepDeformationMapRegistration.utils.misc import resize_displacement_map, scale_transformation
29
  from DeepDeformationMapRegistration.utils.misc import DisplacementMapInterpolator, get_segmentations_centroids, segmentation_ohe_to_cardinal
30
  from EvaluationScripts.Evaluate_class import EvaluationFigures, resize_pts_to_original_space, resize_img_to_original_space, resize_transformation
31
+ from scipy.ndimage import zoom
32
  from tqdm import tqdm
33
+ import medpy.metric as medpy_metrics
34
  import h5py
35
  import re
36
  from Brain_study.data_generator import BatchGenerator
37
+ from voxelmorph.tf.layers import SpatialTransformer
38
 
39
  import argparse
40
 
 
52
  parser.add_argument('-d', '--dir', nargs='+', type=str, help='Directory where ./checkpoints/best_model.h5 is located', default=None)
53
  parser.add_argument('--gpu', type=int, help='GPU', default=0)
54
  parser.add_argument('--dataset', type=str, help='Dataset to run predictions on',
55
+ default='/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/test')
56
  parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
57
  parser.add_argument('--outdirname', type=str, default='Evaluate')
58
+
59
  args = parser.parse_args()
60
  if args.model is not None:
61
  assert '.h5' in args.model[0], 'No checkpoint file provided, use -d/--dir instead'
 
87
  config.log_device_placement = False ## to log device placement (on which device the operation ran)
88
  config.allow_soft_placement = True
89
 
90
+ sess = tf.compat.v1.Session(config=config)
91
+ tf.compat.v1.keras.backend.set_session(sess)
92
 
93
  # Loss and metric functions. Common to all models
94
  loss_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).loss,
 
105
  GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric_macro]
106
 
107
  ### METRICS GRAPH ###
108
+ fix_img_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, 1), name='fix_img')
109
+ pred_img_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, 1), name='pred_img')
110
+ fix_seg_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, nb_labels), name='fix_seg')
111
+ pred_seg_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, nb_labels), name='pred_seg')
112
 
113
  ssim_tf = metric_fncs[0](fix_img_ph, pred_img_ph)
114
  ncc_tf = metric_fncs[1](fix_img_ph, pred_img_ph)
 
122
  # Needed for VxmDense type of network
123
  warp_segmentation = vxm.networks.Transform(image_output_shape, interp_method='nearest', nb_feats=nb_labels)
124
 
125
+ dm_interp = DisplacementMapInterpolator(image_output_shape, 'griddata', step=4)
126
 
127
  for MODEL_FILE, DATA_ROOT_DIR in zip(MODEL_FILE_LIST, DATA_ROOT_DIR_LIST):
128
  print('MODEL LOCATION: ', MODEL_FILE)
 
175
  fix_seg = f['fix_segmentations'][:][np.newaxis, ...].astype(np.float32)
176
  mov_seg = f['mov_segmentations'][:][np.newaxis, ...].astype(np.float32)
177
  fix_centroids = f['fix_centroids'][:]
178
+ isotropic_shape = f['isotropic_shape'][:]
179
+ voxel_size = np.divide(fix_img.shape[1:-1], isotropic_shape)
180
 
181
  if network.name == 'vxm_dense_semi_supervised_seg':
182
  t0 = time.time()
 
188
  pred_seg = warp_segmentation.predict([mov_seg, disp_map])
189
  t1 = time.time()
190
 
191
+ pred_img = min_max_norm(pred_img)
192
+ mov_centroids, missing_lbls = get_segmentations_centroids(mov_seg[0, ...], ohe=True, expected_lbls=range(1, nb_labels + 1))
193
  # pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) # with tps, it returns the pred_centroids directly
194
  pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) + mov_centroids
195
 
196
+ # Up sample the segmentation masks to isotropic resolution
197
+ zoom_factors = np.diag(scale_transformation(image_output_shape, isotropic_shape))
198
+ pred_seg_isot = zoom(pred_seg[0, ...], zoom_factors, order=0)[np.newaxis, ...]
199
+ fix_seg_isot = zoom(fix_seg[0, ...], zoom_factors, order=0)[np.newaxis, ...]
200
+
201
+ pred_img_isot = zoom(pred_img[0, ...], zoom_factors, order=3)[np.newaxis, ...]
202
+ fix_img_isot = zoom(fix_img[0, ...], zoom_factors, order=3)[np.newaxis, ...]
203
+
204
  # I need the labels to be OHE to compute the segmentation metrics.
205
+ # dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf], {'fix_seg:0': fix_seg, 'pred_seg:0': pred_seg})
206
+ dice = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) / np.sum(fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
207
+ hd = np.mean(safe_medpy_metric(pred_seg_isot[0, ...], fix_seg_isot[0, ...], nb_labels, medpy_metrics.hd, {'voxelspacing': voxel_size}))
208
+ dice_macro = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
209
 
210
  pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
211
  mov_seg_card = segmentation_ohe_to_cardinal(mov_seg).astype(np.float32)
212
  fix_seg_card = segmentation_ohe_to_cardinal(fix_seg).astype(np.float32)
213
 
214
+ ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf], {'fix_img:0': fix_img_isot, 'pred_img:0': pred_img_isot})
215
+ ssim = np.mean(ssim) # returns a list of values, which correspond to the ssim of each patch
216
+ ms_ssim = ms_ssim[0] # returns an array of shape (1,)
217
 
218
  # Rescale the points back to isotropic space, where we have a correspondence voxel <-> mm
219
+ # upsample_scale = 128 / 64
220
+ fix_centroids_isotropic = fix_centroids * voxel_size
221
  # mov_centroids_isotropic = mov_centroids * upsample_scale
222
+ pred_centroids_isotropic = pred_centroids * voxel_size
223
 
224
+ # fix_centroids_isotropic = np.divide(fix_centroids_isotropic, C.IXI_DATASET_iso_to_cubic_scales)
225
+ # # mov_centroids_isotropic = np.divide(mov_centroids_isotropic, C.IXI_DATASET_iso_to_cubic_scales)
226
+ # pred_centroids_isotropic = np.divide(pred_centroids_isotropic, C.IXI_DATASET_iso_to_cubic_scales)
227
  # Now we can measure the TRE in mm
228
  tre_array = target_registration_error(fix_centroids_isotropic, pred_centroids_isotropic, False).eval()
229
  tre = np.mean([v for v in tre_array if not np.isnan(v)])
 
254
  # plt.savefig(os.path.join(output_folder, '{:03d}_hist_mag_ssim_{:.03f}_dice_{:.03f}.png'.format(step, ssim, dice)))
255
  # plt.close()
256
 
257
+ plot_predictions(img_batches=[fix_img, mov_img, pred_img], disp_map_batch=disp_map, seg_batches=[fix_seg_card, mov_seg_card, pred_seg_card], filename=os.path.join(output_folder, '{:03d}_figures_seg.png'.format(step)), show=False, step=16)
258
+ plot_predictions(img_batches=[fix_img, mov_img, pred_img], disp_map_batch=disp_map, filename=os.path.join(output_folder, '{:03d}_figures_img.png'.format(step)), show=False, step=16)
259
+ save_disp_map_img(disp_map, 'Displacement map', os.path.join(output_folder, '{:03d}_disp_map_fig.png'.format(step)), show=False, step=16)
260
 
261
  progress_bar.set_description('SSIM {:.04f}\tDICE: {:.04f}'.format(ssim, dice))
262
 
263
  print('Summary\n=======\n')
264
+ metrics_df = pd.read_csv(metrics_file, sep=';', header=0)
265
+ print('\nAVG:\n')
266
+ print(metrics_df.mean(axis=0))
267
+ print('\nSTD:\n')
268
+ print(metrics_df.std(axis=0))
269
+ print('\nHD95perc:\n')
270
+ print(metrics_df['HD'].describe(percentiles=[.95]))
271
  print('\n=======\n')
272
  tf.keras.backend.clear_session()
273
  # sess.close()
Brain_study/MultiTrain_config.py CHANGED
@@ -62,6 +62,11 @@ if __name__ == '__main__':
62
  except KeyError as e:
63
  head = [16, 16]
64
 
 
 
 
 
 
65
  launch_train(dataset_folder=datasetConfig['train'],
66
  validation_folder=datasetConfig['validation'],
67
  output_folder=output_folder,
@@ -75,4 +80,5 @@ if __name__ == '__main__':
75
  early_stop_patience=eval(trainConfig['earlyStopPatience']),
76
  unet=unet,
77
  head=head,
 
78
  **loss_config)
 
62
  except KeyError as e:
63
  head = [16, 16]
64
 
65
+ try:
66
+ resume_checkpoint = trainConfig['resumeCheckpoint']
67
+ except KeyError as e:
68
+ resume_checkpoint = None
69
+
70
  launch_train(dataset_folder=datasetConfig['train'],
71
  validation_folder=datasetConfig['validation'],
72
  output_folder=output_folder,
 
80
  early_stop_patience=eval(trainConfig['earlyStopPatience']),
81
  unet=unet,
82
  head=head,
83
+ resume=resume_checkpoint,
84
  **loss_config)
Brain_study/Train_Baseline.py CHANGED
@@ -8,7 +8,7 @@ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
8
  import numpy as np
9
  import tensorflow as tf
10
 
11
- from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
12
  from tensorflow.keras import Input
13
  from tensorflow.keras.models import Model
14
  from tensorflow.python.keras.utils import Progbar
@@ -32,11 +32,12 @@ from Brain_study.utils import SummaryDictionary, named_logs
32
 
33
  from tqdm import tqdm
34
  from datetime import datetime
 
35
 
36
 
37
  def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr=1e-4, rw=5e-3, simil='ssim',
38
  acc_gradients=16, batch_size=1, max_epochs=10000, early_stop_patience=1000, image_size=64,
39
- unet=[16, 32, 64, 128, 256], head=[16, 16]):
40
  assert dataset_folder is not None and output_folder is not None
41
 
42
  os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
@@ -46,7 +47,16 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
46
  if batch_size != 1 and acc_gradients != 1:
47
  warnings.warn('WARNING: Batch size and Accumulative gradient step are set!')
48
 
49
- output_folder = os.path.join(output_folder + '_' + datetime.now().strftime("%H%M%S-%d%m%Y"))
 
 
 
 
 
 
 
 
 
50
  os.makedirs(output_folder, exist_ok=True)
51
  # dataset_copy = DatasetCopy(dataset_folder, os.path.join(output_folder, 'temp'))
52
  log_file = open(os.path.join(output_folder, 'log.txt'), 'w')
@@ -141,6 +151,26 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
141
  nb_unet_features=nb_features,
142
  int_steps=0)
143
  network.summary(line_length=150)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  # Losses and loss weights
145
  SSIM_KER_SIZE = 5
146
  MS_SSIM_WEIGHTS = _MSSSIM_WEIGHTS[:3]
@@ -178,14 +208,14 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
178
 
179
  # Train
180
  os.makedirs(output_folder, exist_ok=True)
181
- os.makedirs(os.path.join(output_folder, 'checkpoints'), exist_ok=True)
182
  os.makedirs(os.path.join(output_folder, 'tensorboard'), exist_ok=True)
183
  os.makedirs(os.path.join(output_folder, 'history'), exist_ok=True)
184
 
185
  callback_best_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'),
186
  save_best_only=True, monitor='val_loss', verbose=1, mode='min')
187
- # callback_save_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'weights.{epoch:05d}-{val_loss:.2f}.h5'),
188
- # save_best_only=True, save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
189
  # CSVLogger(train_log_name, ';'),
190
  # UpdateLossweights([haus_weight, dice_weight], [const.MODEL+'_resampler_seg', const.MODEL+'_resampler_seg'])
191
  callback_tensorboard = TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
@@ -194,6 +224,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
194
  write_graph=True, write_grads=True
195
  )
196
  callback_early_stop = EarlyStopping(monitor='val_loss', verbose=1, patience=C.EARLY_STOP_PATIENCE, min_delta=0.00001)
 
197
 
198
  losses = {'transformer': loss_fnc,
199
  'flow': vxm.losses.Grad('l2').loss}
@@ -215,8 +246,9 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
215
 
216
  callback_tensorboard.set_model(network)
217
  callback_best_model.set_model(network)
218
- # callback_save_model.set_model(network)
219
  callback_early_stop.set_model(network)
 
220
  # TODO: https://towardsdatascience.com/writing-tensorflow-2-custom-loops-438b1ab6eb6c
221
 
222
  summary = SummaryDictionary(network, C.BATCH_SIZE)
@@ -226,13 +258,13 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
226
  callback_tensorboard.on_train_begin()
227
  callback_early_stop.on_train_begin()
228
  callback_best_model.on_train_begin()
229
- # callback_save_model.on_train_begin()
230
- for epoch in range(C.EPOCHS):
231
  callback_tensorboard.on_epoch_begin(epoch)
232
  callback_early_stop.on_epoch_begin(epoch)
233
  callback_best_model.on_epoch_begin(epoch)
234
- # callback_save_model.on_epoch_begin(epoch)
235
-
236
  print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
237
  print('TRAINING')
238
 
@@ -241,9 +273,9 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
241
  for step, (in_batch, _) in enumerate(train_generator, 1):
242
  # callback_tensorboard.on_train_batch_begin(step)
243
  callback_best_model.on_train_batch_begin(step)
244
- # callback_save_model.on_train_batch_begin(step)
245
  callback_early_stop.on_train_batch_begin(step)
246
-
247
  try:
248
  fix_img, mov_img, *_ = augm_model_train.predict(in_batch)
249
  np.nan_to_num(fix_img, copy=False)
@@ -273,8 +305,9 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
273
  summary.on_train_batch_end(ret)
274
  # callback_tensorboard.on_train_batch_end(step, named_logs(network, ret))
275
  callback_best_model.on_train_batch_end(step, named_logs(network, ret))
276
- # callback_save_model.on_train_batch_end(step, named_logs(network, ret))
277
  callback_early_stop.on_train_batch_end(step, named_logs(network, ret))
 
278
  progress_bar.update(step, zip(names, ret))
279
  log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
280
  val_values = progress_bar._values.copy()
@@ -309,13 +342,15 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
309
  callback_tensorboard.on_epoch_end(epoch, epoch_summary)
310
  callback_early_stop.on_epoch_end(epoch, epoch_summary)
311
  callback_best_model.on_epoch_end(epoch, epoch_summary)
312
- # callback_save_model.on_epoch_end(epoch, epoch_summary)
 
313
  print('End of epoch {}: '.format(epoch), ret, '\n')
314
 
315
  callback_tensorboard.on_train_end()
316
- # callback_save_model.on_train_end()
317
  callback_best_model.on_train_end()
318
  callback_early_stop.on_train_end()
 
319
 
320
 
321
  if __name__ == '__main__':
 
8
  import numpy as np
9
  import tensorflow as tf
10
 
11
+ from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping, ReduceLROnPlateau
12
  from tensorflow.keras import Input
13
  from tensorflow.keras.models import Model
14
  from tensorflow.python.keras.utils import Progbar
 
32
 
33
  from tqdm import tqdm
34
  from datetime import datetime
35
+ import re
36
 
37
 
38
  def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr=1e-4, rw=5e-3, simil='ssim',
39
  acc_gradients=16, batch_size=1, max_epochs=10000, early_stop_patience=1000, image_size=64,
40
+ unet=[16, 32, 64, 128, 256], head=[16, 16], resume=None):
41
  assert dataset_folder is not None and output_folder is not None
42
 
43
  os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
 
47
  if batch_size != 1 and acc_gradients != 1:
48
  warnings.warn('WARNING: Batch size and Accumulative gradient step are set!')
49
 
50
+ if resume is not None:
51
+ try:
52
+ assert os.path.exists(resume) and len(os.listdir(os.path.join(resume, 'checkpoints'))), 'Invalid directory: ' + resume
53
+ output_folder = resume
54
+ resume = True
55
+ except AssertionError:
56
+ output_folder = os.path.join(output_folder + '_' + datetime.now().strftime("%H%M%S-%d%m%Y"))
57
+ resume = False
58
+ else:
59
+ resume = False
60
  os.makedirs(output_folder, exist_ok=True)
61
  # dataset_copy = DatasetCopy(dataset_folder, os.path.join(output_folder, 'temp'))
62
  log_file = open(os.path.join(output_folder, 'log.txt'), 'w')
 
151
  nb_unet_features=nb_features,
152
  int_steps=0)
153
  network.summary(line_length=150)
154
+
155
+ resume_epoch = 0
156
+ if resume:
157
+ cp_dir = os.path.join(output_folder, 'checkpoints')
158
+ cp_file_list = [os.path.join(cp_dir, f) for f in os.listdir(cp_dir) if (f.startswith('checkpoint') and f.endswith('.h5'))]
159
+ if len(cp_file_list):
160
+ cp_file_list.sort()
161
+ checkpoint_file = cp_file_list[-1]
162
+ if os.path.exists(checkpoint_file):
163
+ network.load_weights(checkpoint_file, by_name=True)
164
+ print('Loaded checkpoint file: ' + checkpoint_file)
165
+ try:
166
+ resume_epoch = int(re.match('checkpoint\.(\d+)-*.h5', os.path.split(checkpoint_file)[-1])[1])
167
+ except TypeError:
168
+ # Checkpoint file has no epoch number in the name
169
+ resume_epoch = 0
170
+ print('Resuming from epoch: {:d}'.format(resume_epoch))
171
+ else:
172
+ warnings.warn('Checkpoint file NOT found. Training from scratch')
173
+
174
  # Losses and loss weights
175
  SSIM_KER_SIZE = 5
176
  MS_SSIM_WEIGHTS = _MSSSIM_WEIGHTS[:3]
 
208
 
209
  # Train
210
  os.makedirs(output_folder, exist_ok=True)
211
+ os.makedirs(os.path.join(output_folder, 'checkpoints'), exist_ok=True) # exist_ok=True leaves directory unaltered.
212
  os.makedirs(os.path.join(output_folder, 'tensorboard'), exist_ok=True)
213
  os.makedirs(os.path.join(output_folder, 'history'), exist_ok=True)
214
 
215
  callback_best_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'),
216
  save_best_only=True, monitor='val_loss', verbose=1, mode='min')
217
+ callback_save_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'checkpoint.{epoch:05d}-{val_loss:.2f}.h5'),
218
+ save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
219
  # CSVLogger(train_log_name, ';'),
220
  # UpdateLossweights([haus_weight, dice_weight], [const.MODEL+'_resampler_seg', const.MODEL+'_resampler_seg'])
221
  callback_tensorboard = TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
 
224
  write_graph=True, write_grads=True
225
  )
226
  callback_early_stop = EarlyStopping(monitor='val_loss', verbose=1, patience=C.EARLY_STOP_PATIENCE, min_delta=0.00001)
227
+ callback_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10)
228
 
229
  losses = {'transformer': loss_fnc,
230
  'flow': vxm.losses.Grad('l2').loss}
 
246
 
247
  callback_tensorboard.set_model(network)
248
  callback_best_model.set_model(network)
249
+ callback_save_model.set_model(network)
250
  callback_early_stop.set_model(network)
251
+ callback_lr.set_model(network)
252
  # TODO: https://towardsdatascience.com/writing-tensorflow-2-custom-loops-438b1ab6eb6c
253
 
254
  summary = SummaryDictionary(network, C.BATCH_SIZE)
 
258
  callback_tensorboard.on_train_begin()
259
  callback_early_stop.on_train_begin()
260
  callback_best_model.on_train_begin()
261
+ callback_save_model.on_train_begin()
262
+ for epoch in range(resume_epoch, C.EPOCHS):
263
  callback_tensorboard.on_epoch_begin(epoch)
264
  callback_early_stop.on_epoch_begin(epoch)
265
  callback_best_model.on_epoch_begin(epoch)
266
+ callback_save_model.on_epoch_begin(epoch)
267
+ callback_lr.on_epoch_begin(epoch)
268
  print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
269
  print('TRAINING')
270
 
 
273
  for step, (in_batch, _) in enumerate(train_generator, 1):
274
  # callback_tensorboard.on_train_batch_begin(step)
275
  callback_best_model.on_train_batch_begin(step)
276
+ callback_save_model.on_train_batch_begin(step)
277
  callback_early_stop.on_train_batch_begin(step)
278
+ callback_lr.on_train_batch_begin(step)
279
  try:
280
  fix_img, mov_img, *_ = augm_model_train.predict(in_batch)
281
  np.nan_to_num(fix_img, copy=False)
 
305
  summary.on_train_batch_end(ret)
306
  # callback_tensorboard.on_train_batch_end(step, named_logs(network, ret))
307
  callback_best_model.on_train_batch_end(step, named_logs(network, ret))
308
+ callback_save_model.on_train_batch_end(step, named_logs(network, ret))
309
  callback_early_stop.on_train_batch_end(step, named_logs(network, ret))
310
+ callback_lr.on_train_batch_end(step, named_logs(network, ret))
311
  progress_bar.update(step, zip(names, ret))
312
  log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
313
  val_values = progress_bar._values.copy()
 
342
  callback_tensorboard.on_epoch_end(epoch, epoch_summary)
343
  callback_early_stop.on_epoch_end(epoch, epoch_summary)
344
  callback_best_model.on_epoch_end(epoch, epoch_summary)
345
+ callback_save_model.on_epoch_end(epoch, epoch_summary)
346
+ callback_lr.on_epoch_end(epoch, epoch_summary)
347
  print('End of epoch {}: '.format(epoch), ret, '\n')
348
 
349
  callback_tensorboard.on_train_end()
350
+ callback_save_model.on_train_end()
351
  callback_best_model.on_train_end()
352
  callback_early_stop.on_train_end()
353
+ callback_lr.on_train_end()
354
 
355
 
356
  if __name__ == '__main__':
Brain_study/Train_SegmentationGuided.py CHANGED
@@ -1,11 +1,12 @@
1
  import os, sys
 
2
  currentdir = os.path.dirname(os.path.realpath(__file__))
3
  parentdir = os.path.dirname(currentdir)
4
  sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
5
 
6
  import numpy as np
7
  import tensorflow as tf
8
- from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
9
  from tensorflow.keras import Input
10
  from tensorflow.keras.models import Model
11
  from tensorflow.python.keras.utils import Progbar
@@ -30,11 +31,14 @@ from Brain_study.utils import SummaryDictionary, named_logs
30
 
31
  import time
32
  import warnings
 
 
33
 
34
 
35
  def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr=1e-4, rw=5e-3, simil='ssim', segm='hd',
36
  acc_gradients=16, batch_size=1, max_epochs=10000, early_stop_patience=1000, image_size=64,
37
- unet=[16, 32, 64, 128, 256], head=[16, 16]):
 
38
  assert dataset_folder is not None and output_folder is not None
39
 
40
  os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
@@ -44,7 +48,16 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
44
  if batch_size != 1 and acc_gradients != 1:
45
  warnings.warn('WARNING: Batch size and Accumulative gradient step are set!')
46
 
47
- output_folder = os.path.join(output_folder + '_' + datetime.now().strftime("%H%M%S-%d%m%Y"))
 
 
 
 
 
 
 
 
 
48
  os.makedirs(output_folder, exist_ok=True)
49
  log_file = open(os.path.join(output_folder, 'log.txt'), 'w')
50
  C.TRAINING_DATASET = dataset_folder
@@ -85,6 +98,10 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
85
  directory_val=C.VALIDATION_DATASET)
86
 
87
  train_generator = data_generator.get_train_generator()
 
 
 
 
88
  validation_generator = data_generator.get_validation_generator()
89
 
90
  image_input_shape = train_generator.get_data_shape()[1][:-1]
@@ -96,7 +113,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
96
  config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
97
  config.gpu_options.allow_growth = True
98
  config.log_device_placement = False ## to log device placement (on which device the operation ran)
99
- config.allow_soft_placement = True # https://github.com/tensorflow/tensorflow/issues/30782
100
  sess = tf.Session(config=config)
101
  tf.keras.backend.set_session(sess)
102
 
@@ -128,6 +145,27 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
128
  int_steps=0,
129
  int_downsize=1,
130
  seg_downsize=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
  # Compile the model
133
  SSIM_KER_SIZE = 5
@@ -190,8 +228,8 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
190
 
191
  callback_best_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'),
192
  save_best_only=True, monitor='val_loss', verbose=1, mode='min')
193
- # callback_save_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'weights.{epoch:05d}-{val_loss:.2f}.h5'),
194
- # save_best_only=True, save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
195
  # CSVLogger(train_log_name, ';'),
196
  # UpdateLossweights([haus_weight, dice_weight], [const.MODEL+'_resampler_seg', const.MODEL+'_resampler_seg'])
197
  callback_tensorboard = TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
@@ -200,6 +238,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
200
  write_graph=True, write_grads=True
201
  )
202
  callback_early_stop = EarlyStopping(monitor='val_loss', verbose=1, patience=C.EARLY_STOP_PATIENCE, min_delta=0.00001)
 
203
 
204
  # Compile the model
205
  optimizer = AdamAccumulated(C.ACCUM_GRADIENT_STEP, lr=C.LEARNING_RATE)
@@ -210,9 +249,9 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
210
 
211
  callback_tensorboard.set_model(network)
212
  callback_best_model.set_model(network)
213
- # callback_save_model.set_model(network)
214
  callback_early_stop.set_model(network)
215
-
216
  summary = SummaryDictionary(network, C.BATCH_SIZE, C.ACCUM_GRADIENT_STEP)
217
  names = network.metrics_names # It give both the loss and metric names
218
  log_file.write('\n\n[{}]\tINFO:\tStart training\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y')))
@@ -221,13 +260,14 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
221
  callback_tensorboard.on_train_begin()
222
  callback_early_stop.on_train_begin()
223
  callback_best_model.on_train_begin()
224
- # callback_save_model.on_train_begin()
225
- for epoch in range(C.EPOCHS):
 
226
  callback_tensorboard.on_epoch_begin(epoch)
227
  callback_early_stop.on_epoch_begin(epoch)
228
  callback_best_model.on_epoch_begin(epoch)
229
- # callback_save_model.on_epoch_begin(epoch)
230
-
231
  print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
232
  print('TRAINING')
233
 
@@ -238,12 +278,13 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
238
  #print('Loaded in {} s'.format(time.time() - t0))
239
  # callback_tensorboard.on_train_batch_begin(step)
240
  callback_best_model.on_train_batch_begin(step)
241
- # callback_save_model.on_train_batch_begin(step)
242
  callback_early_stop.on_train_batch_begin(step)
243
-
244
  try:
245
  t0 = time.time()
246
  fix_img, mov_img, fix_seg, mov_seg = augm_model.predict(in_batch)
 
247
  #print('Augmented in {} s'.format(time.time() - t0))
248
  np.nan_to_num(fix_img, copy=False)
249
  np.nan_to_num(mov_img, copy=False)
@@ -275,8 +316,9 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
275
  summary.on_train_batch_end(ret)
276
  # callback_tensorboard.on_train_batch_end(step, named_logs(network, ret))
277
  callback_best_model.on_train_batch_end(step, named_logs(network, ret))
278
- # callback_save_model.on_train_batch_end(step, named_logs(network, ret))
279
  callback_early_stop.on_train_batch_end(step, named_logs(network, ret))
 
280
  progress_bar.update(step, zip(names, ret))
281
  log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
282
  t0 = time.time()
@@ -313,13 +355,14 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
313
  callback_tensorboard.on_epoch_end(epoch, epoch_summary)
314
  callback_early_stop.on_epoch_end(epoch, epoch_summary)
315
  callback_best_model.on_epoch_end(epoch, epoch_summary)
316
- # callback_save_model.on_epoch_end(epoch, named_logs(network, ret, True))
 
317
 
318
  callback_tensorboard.on_train_end()
319
- # callback_save_model.on_train_end()
320
  callback_best_model.on_train_end()
321
  callback_early_stop.on_train_end()
322
-
323
 
324
  if __name__ == '__main__':
325
  os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
 
1
  import os, sys
2
+
3
  currentdir = os.path.dirname(os.path.realpath(__file__))
4
  parentdir = os.path.dirname(currentdir)
5
  sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
6
 
7
  import numpy as np
8
  import tensorflow as tf
9
+ from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping, ReduceLROnPlateau
10
  from tensorflow.keras import Input
11
  from tensorflow.keras.models import Model
12
  from tensorflow.python.keras.utils import Progbar
 
31
 
32
  import time
33
  import warnings
34
+ import re
35
+ import tqdm
36
 
37
 
38
  def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr=1e-4, rw=5e-3, simil='ssim', segm='hd',
39
  acc_gradients=16, batch_size=1, max_epochs=10000, early_stop_patience=1000, image_size=64,
40
+ unet=[16, 32, 64, 128, 256], head=[16, 16], resume=None):
41
+
42
  assert dataset_folder is not None and output_folder is not None
43
 
44
  os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
 
48
  if batch_size != 1 and acc_gradients != 1:
49
  warnings.warn('WARNING: Batch size and Accumulative gradient step are set!')
50
 
51
+ if resume is not None:
52
+ try:
53
+ assert os.path.exists(resume) and len(os.listdir(os.path.join(resume, 'checkpoints'))), 'Invalid directory: ' + resume
54
+ output_folder = resume
55
+ resume = True
56
+ except AssertionError:
57
+ output_folder = os.path.join(output_folder + '_' + datetime.now().strftime("%H%M%S-%d%m%Y"))
58
+ resume = False
59
+ else:
60
+ resume = False
61
  os.makedirs(output_folder, exist_ok=True)
62
  log_file = open(os.path.join(output_folder, 'log.txt'), 'w')
63
  C.TRAINING_DATASET = dataset_folder
 
98
  directory_val=C.VALIDATION_DATASET)
99
 
100
  train_generator = data_generator.get_train_generator()
101
+
102
+ # for l in tqdm.tqdm(train_generator, smoothing=0):
103
+ # pass
104
+ # exit()
105
  validation_generator = data_generator.get_validation_generator()
106
 
107
  image_input_shape = train_generator.get_data_shape()[1][:-1]
 
113
  config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
114
  config.gpu_options.allow_growth = True
115
  config.log_device_placement = False ## to log device placement (on which device the operation ran)
116
+ # config.allow_soft_placement = False # https://github.com/tensorflow/tensorflow/issues/30782
117
  sess = tf.Session(config=config)
118
  tf.keras.backend.set_session(sess)
119
 
 
145
  int_steps=0,
146
  int_downsize=1,
147
  seg_downsize=1)
148
+ network.summary(line_length=C.SUMMARY_LINE_LENGTH)
149
+ network.summary(line_length=C.SUMMARY_LINE_LENGTH, print_fn=log_file.writelines)
150
+
151
+ resume_epoch = 0
152
+ if resume:
153
+ cp_dir = os.path.join(output_folder, 'checkpoints')
154
+ cp_file_list = [os.path.join(cp_dir, f) for f in os.listdir(cp_dir) if (f.startswith('checkpoint') and f.endswith('.h5'))]
155
+ if len(cp_file_list):
156
+ cp_file_list.sort()
157
+ checkpoint_file = cp_file_list[-1]
158
+ if os.path.exists(checkpoint_file):
159
+ network.load_weights(checkpoint_file, by_name=True)
160
+ print('Loaded checkpoint file: ' + checkpoint_file)
161
+ try:
162
+ resume_epoch = int(re.match('checkpoint\.(\d+)-*.h5', os.path.split(checkpoint_file)[-1])[1])
163
+ except TypeError:
164
+ # Checkpoint file has no epoch number in the name
165
+ resume_epoch = 0
166
+ print('Resuming from epoch: {:d}'.format(resume_epoch))
167
+ else:
168
+ warnings.warn('Checkpoint file NOT found. Training from scratch')
169
 
170
  # Compile the model
171
  SSIM_KER_SIZE = 5
 
228
 
229
  callback_best_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'),
230
  save_best_only=True, monitor='val_loss', verbose=1, mode='min')
231
+ callback_save_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'checkpoint.{epoch:05d}-{val_loss:.2f}.h5'),
232
+ save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
233
  # CSVLogger(train_log_name, ';'),
234
  # UpdateLossweights([haus_weight, dice_weight], [const.MODEL+'_resampler_seg', const.MODEL+'_resampler_seg'])
235
  callback_tensorboard = TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
 
238
  write_graph=True, write_grads=True
239
  )
240
  callback_early_stop = EarlyStopping(monitor='val_loss', verbose=1, patience=C.EARLY_STOP_PATIENCE, min_delta=0.00001)
241
+ callback_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10)
242
 
243
  # Compile the model
244
  optimizer = AdamAccumulated(C.ACCUM_GRADIENT_STEP, lr=C.LEARNING_RATE)
 
249
 
250
  callback_tensorboard.set_model(network)
251
  callback_best_model.set_model(network)
252
+ callback_save_model.set_model(network)
253
  callback_early_stop.set_model(network)
254
+ callback_lr.set_model(network)
255
  summary = SummaryDictionary(network, C.BATCH_SIZE, C.ACCUM_GRADIENT_STEP)
256
  names = network.metrics_names # It give both the loss and metric names
257
  log_file.write('\n\n[{}]\tINFO:\tStart training\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y')))
 
260
  callback_tensorboard.on_train_begin()
261
  callback_early_stop.on_train_begin()
262
  callback_best_model.on_train_begin()
263
+ callback_save_model.on_train_begin()
264
+ callback_lr.on_train_begin()
265
+ for epoch in range(resume_epoch, C.EPOCHS):
266
  callback_tensorboard.on_epoch_begin(epoch)
267
  callback_early_stop.on_epoch_begin(epoch)
268
  callback_best_model.on_epoch_begin(epoch)
269
+ callback_save_model.on_epoch_begin(epoch)
270
+ callback_lr.on_epoch_begin(epoch)
271
  print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
272
  print('TRAINING')
273
 
 
278
  #print('Loaded in {} s'.format(time.time() - t0))
279
  # callback_tensorboard.on_train_batch_begin(step)
280
  callback_best_model.on_train_batch_begin(step)
281
+ callback_save_model.on_train_batch_begin(step)
282
  callback_early_stop.on_train_batch_begin(step)
283
+ callback_lr.on_train_batch_begin(step)
284
  try:
285
  t0 = time.time()
286
  fix_img, mov_img, fix_seg, mov_seg = augm_model.predict(in_batch)
287
+
288
  #print('Augmented in {} s'.format(time.time() - t0))
289
  np.nan_to_num(fix_img, copy=False)
290
  np.nan_to_num(mov_img, copy=False)
 
316
  summary.on_train_batch_end(ret)
317
  # callback_tensorboard.on_train_batch_end(step, named_logs(network, ret))
318
  callback_best_model.on_train_batch_end(step, named_logs(network, ret))
319
+ callback_save_model.on_train_batch_end(step, named_logs(network, ret))
320
  callback_early_stop.on_train_batch_end(step, named_logs(network, ret))
321
+ callback_lr.on_train_batch_end(step, named_logs(network, ret))
322
  progress_bar.update(step, zip(names, ret))
323
  log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
324
  t0 = time.time()
 
355
  callback_tensorboard.on_epoch_end(epoch, epoch_summary)
356
  callback_early_stop.on_epoch_end(epoch, epoch_summary)
357
  callback_best_model.on_epoch_end(epoch, epoch_summary)
358
+ callback_save_model.on_epoch_end(epoch, epoch_summary)
359
+ callback_lr.on_epoch_end(epoch, epoch_summary)
360
 
361
  callback_tensorboard.on_train_end()
362
+ callback_save_model.on_train_end()
363
  callback_best_model.on_train_end()
364
  callback_early_stop.on_train_end()
365
+ callback_lr.on_train_end()
366
 
367
  if __name__ == '__main__':
368
  os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
Brain_study/Train_UncertaintyWeighted.py CHANGED
@@ -5,7 +5,7 @@ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
5
 
6
  import numpy as np
7
  import tensorflow as tf
8
- from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
9
  from tensorflow.keras import Input
10
  from tensorflow.keras.models import Model
11
  from tensorflow.python.keras.utils import Progbar
@@ -27,11 +27,12 @@ from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccum
27
  from Brain_study.data_generator import BatchGenerator
28
  from Brain_study.utils import SummaryDictionary, named_logs
29
  import warnings
 
30
 
31
 
32
  def launch_train(dataset_folder, validation_folder, output_folder, prior_reg_w=5e-3, lr=1e-4, rw=5e-3,
33
  gpu_num=0, simil=['mse'], segm=['dice'], acc_gradients=16, batch_size=1, max_epochs=10000,
34
- early_stop_patience=1000, image_size=64, unet=[16, 32, 64, 128, 256], head=[16, 16]):
35
  assert dataset_folder is not None and output_folder is not None
36
 
37
  os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
@@ -41,8 +42,16 @@ def launch_train(dataset_folder, validation_folder, output_folder, prior_reg_w=5
41
  if batch_size != 1 and acc_gradients != 1:
42
  warnings.warn('WARNING: Batch size and Accumulative gradient step are set!')
43
 
44
- output_folder = os.path.join(output_folder + '_' + datetime.now().strftime("%H%M%S-%d%m%Y"))
45
- # dataset_copy = DatasetCopy(dataset_folder, os.path.join(output_folder, 'temp'))
 
 
 
 
 
 
 
 
46
  os.makedirs(output_folder, exist_ok=True)
47
  log_file = open(os.path.join(output_folder, 'log.txt'), 'w')
48
  C.TRAINING_DATASET = dataset_folder #dataset_copy.copy_dataset()
@@ -94,7 +103,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, prior_reg_w=5
94
  config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
95
  config.gpu_options.allow_growth = True
96
  config.log_device_placement = False ## to log device placement (on which device the operation ran)
97
- config.allow_soft_placement = True
98
  sess = tf.Session(config=config)
99
  tf.keras.backend.set_session(sess)
100
 
@@ -161,6 +170,26 @@ def launch_train(dataset_folder, validation_folder, output_folder, prior_reg_w=5
161
  int_steps=0,
162
  int_downsize=1,
163
  seg_downsize=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  # Network inputs: mov_img, fix_img, mov_seg
165
  # Network outputs: pred_img, disp_map, pred_seg
166
  grad = tf.keras.Input(shape=(*image_output_shape, 3), name='multiLoss_grad_input', dtype=tf.float32)
@@ -194,8 +223,8 @@ def launch_train(dataset_folder, validation_folder, output_folder, prior_reg_w=5
194
 
195
  callback_best_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'),
196
  save_best_only=True, monitor='val_loss', verbose=1, mode='min')
197
- # callback_save_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'weights.{epoch:05d}-{val_loss:.2f}.h5'),
198
- # save_best_only=True, save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
199
  # CSVLogger(train_log_name, ';'),
200
  # UpdateLossweights([haus_weight, dice_weight], [const.MODEL+'_resampler_seg', const.MODEL+'_resampler_seg'])
201
  callback_tensorboard = TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
@@ -204,6 +233,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, prior_reg_w=5
204
  write_graph=True, write_grads=True
205
  )
206
  callback_early_stop = EarlyStopping(monitor='val_loss', verbose=1, patience=C.EARLY_STOP_PATIENCE, min_delta=0.00001)
 
207
 
208
  # Compile the model
209
  optimizer = AdamAccumulated(C.ACCUM_GRADIENT_STEP, lr=C.LEARNING_RATE)
@@ -211,8 +241,10 @@ def launch_train(dataset_folder, validation_folder, output_folder, prior_reg_w=5
211
 
212
  callback_tensorboard.set_model(full_model)
213
  callback_best_model.set_model(network) # ONLY SAVE THE NETWORK!!!
214
- # callback_save_model.set_model(network)
215
  callback_early_stop.set_model(full_model)
 
 
216
  # TODO: https://towardsdatascience.com/writing-tensorflow-2-custom-loops-438b1ab6eb6c
217
 
218
  summary = SummaryDictionary(full_model, C.BATCH_SIZE)
@@ -223,12 +255,15 @@ def launch_train(dataset_folder, validation_folder, output_folder, prior_reg_w=5
223
  callback_tensorboard.on_train_begin()
224
  callback_early_stop.on_train_begin()
225
  callback_best_model.on_train_begin()
226
- # callback_save_model.on_train_begin()
227
- for epoch in range(C.EPOCHS):
 
 
228
  callback_tensorboard.on_epoch_begin(epoch)
229
  callback_early_stop.on_epoch_begin(epoch)
230
  callback_best_model.on_epoch_begin(epoch)
231
- # callback_save_model.on_epoch_begin(epoch)
 
232
 
233
  print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
234
  print('TRAINING')
@@ -238,8 +273,9 @@ def launch_train(dataset_folder, validation_folder, output_folder, prior_reg_w=5
238
  for step, (in_batch, _) in enumerate(train_generator, 1):
239
  # callback_tensorboard.on_train_batch_begin(step)
240
  callback_best_model.on_train_batch_begin(step)
241
- # callback_save_model.on_train_batch_begin(step)
242
  callback_early_stop.on_train_batch_begin(step)
 
243
 
244
  try:
245
  fix_img, mov_img, fix_seg, mov_seg = augmentation_model.predict(in_batch)
@@ -255,8 +291,9 @@ def launch_train(dataset_folder, validation_folder, output_folder, prior_reg_w=5
255
  summary.on_train_batch_end(ret)
256
  # callback_tensorboard.on_train_batch_end(step, named_logs(full_model, ret))
257
  callback_best_model.on_train_batch_end(step, named_logs(full_model, ret))
258
- # callback_save_model.on_train_batch_end(step, named_logs(network, ret))
259
  callback_early_stop.on_train_batch_end(step, named_logs(full_model, ret))
 
260
  log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
261
  # print(ret, '\n')
262
  progress_bar.update(step, zip(names, ret))
@@ -292,12 +329,14 @@ def launch_train(dataset_folder, validation_folder, output_folder, prior_reg_w=5
292
  callback_tensorboard.on_epoch_end(epoch, epoch_summary)
293
  callback_best_model.on_epoch_end(epoch, epoch_summary)
294
  callback_early_stop.on_epoch_end(epoch, epoch_summary)
295
- # callback_save_model.on_train_end(epoch, epoch_summary)
 
296
 
297
  callback_tensorboard.on_train_end()
298
- # callback_save_model.on_train_end()
299
  callback_best_model.on_train_end()
300
  callback_early_stop.on_train_end()
 
301
 
302
 
303
  if __name__ == '__main__':
 
5
 
6
  import numpy as np
7
  import tensorflow as tf
8
+ from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping, ReduceLROnPlateau
9
  from tensorflow.keras import Input
10
  from tensorflow.keras.models import Model
11
  from tensorflow.python.keras.utils import Progbar
 
27
  from Brain_study.data_generator import BatchGenerator
28
  from Brain_study.utils import SummaryDictionary, named_logs
29
  import warnings
30
+ import re
31
 
32
 
33
  def launch_train(dataset_folder, validation_folder, output_folder, prior_reg_w=5e-3, lr=1e-4, rw=5e-3,
34
  gpu_num=0, simil=['mse'], segm=['dice'], acc_gradients=16, batch_size=1, max_epochs=10000,
35
+ early_stop_patience=1000, image_size=64, unet=[16, 32, 64, 128, 256], head=[16, 16], resume=None):
36
  assert dataset_folder is not None and output_folder is not None
37
 
38
  os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
 
42
  if batch_size != 1 and acc_gradients != 1:
43
  warnings.warn('WARNING: Batch size and Accumulative gradient step are set!')
44
 
45
+ if resume is not None:
46
+ try:
47
+ assert os.path.exists(resume) and len(os.listdir(os.path.join(resume, 'checkpoints'))), 'Invalid directory: ' + resume
48
+ output_folder = resume
49
+ resume = True
50
+ except AssertionError:
51
+ output_folder = os.path.join(output_folder + '_' + datetime.now().strftime("%H%M%S-%d%m%Y"))
52
+ resume = False
53
+ else:
54
+ resume = False
55
  os.makedirs(output_folder, exist_ok=True)
56
  log_file = open(os.path.join(output_folder, 'log.txt'), 'w')
57
  C.TRAINING_DATASET = dataset_folder #dataset_copy.copy_dataset()
 
103
  config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
104
  config.gpu_options.allow_growth = True
105
  config.log_device_placement = False ## to log device placement (on which device the operation ran)
106
+ # config.allow_soft_placement = True
107
  sess = tf.Session(config=config)
108
  tf.keras.backend.set_session(sess)
109
 
 
170
  int_steps=0,
171
  int_downsize=1,
172
  seg_downsize=1)
173
+
174
+ resume_epoch = 0
175
+ if resume:
176
+ cp_dir = os.path.join(output_folder, 'checkpoints')
177
+ cp_file_list = [os.path.join(cp_dir, f) for f in os.listdir(cp_dir) if (f.startswith('checkpoint') and f.endswith('.h5'))]
178
+ if len(cp_file_list):
179
+ cp_file_list.sort()
180
+ checkpoint_file = cp_file_list[-1]
181
+ if os.path.exists(checkpoint_file):
182
+ network.load_weights(checkpoint_file, by_name=True)
183
+ print('Loaded checkpoint file: ' + checkpoint_file)
184
+ try:
185
+ resume_epoch = int(re.match('checkpoint\.(\d+)-*.h5', os.path.split(checkpoint_file)[-1])[1])
186
+ except TypeError:
187
+ # Checkpoint file has no epoch number in the name
188
+ resume_epoch = 0
189
+ print('Resuming from epoch: {:d}'.format(resume_epoch))
190
+ else:
191
+ warnings.warn('Checkpoint file NOT found. Training from scratch')
192
+
193
  # Network inputs: mov_img, fix_img, mov_seg
194
  # Network outputs: pred_img, disp_map, pred_seg
195
  grad = tf.keras.Input(shape=(*image_output_shape, 3), name='multiLoss_grad_input', dtype=tf.float32)
 
223
 
224
  callback_best_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'),
225
  save_best_only=True, monitor='val_loss', verbose=1, mode='min')
226
+ callback_save_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'checkpoint.{epoch:05d}-{val_loss:.2f}.h5'),
227
+ save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
228
  # CSVLogger(train_log_name, ';'),
229
  # UpdateLossweights([haus_weight, dice_weight], [const.MODEL+'_resampler_seg', const.MODEL+'_resampler_seg'])
230
  callback_tensorboard = TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
 
233
  write_graph=True, write_grads=True
234
  )
235
  callback_early_stop = EarlyStopping(monitor='val_loss', verbose=1, patience=C.EARLY_STOP_PATIENCE, min_delta=0.00001)
236
+ callback_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10)
237
 
238
  # Compile the model
239
  optimizer = AdamAccumulated(C.ACCUM_GRADIENT_STEP, lr=C.LEARNING_RATE)
 
241
 
242
  callback_tensorboard.set_model(full_model)
243
  callback_best_model.set_model(network) # ONLY SAVE THE NETWORK!!!
244
+ callback_save_model.set_model(network)
245
  callback_early_stop.set_model(full_model)
246
+ callback_lr.set_model(full_model)
247
+
248
  # TODO: https://towardsdatascience.com/writing-tensorflow-2-custom-loops-438b1ab6eb6c
249
 
250
  summary = SummaryDictionary(full_model, C.BATCH_SIZE)
 
255
  callback_tensorboard.on_train_begin()
256
  callback_early_stop.on_train_begin()
257
  callback_best_model.on_train_begin()
258
+ callback_save_model.on_train_begin()
259
+ callback_lr.on_train_begin()
260
+
261
+ for epoch in range(resume_epoch, C.EPOCHS):
262
  callback_tensorboard.on_epoch_begin(epoch)
263
  callback_early_stop.on_epoch_begin(epoch)
264
  callback_best_model.on_epoch_begin(epoch)
265
+ callback_save_model.on_epoch_begin(epoch)
266
+ callback_lr.on_epoch_begin(epoch)
267
 
268
  print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
269
  print('TRAINING')
 
273
  for step, (in_batch, _) in enumerate(train_generator, 1):
274
  # callback_tensorboard.on_train_batch_begin(step)
275
  callback_best_model.on_train_batch_begin(step)
276
+ callback_save_model.on_train_batch_begin(step)
277
  callback_early_stop.on_train_batch_begin(step)
278
+ callback_lr.on_train_batch_begin(step)
279
 
280
  try:
281
  fix_img, mov_img, fix_seg, mov_seg = augmentation_model.predict(in_batch)
 
291
  summary.on_train_batch_end(ret)
292
  # callback_tensorboard.on_train_batch_end(step, named_logs(full_model, ret))
293
  callback_best_model.on_train_batch_end(step, named_logs(full_model, ret))
294
+ callback_save_model.on_train_batch_end(step, named_logs(full_model, ret))
295
  callback_early_stop.on_train_batch_end(step, named_logs(full_model, ret))
296
+ callback_lr.on_train_batch_end(step, named_logs(full_model, ret))
297
  log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
298
  # print(ret, '\n')
299
  progress_bar.update(step, zip(names, ret))
 
329
  callback_tensorboard.on_epoch_end(epoch, epoch_summary)
330
  callback_best_model.on_epoch_end(epoch, epoch_summary)
331
  callback_early_stop.on_epoch_end(epoch, epoch_summary)
332
+ callback_save_model.on_epoch_end(epoch, epoch_summary)
333
+ callback_lr.on_epoch_end(epoch, epoch_summary)
334
 
335
  callback_tensorboard.on_train_end()
336
+ callback_save_model.on_train_end()
337
  callback_best_model.on_train_end()
338
  callback_early_stop.on_train_end()
339
+ callback_lr.on_train_end()
340
 
341
 
342
  if __name__ == '__main__':
Brain_study/data_generator.py CHANGED
@@ -1,5 +1,5 @@
1
  import warnings
2
-
3
  import numpy as np
4
  from tensorflow import keras
5
  import os
@@ -14,6 +14,7 @@ import tensorflow as tf
14
 
15
  import DeepDeformationMapRegistration.utils.constants as C
16
  from DeepDeformationMapRegistration.utils.operators import min_max_norm
 
17
  from DeepDeformationMapRegistration.utils.thin_plate_splines import ThinPlateSplines
18
  from voxelmorph.tf.layers import SpatialTransformer
19
  from Brain_study.format_dataset import SEGMENTATION_NR2LBL_LUT, SEGMENTATION_LBL2NR_LUT
@@ -22,6 +23,10 @@ from tensorflow.python.keras.preprocessing.image import Iterator
22
  from tensorflow.python.keras.utils import Sequence
23
  import sys
24
 
 
 
 
 
25
  #import concurrent.futures
26
  #import multiprocessing as mp
27
  import time
@@ -34,13 +39,15 @@ class BatchGenerator:
34
  split=0.7,
35
  combine_segmentations=True,
36
  labels=['all'],
37
- directory_val=None):
 
38
  self.file_directory = directory
39
  self.batch_size = batch_size
40
  self.combine_segmentations = combine_segmentations
41
  self.labels = labels
42
  self.shuffle = shuffle
43
  self.split = split
 
44
 
45
  if directory_val is None:
46
  self.file_list = [os.path.join(directory, f) for f in os.listdir(directory) if f.endswith(('h5', 'hd5'))]
@@ -48,11 +55,11 @@ class BatchGenerator:
48
  self.num_samples = len(self.file_list)
49
  training_samples = self.file_list[:int(self.num_samples * self.split)]
50
 
51
- self.train_iter = BatchIterator(training_samples, batch_size, shuffle, combine_segmentations, labels)
52
  if self.split < 1.:
53
  validation_samples = list(set(self.file_list) - set(training_samples))
54
  self.validation_iter = BatchIterator(validation_samples, batch_size, shuffle, combine_segmentations, ['all'],
55
- validation=True)
56
  else:
57
  self.validation_iter = None
58
  else:
@@ -66,7 +73,7 @@ class BatchGenerator:
66
  self.file_list = training_samples + validation_samples
67
 
68
  self.train_iter = BatchIterator(training_samples, batch_size, shuffle, combine_segmentations, labels)
69
- self.validation_iter = BatchIterator(validation_samples, batch_size, shuffle, combine_segmentations, ['all'],
70
  validation=True)
71
 
72
  def get_train_generator(self):
@@ -92,7 +99,8 @@ ALL_LABELS_LOC = {label: loc for label, loc in zip(ALL_LABELS, range(0, len(ALL_
92
 
93
  class BatchIterator(Sequence):
94
  def __init__(self, file_list, batch_size, shuffle, combine_segmentations=True, labels=['all'],
95
- zero_grads=[64, 64, 64, 3], validation=False, **kwargs):
 
96
  # super(BatchIterator, self).__init__(n=len(file_list),
97
  # batch_size=batch_size,
98
  # shuffle=shuffle,
@@ -103,31 +111,51 @@ class BatchIterator(Sequence):
103
  self.file_list = file_list
104
  self.combine_segmentations = combine_segmentations
105
  self.labels = labels
106
- self.zero_grads = zero_grads
107
  self.idx_list = np.arange(0, len(self.file_list))
108
  self.validation = validation
 
 
109
  self._initialize()
110
  self.shuffle_samples()
111
 
112
  def _initialize(self):
113
- with h5py.File(self.file_list[0], 'r') as f:
114
- self.image_shape = list(f['image'][:].shape)
115
- self.segm_shape = list(f['segmentation'][:].shape)
116
- if not self.combine_segmentations:
117
- self.segm_shape[-1] = len(f['segmentation_labels'][:]) if self.labels[0].lower() == 'all' else len(self.labels)
118
-
119
- self.batch_shape = self.image_shape.copy()
120
- if self.labels[0].lower() != 'none':
121
- self.batch_shape[-1] = 2 if self.combine_segmentations else 1 + self.segm_shape[-1] # +1 because we have the fix and the moving images
122
-
123
  if self.labels[0] != 'all':
124
- if isinstance(self.labels[0], str):
125
- self.labels = [SEGMENTATION_LBL2NR_LUT[lbl] for lbl in self.labels]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
  self.num_steps = len(self.file_list) // self.batch_size + (1 if len(self.file_list) % self.batch_size else 0)
128
  #self.executor = concurrent.futures.ProcessPoolExecutor(max_workers=self.batch_size)
129
  #self.mp_pool = mp.Pool(self.batch_size)
130
 
 
 
 
 
 
 
 
 
131
  def shuffle_samples(self):
132
  np.random.shuffle(self.idx_list)
133
 
@@ -137,7 +165,7 @@ class BatchIterator(Sequence):
137
  def _filter_segmentations(self, segm, segm_labels):
138
  if self.combine_segmentations:
139
  # TODO
140
- warnings.warn('Cannot select labels when combinine_segmentations options is active')
141
  if self.labels[0] != 'all':
142
  if set(self.labels).issubset(set(segm_labels)):
143
  # If labels in self.labels are in segm
@@ -155,31 +183,38 @@ class BatchIterator(Sequence):
155
  def _load_sample(self, file_path):
156
  with h5py.File(file_path, 'r') as f:
157
  img = f['image'][:]
158
- segm_labels = f['segmentation_labels'][:]
159
- if self.combine_segmentations:
160
- segm = f['segmentation'][:]
161
- else:
162
- segm = f['segmentation_expanded'][:]
163
- if segm.shape[-1] != self.segm_shape[-1]:
164
- aux = np.zeros(self.segm_shape)
165
- aux[..., :segm.shape[-1]] = segm # Ensure the same shape in case there are missing labels in aux
166
- segm = aux
167
- # TODO: selection label segm = aux[..., self.labels] but:
168
- # what if aux does not have a label in self.labels??
169
 
170
- if self.labels[0].lower() != 'none' or self.validation: # I expect to ask for the segmentations during val
171
- segm = self._filter_segmentations(segm, segm_labels)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
  if self.validation:
174
- ret_val = np.concatenate([img, segm], axis=-1), (img, segm, np.zeros(self.zero_grads))
175
  else:
176
- ret_val = np.concatenate([img, segm], axis=-1), (img, np.zeros(self.zero_grads))
177
  else:
178
- ret_val = img, (img, np.zeros(self.zero_grads))
179
  return ret_val
180
 
181
  def __getitem__(self, idx):
182
  in_batch = list()
 
183
  # out_batch = list()
184
 
185
  batch_idxs = self.idx_list[idx * self.batch_size:(idx + 1) * self.batch_size]
@@ -193,14 +228,22 @@ class BatchIterator(Sequence):
193
  # # out_batch.append(i)
194
  # else:
195
  # No need for multithreading, we are loading a single file
196
- for f in file_list:
197
- b, i = self._load_sample(f)
 
 
 
 
198
  in_batch.append(b)
199
  # out_batch.append(i)
200
 
201
- in_batch = np.asarray(in_batch)
 
 
 
 
202
  # out_batch = np.asarray(out_batch)
203
- return in_batch, in_batch
204
 
205
  def __iter__(self):
206
  """Create a generator that iterate over the Sequence."""
@@ -217,9 +260,7 @@ class BatchIterator(Sequence):
217
  if self.combine_segmentations:
218
  labels = [1]
219
  else:
220
- with h5py.File(self.file_list[0], 'r') as f:
221
- labels = np.unique(f['segmentation'][:])
222
- labels = np.sort(labels)[1:] # Ignore the background
223
  return labels
224
 
225
 
 
1
  import warnings
2
+ import time
3
  import numpy as np
4
  from tensorflow import keras
5
  import os
 
14
 
15
  import DeepDeformationMapRegistration.utils.constants as C
16
  from DeepDeformationMapRegistration.utils.operators import min_max_norm
17
+ from DeepDeformationMapRegistration.utils.misc import segmentation_cardinal_to_ohe
18
  from DeepDeformationMapRegistration.utils.thin_plate_splines import ThinPlateSplines
19
  from voxelmorph.tf.layers import SpatialTransformer
20
  from Brain_study.format_dataset import SEGMENTATION_NR2LBL_LUT, SEGMENTATION_LBL2NR_LUT
 
23
  from tensorflow.python.keras.utils import Sequence
24
  import sys
25
 
26
+ from collections import defaultdict
27
+
28
+ from Brain_study.format_dataset import SEGMENTATION_LOC
29
+
30
  #import concurrent.futures
31
  #import multiprocessing as mp
32
  import time
 
39
  split=0.7,
40
  combine_segmentations=True,
41
  labels=['all'],
42
+ directory_val=None,
43
+ return_isotropic_shape=False):
44
  self.file_directory = directory
45
  self.batch_size = batch_size
46
  self.combine_segmentations = combine_segmentations
47
  self.labels = labels
48
  self.shuffle = shuffle
49
  self.split = split
50
+ self.return_isotropic_shape=return_isotropic_shape
51
 
52
  if directory_val is None:
53
  self.file_list = [os.path.join(directory, f) for f in os.listdir(directory) if f.endswith(('h5', 'hd5'))]
 
55
  self.num_samples = len(self.file_list)
56
  training_samples = self.file_list[:int(self.num_samples * self.split)]
57
 
58
+ self.train_iter = BatchIterator(training_samples, batch_size, shuffle, combine_segmentations, labels, return_isotropic_shape=return_isotropic_shape)
59
  if self.split < 1.:
60
  validation_samples = list(set(self.file_list) - set(training_samples))
61
  self.validation_iter = BatchIterator(validation_samples, batch_size, shuffle, combine_segmentations, ['all'],
62
+ validation=True, return_isotropic_shape=return_isotropic_shape)
63
  else:
64
  self.validation_iter = None
65
  else:
 
73
  self.file_list = training_samples + validation_samples
74
 
75
  self.train_iter = BatchIterator(training_samples, batch_size, shuffle, combine_segmentations, labels)
76
+ self.validation_iter = BatchIterator(validation_samples, batch_size, shuffle, combine_segmentations, labels,
77
  validation=True)
78
 
79
  def get_train_generator(self):
 
99
 
100
  class BatchIterator(Sequence):
101
  def __init__(self, file_list, batch_size, shuffle, combine_segmentations=True, labels=['all'],
102
+ zero_grads=[64, 64, 64, 3], validation=False, sequential_labels=True,
103
+ return_isotropic_shape=False, **kwargs):
104
  # super(BatchIterator, self).__init__(n=len(file_list),
105
  # batch_size=batch_size,
106
  # shuffle=shuffle,
 
111
  self.file_list = file_list
112
  self.combine_segmentations = combine_segmentations
113
  self.labels = labels
114
+ self.zero_grads = np.zeros(zero_grads)
115
  self.idx_list = np.arange(0, len(self.file_list))
116
  self.validation = validation
117
+ self.sequential_labels = sequential_labels
118
+ self.return_isotropic_shape = return_isotropic_shape
119
  self._initialize()
120
  self.shuffle_samples()
121
 
122
  def _initialize(self):
123
+ if (isinstance(self.labels[0], str) and self.labels[0].lower() != 'none'):
 
 
 
 
 
 
 
 
 
124
  if self.labels[0] != 'all':
125
+ # Labels are tag names. Convert to numeric and check if the expected labels are in sequence or not
126
+ self.labels = [SEGMENTATION_LBL2NR_LUT[lbl] for lbl in self.labels]
127
+ if not self.sequential_labels:
128
+ self.labels = [SEGMENTATION_LOC[lbl] for lbl in self.labels]
129
+ self.labels_dict = lambda x: SEGMENTATION_LOC[x] if x in self.labels else 0
130
+ else:
131
+ self.labels_dict = lambda x: ALL_LABELS_LOC[x] if x in self.labels else 0
132
+ else:
133
+ # Use all labels
134
+ if self.sequential_labels:
135
+ self.labels = list(set(SEGMENTATION_LOC.values()))
136
+ self.labels_dict = lambda x: SEGMENTATION_LOC[x] if x else 0
137
+ else:
138
+ self.labels = list(ALL_LABELS)
139
+ self.labels_dict = lambda x: ALL_LABELS_LOC[x] if x in self.labels else 0
140
+ elif hasattr(self.labels[0], 'lower') and self.labels[0].lower() == 'none':
141
+ # self.labels = list()
142
+ self.labels_dict = dict()
143
+ else:
144
+ assert np.all([isinstance(lbl, (int, float)) for lbl in self.labels]), "Labels must be a str, int or float"
145
+ # Nothing to do, the self.labels contains a list of numbers
146
 
147
  self.num_steps = len(self.file_list) // self.batch_size + (1 if len(self.file_list) % self.batch_size else 0)
148
  #self.executor = concurrent.futures.ProcessPoolExecutor(max_workers=self.batch_size)
149
  #self.mp_pool = mp.Pool(self.batch_size)
150
 
151
+ with h5py.File(self.file_list[0], 'r') as f:
152
+ self.image_shape = list(f['image'][:].shape)
153
+ self.segm_shape = self.image_shape.copy()
154
+ self.segm_shape[-1] = len(self.labels) if not self.combine_segmentations else 1
155
+
156
+ self.batch_shape = self.image_shape.copy()
157
+ self.batch_shape[-1] = self.image_shape[-1] + self.segm_shape[-1]
158
+
159
  def shuffle_samples(self):
160
  np.random.shuffle(self.idx_list)
161
 
 
165
  def _filter_segmentations(self, segm, segm_labels):
166
  if self.combine_segmentations:
167
  # TODO
168
+ warnings.warn('Cannot select labels when combine_segmentations options is active')
169
  if self.labels[0] != 'all':
170
  if set(self.labels).issubset(set(segm_labels)):
171
  # If labels in self.labels are in segm
 
183
  def _load_sample(self, file_path):
184
  with h5py.File(file_path, 'r') as f:
185
  img = f['image'][:]
186
+ segm = f['segmentation'][:]
187
+ isot_shape = f['isotropic_shape'][:]
 
 
 
 
 
 
 
 
 
188
 
189
+ if not self.combine_segmentations:
190
+ if self.sequential_labels:
191
+ # TODO: I am assuming I want all the labels
192
+ segm = np.squeeze(np.eye(len(self.labels))[segm])
193
+ else:
194
+ lbls_list = list(ALL_LABELS) if self.labels[0] == 'all' else self.labels
195
+ segm = segmentation_cardinal_to_ohe(segm, lbls_list) # Filtering is done here
196
+ # aux = np.zeros(self.segm_shape)
197
+ # aux[..., :segm.shape[-1]] = segm # Ensure the same shape in case there are missing labels in aux
198
+ # segm = aux
199
+ # TODO: selection label segm = aux[..., self.labels] but:
200
+ # what if aux does not have a label in self.labels??
201
+
202
+ img = np.asarray(img, dtype=np.float32)
203
+ segm = np.asarray(segm, dtype=np.float32)
204
+ if not isinstance(self.labels[0], str) or self.labels[0].lower() != 'none' or self.validation: # I expect to ask for the segmentations during val
205
+ # segm = self._filter_segmentations(segm, segm_labels)
206
 
207
  if self.validation:
208
+ ret_val = np.concatenate([img, segm], axis=-1), (img, segm, self.zero_grads), isot_shape
209
  else:
210
+ ret_val = np.concatenate([img, segm], axis=-1), (img, self.zero_grads), isot_shape
211
  else:
212
+ ret_val = img, (img, self.zero_grads), isot_shape
213
  return ret_val
214
 
215
  def __getitem__(self, idx):
216
  in_batch = list()
217
+ isotropic_shape = list()
218
  # out_batch = list()
219
 
220
  batch_idxs = self.idx_list[idx * self.batch_size:(idx + 1) * self.batch_size]
 
228
  # # out_batch.append(i)
229
  # else:
230
  # No need for multithreading, we are loading a single file
231
+ # in_batch = np.zeros([self.batch_size] + self.batch_shape, dtype=np.float32)
232
+ for batch_idx, f in enumerate(file_list):
233
+ b, i, isot_shape = self._load_sample(f)
234
+ # in_batch[batch_idx, :, :, :, :] = b
235
+ if self.return_isotropic_shape:
236
+ isotropic_shape.append(isot_shape)
237
  in_batch.append(b)
238
  # out_batch.append(i)
239
 
240
+ in_batch = np.asarray(in_batch, dtype=np.float32)
241
+ ret_val = (in_batch, in_batch)
242
+ if self.return_isotropic_shape:
243
+ isotropic_shape = np.asarray(isotropic_shape, dtype=np.int)
244
+ ret_val += (isotropic_shape,)
245
  # out_batch = np.asarray(out_batch)
246
+ return ret_val
247
 
248
  def __iter__(self):
249
  """Create a generator that iterate over the Sequence."""
 
260
  if self.combine_segmentations:
261
  labels = [1]
262
  else:
263
+ labels = self.labels
 
 
264
  return labels
265
 
266
 
Brain_study/format_dataset.py CHANGED
@@ -1,12 +1,18 @@
1
  import h5py
2
  import nibabel as nib
3
  from nilearn.image import resample_img
4
- import os
5
  import re
6
  import numpy as np
7
  from scipy.ndimage import zoom
8
  from tqdm import tqdm
9
 
 
 
 
 
 
 
10
  SEGMENTATION_NR2LBL_LUT = {0: 'background',
11
  2: 'parietal-right-gm',
12
  3: 'lateral-ventricle-left',
@@ -36,15 +42,26 @@ SEGMENTATION_NR2LBL_LUT = {0: 'background',
36
  233: '4th-ventricle',
37
  254: 'fornix-right',
38
  255: 'csf'}
 
39
  SEGMENTATION_LBL2NR_LUT = {v: k for k, v in SEGMENTATION_NR2LBL_LUT.items()}
40
 
 
 
 
 
 
 
 
 
 
 
41
  IMG_DIRECTORY = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1'
42
  SEG_DIRECTORY = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/anatomical_masks'
43
 
44
  IMG_NAME_PATTERN = '(.*).nii.gz'
45
  SEG_NAME_PATTERN = '(.*)_lobes.nii.gz'
46
 
47
- OUT_DIRECTORY = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/test'
48
 
49
  if __name__ == '__main__':
50
  img_list = [os.path.join(IMG_DIRECTORY, f) for f in os.listdir(IMG_DIRECTORY) if f.endswith('.nii.gz')]
@@ -54,6 +71,9 @@ if __name__ == '__main__':
54
  seg_list.sort()
55
 
56
  os.makedirs(OUT_DIRECTORY, exist_ok=True)
 
 
 
57
  for seg_file in tqdm(seg_list):
58
  img_name = re.match(SEG_NAME_PATTERN, os.path.split(seg_file)[-1])[1]
59
  img_file = os.path.join(IMG_DIRECTORY, img_name + '.nii.gz')
@@ -70,6 +90,8 @@ if __name__ == '__main__':
70
  seg = np.asarray(seg.dataobj)
71
  seg = zoom(seg, np.asarray([128]*3) / np.asarray(isot_shape), order=0)
72
 
 
 
73
  unique_lbls = np.unique(seg)[1:] # Omit background
74
  seg_expanded = np.tile(np.zeros_like(seg)[..., np.newaxis], (1, 1, 1, len(unique_lbls)))
75
  for ch, lbl in enumerate(unique_lbls):
@@ -79,12 +101,13 @@ if __name__ == '__main__':
79
 
80
  h5_file.create_dataset('image', data=img[..., np.newaxis], dtype=np.float32)
81
  h5_file.create_dataset('segmentation', data=seg[..., np.newaxis].astype(np.uint8), dtype=np.uint8)
82
- h5_file.create_dataset('segmentation_expanded', data=seg_expanded.astype(np.uint8), dtype=np.uint8)
83
  h5_file.create_dataset('segmentation_labels', data=unique_lbls)
84
  h5_file.create_dataset('isotropic_shape', data=isot_shape)
85
 
86
  h5_file.close()
87
-
 
88
 
89
 
90
 
 
1
  import h5py
2
  import nibabel as nib
3
  from nilearn.image import resample_img
4
+ import os, sys
5
  import re
6
  import numpy as np
7
  from scipy.ndimage import zoom
8
  from tqdm import tqdm
9
 
10
+ currentdir = os.path.dirname(os.path.realpath(__file__))
11
+ parentdir = os.path.dirname(currentdir)
12
+ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
13
+
14
+ from Brain_study.split_dataset import split
15
+
16
  SEGMENTATION_NR2LBL_LUT = {0: 'background',
17
  2: 'parietal-right-gm',
18
  3: 'lateral-ventricle-left',
 
42
  233: '4th-ventricle',
43
  254: 'fornix-right',
44
  255: 'csf'}
45
+
46
  SEGMENTATION_LBL2NR_LUT = {v: k for k, v in SEGMENTATION_NR2LBL_LUT.items()}
47
 
48
+ ALL_LABELS = {2., 3., 4., 6., 8., 9., 11., 12., 14., 16., 20., 23., 29., 33., 39., 53., 67., 76., 102., 203., 210.,
49
+ 211., 218., 219., 232., 233., 254., 255.}
50
+ LABELS_COMBINED = {0, (2, 6), (3, 9), (4, 8), (11, 12), (14, 16), 20, (23, 33), (29, 254), (39, 53), (67, 76), (102, 203), (210, 211), (218, 219), 232, 233, 255}
51
+ SEGMENTATION_LOC = {}
52
+ for loc, label in enumerate(LABELS_COMBINED):
53
+ if isinstance(label, tuple):
54
+ SEGMENTATION_LOC.update(dict.fromkeys(label, loc))
55
+ else:
56
+ SEGMENTATION_LOC[label] = loc
57
+
58
  IMG_DIRECTORY = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1'
59
  SEG_DIRECTORY = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/anatomical_masks'
60
 
61
  IMG_NAME_PATTERN = '(.*).nii.gz'
62
  SEG_NAME_PATTERN = '(.*)_lobes.nii.gz'
63
 
64
+ OUT_DIRECTORY = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/ERASEME_sequential'
65
 
66
  if __name__ == '__main__':
67
  img_list = [os.path.join(IMG_DIRECTORY, f) for f in os.listdir(IMG_DIRECTORY) if f.endswith('.nii.gz')]
 
71
  seg_list.sort()
72
 
73
  os.makedirs(OUT_DIRECTORY, exist_ok=True)
74
+
75
+ vectorize_fnc = np.vectorize(lambda x: SEGMENTATION_LOC[x] if x in SEGMENTATION_LOC.keys() else 0)
76
+ change_labels = lambda x: np.reshape(vectorize_fnc(x.ravel()), x.shape)
77
  for seg_file in tqdm(seg_list):
78
  img_name = re.match(SEG_NAME_PATTERN, os.path.split(seg_file)[-1])[1]
79
  img_file = os.path.join(IMG_DIRECTORY, img_name + '.nii.gz')
 
90
  seg = np.asarray(seg.dataobj)
91
  seg = zoom(seg, np.asarray([128]*3) / np.asarray(isot_shape), order=0)
92
 
93
+ seg = change_labels(seg) # This way the segmentation numbering is continuous
94
+
95
  unique_lbls = np.unique(seg)[1:] # Omit background
96
  seg_expanded = np.tile(np.zeros_like(seg)[..., np.newaxis], (1, 1, 1, len(unique_lbls)))
97
  for ch, lbl in enumerate(unique_lbls):
 
101
 
102
  h5_file.create_dataset('image', data=img[..., np.newaxis], dtype=np.float32)
103
  h5_file.create_dataset('segmentation', data=seg[..., np.newaxis].astype(np.uint8), dtype=np.uint8)
104
+ # h5_file.create_dataset('segmentation_expanded', data=seg_expanded.astype(np.uint8), dtype=np.uint8)
105
  h5_file.create_dataset('segmentation_labels', data=unique_lbls)
106
  h5_file.create_dataset('isotropic_shape', data=isot_shape)
107
 
108
  h5_file.close()
109
+ # We should only have train and test. The val split is done by the batch generator
110
+ split(train_perc=0.70, validation_perc=0.15, test_perc=0.15, data_dir=OUT_DIRECTORY, move_files=True)
111
 
112
 
113
 
Brain_study/split_dataset.py CHANGED
@@ -4,51 +4,55 @@ import random
4
  import warnings
5
 
6
  import math
7
- from shutil import copyfile
8
  from tqdm import tqdm
9
  import concurrent.futures
10
  import numpy as np
11
 
12
 
13
- def copy_file(s_d):
14
  s, d = s_d
15
  file_name = os.path.split(s)[-1]
16
  copyfile(s, os.path.join(d, file_name))
17
  return int(os.path.exists(d))
18
 
19
 
20
- if __name__ == '__main__':
21
- parser = argparse.ArgumentParser()
22
- parser.add_argument('--train', '-t', type=float, default=.70, help='Train percentage. Default: 0.70')
23
- parser.add_argument('--validation', '-v', type=float, default=0.15, help='Validation percentage. Default: 0.15')
24
- parser.add_argument('--test', '-s', type=float, default=0.15, help='Test percentage. Default: 0.15')
25
- parser.add_argument('-d', '--dir', type=str, help='Directory where the data is')
26
- parser.add_argument('-f', '--format', type=str, help='Format of the data files. Default: h5', default='h5')
27
- parser.add_argument('-r', '--random', type=bool, help='Randomly split the dataset or not. Default: True', default=True)
28
 
29
- args = parser.parse_args()
30
 
31
- assert args.train + args.validation + args.test == 1.0, 'Train+Validation+Test != 1 (100%)'
 
 
 
 
 
 
 
 
32
 
33
- file_set = [os.path.join(args.dir, f) for f in os.listdir(args.dir) if f.endswith(args.format)]
34
- random.shuffle(file_set) if args.random else file_set.sort()
35
 
36
  num_files = len(file_set)
37
- num_validation = math.floor(num_files * args.validation)
38
- num_test = math.floor(num_files * args.test)
39
  num_train = num_files - num_test - num_validation
40
 
41
- dataset_root, dataset_name = os.path.split(args.dir)
42
- dst_train = os.path.join(dataset_root, 'SPLIT_'+dataset_name, 'train_set')
43
- dst_validation = os.path.join(dataset_root, 'SPLIT_'+dataset_name, 'validation_set')
44
- dst_test = os.path.join(dataset_root, 'SPLIT_'+dataset_name, 'test_set')
45
 
46
  print('OUTPUT INFORMATION\n=============')
47
  print('Train:\t\t{}'.format(num_train))
48
  print('Validation:\t{}'.format(num_validation))
49
  print('Test:\t\t{}'.format(num_test))
50
  print('Num. samples\t{}'.format(num_files))
51
- print('Path:\t\t', os.path.join(dataset_root, 'SPLIT_'+dataset_name))
52
 
53
  dest = [dst_train] * num_train + [dst_validation] * num_validation + [dst_test] * num_test
54
 
@@ -57,8 +61,10 @@ if __name__ == '__main__':
57
  os.makedirs(dst_test, exist_ok=True)
58
 
59
  progress_bar = tqdm(zip(file_set, dest), desc='Copying files', total=num_files)
 
 
60
  with concurrent.futures.ProcessPoolExecutor(max_workers=10) as ex:
61
- results = list(tqdm(ex.map(copy_file, zip(file_set, dest)), desc='Copying files', total=num_files))
62
 
63
  num_copies = np.sum(results)
64
  if num_copies == num_files:
@@ -66,3 +72,18 @@ if __name__ == '__main__':
66
  else:
67
  warnings.warn('Missing files: {}'.format(num_files - num_copies))
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import warnings
5
 
6
  import math
7
+ from shutil import copyfile, move
8
  from tqdm import tqdm
9
  import concurrent.futures
10
  import numpy as np
11
 
12
 
13
+ def copy_file_fnc(s_d):
14
  s, d = s_d
15
  file_name = os.path.split(s)[-1]
16
  copyfile(s, os.path.join(d, file_name))
17
  return int(os.path.exists(d))
18
 
19
 
20
+ def move_file_fnc(s_d):
21
+ s, d = s_d
22
+ file_name = os.path.split(s)[-1]
23
+ move(s, os.path.join(d, file_name))
24
+ return int(os.path.exists(d))
 
 
 
25
 
 
26
 
27
+ def split(train_perc: float=0.7,
28
+ validation_perc: float=0.15,
29
+ test_perc: float=0.15,
30
+ data_dir: str='',
31
+ file_format: str='h5',
32
+ random_split: bool=True,
33
+ move_files: bool=False):
34
+ assert train_perc + validation_perc + test_perc == 1.0, 'Train+Validation+Test != 1 (100%)'
35
+ assert train_perc > 0 and test_perc > 0, 'Train and test percentages must be greater than zero'
36
 
37
+ file_set = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith(file_format)]
38
+ random.shuffle(file_set) if random_split else file_set.sort()
39
 
40
  num_files = len(file_set)
41
+ num_validation = math.floor(num_files * validation_perc)
42
+ num_test = math.floor(num_files * test_perc)
43
  num_train = num_files - num_test - num_validation
44
 
45
+ dataset_root, dataset_name = os.path.split(data_dir)
46
+ dst_train = os.path.join(dataset_root, 'SPLIT_' + dataset_name, 'train_set')
47
+ dst_validation = os.path.join(dataset_root, 'SPLIT_' + dataset_name, 'validation_set')
48
+ dst_test = os.path.join(dataset_root, 'SPLIT_' + dataset_name, 'test_set')
49
 
50
  print('OUTPUT INFORMATION\n=============')
51
  print('Train:\t\t{}'.format(num_train))
52
  print('Validation:\t{}'.format(num_validation))
53
  print('Test:\t\t{}'.format(num_test))
54
  print('Num. samples\t{}'.format(num_files))
55
+ print('Path:\t\t', os.path.join(dataset_root, 'SPLIT_' + dataset_name))
56
 
57
  dest = [dst_train] * num_train + [dst_validation] * num_validation + [dst_test] * num_test
58
 
 
61
  os.makedirs(dst_test, exist_ok=True)
62
 
63
  progress_bar = tqdm(zip(file_set, dest), desc='Copying files', total=num_files)
64
+ operation = move_file_fnc if move_files else copy_file_fnc
65
+ desc = 'Moving files' if move_files else 'Copying files'
66
  with concurrent.futures.ProcessPoolExecutor(max_workers=10) as ex:
67
+ results = list(tqdm(ex.map(operation, zip(file_set, dest)), desc=desc, total=num_files))
68
 
69
  num_copies = np.sum(results)
70
  if num_copies == num_files:
 
72
  else:
73
  warnings.warn('Missing files: {}'.format(num_files - num_copies))
74
 
75
+ if __name__ == '__main__':
76
+ parser = argparse.ArgumentParser()
77
+ parser.add_argument('--train', '-t', type=float, default=.70, help='Train percentage. Default: 0.70')
78
+ parser.add_argument('--validation', '-v', type=float, default=0.15, help='Validation percentage. Default: 0.15')
79
+ parser.add_argument('--test', '-s', type=float, default=0.15, help='Test percentage. Default: 0.15')
80
+ parser.add_argument('-d', '--dir', type=str, help='Directory where the data is')
81
+ parser.add_argument('-f', '--format', type=str, help='Format of the data files. Default: h5', default='h5')
82
+ parser.add_argument('-r', '--random', help='Randomly split the dataset or not. Default: True', action='store_true', default=True)
83
+ parser.add_argument('-m', '--movefiles', help='Move files. Otherwise copy. Default: False', action='store_true', default=False)
84
+
85
+ args = parser.parse_args()
86
+
87
+ split(args.train, args.validation, args.test, args.dir, args.format, args.random, args.movefiles)
88
+
89
+
COMET/Build_test_set.py CHANGED
@@ -18,10 +18,11 @@ import DeepDeformationMapRegistration.utils.constants as C
18
  from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
19
  from DeepDeformationMapRegistration.layers import AugmentationLayer
20
  from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
21
- from DeepDeformationMapRegistration.utils.misc import get_segmentations_centroids
22
  from tqdm import tqdm
23
 
24
  from Brain_study.data_generator import BatchGenerator
 
25
 
26
  from skimage.measure import regionprops
27
  from scipy.interpolate import griddata
@@ -35,12 +36,6 @@ POINTS = None
35
  MISSING_CENTROID = np.asarray([[np.nan]*3])
36
 
37
 
38
- def get_mov_centroids(fix_seg, disp_map, nb_labels=28, brain_study=True):
39
- fix_centroids, _ = get_segmentations_centroids(fix_seg[0, ...], ohe=True, expected_lbls=range(0, nb_labels), brain_study=brain_study)
40
- disp = griddata(POINTS, disp_map.reshape([-1, 3]), fix_centroids, method='linear')
41
- return fix_centroids, fix_centroids + disp, disp
42
-
43
-
44
  if __name__ == '__main__':
45
  parser = argparse.ArgumentParser()
46
  parser.add_argument('-d', '--dir', type=str, help='Directory where to store the files', default='')
@@ -48,6 +43,7 @@ if __name__ == '__main__':
48
  parser.add_argument('--gpu', type=int, help='GPU', default=0)
49
  parser.add_argument('--dataset', type=str, help='Dataset to build the test set', default='')
50
  parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
 
51
  args = parser.parse_args()
52
 
53
  assert args.dataset != '', "Missing original dataset dataset"
@@ -68,12 +64,20 @@ if __name__ == '__main__':
68
  os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
69
  os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) # Check availability before running using 'nvidia-smi'
70
 
71
- data_generator = BatchGenerator(DATASET, 1, False, 1.0, False, ['all'])
72
 
73
  img_generator = data_generator.get_train_generator()
74
  nb_labels = len(img_generator.get_segmentation_labels())
75
  image_input_shape = img_generator.get_data_shape()[-1][:-1]
76
- image_output_shape = [64] * 3
 
 
 
 
 
 
 
 
77
  # Build model
78
 
79
  xx = np.linspace(0, image_output_shape[0], image_output_shape[0], endpoint=False)
@@ -104,15 +108,17 @@ if __name__ == '__main__':
104
  config.gpu_options.allow_growth = True
105
  config.log_device_placement = False ## to log device placement (on which device the operation ran)
106
 
 
 
107
  sess = tf.Session(config=config)
108
  tf.keras.backend.set_session(sess)
109
  with sess.as_default():
110
  sess.run(tf.global_variables_initializer())
111
  progress_bar = tqdm(enumerate(img_generator, 1), desc='Generating samples', total=len(img_generator))
112
- for step, (in_batch, _) in progress_bar:
113
  fix_img, mov_img, fix_seg, mov_seg, disp_map = augm_model.predict(in_batch)
114
 
115
- fix_centroids, mov_centroids, disp_centroids = get_mov_centroids(fix_seg, disp_map, nb_labels, False)
116
 
117
  out_file = os.path.join(OUTPUT_FOLDER_DIR, 'test_sample_{:04d}.h5'.format(step))
118
  out_file_dm = os.path.join(OUTPUT_FOLDER_DIR, 'test_sample_dm_{:04d}.h5'.format(step))
@@ -127,7 +133,7 @@ if __name__ == '__main__':
127
  f.create_dataset('mov_segmentations', shape=segm_shape[1:], dtype=np.uint8, data=mov_seg[0, ...])
128
  f.create_dataset('fix_centroids', shape=centroids_shape, dtype=np.float32, data=fix_centroids)
129
  f.create_dataset('mov_centroids', shape=centroids_shape, dtype=np.float32, data=mov_centroids)
130
-
131
  with h5py.File(out_file_dm, 'w') as f:
132
  f.create_dataset('disp_map', shape=disp_shape[1:], dtype=np.float32, data=disp_map)
133
  f.create_dataset('disp_centroids', shape=centroids_shape, dtype=np.float32, data=disp_centroids)
 
18
  from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
19
  from DeepDeformationMapRegistration.layers import AugmentationLayer
20
  from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
21
+ from DeepDeformationMapRegistration.utils.misc import DisplacementMapInterpolator
22
  from tqdm import tqdm
23
 
24
  from Brain_study.data_generator import BatchGenerator
25
+ from Brain_study.Build_test_set import get_mov_centroids
26
 
27
  from skimage.measure import regionprops
28
  from scipy.interpolate import griddata
 
36
  MISSING_CENTROID = np.asarray([[np.nan]*3])
37
 
38
 
 
 
 
 
 
 
39
  if __name__ == '__main__':
40
  parser = argparse.ArgumentParser()
41
  parser.add_argument('-d', '--dir', type=str, help='Directory where to store the files', default='')
 
43
  parser.add_argument('--gpu', type=int, help='GPU', default=0)
44
  parser.add_argument('--dataset', type=str, help='Dataset to build the test set', default='')
45
  parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
46
+ parser.add_argument('--output_shape', help='If an int, a cubic shape is presumed. Otherwise provide it as a space separated sequence', nargs='+', default=128)
47
  args = parser.parse_args()
48
 
49
  assert args.dataset != '', "Missing original dataset dataset"
 
64
  os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
65
  os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) # Check availability before running using 'nvidia-smi'
66
 
67
+ data_generator = BatchGenerator(DATASET, 1, False, 1.0, False, [0, 1, 2], return_isotropic_shape=True)
68
 
69
  img_generator = data_generator.get_train_generator()
70
  nb_labels = len(img_generator.get_segmentation_labels())
71
  image_input_shape = img_generator.get_data_shape()[-1][:-1]
72
+
73
+ if isinstance(args.output_shape, int):
74
+ image_output_shape = [args.output_shape] * 3
75
+ elif isinstance(args.output_shape, list):
76
+ assert len(args.output_shape) == 3, 'Invalid output shape, expected three values and got {}'.format(len(args.output_shape))
77
+ image_output_shape = [int(s) for s in args.output_shape]
78
+ else:
79
+ raise ValueError('Invalid output_shape. Must be an int or a space-separated sequence of ints')
80
+ print('Scaling to: ', image_output_shape)
81
  # Build model
82
 
83
  xx = np.linspace(0, image_output_shape[0], image_output_shape[0], endpoint=False)
 
108
  config.gpu_options.allow_growth = True
109
  config.log_device_placement = False ## to log device placement (on which device the operation ran)
110
 
111
+ dm_interp = DisplacementMapInterpolator(image_output_shape, 'griddata', step=8)
112
+
113
  sess = tf.Session(config=config)
114
  tf.keras.backend.set_session(sess)
115
  with sess.as_default():
116
  sess.run(tf.global_variables_initializer())
117
  progress_bar = tqdm(enumerate(img_generator, 1), desc='Generating samples', total=len(img_generator))
118
+ for step, (in_batch, _, isotropic_shape) in progress_bar:
119
  fix_img, mov_img, fix_seg, mov_seg, disp_map = augm_model.predict(in_batch)
120
 
121
+ fix_centroids, mov_centroids, disp_centroids = get_mov_centroids(fix_seg, disp_map, nb_labels, False, dm_interp=dm_interp)
122
 
123
  out_file = os.path.join(OUTPUT_FOLDER_DIR, 'test_sample_{:04d}.h5'.format(step))
124
  out_file_dm = os.path.join(OUTPUT_FOLDER_DIR, 'test_sample_dm_{:04d}.h5'.format(step))
 
133
  f.create_dataset('mov_segmentations', shape=segm_shape[1:], dtype=np.uint8, data=mov_seg[0, ...])
134
  f.create_dataset('fix_centroids', shape=centroids_shape, dtype=np.float32, data=fix_centroids)
135
  f.create_dataset('mov_centroids', shape=centroids_shape, dtype=np.float32, data=mov_centroids)
136
+ f.create_dataset('isotropic_shape', data=np.squeeze(isotropic_shape))
137
  with h5py.File(out_file_dm, 'w') as f:
138
  f.create_dataset('disp_map', shape=disp_shape[1:], dtype=np.float32, data=disp_map)
139
  f.create_dataset('disp_centroids', shape=centroids_shape, dtype=np.float32, data=disp_centroids)
COMET/COMET_train.py CHANGED
@@ -8,7 +8,7 @@ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
8
 
9
  from datetime import datetime
10
 
11
- from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
12
  from tensorflow.python.keras.utils import Progbar
13
  from tensorflow.keras import Input
14
  from tensorflow.keras.models import Model
@@ -27,6 +27,7 @@ from Brain_study.data_generator import BatchGenerator
27
  from Brain_study.utils import SummaryDictionary, named_logs
28
 
29
  import COMET.augmentation_constants as COMET_C
 
30
 
31
  import numpy as np
32
  import tensorflow as tf
@@ -40,7 +41,7 @@ import warnings
40
  def launch_train(dataset_folder, validation_folder, output_folder, model_file, gpu_num=0, lr=1e-4, rw=5e-3, simil='ssim',
41
  segm='dice', max_epochs=C.EPOCHS, early_stop_patience=1000, freeze_layers=None,
42
  acc_gradients=1, batch_size=16, image_size=64,
43
- unet=[16, 32, 64, 128, 256], head=[16, 16]):
44
  # 0. Input checks
45
  assert dataset_folder is not None and output_folder is not None
46
  if model_file != '':
@@ -54,12 +55,16 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
54
  if batch_size != 1 and acc_gradients != 1:
55
  warnings.warn('WARNING: Batch size and Accumulative gradient step are set!')
56
 
57
- if freeze_layers is not None:
58
- assert all(s in ['INPUT', 'OUTPUT', 'ENCODER', 'DECODER', 'TOP', 'BOTTOM'] for s in freeze_layers), \
59
- 'Invalid option for "freeze". Expected one or several of: INPUT, OUTPUT, ENCODER, DECODER, TOP, BOTTOM'
60
- freeze_layers = [list(COMET_C.LAYER_RANGES[l]) for l in list(set(freeze_layers))]
61
- if len(freeze_layers) > 1:
62
- freeze_layers = list(itertools.chain.from_iterable(freeze_layers))
 
 
 
 
63
 
64
  os.makedirs(output_folder, exist_ok=True)
65
  # dataset_copy = DatasetCopy(dataset_folder, os.path.join(output_folder, 'temp'))
@@ -159,34 +164,37 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
159
  if model_file != '':
160
  network.load_weights(model_file, by_name=True)
161
  print('MODEL LOCATION: ', model_file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  # 4. Freeze/unfreeze model layers
163
- # freeze_layers = range(0, len(network.layers) - 8) # Do not freeze the last layers after the UNet (8 last layers)
164
- # for l in freeze_layers:
165
- # network.layers[l].trainable = False
166
- # msg = "[INF]: Frozen layers {} to {}".format(0, len(network.layers) - 8)
167
- # print(msg)
168
- # log_file.write("INF: Frozen layers {} to {}".format(0, len(network.layers) - 8))
169
- if freeze_layers is not None:
170
- aux = list()
171
- for r in freeze_layers:
172
- for l in range(*r):
173
- network.layers[l].trainable = False
174
- aux.append(l)
175
- aux.sort()
176
- msg = "[INF]: Frozen layers {}".format(', '.join([str(a) for a in aux]))
177
  else:
178
  msg = "[INF] None frozen layers"
179
  print(msg)
180
  log_file.write(msg)
181
- # network.trainable = False # Freeze the base model
182
- # # Create a new model on top
183
- # input_new_model = keras.Input(network.input_shape)
184
- # x = base_model(input_new_model, training=False)
185
- # x =
186
- # network = keras.Model(input_new_model, x)
187
-
188
- network.summary()
189
- network.summary(print_fn=log_file.writelines)
190
  # Complete the model with the augmentation layer
191
  augm_train_input_shape = train_generator.get_data_shape()[-1]
192
  input_layer_train = Input(shape=augm_train_input_shape, name='input_train')
@@ -269,6 +277,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
269
  save_best_only=True, monitor='val_loss', verbose=1, mode='min')
270
  callback_save_checkpoint = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'checkpoint.h5'),
271
  save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
 
272
 
273
  losses = {'transformer': loss_fnc,
274
  'flow': vxm.losses.Grad('l2').loss}
@@ -281,7 +290,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
281
  loss_weights = {'transformer': 1.,
282
  'flow': rw}
283
 
284
- optimizer = AdamAccumulated(C.ACCUM_GRADIENT_STEP, C.LEARNING_RATE)
285
  network.compile(optimizer=optimizer,
286
  loss=losses,
287
  loss_weights=loss_weights,
@@ -292,6 +301,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
292
  callback_early_stop.set_model(network)
293
  callback_best_model.set_model(network)
294
  callback_save_checkpoint.set_model(network)
 
295
 
296
  summary = SummaryDictionary(network, C.BATCH_SIZE)
297
  names = network.metrics_names
@@ -303,12 +313,14 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
303
  callback_early_stop.on_train_begin()
304
  callback_best_model.on_train_begin()
305
  callback_save_checkpoint.on_train_begin()
 
306
 
307
- for epoch in range(C.EPOCHS):
308
  callback_tensorboard.on_epoch_begin(epoch)
309
  callback_early_stop.on_epoch_begin(epoch)
310
  callback_best_model.on_epoch_begin(epoch)
311
  callback_save_checkpoint.on_epoch_begin(epoch)
 
312
  print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
313
  print("TRAIN")
314
 
@@ -318,6 +330,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
318
  callback_best_model.on_train_batch_begin(step)
319
  callback_save_checkpoint.on_train_batch_begin(step)
320
  callback_early_stop.on_train_batch_begin(step)
 
321
 
322
  try:
323
  fix_img, mov_img, fix_seg, mov_seg = augm_model_train.predict(in_batch)
@@ -351,6 +364,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
351
  callback_best_model.on_train_batch_end(step, named_logs(network, ret))
352
  callback_save_checkpoint.on_train_batch_end(step, named_logs(network, ret))
353
  callback_early_stop.on_train_batch_end(step, named_logs(network, ret))
 
354
  progress_bar.update(step, zip(names, ret))
355
  log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
356
  val_values = progress_bar._values.copy()
@@ -384,10 +398,12 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
384
  callback_best_model.on_epoch_end(epoch, epoch_summary)
385
  callback_save_checkpoint.on_epoch_end(epoch, epoch_summary)
386
  callback_early_stop.on_epoch_end(epoch, epoch_summary)
 
387
  print('End of epoch {}: '.format(epoch), ret, '\n')
388
 
389
  callback_tensorboard.on_train_end()
390
  callback_best_model.on_train_end()
391
  callback_save_checkpoint.on_train_end()
392
  callback_early_stop.on_train_end()
 
393
  # 7. Wrap up
 
8
 
9
  from datetime import datetime
10
 
11
+ from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping, ReduceLROnPlateau
12
  from tensorflow.python.keras.utils import Progbar
13
  from tensorflow.keras import Input
14
  from tensorflow.keras.models import Model
 
27
  from Brain_study.utils import SummaryDictionary, named_logs
28
 
29
  import COMET.augmentation_constants as COMET_C
30
+ from COMET.utils import freeze_layers_by_group
31
 
32
  import numpy as np
33
  import tensorflow as tf
 
41
  def launch_train(dataset_folder, validation_folder, output_folder, model_file, gpu_num=0, lr=1e-4, rw=5e-3, simil='ssim',
42
  segm='dice', max_epochs=C.EPOCHS, early_stop_patience=1000, freeze_layers=None,
43
  acc_gradients=1, batch_size=16, image_size=64,
44
+ unet=[16, 32, 64, 128, 256], head=[16, 16], resume=None):
45
  # 0. Input checks
46
  assert dataset_folder is not None and output_folder is not None
47
  if model_file != '':
 
55
  if batch_size != 1 and acc_gradients != 1:
56
  warnings.warn('WARNING: Batch size and Accumulative gradient step are set!')
57
 
58
+ if resume is not None:
59
+ try:
60
+ assert os.path.exists(resume) and len(os.listdir(os.path.join(resume, 'checkpoints'))), 'Invalid directory: ' + resume
61
+ output_folder = resume
62
+ resume = True
63
+ except AssertionError:
64
+ output_folder = os.path.join(output_folder + '_' + datetime.now().strftime("%H%M%S-%d%m%Y"))
65
+ resume = False
66
+ else:
67
+ resume = False
68
 
69
  os.makedirs(output_folder, exist_ok=True)
70
  # dataset_copy = DatasetCopy(dataset_folder, os.path.join(output_folder, 'temp'))
 
164
  if model_file != '':
165
  network.load_weights(model_file, by_name=True)
166
  print('MODEL LOCATION: ', model_file)
167
+
168
+ resume_epoch = 0
169
+ if resume:
170
+ cp_dir = os.path.join(output_folder, 'checkpoints')
171
+ cp_file_list = [os.path.join(cp_dir, f) for f in os.listdir(cp_dir) if (f.startswith('checkpoint') and f.endswith('.h5'))]
172
+ if len(cp_file_list):
173
+ cp_file_list.sort()
174
+ checkpoint_file = cp_file_list[-1]
175
+ if os.path.exists(checkpoint_file):
176
+ network.load_weights(checkpoint_file, by_name=True)
177
+ print('Loaded checkpoint file: ' + checkpoint_file)
178
+ try:
179
+ resume_epoch = int(re.match('checkpoint\.(\d+)-*.h5', os.path.split(checkpoint_file)[-1])[1])
180
+ except TypeError:
181
+ # Checkpoint file has no epoch number in the name
182
+ resume_epoch = 0
183
+ print('Resuming from epoch: {:d}'.format(resume_epoch))
184
+ else:
185
+ warnings.warn('Checkpoint file NOT found. Training from scratch')
186
+
187
  # 4. Freeze/unfreeze model layers
188
+ _, frozen_layers = freeze_layers_by_group(network, freeze_layers)
189
+ if frozen_layers is not None:
190
+ msg = "[INF]: Frozen layers {}".format(', '.join([str(a) for a in frozen_layers]))
 
 
 
 
 
 
 
 
 
 
 
191
  else:
192
  msg = "[INF] None frozen layers"
193
  print(msg)
194
  log_file.write(msg)
195
+
196
+ network.summary(line_length=C.SUMMARY_LINE_LENGTH)
197
+ network.summary(line_length=C.SUMMARY_LINE_LENGTH, print_fn=log_file.writelines)
 
 
 
 
 
 
198
  # Complete the model with the augmentation layer
199
  augm_train_input_shape = train_generator.get_data_shape()[-1]
200
  input_layer_train = Input(shape=augm_train_input_shape, name='input_train')
 
277
  save_best_only=True, monitor='val_loss', verbose=1, mode='min')
278
  callback_save_checkpoint = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'checkpoint.h5'),
279
  save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
280
+ callback_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10)
281
 
282
  losses = {'transformer': loss_fnc,
283
  'flow': vxm.losses.Grad('l2').loss}
 
290
  loss_weights = {'transformer': 1.,
291
  'flow': rw}
292
 
293
+ optimizer = AdamAccumulated(accumulation_steps=C.ACCUM_GRADIENT_STEP, learning_rate=C.LEARNING_RATE)
294
  network.compile(optimizer=optimizer,
295
  loss=losses,
296
  loss_weights=loss_weights,
 
301
  callback_early_stop.set_model(network)
302
  callback_best_model.set_model(network)
303
  callback_save_checkpoint.set_model(network)
304
+ callback_lr.set_model(network)
305
 
306
  summary = SummaryDictionary(network, C.BATCH_SIZE)
307
  names = network.metrics_names
 
313
  callback_early_stop.on_train_begin()
314
  callback_best_model.on_train_begin()
315
  callback_save_checkpoint.on_train_begin()
316
+ callback_lr.on_train_begin()
317
 
318
+ for epoch in range(resume_epoch, C.EPOCHS):
319
  callback_tensorboard.on_epoch_begin(epoch)
320
  callback_early_stop.on_epoch_begin(epoch)
321
  callback_best_model.on_epoch_begin(epoch)
322
  callback_save_checkpoint.on_epoch_begin(epoch)
323
+ callback_lr.on_epoch_begin(epoch)
324
  print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
325
  print("TRAIN")
326
 
 
330
  callback_best_model.on_train_batch_begin(step)
331
  callback_save_checkpoint.on_train_batch_begin(step)
332
  callback_early_stop.on_train_batch_begin(step)
333
+ callback_lr.on_train_batch_begin(step)
334
 
335
  try:
336
  fix_img, mov_img, fix_seg, mov_seg = augm_model_train.predict(in_batch)
 
364
  callback_best_model.on_train_batch_end(step, named_logs(network, ret))
365
  callback_save_checkpoint.on_train_batch_end(step, named_logs(network, ret))
366
  callback_early_stop.on_train_batch_end(step, named_logs(network, ret))
367
+ callback_lr.on_train_batch_end(step, named_logs(network, ret))
368
  progress_bar.update(step, zip(names, ret))
369
  log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
370
  val_values = progress_bar._values.copy()
 
398
  callback_best_model.on_epoch_end(epoch, epoch_summary)
399
  callback_save_checkpoint.on_epoch_end(epoch, epoch_summary)
400
  callback_early_stop.on_epoch_end(epoch, epoch_summary)
401
+ callback_lr.on_epoch_end(epoch, epoch_summary)
402
  print('End of epoch {}: '.format(epoch), ret, '\n')
403
 
404
  callback_tensorboard.on_train_end()
405
  callback_best_model.on_train_end()
406
  callback_save_checkpoint.on_train_end()
407
  callback_early_stop.on_train_end()
408
+ callback_lr.on_train_end()
409
  # 7. Wrap up
COMET/COMET_train_UW.py CHANGED
@@ -6,7 +6,7 @@ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
6
 
7
  from datetime import datetime
8
 
9
- from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
10
  from tensorflow.python.keras.utils import Progbar
11
  from tensorflow.keras import Input
12
  from tensorflow.keras.models import Model
@@ -26,6 +26,7 @@ from Brain_study.data_generator import BatchGenerator
26
  from Brain_study.utils import SummaryDictionary, named_logs
27
 
28
  import COMET.augmentation_constants as COMET_C
 
29
 
30
  import numpy as np
31
  import tensorflow as tf
@@ -39,7 +40,7 @@ import warnings
39
  def launch_train(dataset_folder, validation_folder, output_folder, model_file, gpu_num=0, lr=1e-4, rw=5e-3,
40
  simil=['ssim'], segm=['dice'], max_epochs=C.EPOCHS, early_stop_patience=1000, prior_reg_w=5e-3,
41
  freeze_layers=None, acc_gradients=1, batch_size=16, image_size=64,
42
- unet=[16, 32, 64, 128, 256], head=[16, 16]):
43
  # 0. Input checks
44
  assert dataset_folder is not None and output_folder is not None
45
  if model_file != '':
@@ -53,15 +54,16 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
53
  if batch_size != 1 and acc_gradients != 1:
54
  warnings.warn('WARNING: Batch size and Accumulative gradient step are set!')
55
 
56
- if freeze_layers is not None:
57
- assert all(s in ['INPUT', 'OUTPUT', 'ENCODER', 'DECODER', 'TOP', 'BOTTOM'] for s in freeze_layers), \
58
- 'Invalid option for "freeze". Expected one or several of: INPUT, OUTPUT, ENCODER, DECODER, TOP, BOTTOM'
59
- multiple_ranges = 'TOP' in freeze_layers
60
- freeze_layers = [list(COMET_C.LAYER_RANGES[l]) for l in list(set(freeze_layers))]
61
- freeze_layers = freeze_layers[0] if multiple_ranges else freeze_layers
62
-
63
- # if len(freeze_layers) > 1:
64
- # freeze_layers = list(itertools.chain.from_iterable(freeze_layers))
 
65
 
66
  os.makedirs(output_folder, exist_ok=True)
67
  # dataset_copy = DatasetCopy(dataset_folder, os.path.join(output_folder, 'temp'))
@@ -103,8 +105,9 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
103
  print(aux)
104
 
105
  # 2. Data generator
 
106
  data_generator = BatchGenerator(C.TRAINING_DATASET, C.BATCH_SIZE if C.ACCUM_GRADIENT_STEP == 1 else 1, True,
107
- C.TRAINING_PERC, labels=['all'], combine_segmentations=False,
108
  directory_val=C.VALIDATION_DATASET)
109
 
110
  train_generator = data_generator.get_train_generator()
@@ -137,22 +140,36 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
137
  print('MODEL LOCATION: ', model_file)
138
  network.load_weights(model_file, by_name=True)
139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  # 4. Freeze/unfreeze model layers
141
- if freeze_layers is not None:
142
- aux = list()
143
- for r in freeze_layers:
144
- for l in range(*r):
145
- network.layers[l].trainable = False
146
- aux.append(l)
147
- aux.sort()
148
- msg = "[INF]: Frozen layers {}".format(', '.join([str(a) for a in aux]))
149
  else:
150
  msg = "[INF] None frozen layers"
151
  print(msg)
152
  log_file.write(msg)
153
 
154
- network.summary()
155
- network.summary(print_fn=log_file.write)
156
  # Complete the model with the augmentation layer
157
  input_layer_train = Input(shape=train_generator.get_data_shape()[0], name='input_train')
158
  augm_layer = AugmentationLayer(max_displacement=COMET_C.MAX_AUG_DISP, # Max 30 mm in isotropic space
@@ -250,7 +267,9 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
250
  callback_save_checkpoint = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'checkpoint.h5'),
251
  save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
252
 
253
- optimizer = AdamAccumulated(C.ACCUM_GRADIENT_STEP, C.LEARNING_RATE)
 
 
254
  full_model.compile(optimizer=optimizer,
255
  loss=None, )
256
 
@@ -259,6 +278,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
259
  callback_early_stop.set_model(full_model)
260
  callback_best_model.set_model(network) # ONLY SAVE THE NETWORK!!!
261
  callback_save_checkpoint.set_model(network) # ONLY SAVE THE NETWORK!!!
 
262
 
263
  summary = SummaryDictionary(full_model, C.BATCH_SIZE)
264
  names = full_model.metrics_names
@@ -271,12 +291,15 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
271
  callback_early_stop.on_train_begin()
272
  callback_best_model.on_train_begin()
273
  callback_save_checkpoint.on_train_begin()
 
274
 
275
- for epoch in range(C.EPOCHS):
276
  callback_tensorboard.on_epoch_begin(epoch)
277
  callback_early_stop.on_epoch_begin(epoch)
278
  callback_best_model.on_epoch_begin(epoch)
279
  callback_save_checkpoint.on_epoch_begin(epoch)
 
 
280
  print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
281
  print("TRAIN")
282
 
@@ -287,6 +310,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
287
  callback_best_model.on_train_batch_begin(step)
288
  callback_save_checkpoint.on_train_batch_begin(step)
289
  callback_early_stop.on_train_batch_begin(step)
 
290
 
291
  try:
292
  fix_img, mov_img, fix_seg, mov_seg = augm_model.predict(in_batch)
@@ -310,6 +334,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
310
  callback_best_model.on_train_batch_end(step, named_logs(full_model, ret))
311
  callback_save_checkpoint.on_train_batch_end(step, named_logs(full_model, ret))
312
  callback_early_stop.on_train_batch_end(step, named_logs(full_model, ret))
 
313
  progress_bar.update(step, zip(names, ret))
314
  log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
315
  val_values = progress_bar._values.copy()
@@ -341,10 +366,12 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
341
  callback_best_model.on_epoch_end(epoch, epoch_summary)
342
  callback_save_checkpoint.on_epoch_end(epoch, epoch_summary)
343
  callback_early_stop.on_epoch_end(epoch, epoch_summary)
 
344
  print('End of epoch {}: '.format(epoch), ret, '\n')
345
 
346
  callback_tensorboard.on_train_end()
347
  callback_best_model.on_train_end()
348
  callback_save_checkpoint.on_train_end()
349
  callback_early_stop.on_train_end()
 
350
  # 7. Wrap up
 
6
 
7
  from datetime import datetime
8
 
9
+ from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping, ReduceLROnPlateau
10
  from tensorflow.python.keras.utils import Progbar
11
  from tensorflow.keras import Input
12
  from tensorflow.keras.models import Model
 
26
  from Brain_study.utils import SummaryDictionary, named_logs
27
 
28
  import COMET.augmentation_constants as COMET_C
29
+ from COMET.utils import freeze_layers_by_group
30
 
31
  import numpy as np
32
  import tensorflow as tf
 
40
  def launch_train(dataset_folder, validation_folder, output_folder, model_file, gpu_num=0, lr=1e-4, rw=5e-3,
41
  simil=['ssim'], segm=['dice'], max_epochs=C.EPOCHS, early_stop_patience=1000, prior_reg_w=5e-3,
42
  freeze_layers=None, acc_gradients=1, batch_size=16, image_size=64,
43
+ unet=[16, 32, 64, 128, 256], head=[16, 16], resume=None):
44
  # 0. Input checks
45
  assert dataset_folder is not None and output_folder is not None
46
  if model_file != '':
 
54
  if batch_size != 1 and acc_gradients != 1:
55
  warnings.warn('WARNING: Batch size and Accumulative gradient step are set!')
56
 
57
+ if resume is not None:
58
+ try:
59
+ assert os.path.exists(resume) and len(os.listdir(os.path.join(resume, 'checkpoints'))), 'Invalid directory: ' + resume
60
+ output_folder = resume
61
+ resume = True
62
+ except AssertionError:
63
+ output_folder = os.path.join(output_folder + '_' + datetime.now().strftime("%H%M%S-%d%m%Y"))
64
+ resume = False
65
+ else:
66
+ resume = False
67
 
68
  os.makedirs(output_folder, exist_ok=True)
69
  # dataset_copy = DatasetCopy(dataset_folder, os.path.join(output_folder, 'temp'))
 
105
  print(aux)
106
 
107
  # 2. Data generator
108
+ used_labels = [0, 1, 2]
109
  data_generator = BatchGenerator(C.TRAINING_DATASET, C.BATCH_SIZE if C.ACCUM_GRADIENT_STEP == 1 else 1, True,
110
+ C.TRAINING_PERC, labels=used_labels, combine_segmentations=False,
111
  directory_val=C.VALIDATION_DATASET)
112
 
113
  train_generator = data_generator.get_train_generator()
 
140
  print('MODEL LOCATION: ', model_file)
141
  network.load_weights(model_file, by_name=True)
142
 
143
+ resume_epoch = 0
144
+ if resume:
145
+ cp_dir = os.path.join(output_folder, 'checkpoints')
146
+ cp_file_list = [os.path.join(cp_dir, f) for f in os.listdir(cp_dir) if (f.startswith('checkpoint') and f.endswith('.h5'))]
147
+ if len(cp_file_list):
148
+ cp_file_list.sort()
149
+ checkpoint_file = cp_file_list[-1]
150
+ if os.path.exists(checkpoint_file):
151
+ network.load_weights(checkpoint_file, by_name=True)
152
+ print('Loaded checkpoint file: ' + checkpoint_file)
153
+ try:
154
+ resume_epoch = int(re.match('checkpoint\.(\d+)-*.h5', os.path.split(checkpoint_file)[-1])[1])
155
+ except TypeError:
156
+ # Checkpoint file has no epoch number in the name
157
+ resume_epoch = 0
158
+ print('Resuming from epoch: {:d}'.format(resume_epoch))
159
+ else:
160
+ warnings.warn('Checkpoint file NOT found. Training from scratch')
161
+
162
  # 4. Freeze/unfreeze model layers
163
+ _, frozen_layers = freeze_layers_by_group(network, freeze_layers)
164
+ if frozen_layers is not None:
165
+ msg = "[INF]: Frozen layers {}".format(', '.join([str(a) for a in frozen_layers]))
 
 
 
 
 
166
  else:
167
  msg = "[INF] None frozen layers"
168
  print(msg)
169
  log_file.write(msg)
170
 
171
+ network.summary(line_length=C.SUMMARY_LINE_LENGTH)
172
+ network.summary(line_length=C.SUMMARY_LINE_LENGTH, print_fn=log_file.write)
173
  # Complete the model with the augmentation layer
174
  input_layer_train = Input(shape=train_generator.get_data_shape()[0], name='input_train')
175
  augm_layer = AugmentationLayer(max_displacement=COMET_C.MAX_AUG_DISP, # Max 30 mm in isotropic space
 
267
  callback_save_checkpoint = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'checkpoint.h5'),
268
  save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
269
 
270
+ callback_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10)
271
+
272
+ optimizer = AdamAccumulated(accumulation_steps=C.ACCUM_GRADIENT_STEP, learning_rate=C.LEARNING_RATE)
273
  full_model.compile(optimizer=optimizer,
274
  loss=None, )
275
 
 
278
  callback_early_stop.set_model(full_model)
279
  callback_best_model.set_model(network) # ONLY SAVE THE NETWORK!!!
280
  callback_save_checkpoint.set_model(network) # ONLY SAVE THE NETWORK!!!
281
+ callback_lr.set_model(full_model)
282
 
283
  summary = SummaryDictionary(full_model, C.BATCH_SIZE)
284
  names = full_model.metrics_names
 
291
  callback_early_stop.on_train_begin()
292
  callback_best_model.on_train_begin()
293
  callback_save_checkpoint.on_train_begin()
294
+ callback_lr.on_train_begin()
295
 
296
+ for epoch in range(resume_epoch, C.EPOCHS):
297
  callback_tensorboard.on_epoch_begin(epoch)
298
  callback_early_stop.on_epoch_begin(epoch)
299
  callback_best_model.on_epoch_begin(epoch)
300
  callback_save_checkpoint.on_epoch_begin(epoch)
301
+ callback_lr.on_epoch_begin(epoch)
302
+
303
  print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
304
  print("TRAIN")
305
 
 
310
  callback_best_model.on_train_batch_begin(step)
311
  callback_save_checkpoint.on_train_batch_begin(step)
312
  callback_early_stop.on_train_batch_begin(step)
313
+ callback_lr.on_train_batch_begin(step)
314
 
315
  try:
316
  fix_img, mov_img, fix_seg, mov_seg = augm_model.predict(in_batch)
 
334
  callback_best_model.on_train_batch_end(step, named_logs(full_model, ret))
335
  callback_save_checkpoint.on_train_batch_end(step, named_logs(full_model, ret))
336
  callback_early_stop.on_train_batch_end(step, named_logs(full_model, ret))
337
+ callback_lr.on_train_batch_end(step, named_logs(full_model, ret))
338
  progress_bar.update(step, zip(names, ret))
339
  log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
340
  val_values = progress_bar._values.copy()
 
366
  callback_best_model.on_epoch_end(epoch, epoch_summary)
367
  callback_save_checkpoint.on_epoch_end(epoch, epoch_summary)
368
  callback_early_stop.on_epoch_end(epoch, epoch_summary)
369
+ callback_lr.on_epoch_end(epoch, epoch_summary)
370
  print('End of epoch {}: '.format(epoch), ret, '\n')
371
 
372
  callback_tensorboard.on_train_end()
373
  callback_best_model.on_train_end()
374
  callback_save_checkpoint.on_train_end()
375
  callback_early_stop.on_train_end()
376
+ callback_lr.on_train_end()
377
  # 7. Wrap up
COMET/COMET_train_seggguided.py CHANGED
@@ -8,7 +8,7 @@ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
8
 
9
  from datetime import datetime
10
 
11
- from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
12
  from tensorflow.python.keras.utils import Progbar
13
  from tensorflow.keras import Input
14
  from tensorflow.keras.models import Model
@@ -27,6 +27,7 @@ from Brain_study.data_generator import BatchGenerator
27
  from Brain_study.utils import SummaryDictionary, named_logs
28
 
29
  import COMET.augmentation_constants as COMET_C
 
30
 
31
  import numpy as np
32
  import tensorflow as tf
@@ -40,7 +41,7 @@ import warnings
40
  def launch_train(dataset_folder, validation_folder, output_folder, model_file, gpu_num=0, lr=1e-4, rw=5e-3, simil='ssim',
41
  segm='dice', max_epochs=C.EPOCHS, early_stop_patience=1000, freeze_layers=None,
42
  acc_gradients=1, batch_size=16, image_size=64,
43
- unet=[16, 32, 64, 128, 256], head=[16, 16]):
44
  # 0. Input checks
45
  assert dataset_folder is not None and output_folder is not None
46
  if model_file != '':
@@ -54,12 +55,16 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
54
  if batch_size != 1 and acc_gradients != 1:
55
  warnings.warn('WARNING: Batch size and Accumulative gradient step are set!')
56
 
57
- if freeze_layers is not None:
58
- assert all(s in ['INPUT', 'OUTPUT', 'ENCODER', 'DECODER', 'TOP', 'BOTTOM'] for s in freeze_layers), \
59
- 'Invalid option for "freeze". Expected one or several of: INPUT, OUTPUT, ENCODER, DECODER, TOP, BOTTOM'
60
- freeze_layers = [list(COMET_C.LAYER_RANGES[l]) for l in list(set(freeze_layers))]
61
- if len(freeze_layers) > 1:
62
- freeze_layers = list(itertools.chain.from_iterable(freeze_layers))
 
 
 
 
63
 
64
  os.makedirs(output_folder, exist_ok=True)
65
  # dataset_copy = DatasetCopy(dataset_folder, os.path.join(output_folder, 'temp'))
@@ -101,9 +106,9 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
101
  print(aux)
102
 
103
  # 2. Data generator
104
- used_labels = 'all'
105
  data_generator = BatchGenerator(C.TRAINING_DATASET, C.BATCH_SIZE if C.ACCUM_GRADIENT_STEP == 1 else 1, True,
106
- C.TRAINING_PERC, labels=[used_labels], combine_segmentations=False,
107
  directory_val=C.VALIDATION_DATASET)
108
 
109
  train_generator = data_generator.get_train_generator()
@@ -163,34 +168,37 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
163
  if model_file != '':
164
  network.load_weights(model_file, by_name=True)
165
  print('MODEL LOCATION: ', model_file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  # 4. Freeze/unfreeze model layers
167
- # freeze_layers = range(0, len(network.layers) - 8) # Do not freeze the last layers after the UNet (8 last layers)
168
- # for l in freeze_layers:
169
- # network.layers[l].trainable = False
170
- # msg = "[INF]: Frozen layers {} to {}".format(0, len(network.layers) - 8)
171
- # print(msg)
172
- # log_file.write("INF: Frozen layers {} to {}".format(0, len(network.layers) - 8))
173
- if freeze_layers is not None:
174
- aux = list()
175
- for r in freeze_layers:
176
- for l in range(*r):
177
- network.layers[l].trainable = False
178
- aux.append(l)
179
- aux.sort()
180
- msg = "[INF]: Frozen layers {}".format(', '.join([str(a) for a in aux]))
181
  else:
182
  msg = "[INF] None frozen layers"
183
  print(msg)
184
  log_file.write(msg)
185
- # network.trainable = False # Freeze the base model
186
- # # Create a new model on top
187
- # input_new_model = keras.Input(network.input_shape)
188
- # x = base_model(input_new_model, training=False)
189
- # x =
190
- # network = keras.Model(input_new_model, x)
191
-
192
- network.summary()
193
- network.summary(print_fn=log_file.writelines)
194
  # Complete the model with the augmentation layer
195
  augm_train_input_shape = train_generator.get_data_shape()[0]
196
  input_layer_train = Input(shape=augm_train_input_shape, name='input_train')
@@ -282,6 +290,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
282
  save_best_only=True, monitor='val_loss', verbose=1, mode='min')
283
  callback_save_checkpoint = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'checkpoint.h5'),
284
  save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
 
285
 
286
  losses = {'transformer': loss_fnc,
287
  'seg_transformer': loss_segm,
@@ -300,8 +309,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
300
  'seg_transformer': 1.,
301
  'flow': rw}
302
 
303
-
304
- optimizer = AdamAccumulated(C.ACCUM_GRADIENT_STEP, C.LEARNING_RATE)
305
  network.compile(optimizer=optimizer,
306
  loss=losses,
307
  loss_weights=loss_weights,
@@ -312,6 +320,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
312
  callback_early_stop.set_model(network)
313
  callback_best_model.set_model(network)
314
  callback_save_checkpoint.set_model(network)
 
315
 
316
  summary = SummaryDictionary(network, C.BATCH_SIZE)
317
  names = network.metrics_names
@@ -323,12 +332,15 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
323
  callback_early_stop.on_train_begin()
324
  callback_best_model.on_train_begin()
325
  callback_save_checkpoint.on_train_begin()
 
326
 
327
- for epoch in range(C.EPOCHS):
328
  callback_tensorboard.on_epoch_begin(epoch)
329
  callback_early_stop.on_epoch_begin(epoch)
330
  callback_best_model.on_epoch_begin(epoch)
331
  callback_save_checkpoint.on_epoch_begin(epoch)
 
 
332
  print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
333
  print("TRAIN")
334
 
@@ -338,6 +350,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
338
  callback_best_model.on_train_batch_begin(step)
339
  callback_save_checkpoint.on_train_batch_begin(step)
340
  callback_early_stop.on_train_batch_begin(step)
 
341
 
342
  try:
343
  fix_img, mov_img, fix_seg, mov_seg = augm_model_train.predict(in_batch)
@@ -371,6 +384,8 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
371
  callback_best_model.on_train_batch_end(step, named_logs(network, ret))
372
  callback_save_checkpoint.on_train_batch_end(step, named_logs(network, ret))
373
  callback_early_stop.on_train_batch_end(step, named_logs(network, ret))
 
 
374
  progress_bar.update(step, zip(names, ret))
375
  log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
376
  val_values = progress_bar._values.copy()
@@ -405,10 +420,13 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
405
  callback_best_model.on_epoch_end(epoch, epoch_summary)
406
  callback_save_checkpoint.on_epoch_end(epoch, epoch_summary)
407
  callback_early_stop.on_epoch_end(epoch, epoch_summary)
 
 
408
  print('End of epoch {}: '.format(epoch), ret, '\n')
409
 
410
  callback_tensorboard.on_train_end()
411
  callback_best_model.on_train_end()
412
  callback_save_checkpoint.on_train_end()
413
  callback_early_stop.on_train_end()
 
414
  # 7. Wrap up
 
8
 
9
  from datetime import datetime
10
 
11
+ from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping, ReduceLROnPlateau
12
  from tensorflow.python.keras.utils import Progbar
13
  from tensorflow.keras import Input
14
  from tensorflow.keras.models import Model
 
27
  from Brain_study.utils import SummaryDictionary, named_logs
28
 
29
  import COMET.augmentation_constants as COMET_C
30
+ from COMET.utils import freeze_layers_by_group
31
 
32
  import numpy as np
33
  import tensorflow as tf
 
41
  def launch_train(dataset_folder, validation_folder, output_folder, model_file, gpu_num=0, lr=1e-4, rw=5e-3, simil='ssim',
42
  segm='dice', max_epochs=C.EPOCHS, early_stop_patience=1000, freeze_layers=None,
43
  acc_gradients=1, batch_size=16, image_size=64,
44
+ unet=[16, 32, 64, 128, 256], head=[16, 16], resume=None):
45
  # 0. Input checks
46
  assert dataset_folder is not None and output_folder is not None
47
  if model_file != '':
 
55
  if batch_size != 1 and acc_gradients != 1:
56
  warnings.warn('WARNING: Batch size and Accumulative gradient step are set!')
57
 
58
+ if resume is not None:
59
+ try:
60
+ assert os.path.exists(resume) and len(os.listdir(os.path.join(resume, 'checkpoints'))), 'Invalid directory: ' + resume
61
+ output_folder = resume
62
+ resume = True
63
+ except AssertionError:
64
+ output_folder = os.path.join(output_folder + '_' + datetime.now().strftime("%H%M%S-%d%m%Y"))
65
+ resume = False
66
+ else:
67
+ resume = False
68
 
69
  os.makedirs(output_folder, exist_ok=True)
70
  # dataset_copy = DatasetCopy(dataset_folder, os.path.join(output_folder, 'temp'))
 
106
  print(aux)
107
 
108
  # 2. Data generator
109
+ used_labels = [0, 1, 2]
110
  data_generator = BatchGenerator(C.TRAINING_DATASET, C.BATCH_SIZE if C.ACCUM_GRADIENT_STEP == 1 else 1, True,
111
+ C.TRAINING_PERC, labels=used_labels, combine_segmentations=False,
112
  directory_val=C.VALIDATION_DATASET)
113
 
114
  train_generator = data_generator.get_train_generator()
 
168
  if model_file != '':
169
  network.load_weights(model_file, by_name=True)
170
  print('MODEL LOCATION: ', model_file)
171
+
172
+ resume_epoch = 0
173
+ if resume:
174
+ cp_dir = os.path.join(output_folder, 'checkpoints')
175
+ cp_file_list = [os.path.join(cp_dir, f) for f in os.listdir(cp_dir) if (f.startswith('checkpoint') and f.endswith('.h5'))]
176
+ if len(cp_file_list):
177
+ cp_file_list.sort()
178
+ checkpoint_file = cp_file_list[-1]
179
+ if os.path.exists(checkpoint_file):
180
+ network.load_weights(checkpoint_file, by_name=True)
181
+ print('Loaded checkpoint file: ' + checkpoint_file)
182
+ try:
183
+ resume_epoch = int(re.match('checkpoint\.(\d+)-*.h5', os.path.split(checkpoint_file)[-1])[1])
184
+ except TypeError:
185
+ # Checkpoint file has no epoch number in the name
186
+ resume_epoch = 0
187
+ print('Resuming from epoch: {:d}'.format(resume_epoch))
188
+ else:
189
+ warnings.warn('Checkpoint file NOT found. Training from scratch')
190
+
191
  # 4. Freeze/unfreeze model layers
192
+ _, frozen_layers = freeze_layers_by_group(network, freeze_layers)
193
+ if frozen_layers is not None:
194
+ msg = "[INF]: Frozen layers {}".format(', '.join([str(a) for a in frozen_layers]))
 
 
 
 
 
 
 
 
 
 
 
195
  else:
196
  msg = "[INF] None frozen layers"
197
  print(msg)
198
  log_file.write(msg)
199
+
200
+ network.summary(line_length=C.SUMMARY_LINE_LENGTH)
201
+ network.summary(line_length=C.SUMMARY_LINE_LENGTH, print_fn=log_file.writelines)
 
 
 
 
 
 
202
  # Complete the model with the augmentation layer
203
  augm_train_input_shape = train_generator.get_data_shape()[0]
204
  input_layer_train = Input(shape=augm_train_input_shape, name='input_train')
 
290
  save_best_only=True, monitor='val_loss', verbose=1, mode='min')
291
  callback_save_checkpoint = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'checkpoint.h5'),
292
  save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
293
+ callback_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10)
294
 
295
  losses = {'transformer': loss_fnc,
296
  'seg_transformer': loss_segm,
 
309
  'seg_transformer': 1.,
310
  'flow': rw}
311
 
312
+ optimizer = AdamAccumulated(accumulation_steps=C.ACCUM_GRADIENT_STEP, learning_rate=C.LEARNING_RATE)
 
313
  network.compile(optimizer=optimizer,
314
  loss=losses,
315
  loss_weights=loss_weights,
 
320
  callback_early_stop.set_model(network)
321
  callback_best_model.set_model(network)
322
  callback_save_checkpoint.set_model(network)
323
+ callback_lr.set_model(network)
324
 
325
  summary = SummaryDictionary(network, C.BATCH_SIZE)
326
  names = network.metrics_names
 
332
  callback_early_stop.on_train_begin()
333
  callback_best_model.on_train_begin()
334
  callback_save_checkpoint.on_train_begin()
335
+ callback_lr.on_train_begin()
336
 
337
+ for epoch in range(resume_epoch, C.EPOCHS):
338
  callback_tensorboard.on_epoch_begin(epoch)
339
  callback_early_stop.on_epoch_begin(epoch)
340
  callback_best_model.on_epoch_begin(epoch)
341
  callback_save_checkpoint.on_epoch_begin(epoch)
342
+ callback_lr.on_epoch_begin(epoch)
343
+
344
  print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
345
  print("TRAIN")
346
 
 
350
  callback_best_model.on_train_batch_begin(step)
351
  callback_save_checkpoint.on_train_batch_begin(step)
352
  callback_early_stop.on_train_batch_begin(step)
353
+ callback_lr.on_train_batch_begin(step)
354
 
355
  try:
356
  fix_img, mov_img, fix_seg, mov_seg = augm_model_train.predict(in_batch)
 
384
  callback_best_model.on_train_batch_end(step, named_logs(network, ret))
385
  callback_save_checkpoint.on_train_batch_end(step, named_logs(network, ret))
386
  callback_early_stop.on_train_batch_end(step, named_logs(network, ret))
387
+ callback_lr.on_predict_batch_end(step, named_logs(network, ret))
388
+
389
  progress_bar.update(step, zip(names, ret))
390
  log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
391
  val_values = progress_bar._values.copy()
 
420
  callback_best_model.on_epoch_end(epoch, epoch_summary)
421
  callback_save_checkpoint.on_epoch_end(epoch, epoch_summary)
422
  callback_early_stop.on_epoch_end(epoch, epoch_summary)
423
+ callback_lr.on_epoch_end(epoch,epoch_summary)
424
+
425
  print('End of epoch {}: '.format(epoch), ret, '\n')
426
 
427
  callback_tensorboard.on_train_end()
428
  callback_best_model.on_train_end()
429
  callback_save_checkpoint.on_train_end()
430
  callback_early_stop.on_train_end()
431
+ callback_lr.on_train_end()
432
  # 7. Wrap up
COMET/Evaluate_network.py CHANGED
@@ -17,9 +17,10 @@ import tensorflow as tf
17
  import numpy as np
18
  import pandas as pd
19
  import voxelmorph as vxm
 
20
 
21
  import DeepDeformationMapRegistration.utils.constants as C
22
- from DeepDeformationMapRegistration.utils.operators import min_max_norm
23
  from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
24
  from DeepDeformationMapRegistration.layers import AugmentationLayer, UncertaintyWeighting
25
  from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplified, NCC, GeneralizedDICEScore, HausdorffDistanceErosion, target_registration_error
@@ -27,10 +28,14 @@ from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimila
27
  from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
28
  from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
29
  from DeepDeformationMapRegistration.utils.misc import DisplacementMapInterpolator, get_segmentations_centroids, segmentation_ohe_to_cardinal, segmentation_cardinal_to_ohe
 
 
30
  from EvaluationScripts.Evaluate_class import EvaluationFigures, resize_pts_to_original_space, resize_img_to_original_space, resize_transformation
31
  from scipy.interpolate import RegularGridInterpolator
32
  from tqdm import tqdm
33
 
 
 
34
  import h5py
35
  import re
36
  from Brain_study.data_generator import BatchGenerator
@@ -44,6 +49,7 @@ import neurite as ne
44
 
45
 
46
  DATASET = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/test_fixed'
 
47
  MODEL_FILE = '/mnt/EncryptedData1/Users/javier/train_output/COMET/ERASE/COMET_L_ssim__MET_mse_ncc_ssim_141343-01122021/checkpoints/best_model.h5'
48
  DATA_ROOT_DIR = '/mnt/EncryptedData1/Users/javier/train_output/COMET/ERASE/COMET_L_ssim__MET_mse_ncc_ssim_141343-01122021/'
49
 
@@ -56,6 +62,7 @@ if __name__ == '__main__':
56
  parser.add_argument('--dataset', type=str, help='Dataset to run predictions on', default=DATASET)
57
  parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
58
  parser.add_argument('--outdirname', type=str, default='Evaluate')
 
59
  args = parser.parse_args()
60
  if args.model is not None:
61
  assert '.h5' in args.model[0], 'No checkpoint file provided, use -d/--dir instead'
@@ -82,6 +89,9 @@ if __name__ == '__main__':
82
  DATASET = args.dataset
83
  list_test_files = [os.path.join(DATASET, f) for f in os.listdir(DATASET) if f.endswith('h5') and 'dm' not in f]
84
  list_test_files.sort()
 
 
 
85
 
86
  with h5py.File(list_test_files[0], 'r') as f:
87
  image_input_shape = image_output_shape = list(f['fix_image'][:].shape[:-1])
@@ -96,8 +106,8 @@ if __name__ == '__main__':
96
  config.log_device_placement = False ## to log device placement (on which device the operation ran)
97
  config.allow_soft_placement = True
98
 
99
- sess = tf.Session(config=config)
100
- tf.keras.backend.set_session(sess)
101
 
102
  # Loss and metric functions. Common to all models
103
  loss_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).loss,
@@ -114,24 +124,24 @@ if __name__ == '__main__':
114
  GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric_macro]
115
 
116
  ### METRICS GRAPH ###
117
- fix_img_ph = tf.placeholder(tf.float32, (1, *image_output_shape, 1), name='fix_img')
118
- pred_img_ph = tf.placeholder(tf.float32, (1, *image_output_shape, 1), name='pred_img')
119
- fix_seg_ph = tf.placeholder(tf.float32, (1, *image_output_shape, nb_labels), name='fix_seg')
120
- pred_seg_ph = tf.placeholder(tf.float32, (1, *image_output_shape, nb_labels), name='pred_seg')
121
 
122
  ssim_tf = metric_fncs[0](fix_img_ph, pred_img_ph)
123
  ncc_tf = metric_fncs[1](fix_img_ph, pred_img_ph)
124
  mse_tf = metric_fncs[2](fix_img_ph, pred_img_ph)
125
  ms_ssim_tf = metric_fncs[3](fix_img_ph, pred_img_ph)
126
- dice_tf = metric_fncs[4](fix_seg_ph, pred_seg_ph)
127
- hd_tf = metric_fncs[5](fix_seg_ph, pred_seg_ph)
128
- dice_macro_tf = metric_fncs[6](fix_seg_ph, pred_seg_ph)
129
  # hd_exact_tf = HausdorffDistance_exact(fix_seg_ph, pred_seg_ph, ohe=True)
130
 
131
  # Needed for VxmDense type of network
132
  warp_segmentation = vxm.networks.Transform(image_output_shape, interp_method='nearest', nb_feats=nb_labels)
133
 
134
- dm_interp = DisplacementMapInterpolator(image_output_shape, 'griddata')
135
 
136
  for MODEL_FILE, DATA_ROOT_DIR in zip(MODEL_FILE_LIST, DATA_ROOT_DIR_LIST):
137
  print('MODEL LOCATION: ', MODEL_FILE)
@@ -144,6 +154,14 @@ if __name__ == '__main__':
144
  os.makedirs(output_folder, exist_ok=True)
145
  print('DESTINATION FOLDER: ', output_folder)
146
 
 
 
 
 
 
 
 
 
147
  try:
148
  network = tf.keras.models.load_model(MODEL_FILE, {'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
149
  'VxmDense': vxm.networks.VxmDense,
@@ -172,10 +190,16 @@ if __name__ == '__main__':
172
  with open(metrics_file, 'w') as f:
173
  f.write(';'.join(csv_header)+'\n')
174
 
 
 
 
 
 
175
  ssim = ncc = mse = ms_ssim = dice = hd = 0
176
  with sess.as_default():
177
  sess.run(tf.global_variables_initializer())
178
  network.load_weights(MODEL_FILE, by_name=True)
 
179
  progress_bar = tqdm(enumerate(list_test_files, 1), desc='Evaluation', total=len(list_test_files))
180
  for step, in_batch in progress_bar:
181
  with h5py.File(in_batch, 'r') as f:
@@ -184,6 +208,8 @@ if __name__ == '__main__':
184
  fix_seg = f['fix_segmentations'][:][np.newaxis, ...].astype(np.float32)
185
  mov_seg = f['mov_segmentations'][:][np.newaxis, ...].astype(np.float32)
186
  fix_centroids = f['fix_centroids'][:]
 
 
187
 
188
  if network.name == 'vxm_dense_semi_supervised_seg':
189
  t0 = time.time()
@@ -196,29 +222,40 @@ if __name__ == '__main__':
196
  t1 = time.time()
197
 
198
  pred_img = min_max_norm(pred_img)
199
- mov_centroids, missing_lbls = get_segmentations_centroids(mov_seg[0, ...], ohe=True, expected_lbls=range(0, nb_labels), brain_study=False)
200
  # pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) # with tps, it returns the pred_centroids directly
201
  pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) + mov_centroids
202
 
 
 
 
 
 
 
 
 
203
  # I need the labels to be OHE to compute the segmentation metrics.
204
- dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf], {'fix_seg:0': fix_seg, 'pred_seg:0': pred_seg})
 
 
 
205
 
206
  pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
207
  mov_seg_card = segmentation_ohe_to_cardinal(mov_seg).astype(np.float32)
208
  fix_seg_card = segmentation_ohe_to_cardinal(fix_seg).astype(np.float32)
209
 
210
  ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf], {'fix_img:0': fix_img, 'pred_img:0': pred_img})
 
211
  ms_ssim = ms_ssim[0]
212
 
213
  # Rescale the points back to isotropic space, where we have a correspondence voxel <-> mm
214
- upsample_scale = 128 / 64
215
- fix_centroids_isotropic = fix_centroids * upsample_scale
216
- # mov_centroids_isotropic = mov_centroids * upsample_scale
217
- pred_centroids_isotropic = pred_centroids * upsample_scale
218
-
219
- fix_centroids_isotropic = np.divide(fix_centroids_isotropic, C.COMET_DATASET_iso_to_cubic_scales)
220
- # mov_centroids_isotropic = np.divide(mov_centroids_isotropic, C.COMET_DATASET_iso_to_cubic_scales)
221
- pred_centroids_isotropic = np.divide(pred_centroids_isotropic, C.COMET_DATASET_iso_to_cubic_scales)
222
  # Now we can measure the TRE in mm
223
  tre_array = target_registration_error(fix_centroids_isotropic, pred_centroids_isotropic, False).eval()
224
  tre = np.mean([v for v in tre_array if not np.isnan(v)])
@@ -249,15 +286,103 @@ if __name__ == '__main__':
249
  # plt.savefig(os.path.join(output_folder, '{:03d}_hist_mag_ssim_{:.03f}_dice_{:.03f}.png'.format(step, ssim, dice)))
250
  # plt.close()
251
 
252
- plot_predictions(fix_img, mov_img, disp_map, pred_img, os.path.join(output_folder, '{:03d}_figures_img.png'.format(step)), show=False)
253
- plot_predictions(fix_seg, mov_seg, disp_map, pred_seg, os.path.join(output_folder, '{:03d}_figures_seg.png'.format(step)), show=False)
254
- save_disp_map_img(disp_map, 'Displacement map', os.path.join(output_folder, '{:03d}_disp_map_fig.png'.format(step)), show=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
 
256
- progress_bar.set_description('SSIM {:.04f}\tG_DICE: {:.04f}\tM_DICE: {:.04f}'.format(ssim, dice, dice_macro))
257
 
258
  print('Summary\n=======\n')
259
- print('\nAVG:\n' + str(pd.read_csv(metrics_file, sep=';', header=0).mean(axis=0)) + '\nSTD:\n' + str(pd.read_csv(metrics_file, sep=';', header=0).std(axis=0)))
 
 
 
 
 
 
260
  print('\n=======\n')
 
 
 
 
 
 
 
 
 
 
261
  tf.keras.backend.clear_session()
262
  # sess.close()
263
  del network
 
17
  import numpy as np
18
  import pandas as pd
19
  import voxelmorph as vxm
20
+ from voxelmorph.tf.layers import SpatialTransformer
21
 
22
  import DeepDeformationMapRegistration.utils.constants as C
23
+ from DeepDeformationMapRegistration.utils.operators import min_max_norm, safe_medpy_metric
24
  from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
25
  from DeepDeformationMapRegistration.layers import AugmentationLayer, UncertaintyWeighting
26
  from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplified, NCC, GeneralizedDICEScore, HausdorffDistanceErosion, target_registration_error
 
28
  from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
29
  from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
30
  from DeepDeformationMapRegistration.utils.misc import DisplacementMapInterpolator, get_segmentations_centroids, segmentation_ohe_to_cardinal, segmentation_cardinal_to_ohe
31
+ from DeepDeformationMapRegistration.utils.misc import resize_displacement_map, scale_transformation, GaussianFilter
32
+ import medpy.metric as medpy_metrics
33
  from EvaluationScripts.Evaluate_class import EvaluationFigures, resize_pts_to_original_space, resize_img_to_original_space, resize_transformation
34
  from scipy.interpolate import RegularGridInterpolator
35
  from tqdm import tqdm
36
 
37
+ from scipy.ndimage import gaussian_filter, zoom
38
+
39
  import h5py
40
  import re
41
  from Brain_study.data_generator import BatchGenerator
 
49
 
50
 
51
  DATASET = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/test_fixed'
52
+ DATASET_FR = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/test_fixed/full_res'
53
  MODEL_FILE = '/mnt/EncryptedData1/Users/javier/train_output/COMET/ERASE/COMET_L_ssim__MET_mse_ncc_ssim_141343-01122021/checkpoints/best_model.h5'
54
  DATA_ROOT_DIR = '/mnt/EncryptedData1/Users/javier/train_output/COMET/ERASE/COMET_L_ssim__MET_mse_ncc_ssim_141343-01122021/'
55
 
 
62
  parser.add_argument('--dataset', type=str, help='Dataset to run predictions on', default=DATASET)
63
  parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
64
  parser.add_argument('--outdirname', type=str, default='Evaluate')
65
+ parser.add_argument('--fullres', action='store_true', default=False)
66
  args = parser.parse_args()
67
  if args.model is not None:
68
  assert '.h5' in args.model[0], 'No checkpoint file provided, use -d/--dir instead'
 
89
  DATASET = args.dataset
90
  list_test_files = [os.path.join(DATASET, f) for f in os.listdir(DATASET) if f.endswith('h5') and 'dm' not in f]
91
  list_test_files.sort()
92
+ if args.fullres:
93
+ list_test_fr_files = [os.path.join(DATASET_FR, f) for f in os.listdir(DATASET_FR) if f.endswith('h5') and 'dm' not in f]
94
+ list_test_fr_files.sort()
95
 
96
  with h5py.File(list_test_files[0], 'r') as f:
97
  image_input_shape = image_output_shape = list(f['fix_image'][:].shape[:-1])
 
106
  config.log_device_placement = False ## to log device placement (on which device the operation ran)
107
  config.allow_soft_placement = True
108
 
109
+ sess = tf.compat.v1.Session(config=config)
110
+ tf.compat.v1.keras.backend.set_session(sess)
111
 
112
  # Loss and metric functions. Common to all models
113
  loss_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).loss,
 
124
  GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric_macro]
125
 
126
  ### METRICS GRAPH ###
127
+ fix_img_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, 1), name='fix_img')
128
+ pred_img_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, 1), name='pred_img')
129
+ fix_seg_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, nb_labels), name='fix_seg')
130
+ pred_seg_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, nb_labels), name='pred_seg')
131
 
132
  ssim_tf = metric_fncs[0](fix_img_ph, pred_img_ph)
133
  ncc_tf = metric_fncs[1](fix_img_ph, pred_img_ph)
134
  mse_tf = metric_fncs[2](fix_img_ph, pred_img_ph)
135
  ms_ssim_tf = metric_fncs[3](fix_img_ph, pred_img_ph)
136
+ # dice_tf = metric_fncs[4](fix_seg_ph, pred_seg_ph)
137
+ # hd_tf = metric_fncs[5](fix_seg_ph, pred_seg_ph)
138
+ # dice_macro_tf = metric_fncs[6](fix_seg_ph, pred_seg_ph)
139
  # hd_exact_tf = HausdorffDistance_exact(fix_seg_ph, pred_seg_ph, ohe=True)
140
 
141
  # Needed for VxmDense type of network
142
  warp_segmentation = vxm.networks.Transform(image_output_shape, interp_method='nearest', nb_feats=nb_labels)
143
 
144
+ dm_interp = DisplacementMapInterpolator(image_output_shape, 'griddata', step=4)
145
 
146
  for MODEL_FILE, DATA_ROOT_DIR in zip(MODEL_FILE_LIST, DATA_ROOT_DIR_LIST):
147
  print('MODEL LOCATION: ', MODEL_FILE)
 
154
  os.makedirs(output_folder, exist_ok=True)
155
  print('DESTINATION FOLDER: ', output_folder)
156
 
157
+ if args.fullres:
158
+ output_folder_fr = os.path.join(DATA_ROOT_DIR, args.outdirname, 'full_resolution') # '/mnt/EncryptedData1/Users/javier/train_output/DDMR/THESIS/eval/BASELINE_TRAIN_Affine_ncc_EVAL_Affine'
159
+ # os.makedirs(os.path.join(output_folder, 'images'), exist_ok=True)
160
+ if args.erase:
161
+ shutil.rmtree(output_folder_fr, ignore_errors=True)
162
+ os.makedirs(output_folder_fr, exist_ok=True)
163
+ print('DESTINATION FOLDER FULL RESOLUTION: ', output_folder_fr)
164
+
165
  try:
166
  network = tf.keras.models.load_model(MODEL_FILE, {'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
167
  'VxmDense': vxm.networks.VxmDense,
 
190
  with open(metrics_file, 'w') as f:
191
  f.write(';'.join(csv_header)+'\n')
192
 
193
+ if args.fullres:
194
+ metrics_file_fr = os.path.join(output_folder_fr, 'metrics.csv')
195
+ with open(metrics_file_fr, 'w') as f:
196
+ f.write(';'.join(csv_header) + '\n')
197
+
198
  ssim = ncc = mse = ms_ssim = dice = hd = 0
199
  with sess.as_default():
200
  sess.run(tf.global_variables_initializer())
201
  network.load_weights(MODEL_FILE, by_name=True)
202
+ network.summary(line_length=C.SUMMARY_LINE_LENGTH)
203
  progress_bar = tqdm(enumerate(list_test_files, 1), desc='Evaluation', total=len(list_test_files))
204
  for step, in_batch in progress_bar:
205
  with h5py.File(in_batch, 'r') as f:
 
208
  fix_seg = f['fix_segmentations'][:][np.newaxis, ...].astype(np.float32)
209
  mov_seg = f['mov_segmentations'][:][np.newaxis, ...].astype(np.float32)
210
  fix_centroids = f['fix_centroids'][:]
211
+ isotropic_shape = f['isotropic_shape'][:]
212
+ voxel_size = np.divide(fix_img.shape[1:-1], isotropic_shape)
213
 
214
  if network.name == 'vxm_dense_semi_supervised_seg':
215
  t0 = time.time()
 
222
  t1 = time.time()
223
 
224
  pred_img = min_max_norm(pred_img)
225
+ mov_centroids, missing_lbls = get_segmentations_centroids(mov_seg[0, ...], ohe=True, expected_lbls=range(1, nb_labels+1), brain_study=False)
226
  # pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) # with tps, it returns the pred_centroids directly
227
  pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) + mov_centroids
228
 
229
+ # Up sample the segmentation masks to isotropic resolution
230
+ zoom_factors = np.diag(scale_transformation(image_output_shape, isotropic_shape))
231
+ pred_seg_isot = zoom(pred_seg[0, ...], zoom_factors, order=0)[np.newaxis, ...]
232
+ fix_seg_isot = zoom(fix_seg[0, ...], zoom_factors, order=0)[np.newaxis, ...]
233
+
234
+ pred_img_isot = zoom(pred_img[0, ...], zoom_factors, order=3)[np.newaxis, ...]
235
+ fix_img_isot = zoom(fix_img[0, ...], zoom_factors, order=3)[np.newaxis, ...]
236
+
237
  # I need the labels to be OHE to compute the segmentation metrics.
238
+ # dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf], {'fix_seg:0': fix_seg, 'pred_seg:0': pred_seg})
239
+ dice = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) / np.sum(fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
240
+ hd = np.mean(safe_medpy_metric(pred_seg_isot[0, ...], fix_seg_isot[0, ...], nb_labels, medpy_metrics.hd, {'voxelspacing': voxel_size}))
241
+ dice_macro = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
242
 
243
  pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
244
  mov_seg_card = segmentation_ohe_to_cardinal(mov_seg).astype(np.float32)
245
  fix_seg_card = segmentation_ohe_to_cardinal(fix_seg).astype(np.float32)
246
 
247
  ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf], {'fix_img:0': fix_img, 'pred_img:0': pred_img})
248
+ ssim = np.mean(ssim)
249
  ms_ssim = ms_ssim[0]
250
 
251
  # Rescale the points back to isotropic space, where we have a correspondence voxel <-> mm
252
+ fix_centroids_isotropic = fix_centroids * voxel_size
253
+ # mov_centroids_isotropic = mov_centroids * voxel_size
254
+ pred_centroids_isotropic = pred_centroids * voxel_size
255
+
256
+ # fix_centroids_isotropic = np.divide(fix_centroids_isotropic, C.COMET_DATASET_iso_to_cubic_scales)
257
+ # # mov_centroids_isotropic = np.divide(mov_centroids_isotropic, C.COMET_DATASET_iso_to_cubic_scales)
258
+ # pred_centroids_isotropic = np.divide(pred_centroids_isotropic, C.COMET_DATASET_iso_to_cubic_scales)
 
259
  # Now we can measure the TRE in mm
260
  tre_array = target_registration_error(fix_centroids_isotropic, pred_centroids_isotropic, False).eval()
261
  tre = np.mean([v for v in tre_array if not np.isnan(v)])
 
286
  # plt.savefig(os.path.join(output_folder, '{:03d}_hist_mag_ssim_{:.03f}_dice_{:.03f}.png'.format(step, ssim, dice)))
287
  # plt.close()
288
 
289
+ plot_predictions(img_batches=[fix_img, mov_img, pred_img], disp_map_batch=disp_map, seg_batches=[fix_seg_card, mov_seg_card, pred_seg_card], filename=os.path.join(output_folder, '{:03d}_figures_seg.png'.format(step)), show=False, step=16)
290
+ plot_predictions(img_batches=[fix_img, mov_img, pred_img], disp_map_batch=disp_map, filename=os.path.join(output_folder, '{:03d}_figures_img.png'.format(step)), show=False, step=16)
291
+ save_disp_map_img(disp_map, 'Displacement map', os.path.join(output_folder, '{:03d}_disp_map_fig.png'.format(step)), show=False, step=16)
292
+
293
+ progress_bar.set_description('SSIM {:.04f}\tM_DICE: {:.04f}'.format(ssim, dice_macro))
294
+
295
+ if args.fullres:
296
+ with h5py.File(list_test_fr_files[step - 1], 'r') as f:
297
+ fix_img = f['fix_image'][:][np.newaxis, ...] # Add batch axis
298
+ mov_img = f['mov_image'][:][np.newaxis, ...]
299
+ fix_seg = f['fix_segmentations'][:][np.newaxis, ...].astype(np.float32)
300
+ mov_seg = f['mov_segmentations'][:][np.newaxis, ...].astype(np.float32)
301
+ fix_centroids = f['fix_centroids'][:]
302
+
303
+ # Up sample the displacement map to the full res
304
+ trf = scale_transformation(image_output_shape, fix_img.shape[1:-1])
305
+ disp_map_fr = resize_displacement_map(np.squeeze(disp_map), None, trf)[np.newaxis, ...]
306
+ disp_map_fr = gaussian_filter(disp_map_fr, 5)
307
+ # disp_mad_fr = sess.run(smooth_filter, feed_dict={'dm:0': disp_map_fr})
308
+
309
+ # Predicted image
310
+ pred_img_fr = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([mov_img, disp_map_fr]).eval()
311
+ pred_seg_fr = SpatialTransformer(interp_method='nearest', indexing='ij', single_transform=False)([mov_seg, disp_map_fr]).eval()
312
+
313
+ # Predicted centroids
314
+ dm_interp_fr = DisplacementMapInterpolator(fix_img.shape[1:-1], 'griddata', step=2)
315
+ pred_centroids = dm_interp_fr(disp_map_fr, mov_centroids, backwards=True) + mov_centroids
316
+
317
+ # Metrics - segmentation
318
+ dice = np.mean([medpy_metrics.dc(pred_seg_fr[..., l], fix_seg[..., l]) / np.sum(fix_seg[..., l]) for l in range(nb_labels)])
319
+ hd = np.mean(safe_medpy_metric(pred_seg[0, ...], fix_seg[0, ...], nb_labels, medpy_metrics.hd, {'voxelspacing': voxel_size}))
320
+ dice_macro = np.mean([medpy_metrics.dc(pred_seg_fr[..., l], fix_seg[..., l]) for l in range(nb_labels)])
321
+
322
+ pred_seg_card_fr = segmentation_ohe_to_cardinal(pred_seg_fr).astype(np.float32)
323
+ mov_seg_card_fr = segmentation_ohe_to_cardinal(mov_seg).astype(np.float32)
324
+ fix_seg_card_fr = segmentation_ohe_to_cardinal(fix_seg).astype(np.float32)
325
+
326
+ # Metrics - image
327
+ ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf],
328
+ {'fix_img:0': fix_img, 'pred_img:0': pred_img_fr})
329
+ ssim = np.mean(ssim)
330
+ ms_ssim = ms_ssim[0]
331
+
332
+ # Metrics - registration
333
+ tre_array = target_registration_error(fix_centroids, pred_centroids, False).eval()
334
+
335
+ new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, t1 - t0, tre, len(missing_lbls),
336
+ missing_lbls]
337
+ with open(metrics_file_fr, 'a') as f:
338
+ f.write(';'.join(map(str, new_line)) + '\n')
339
+
340
+ save_nifti(fix_img[0, ...], os.path.join(output_folder_fr, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
341
+ save_nifti(mov_img[0, ...], os.path.join(output_folder_fr, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
342
+ save_nifti(pred_img[0, ...], os.path.join(output_folder_fr, '{:03d}_pred_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
343
+ save_nifti(fix_seg_card[0, ...], os.path.join(output_folder_fr, '{:03d}_fix_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
344
+ save_nifti(mov_seg_card[0, ...], os.path.join(output_folder_fr, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
345
+ save_nifti(pred_seg_card[0, ...], os.path.join(output_folder_fr, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
346
+
347
+ # with h5py.File(os.path.join(output_folder, '{:03d}_centroids.h5'.format(step)), 'w') as f:
348
+ # f.create_dataset('fix_centroids', dtype=np.float32, data=fix_centroids)
349
+ # f.create_dataset('mov_centroids', dtype=np.float32, data=mov_centroids)
350
+ # f.create_dataset('pred_centroids', dtype=np.float32, data=pred_centroids)
351
+ # f.create_dataset('fix_centroids_isotropic', dtype=np.float32, data=fix_centroids_isotropic)
352
+ # f.create_dataset('mov_centroids_isotropic', dtype=np.float32, data=mov_centroids_isotropic)
353
+
354
+ # magnitude = np.sqrt(np.sum(disp_map[0, ...] ** 2, axis=-1))
355
+ # _ = plt.hist(magnitude.flatten())
356
+ # plt.title('Histogram of disp. magnitudes')
357
+ # plt.show(block=False)
358
+ # plt.savefig(os.path.join(output_folder, '{:03d}_hist_mag_ssim_{:.03f}_dice_{:.03f}.png'.format(step, ssim, dice)))
359
+ # plt.close()
360
+
361
+ plot_predictions(img_batches=[fix_img, mov_img, pred_img_fr], disp_map_batch=disp_map_fr, seg_batches=[fix_seg_card_fr, mov_seg_card_fr, pred_seg_card_fr], filename=os.path.join(output_folder_fr, '{:03d}_figures_seg.png'.format(step)), show=False, step=10)
362
+ plot_predictions(img_batches=[fix_img, mov_img, pred_img_fr], disp_map_batch=disp_map_fr, filename=os.path.join(output_folder_fr, '{:03d}_figures_img.png'.format(step)), show=False, step=10)
363
+ # save_disp_map_img(disp_map_fr, 'Displacement map', os.path.join(output_folder_fr, '{:03d}_disp_map_fig.png'.format(step)), show=False, step=10)
364
 
365
+ progress_bar.set_description('[FR] SSIM {:.04f}\tM_DICE: {:.04f}'.format(ssim, dice_macro))
366
 
367
  print('Summary\n=======\n')
368
+ metrics_df = pd.read_csv(metrics_file, sep=';', header=0)
369
+ print('\nAVG:\n')
370
+ print(metrics_df.mean(axis=0))
371
+ print('\nSTD:\n')
372
+ print(metrics_df.std(axis=0))
373
+ print('\nHD95perc:\n')
374
+ print(metrics_df['HD'].describe(percentiles=[.95]))
375
  print('\n=======\n')
376
+ if args.fullres:
377
+ print('Summary full resolution\n=======\n')
378
+ metrics_df = pd.read_csv(metrics_file_fr, sep=';', header=0)
379
+ print('\nAVG:\n')
380
+ print(metrics_df.mean(axis=0))
381
+ print('\nSTD:\n')
382
+ print(metrics_df.std(axis=0))
383
+ print('\nHD95perc:\n')
384
+ print(metrics_df['HD'].describe(percentiles=[.95]))
385
+ print('\n=======\n')
386
  tf.keras.backend.clear_session()
387
  # sess.close()
388
  del network
COMET/MultiTrain_config.py CHANGED
@@ -9,6 +9,8 @@ from shutil import copy2
9
  import os
10
  from datetime import datetime
11
  import DeepDeformationMapRegistration.utils.constants as C
 
 
12
  TRAIN_DATASET = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/train'
13
 
14
  err = list()
@@ -58,10 +60,11 @@ if __name__ == '__main__':
58
  except KeyError as err:
59
  froozen_layers = None
60
  except NameError as err:
61
- froozen_layers = [trainConfig['freeze'].upper()]
 
62
  if froozen_layers is not None:
63
- assert all(s in ['INPUT', 'OUTPUT', 'ENCODER', 'DECODER', 'TOP', 'BOTTOM'] for s in froozen_layers),\
64
- 'Invalid option for "freeze". Expected one or several of: INPUT, OUTPUT, ENCODER, DECODER, TOP, BOTTOM'
65
  froozen_layers = list(set(froozen_layers)) # Unique elements
66
 
67
  if augmentationConfig:
@@ -70,11 +73,20 @@ if __name__ == '__main__':
70
 
71
 
72
  # copy the configuration file to the destionation folder
73
- os.makedirs(output_folder, exist_ok=True)
74
  copy2(args.ini, os.path.join(output_folder, os.path.split(args.ini)[-1]))
75
 
76
- unet = [int(x) for x in trainConfig['unet'].split(',')] if trainConfig['unet'] else [16, 32, 64, 128, 256]
77
- head = [int(x) for x in trainConfig['head'].split(',')] if trainConfig['head'] else [16, 16]
 
 
 
 
 
 
 
 
 
78
 
79
  launch_train(dataset_folder=datasetConfig['train'],
80
  validation_folder=datasetConfig['validation'],
@@ -85,10 +97,12 @@ if __name__ == '__main__':
85
  simil=simil,
86
  segm=segm,
87
  max_epochs=eval(trainConfig['epochs']),
 
88
  early_stop_patience=eval(trainConfig['earlyStopPatience']),
89
  model_file=trainConfig['model'],
90
  freeze_layers=froozen_layers,
91
  acc_gradients=eval(trainConfig['accumulativeGradients']),
92
  batch_size=eval(trainConfig['batchSize']),
93
  unet=unet,
94
- head=head)
 
 
9
  import os
10
  from datetime import datetime
11
  import DeepDeformationMapRegistration.utils.constants as C
12
+ import re
13
+ from COMET.augmentation_constants import LAYER_SELECTION
14
  TRAIN_DATASET = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/train'
15
 
16
  err = list()
 
60
  except KeyError as err:
61
  froozen_layers = None
62
  except NameError as err:
63
+ froozen_layers = list(filter(lambda x: x != '', re.split(';|\s|,|,\s|;\s', trainConfig['freeze'].upper())))
64
+
65
  if froozen_layers is not None:
66
+ assert all(s in LAYER_SELECTION.keys() for s in froozen_layers), \
67
+ 'Invalid option for "freeze". Expected one or several of: ' + ', '.join(LAYER_SELECTION.keys())
68
  froozen_layers = list(set(froozen_layers)) # Unique elements
69
 
70
  if augmentationConfig:
 
73
 
74
 
75
  # copy the configuration file to the destionation folder
76
+ os.makedirs(output_folder, exist_ok=True) # TODO: move this within the "resume" if case, and bring here the creation of the resume-output folder!
77
  copy2(args.ini, os.path.join(output_folder, os.path.split(args.ini)[-1]))
78
 
79
+ try:
80
+ unet = [int(x) for x in trainConfig['unet'].split(',')] if trainConfig['unet'] else [16, 32, 64, 128, 256]
81
+ head = [int(x) for x in trainConfig['head'].split(',')] if trainConfig['head'] else [16, 16]
82
+ except KeyError as err:
83
+ unet = [16, 32, 64, 128, 256]
84
+ head = [16, 16]
85
+
86
+ try:
87
+ resume_checkpoint = trainConfig['resumeCheckpoint']
88
+ except KeyError as e:
89
+ resume_checkpoint = None
90
 
91
  launch_train(dataset_folder=datasetConfig['train'],
92
  validation_folder=datasetConfig['validation'],
 
97
  simil=simil,
98
  segm=segm,
99
  max_epochs=eval(trainConfig['epochs']),
100
+ image_size=eval(trainConfig['imageSize']),
101
  early_stop_patience=eval(trainConfig['earlyStopPatience']),
102
  model_file=trainConfig['model'],
103
  freeze_layers=froozen_layers,
104
  acc_gradients=eval(trainConfig['accumulativeGradients']),
105
  batch_size=eval(trainConfig['batchSize']),
106
  unet=unet,
107
+ head=head,
108
+ resume=resume_checkpoint)
COMET/augmentation_constants.py CHANGED
@@ -1,4 +1,5 @@
1
  import numpy as np
 
2
 
3
  # Constants for augmentation layer
4
  # .../T1/training/zoom_factors.csv contain the scale factors of all the training samples from isotropic to 128x128x128
@@ -31,4 +32,26 @@ LAYER_RANGES = {'INPUT': (IN_LAYERS),
31
  'ENCODER': (ENCONDER_LAYERS),
32
  'DECODER': (DECODER_LAYERS),
33
  'TOP': (TOP_LAYERS_ENC, TOP_LAYERS_DEC),
34
- 'BOTTOM': (BOTTOM_LAYERS)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import numpy as np
2
+ import re
3
 
4
  # Constants for augmentation layer
5
  # .../T1/training/zoom_factors.csv contain the scale factors of all the training samples from isotropic to 128x128x128
 
32
  'ENCODER': (ENCONDER_LAYERS),
33
  'DECODER': (DECODER_LAYERS),
34
  'TOP': (TOP_LAYERS_ENC, TOP_LAYERS_DEC),
35
+ 'BOTTOM': (BOTTOM_LAYERS)}
36
+
37
+ # LAYER names:
38
+ IN_LAYER_REGEXP = '.*input'
39
+ FC_LAYER_REGEXP = '.*final.*'
40
+ OUT_LAYER_REGEXP = '(?:flow|transformer)'
41
+ ENC_LAYER_REGEXP = '.*enc_(?:conv|pooling)_(\d).*'
42
+ DEC_LAYER_REGEXP = '.*dec_(?:conv|upsample)_(\d).*'
43
+ LEVEL_NUMBER = lambda x: re.match('.*(?:enc|dec)_(?:conv|upsample|pooling)_(\d).*', x)
44
+ IS_TOP_LEVEL = lambda x: int(LEVEL_NUMBER(x)[1]) < 3 if LEVEL_NUMBER(x) is not None else False or bool(re.match(FC_LAYER_REGEXP, x))
45
+ IS_BOTTOM_LEVEL = lambda x: int(LEVEL_NUMBER(x)[1]) >= 3 if LEVEL_NUMBER(x) is not None else False
46
+
47
+ LAYER_SELECTION = {'INPUT': lambda x: bool(re.match(IN_LAYER_REGEXP, x)),
48
+ 'FULLYCONNECTED': lambda x: bool(re.match(FC_LAYER_REGEXP, x)),
49
+ 'ENCODER': lambda x: bool(re.match(ENC_LAYER_REGEXP, x)),
50
+ 'DECODER': lambda x: bool(re.match(DEC_LAYER_REGEXP, x)),
51
+ 'TOP': lambda x: IS_TOP_LEVEL(x),
52
+ 'BOTTOM': lambda x: IS_BOTTOM_LEVEL(x)
53
+ }
54
+
55
+ # STUPID IDEA THAT COMPLICATES THINGS. The points was to allow combinations of the layer groups
56
+ # OR_GROUPS = ['ENCODER', 'DECODER', 'INPUT', 'OUTPUT', 'FULLYCONNECTED'] # These groups can be OR'ed with the AND_GROUPS and among them. E.g., Top layers of the encoder and decoder: ENCODER or DECODER and TOP
57
+ # AND_GROUPS = ['TOP', 'BOTTOM'] # These groups can be AND'ed with the OR_GROUPS and among them
COMET/format_dataset.py CHANGED
@@ -30,7 +30,7 @@ SEG_DIRECTORY = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSL
30
  IMG_NAME_PATTERN = '(.*).nii.gz'
31
  SEG_NAME_PATTERN = '(.*).nii.gz'
32
 
33
- OUT_DIRECTORY = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128'
34
 
35
 
36
  if __name__ == '__main__':
@@ -60,6 +60,7 @@ if __name__ == '__main__':
60
  seg = np.asarray(seg.dataobj)
61
 
62
  segs_are_ohe = bool(len(seg.shape) > 3 and seg.shape[3] > 1)
 
63
  if args.crop:
64
  parenchyma = regionprops(seg[..., 0])[0]
65
  bbox = np.asarray(parenchyma.bbox) + [*[-args.offset]*3, *[args.offset]*3]
@@ -72,10 +73,11 @@ if __name__ == '__main__':
72
  isot_shape = img.shape
73
 
74
  zoom_factors = (np.asarray([128]*3) / np.asarray(img.shape)).tolist()
75
-
76
  img = zoom(img, zoom_factors, order=3)
77
  if args.dilate_segmentations:
78
  seg = binary_dilation(seg, binary_ball, iterations=1)
 
79
  seg = zoom(seg, zoom_factors + [1]*(len(seg.shape) - len(img.shape)), order=0)
80
  zoom_file = zoom_file.append({'scale_i': zoom_factors[0],
81
  'scale_j': zoom_factors[1],
@@ -92,11 +94,15 @@ if __name__ == '__main__':
92
  h5_file = h5py.File(os.path.join(OUT_DIRECTORY, img_name + '.h5'), 'w')
93
 
94
  h5_file.create_dataset('image', data=img[..., np.newaxis], dtype=np.float32)
 
 
95
  h5_file.create_dataset('segmentation', data=seg.astype(np.uint8), dtype=np.uint8)
96
  h5_file.create_dataset('segmentation_expanded', data=seg_expanded.astype(np.uint8), dtype=np.uint8)
97
  h5_file.create_dataset('segmentation_labels', data=np.unique(seg)[1:]) # Remove the 0 (background label)
98
  h5_file.create_dataset('isotropic_shape', data=isot_shape)
99
-
 
 
100
  print('{}: Segmentation labels {}'.format(img_name, np.unique(seg)[1:]))
101
  h5_file.close()
102
 
 
30
  IMG_NAME_PATTERN = '(.*).nii.gz'
31
  SEG_NAME_PATTERN = '(.*).nii.gz'
32
 
33
+ OUT_DIRECTORY = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/w_bboxes'
34
 
35
 
36
  if __name__ == '__main__':
 
60
  seg = np.asarray(seg.dataobj)
61
 
62
  segs_are_ohe = bool(len(seg.shape) > 3 and seg.shape[3] > 1)
63
+ bbox = [0]*6
64
  if args.crop:
65
  parenchyma = regionprops(seg[..., 0])[0]
66
  bbox = np.asarray(parenchyma.bbox) + [*[-args.offset]*3, *[args.offset]*3]
 
73
  isot_shape = img.shape
74
 
75
  zoom_factors = (np.asarray([128]*3) / np.asarray(img.shape)).tolist()
76
+ img_isotropic = np.copy(img)
77
  img = zoom(img, zoom_factors, order=3)
78
  if args.dilate_segmentations:
79
  seg = binary_dilation(seg, binary_ball, iterations=1)
80
+ seg_isotropic = np.copy(seg)
81
  seg = zoom(seg, zoom_factors + [1]*(len(seg.shape) - len(img.shape)), order=0)
82
  zoom_file = zoom_file.append({'scale_i': zoom_factors[0],
83
  'scale_j': zoom_factors[1],
 
94
  h5_file = h5py.File(os.path.join(OUT_DIRECTORY, img_name + '.h5'), 'w')
95
 
96
  h5_file.create_dataset('image', data=img[..., np.newaxis], dtype=np.float32)
97
+ h5_file.create_dataset('image_isotropic', data=img_isotropic[..., np.newaxis], dtype=np.float32)
98
+ h5_file.create_dataset('segmentation_isotropic', data=seg_isotropic.astype(np.uint8), dtype=np.uint8)
99
  h5_file.create_dataset('segmentation', data=seg.astype(np.uint8), dtype=np.uint8)
100
  h5_file.create_dataset('segmentation_expanded', data=seg_expanded.astype(np.uint8), dtype=np.uint8)
101
  h5_file.create_dataset('segmentation_labels', data=np.unique(seg)[1:]) # Remove the 0 (background label)
102
  h5_file.create_dataset('isotropic_shape', data=isot_shape)
103
+ if args.crop:
104
+ h5_file.create_dataset('bounding_box_origin', data=bbox[:3])
105
+ h5_file.create_dataset('bounding_box_shape', data=bbox[3:] - bbox[:3])
106
  print('{}: Segmentation labels {}'.format(img_name, np.unique(seg)[1:]))
107
  h5_file.close()
108
 
DeepDeformationMapRegistration/losses.py CHANGED
@@ -626,7 +626,8 @@ class StructuralSimilarityGaussian:
626
  self.__GF = tf.while_loop(lambda iterator, g_1d: tf.less(iterator, self.dim),
627
  lambda iterator, g_1d: (iterator + 1, tf.expand_dims(g_1d, -1) * tf.transpose(g_1d_expanded)),
628
  [iterator, g_1d],
629
- [iterator.get_shape(), tf.TensorShape([None]*self.dim)] # Shape invariants
 
630
  )[-1]
631
 
632
  self.__GF = tf.divide(self.__GF, tf.reduce_sum(self.__GF)) # Normalization
@@ -759,32 +760,41 @@ class GeneralizedDICEScore:
759
  Learning Los Function for Highly Unbalanced Segmentations" https://arxiv.org/abs/1707.03237
760
  :param input_shape: Shape of the input image, without the batch dimension, e.g., 2D: [H, W, C], 3D: [H, W, D, C]
761
  """
 
 
762
  if input_shape[-1] > 1:
763
- self.flat_shape = [-1, np.prod(np.asarray(input_shape[:-1])), input_shape[-1]]
764
- self.hot_encode = False
 
 
 
765
  elif num_labels is not None:
766
- self.flat_shape = [-1, np.prod(np.asarray(input_shape[:-1])), num_labels]
767
- self.one_hot_enc_shape = [-1, *input_shape[:-1]]
768
- self.hot_encode = True
769
- warnings.warn('Differentiable one-hot encoding not yet implemented')
 
 
 
770
  else:
771
- raise ValueError('If input_shape is not one hot encoded, then num_labels must be provided')
 
772
 
773
  def one_hot_encoding(self, in_img, name=''):
774
  # TODO: Test if differentiable!
775
  labels, indices = tf.unique(tf.reshape(in_img, [-1]), tf.int32, name=name+'_unique')
776
- one_hot = tf.one_hot(indices, tf.size(labels), name=name + '_one_hot')
777
- one_hot = tf.reshape(one_hot, self.one_hot_enc_shape + [tf.size(labels)], name=name + '_reshape')
778
- one_hot = tf.slice(one_hot, [0]*len(self.one_hot_enc_shape) + [1], [-1]*(len(self.one_hot_enc_shape) + 1),
779
  name=name + '_remove_bg')
780
  return one_hot
781
 
782
  def weigthed_dice(self, y_true, y_pred):
783
  # y_true = [B, -1, L]
784
  # y_pred = [B, -1, L]
785
- if self.hot_encode:
786
- y_true = self.one_hot_encoding(y_true, name='GDICE_one_hot_encoding_y_true')
787
- y_pred = self.one_hot_encoding(y_pred, name='GDICE_one_hot_encoding_y_pred')
788
  y_true = tf.reshape(y_true, self.flat_shape, name='GDICE_reshape_y_true') # Flatten along the volume dimensions
789
  y_pred = tf.reshape(y_pred, self.flat_shape, name='GDICE_reshape_y_pred') # Flatten along the volume dimensions
790
 
@@ -793,14 +803,14 @@ class GeneralizedDICEScore:
793
  w = tf.math.divide_no_nan(1., tf.pow(size_y_true, 2), name='GDICE_weight')
794
  numerator = w * tf.reduce_sum(y_true * y_pred, axis=1)
795
  denominator = w * (size_y_true + size_y_pred)
796
- return tf.div_no_nan(2 * tf.reduce_sum(numerator, axis=-1), tf.reduce_sum(denominator, axis=-1))
797
 
798
  def macro_dice(self, y_true, y_pred):
799
  # y_true = [B, -1, L]
800
  # y_pred = [B, -1, L]
801
- if self.hot_encode:
802
- y_true = self.one_hot_encoding(y_true, name='GDICE_one_hot_encoding_y_true')
803
- y_pred = self.one_hot_encoding(y_pred, name='GDICE_one_hot_encoding_y_pred')
804
  y_true = tf.reshape(y_true, self.flat_shape, name='GDICE_reshape_y_true') # Flatten along the volume dimensions
805
  y_pred = tf.reshape(y_pred, self.flat_shape, name='GDICE_reshape_y_pred') # Flatten along the volume dimensions
806
 
@@ -808,7 +818,7 @@ class GeneralizedDICEScore:
808
  size_y_pred = tf.reduce_sum(y_pred, axis=1, name='GDICE_size_y_pred')
809
  numerator = tf.reduce_sum(y_true * y_pred, axis=1)
810
  denominator = (size_y_true + size_y_pred)
811
- return tf.div_no_nan(2 * numerator, denominator)
812
 
813
  @function_decorator('GeneralizeDICE__loss')
814
  def loss(self, y_true, y_pred):
 
626
  self.__GF = tf.while_loop(lambda iterator, g_1d: tf.less(iterator, self.dim),
627
  lambda iterator, g_1d: (iterator + 1, tf.expand_dims(g_1d, -1) * tf.transpose(g_1d_expanded)),
628
  [iterator, g_1d],
629
+ [iterator.get_shape(), tf.TensorShape([None]*self.dim)], # Shape invariants
630
+ back_prop=False,
631
  )[-1]
632
 
633
  self.__GF = tf.divide(self.__GF, tf.reduce_sum(self.__GF)) # Normalization
 
760
  Learning Los Function for Highly Unbalanced Segmentations" https://arxiv.org/abs/1707.03237
761
  :param input_shape: Shape of the input image, without the batch dimension, e.g., 2D: [H, W, C], 3D: [H, W, D, C]
762
  """
763
+ self.smooth = 1e-10 # If y_pred = y_true = null -> dice should be 1
764
+ self.num_labels = num_labels
765
  if input_shape[-1] > 1:
766
+ try:
767
+ self.flat_shape = [-1, np.prod(np.asarray(input_shape[:-1])), input_shape[-1]]
768
+ except TypeError as err:
769
+ self.flat_shape = [-1, None, input_shape[-1]]
770
+ self.cardinal_encoded = False
771
  elif num_labels is not None:
772
+ try:
773
+ self.flat_shape = [-1, np.prod(np.asarray(input_shape[:-1])), input_shape[-1]]
774
+ except TypeError as err:
775
+ self.flat_shape = [-1, None, input_shape[-1]]
776
+ self.cardinal_enc_shape = [-1, *input_shape[:-1]]
777
+ self.cardinal_encoded = True
778
+ warnings.warn('Differentiable cardinal encoding not yet implemented')
779
  else:
780
+ raise ValueError('If input_shape does not correspond to cardinally encoded,'
781
+ 'then num_labels must be provided')
782
 
783
  def one_hot_encoding(self, in_img, name=''):
784
  # TODO: Test if differentiable!
785
  labels, indices = tf.unique(tf.reshape(in_img, [-1]), tf.int32, name=name+'_unique')
786
+ one_hot = tf.one_hot(indices, self.num_labels, name=name + '_one_hot')
787
+ one_hot = tf.reshape(one_hot, self.cardinal_enc_shape + [self.num_labels], name=name + '_reshape')
788
+ one_hot = tf.slice(one_hot, [0] * len(self.cardinal_enc_shape) + [1], [-1] * (len(self.cardinal_enc_shape) + 1),
789
  name=name + '_remove_bg')
790
  return one_hot
791
 
792
  def weigthed_dice(self, y_true, y_pred):
793
  # y_true = [B, -1, L]
794
  # y_pred = [B, -1, L]
795
+ # if self.cardinal_encoded:
796
+ # y_true = self.one_hot_encoding(y_true, name='GDICE_one_hot_encoding_y_true')
797
+ # y_pred = self.one_hot_encoding(y_pred, name='GDICE_one_hot_encoding_y_pred')
798
  y_true = tf.reshape(y_true, self.flat_shape, name='GDICE_reshape_y_true') # Flatten along the volume dimensions
799
  y_pred = tf.reshape(y_pred, self.flat_shape, name='GDICE_reshape_y_pred') # Flatten along the volume dimensions
800
 
 
803
  w = tf.math.divide_no_nan(1., tf.pow(size_y_true, 2), name='GDICE_weight')
804
  numerator = w * tf.reduce_sum(y_true * y_pred, axis=1)
805
  denominator = w * (size_y_true + size_y_pred)
806
+ return tf.div_no_nan(2 * tf.reduce_sum(numerator, axis=-1) + self.smooth, tf.reduce_sum(denominator, axis=-1) + self.smooth)
807
 
808
  def macro_dice(self, y_true, y_pred):
809
  # y_true = [B, -1, L]
810
  # y_pred = [B, -1, L]
811
+ # if self.cardinal_encoded:
812
+ # y_true = self.one_hot_encoding(y_true, name='GDICE_one_hot_encoding_y_true')
813
+ # y_pred = self.one_hot_encoding(y_pred, name='GDICE_one_hot_encoding_y_pred')
814
  y_true = tf.reshape(y_true, self.flat_shape, name='GDICE_reshape_y_true') # Flatten along the volume dimensions
815
  y_pred = tf.reshape(y_pred, self.flat_shape, name='GDICE_reshape_y_pred') # Flatten along the volume dimensions
816
 
 
818
  size_y_pred = tf.reduce_sum(y_pred, axis=1, name='GDICE_size_y_pred')
819
  numerator = tf.reduce_sum(y_true * y_pred, axis=1)
820
  denominator = (size_y_true + size_y_pred)
821
+ return tf.div_no_nan(2 * numerator + self.smooth, denominator + self.smooth)
822
 
823
  @function_decorator('GeneralizeDICE__loss')
824
  def loss(self, y_true, y_pred):
SoA_methods/eval_ants.py CHANGED
@@ -19,6 +19,8 @@ from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
19
  from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
20
  import DeepDeformationMapRegistration.utils.constants as C
21
 
 
 
22
  import voxelmorph as vxm
23
 
24
  from argparse import ArgumentParser
@@ -35,22 +37,24 @@ WARPED_FIX = 'warpedfixout'
35
  FWD_TRFS = 'fwdtransforms'
36
  INV_TRFS = 'invtransforms'
37
 
38
- os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
39
- os.environ['CUDA_VISIBLE_DEVICES'] = '2'
40
 
41
  if __name__ == '__main__':
42
  parser = ArgumentParser()
43
  parser.add_argument('--dataset', type=str, help='Directory with the images')
44
  parser.add_argument('--outdir', type=str, help='Output directory')
 
45
  args = parser.parse_args()
46
 
 
 
 
47
  os.makedirs(args.outdir, exist_ok=True)
48
  os.makedirs(os.path.join(args.outdir, 'SyN'), exist_ok=True)
49
  os.makedirs(os.path.join(args.outdir, 'SyNCC'), exist_ok=True)
50
  dataset_files = os.listdir(args.dataset)
51
  dataset_files.sort()
52
  dataset_files = [os.path.join(args.dataset, f) for f in dataset_files if re.match(DATASET_NAMES, f)]
53
-
54
  dataset_iterator = tqdm(enumerate(dataset_files), desc="Running ANTs")
55
 
56
  f = h5py.File(dataset_files[0], 'r')
@@ -67,18 +71,18 @@ if __name__ == '__main__':
67
  HausdorffDistanceErosion(3, 10, im_shape=image_shape + [nb_labels]).metric,
68
  GeneralizedDICEScore(image_shape + [nb_labels], num_labels=nb_labels).metric_macro]
69
 
70
- fix_img_ph = tf.placeholder(tf.float32, (1, *image_shape, 1), name='fix_img')
71
- pred_img_ph = tf.placeholder(tf.float32, (1, *image_shape, 1), name='pred_img')
72
- fix_seg_ph = tf.placeholder(tf.float32, (1, *image_shape, nb_labels), name='fix_seg')
73
- pred_seg_ph = tf.placeholder(tf.float32, (1, *image_shape, nb_labels), name='pred_seg')
74
 
75
  ssim_tf = metric_fncs[0](fix_img_ph, pred_img_ph)
76
  ncc_tf = metric_fncs[1](fix_img_ph, pred_img_ph)
77
  mse_tf = metric_fncs[2](fix_img_ph, pred_img_ph)
78
  ms_ssim_tf = metric_fncs[3](fix_img_ph, pred_img_ph)
79
- dice_tf = metric_fncs[4](fix_seg_ph, pred_seg_ph)
80
- hd_tf = metric_fncs[5](fix_seg_ph, pred_seg_ph)
81
- dice_macro_tf = metric_fncs[6](fix_seg_ph, pred_seg_ph)
82
 
83
  config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
84
  config.gpu_options.allow_growth = True
@@ -90,7 +94,7 @@ if __name__ == '__main__':
90
  ####
91
  os.environ["ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS"] = "{:d}".format(os.cpu_count()) #https://github.com/ANTsX/ANTsPy/issues/261
92
  print("Running ANTs using {} threads".format(os.environ.get("ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS")))
93
- dm_interp = DisplacementMapInterpolator(image_shape, 'griddata')
94
  # Header of the metrics csv file
95
  csv_header = ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'DICE_MACRO', 'HD', 'Time', 'TRE']
96
 
@@ -147,10 +151,17 @@ if __name__ == '__main__':
147
 
148
  dataset_iterator.set_description('{} ({}): Getting metrics {}'.format(file_num, file_path, reg_method))
149
  with sess.as_default():
150
- dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf],
151
- {'fix_seg:0': fix_seg[np.newaxis, ...], # Batch axis
152
- 'pred_seg:0': pred_seg[np.newaxis, ...] # Batch axis
153
- })
 
 
 
 
 
 
 
154
 
155
  pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
156
  mov_seg_card = segmentation_ohe_to_cardinal(mov_seg).astype(np.float32)
@@ -160,26 +171,28 @@ if __name__ == '__main__':
160
  {'fix_img:0': fix_img[np.newaxis, ...], # Batch axis
161
  'pred_img:0': pred_img[np.newaxis, ...] # Batch axis
162
  })
 
163
  ms_ssim = ms_ssim[0]
164
 
165
  # TRE
166
  disp_map = np.squeeze(np.asarray(nb.load(mov_to_fix_trf_list[0]).dataobj))
 
167
  pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) + mov_centroids
168
- upsample_scale = 128 / 64
169
- fix_centroids_isotropic = fix_centroids * upsample_scale
170
- pred_centroids_isotropic = pred_centroids * upsample_scale
171
 
172
- fix_centroids_isotropic = np.divide(fix_centroids_isotropic, C.COMET_DATASET_iso_to_cubic_scales)
173
- pred_centroids_isotropic = np.divide(pred_centroids_isotropic, C.COMET_DATASET_iso_to_cubic_scales)
174
- tre_array = target_registration_error(fix_centroids_isotropic, pred_centroids_isotropic, False).eval()
175
  tre = np.mean([v for v in tre_array if not np.isnan(v)])
176
 
177
- dataset_iterator.set_description('{} ({}): Saving data {}'.format(file_num, file_path, reg_method))
178
- new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd,
179
- t1_syn-t0_syn if reg_method == 'SyN' else t1_syncc-t0_syncc,
180
- tre]
181
- with open(metrics_file[reg_method], 'a') as f:
182
- f.write(';'.join(map(str, new_line))+'\n')
183
 
184
  save_nifti(fix_img[0, ...], os.path.join(args.outdir, reg_method, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
185
  save_nifti(mov_img[0, ...], os.path.join(args.outdir, reg_method, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
@@ -188,12 +201,17 @@ if __name__ == '__main__':
188
  save_nifti(mov_seg_card[0, ...], os.path.join(args.outdir, reg_method, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
189
  save_nifti(pred_seg_card[0, ...], os.path.join(args.outdir, reg_method, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
190
 
191
- plot_predictions(fix_img[np.newaxis, ...], mov_img[np.newaxis, ...], disp_map[np.newaxis, ...], pred_img[np.newaxis, ...], os.path.join(args.outdir, reg_method, '{:03d}_figures_img.png'.format(step)), show=False)
192
- plot_predictions(fix_seg[np.newaxis, ...], mov_seg[np.newaxis, ...], disp_map[np.newaxis, ...], pred_seg[np.newaxis, ...], os.path.join(args.outdir, reg_method, '{:03d}_figures_seg.png'.format(step)), show=False)
193
  save_disp_map_img(disp_map[np.newaxis, ...], 'Displacement map', os.path.join(args.outdir, reg_method, '{:03d}_disp_map_fig.png'.format(step)), show=False)
194
 
195
  for k in metrics_file.keys():
196
  print('Summary {}\n=======\n'.format(k))
197
- print('\nAVG:\n' + str(pd.read_csv(metrics_file[k], sep=';', header=0).mean(axis=0)) + '\nSTD:\n' + str(
198
- pd.read_csv(metrics_file[k], sep=';', header=0).std(axis=0)))
 
 
 
 
 
199
  print('\n=======\n')
 
19
  from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
20
  import DeepDeformationMapRegistration.utils.constants as C
21
 
22
+ import medpy.metric as medpy_metrics
23
+
24
  import voxelmorph as vxm
25
 
26
  from argparse import ArgumentParser
 
37
  FWD_TRFS = 'fwdtransforms'
38
  INV_TRFS = 'invtransforms'
39
 
 
 
40
 
41
  if __name__ == '__main__':
42
  parser = ArgumentParser()
43
  parser.add_argument('--dataset', type=str, help='Directory with the images')
44
  parser.add_argument('--outdir', type=str, help='Output directory')
45
+ parser.add_argument('--gpu', type=int, help='GPU')
46
  args = parser.parse_args()
47
 
48
+ os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
49
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
50
+
51
  os.makedirs(args.outdir, exist_ok=True)
52
  os.makedirs(os.path.join(args.outdir, 'SyN'), exist_ok=True)
53
  os.makedirs(os.path.join(args.outdir, 'SyNCC'), exist_ok=True)
54
  dataset_files = os.listdir(args.dataset)
55
  dataset_files.sort()
56
  dataset_files = [os.path.join(args.dataset, f) for f in dataset_files if re.match(DATASET_NAMES, f)]
57
+ dataset_files.sort()
58
  dataset_iterator = tqdm(enumerate(dataset_files), desc="Running ANTs")
59
 
60
  f = h5py.File(dataset_files[0], 'r')
 
71
  HausdorffDistanceErosion(3, 10, im_shape=image_shape + [nb_labels]).metric,
72
  GeneralizedDICEScore(image_shape + [nb_labels], num_labels=nb_labels).metric_macro]
73
 
74
+ fix_img_ph = tf.placeholder(tf.float32, (1, None, None, None, 1), name='fix_img')
75
+ pred_img_ph = tf.placeholder(tf.float32, (1, None, None, None, 1), name='pred_img')
76
+ fix_seg_ph = tf.placeholder(tf.float32, (1, None, None, None, nb_labels), name='fix_seg')
77
+ pred_seg_ph = tf.placeholder(tf.float32, (1, None, None, None, nb_labels), name='pred_seg')
78
 
79
  ssim_tf = metric_fncs[0](fix_img_ph, pred_img_ph)
80
  ncc_tf = metric_fncs[1](fix_img_ph, pred_img_ph)
81
  mse_tf = metric_fncs[2](fix_img_ph, pred_img_ph)
82
  ms_ssim_tf = metric_fncs[3](fix_img_ph, pred_img_ph)
83
+ # dice_tf = metric_fncs[4](fix_seg_ph, pred_seg_ph)
84
+ # hd_tf = metric_fncs[5](fix_seg_ph, pred_seg_ph)
85
+ # dice_macro_tf = metric_fncs[6](fix_seg_ph, pred_seg_ph)
86
 
87
  config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
88
  config.gpu_options.allow_growth = True
 
94
  ####
95
  os.environ["ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS"] = "{:d}".format(os.cpu_count()) #https://github.com/ANTsX/ANTsPy/issues/261
96
  print("Running ANTs using {} threads".format(os.environ.get("ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS")))
97
+ # dm_interp = DisplacementMapInterpolator(image_shape, 'griddata')
98
  # Header of the metrics csv file
99
  csv_header = ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'DICE_MACRO', 'HD', 'Time', 'TRE']
100
 
 
151
 
152
  dataset_iterator.set_description('{} ({}): Getting metrics {}'.format(file_num, file_path, reg_method))
153
  with sess.as_default():
154
+ dice = np.mean([medpy_metrics.dc(pred_seg[np.newaxis, ..., l], fix_seg[np.newaxis,..., l]) / np.sum(
155
+ fix_seg[np.newaxis,..., l]) for l in range(nb_labels)])
156
+ hd = np.mean(
157
+ [medpy_metrics.hd(pred_seg[np.newaxis,..., l], fix_seg[np.newaxis,..., l]) for l in range(nb_labels)])
158
+ dice_macro = np.mean(
159
+ [medpy_metrics.dc(pred_seg[np.newaxis,..., l], fix_seg[np.newaxis,..., l]) for l in range(nb_labels)])
160
+
161
+ # dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf],
162
+ # {'fix_seg:0': fix_seg[np.newaxis, ...], # Batch axis
163
+ # 'pred_seg:0': pred_seg[np.newaxis, ...] # Batch axis
164
+ # })
165
 
166
  pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
167
  mov_seg_card = segmentation_ohe_to_cardinal(mov_seg).astype(np.float32)
 
171
  {'fix_img:0': fix_img[np.newaxis, ...], # Batch axis
172
  'pred_img:0': pred_img[np.newaxis, ...] # Batch axis
173
  })
174
+ ssim = np.mean(ssim)
175
  ms_ssim = ms_ssim[0]
176
 
177
  # TRE
178
  disp_map = np.squeeze(np.asarray(nb.load(mov_to_fix_trf_list[0]).dataobj))
179
+ dm_interp = DisplacementMapInterpolator(fix_img.shape[:-1], 'griddata', step=2)
180
  pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) + mov_centroids
181
+ # upsample_scale = 128 / 64
182
+ # fix_centroids_isotropic = fix_centroids * upsample_scale
183
+ # pred_centroids_isotropic = pred_centroids * upsample_scale
184
 
185
+ # fix_centroids_isotropic = np.divide(fix_centroids_isotropic, C.COMET_DATASET_iso_to_cubic_scales)
186
+ # pred_centroids_isotropic = np.divide(pred_centroids_isotropic, C.COMET_DATASET_iso_to_cubic_scales)
187
+ tre_array = target_registration_error(fix_centroids, pred_centroids, False).eval()
188
  tre = np.mean([v for v in tre_array if not np.isnan(v)])
189
 
190
+ # dataset_iterator.set_description('{} ({}): Saving data {}'.format(file_num, file_path, reg_method))
191
+ # new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd,
192
+ # t1_syn-t0_syn if reg_method == 'SyN' else t1_syncc-t0_syncc,
193
+ # tre]
194
+ # with open(metrics_file[reg_method], 'a') as f:
195
+ # f.write(';'.join(map(str, new_line))+'\n')
196
 
197
  save_nifti(fix_img[0, ...], os.path.join(args.outdir, reg_method, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
198
  save_nifti(mov_img[0, ...], os.path.join(args.outdir, reg_method, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
 
201
  save_nifti(mov_seg_card[0, ...], os.path.join(args.outdir, reg_method, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
202
  save_nifti(pred_seg_card[0, ...], os.path.join(args.outdir, reg_method, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
203
 
204
+ plot_predictions(img_batches=[fix_img[np.newaxis, ...], mov_img[np.newaxis, ...], pred_img[np.newaxis, ...]], disp_map_batch=disp_map[np.newaxis, ...], seg_batches=[fix_seg_card[np.newaxis, ...], mov_seg_card[np.newaxis, ...], pred_seg_card[np.newaxis, ...]], filename=os.path.join(args.outdir, reg_method, '{:03d}_figures_seg.png'.format(step)), show=False)
205
+ plot_predictions(img_batches=[fix_img[np.newaxis, ...], mov_img[np.newaxis, ...], pred_img[np.newaxis, ...]], disp_map_batch=disp_map[np.newaxis, ...], filename=os.path.join(args.outdir, reg_method, '{:03d}_figures_img.png'.format(step)), show=False)
206
  save_disp_map_img(disp_map[np.newaxis, ...], 'Displacement map', os.path.join(args.outdir, reg_method, '{:03d}_disp_map_fig.png'.format(step)), show=False)
207
 
208
  for k in metrics_file.keys():
209
  print('Summary {}\n=======\n'.format(k))
210
+ metrics_df = pd.read_csv(metrics_file[k], sep=';', header=0)
211
+ print('\nAVG:\n')
212
+ print(metrics_df.mean(axis=0))
213
+ print('\nSTD:\n')
214
+ print(metrics_df.std(axis=0))
215
+ print('\nHD95perc:\n')
216
+ print(metrics_df['HD'].describe(percentiles=[.95]))
217
  print('\n=======\n')