Merge branch 'jpdefrutos:master' into master
Browse files- Brain_study/ABSTRACT/figures.py +158 -0
- Brain_study/ABSTRACT/format_tables_abstract.py +111 -0
- Brain_study/Build_test_set.py +34 -10
- Brain_study/Evaluate_network.py +6 -5
- Brain_study/Evaluate_network__test_fixed.py +49 -24
- Brain_study/MultiTrain_config.py +6 -0
- Brain_study/Train_Baseline.py +51 -16
- Brain_study/Train_SegmentationGuided.py +61 -18
- Brain_study/Train_UncertaintyWeighted.py +54 -15
- Brain_study/data_generator.py +84 -43
- Brain_study/format_dataset.py +27 -4
- Brain_study/split_dataset.py +43 -22
- COMET/Build_test_set.py +18 -12
- COMET/COMET_train.py +49 -33
- COMET/COMET_train_UW.py +51 -24
- COMET/COMET_train_seggguided.py +54 -36
- COMET/Evaluate_network.py +151 -26
- COMET/MultiTrain_config.py +21 -7
- COMET/augmentation_constants.py +24 -1
- COMET/format_dataset.py +9 -3
- DeepDeformationMapRegistration/layers/augmentation.py +1 -1
- DeepDeformationMapRegistration/layers/upsampling.py +2 -0
- DeepDeformationMapRegistration/losses.py +31 -21
- DeepDeformationMapRegistration/utils/constants.py +3 -2
- DeepDeformationMapRegistration/utils/misc.py +78 -15
- DeepDeformationMapRegistration/utils/operators.py +15 -0
- DeepDeformationMapRegistration/utils/visualization.py +43 -14
- SoA_methods/eval_ants.py +49 -31
- requirements.txt +12 -7
Brain_study/ABSTRACT/figures.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import re
|
4 |
+
import warnings
|
5 |
+
|
6 |
+
import nibabel as nib
|
7 |
+
import numpy as np
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
from matplotlib import cm
|
10 |
+
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
|
11 |
+
|
12 |
+
segm_cm = cm.get_cmap('Dark2', 256)
|
13 |
+
segm_cm = segm_cm(np.linspace(0, 1, 28))
|
14 |
+
segm_cm[0, :] = np.asarray([0, 0, 0, 0])
|
15 |
+
segm_cm = ListedColormap(segm_cm)
|
16 |
+
|
17 |
+
if __name__ == '__main__':
|
18 |
+
parser = argparse.ArgumentParser()
|
19 |
+
|
20 |
+
parser.add_argument('-d', '--dir', type=str, help='Directories where the models are stored', default=None)
|
21 |
+
parser.add_argument('-o', '--output', type=str, help='Output directory', default=os.getcwd())
|
22 |
+
parser.add_argument('--overwrite', type=bool, default=True)
|
23 |
+
args = parser.parse_args()
|
24 |
+
assert args.dir is not None, "No directories provided. Stopping"
|
25 |
+
|
26 |
+
list_fix_img = list()
|
27 |
+
list_mov_img = list()
|
28 |
+
list_fix_seg = list()
|
29 |
+
list_mov_seg = list()
|
30 |
+
list_pred_img = list()
|
31 |
+
list_pred_seg = list()
|
32 |
+
print('Fetching data...')
|
33 |
+
for r, d, f in os.walk(args.dir):
|
34 |
+
if os.path.split(r)[1] == 'Evaluation_paper':
|
35 |
+
for name in f:
|
36 |
+
if re.search('^050', name) and name.endswith('nii.gz'):
|
37 |
+
if re.search('fix_img', name) and name.endswith('nii.gz'):
|
38 |
+
list_fix_img.append(os.path.join(r, name))
|
39 |
+
elif re.search('mov_img', name):
|
40 |
+
list_mov_img.append(os.path.join(r, name))
|
41 |
+
elif re.search('fix_seg', name):
|
42 |
+
list_fix_seg.append(os.path.join(r, name))
|
43 |
+
elif re.search('mov_seg', name):
|
44 |
+
list_mov_seg.append(os.path.join(r, name))
|
45 |
+
elif re.search('pred_img', name):
|
46 |
+
list_pred_img.append(os.path.join(r, name))
|
47 |
+
elif re.search('pred_seg', name):
|
48 |
+
list_pred_seg.append(os.path.join(r, name))
|
49 |
+
|
50 |
+
# Figure: all coronal views
|
51 |
+
# Fix img | Mov img
|
52 |
+
# BASELINE 1 | BASELINE 2 | SEGGUIDED
|
53 |
+
# UW 1 | UW 2 | UW 3
|
54 |
+
list_fix_img.sort()
|
55 |
+
list_fix_seg.sort()
|
56 |
+
list_mov_img.sort()
|
57 |
+
list_mov_seg.sort()
|
58 |
+
list_pred_img.sort()
|
59 |
+
list_pred_seg.sort()
|
60 |
+
print('Making Test_data.png...')
|
61 |
+
selected_slice = 30
|
62 |
+
fix_img = np.asarray(nib.load(list_fix_img[0]).dataobj)[..., selected_slice, 0]
|
63 |
+
mov_img = np.asarray(nib.load(list_mov_img[0]).dataobj)[..., selected_slice, 0]
|
64 |
+
fix_seg = np.asarray(nib.load(list_fix_seg[0]).dataobj)[..., selected_slice, 0]
|
65 |
+
mov_seg = np.asarray(nib.load(list_mov_seg[0]).dataobj)[..., selected_slice, 0]
|
66 |
+
|
67 |
+
fig, ax = plt.subplots(nrows=1, ncols=4, figsize=(9, 3), dpi=200)
|
68 |
+
|
69 |
+
for i, (img, title) in enumerate(zip([(fix_img, fix_seg), (mov_img, mov_seg)],
|
70 |
+
[('Fixed image', 'Fixed Segms.'), ('Moving image', 'Moving Segms.')])):
|
71 |
+
|
72 |
+
ax[i].imshow(img[0], origin='lower', cmap='Greys_r')
|
73 |
+
ax[i+2].imshow(img[0], origin='lower', cmap='Greys_r')
|
74 |
+
ax[i+2].imshow(img[1], origin='lower', cmap=segm_cm, alpha=0.6)
|
75 |
+
|
76 |
+
ax[i].tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False)
|
77 |
+
ax[i+2].tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False)
|
78 |
+
|
79 |
+
ax[i].set_xlabel(title[0], fontsize=16)
|
80 |
+
ax[i+2].set_xlabel(title[1], fontsize=16)
|
81 |
+
|
82 |
+
plt.tight_layout()
|
83 |
+
if not args.overwrite and os.path.exists(os.path.join(args.output, 'Test_data.png')):
|
84 |
+
warnings.warn('File Test_data.png already exists. Skipping')
|
85 |
+
else:
|
86 |
+
plt.savefig(os.path.join(args.output, 'Test_data.png'), format='png')
|
87 |
+
plt.close()
|
88 |
+
|
89 |
+
print('Making Pred_data.png...')
|
90 |
+
fig, ax = plt.subplots(nrows=2, ncols=6, figsize=(9, 3), dpi=200)
|
91 |
+
|
92 |
+
for i, (pred_img_path, pred_seg_path) in enumerate(zip(list_pred_img, list_pred_seg)):
|
93 |
+
img = np.asarray(nib.load(pred_img_path).dataobj)[..., selected_slice, 0]
|
94 |
+
seg = np.asarray(nib.load(pred_seg_path).dataobj)[..., selected_slice, 0]
|
95 |
+
|
96 |
+
ax[0, i].imshow(img, origin='lower', cmap='Greys_r')
|
97 |
+
ax[1, i].imshow(img, origin='lower', cmap='Greys_r')
|
98 |
+
ax[1, i].imshow(seg, origin='lower', cmap=segm_cm, alpha=0.6)
|
99 |
+
|
100 |
+
ax[0, i].tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False)
|
101 |
+
ax[1, i].tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False)
|
102 |
+
|
103 |
+
model = re.search('((UW|SEGGUIDED|BASELINE).*)_{2,}MET', pred_img_path).group(1).rstrip('_')
|
104 |
+
model = model.replace('_Lsim', ' ')
|
105 |
+
model = model.replace('_Lseg', ' ')
|
106 |
+
model = model.replace('_L', ' ')
|
107 |
+
model = model.replace('_', ' ')
|
108 |
+
model = model.upper()
|
109 |
+
model = ' '.join(model.split())
|
110 |
+
|
111 |
+
ax[1, i].set_xlabel(model, fontsize=9)
|
112 |
+
plt.tight_layout()
|
113 |
+
if not args.overwrite and os.path.exists(os.path.join(args.output, 'Pred_data.png')):
|
114 |
+
warnings.warn('File Pred_data.png already exists. Skipping')
|
115 |
+
else:
|
116 |
+
plt.savefig(os.path.join(args.output, 'Pred_data.png'), format='png')
|
117 |
+
plt.close()
|
118 |
+
|
119 |
+
print('Making Pred_data_large.png...')
|
120 |
+
fig, ax = plt.subplots(nrows=2, ncols=8, figsize=(9, 3), dpi=200)
|
121 |
+
list_pred_img = [list_mov_img[0]] + list_pred_img
|
122 |
+
list_pred_img = [list_fix_img[0]] + list_pred_img
|
123 |
+
list_pred_seg = [list_mov_seg[0]] + list_pred_seg
|
124 |
+
list_pred_seg = [list_fix_seg[0]] + list_pred_seg
|
125 |
+
|
126 |
+
for i, (pred_img_path, pred_seg_path) in enumerate(zip(list_pred_img, list_pred_seg)):
|
127 |
+
img = np.asarray(nib.load(pred_img_path).dataobj)[..., selected_slice, 0]
|
128 |
+
seg = np.asarray(nib.load(pred_seg_path).dataobj)[..., selected_slice, 0]
|
129 |
+
|
130 |
+
ax[0, i].imshow(img, origin='lower', cmap='Greys_r')
|
131 |
+
ax[1, i].imshow(img, origin='lower', cmap='Greys_r')
|
132 |
+
ax[1, i].imshow(seg, origin='lower', cmap=segm_cm, alpha=0.6)
|
133 |
+
|
134 |
+
ax[0, i].tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False)
|
135 |
+
ax[1, i].tick_params(axis='both', which='both', bottom=False, left=False, labelleft=False, labelbottom=False)
|
136 |
+
|
137 |
+
if i > 1:
|
138 |
+
model = re.search('((UW|SEGGUIDED|BASELINE).*)_{2,}MET', pred_img_path).group(1).rstrip('_')
|
139 |
+
model = model.replace('_Lsim', ' ')
|
140 |
+
model = model.replace('_Lseg', ' ')
|
141 |
+
model = model.replace('_L', ' ')
|
142 |
+
model = model.replace('_', ' ')
|
143 |
+
model = model.upper()
|
144 |
+
model = ' '.join(model.split())
|
145 |
+
elif i == 0:
|
146 |
+
model = 'Moving image'
|
147 |
+
else:
|
148 |
+
model = 'Fixed image'
|
149 |
+
|
150 |
+
ax[1, i].set_xlabel(model, fontsize=7)
|
151 |
+
plt.tight_layout()
|
152 |
+
if not args.overwrite and os.path.exists(os.path.join(args.output, 'Pred_data_large.png')):
|
153 |
+
warnings.warn('File Pred_data.png already exists. Skipping')
|
154 |
+
else:
|
155 |
+
plt.savefig(os.path.join(args.output, 'Pred_data_large.png'), format='png')
|
156 |
+
plt.close()
|
157 |
+
|
158 |
+
print('...done!')
|
Brain_study/ABSTRACT/format_tables_abstract.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
|
3 |
+
import pandas as pd
|
4 |
+
import os
|
5 |
+
import argparse
|
6 |
+
import re
|
7 |
+
import shutil
|
8 |
+
|
9 |
+
DICT_MODEL_NAMES = {'BASELINE': 'BL',
|
10 |
+
'SEGGUIDED': 'SG',
|
11 |
+
'UW': 'UW'}
|
12 |
+
|
13 |
+
DICT_METRICS_NAMES = {'NCC': 'N',
|
14 |
+
'SSIM': 'S',
|
15 |
+
'DICE': 'D',
|
16 |
+
'DICE MACRO': 'D',
|
17 |
+
'HD': 'H', }
|
18 |
+
|
19 |
+
|
20 |
+
def row_name(in_path: str):
|
21 |
+
model = re.search('((UW|SEGGUIDED|BASELINE).*)_\d', in_path).group(1).rstrip('_')
|
22 |
+
model = model.replace('_Lsim', '')
|
23 |
+
model = model.replace('_Lseg', '')
|
24 |
+
model = model.replace('_L', '')
|
25 |
+
model = model.replace('_', ' ')
|
26 |
+
model = model.upper()
|
27 |
+
elements = model.split()
|
28 |
+
model = elements[0]
|
29 |
+
metrics = list()
|
30 |
+
model = DICT_MODEL_NAMES[model]
|
31 |
+
for m in elements[1:]:
|
32 |
+
if m != 'MACRO':
|
33 |
+
metrics.append(DICT_METRICS_NAMES[m])
|
34 |
+
|
35 |
+
return '{}-{}'.format(model, ''.join(metrics))
|
36 |
+
|
37 |
+
|
38 |
+
if __name__ == '__main__':
|
39 |
+
parser = argparse.ArgumentParser()
|
40 |
+
|
41 |
+
parser.add_argument('-d', '--dir', nargs='+', type=str, help='List of directories where metrics.csv file is',
|
42 |
+
default=None)
|
43 |
+
parser.add_argument('-o', '--output', type=str, help='Output directory', default=os.getcwd())
|
44 |
+
parser.add_argument('--overwrite', type=bool, default=True)
|
45 |
+
parser.add_argument('--filename', type=str, help='Output file name', default='metrics')
|
46 |
+
parser.add_argument('--removemetrics', nargs='+', type=str, default=None)
|
47 |
+
args = parser.parse_args()
|
48 |
+
assert args.dir is not None, "No directories provided. Stopping"
|
49 |
+
|
50 |
+
if len(args.dir) == 1:
|
51 |
+
list_files = list()
|
52 |
+
for r, d, f in os.walk(args.dir[0]):
|
53 |
+
for name in f:
|
54 |
+
if 'metrics.csv' == name: # and os.path.split(r)[1] == 'Evaluation_paper':
|
55 |
+
list_files.append(os.path.join(r, name))
|
56 |
+
else:
|
57 |
+
list_files = [os.path.join(d, 'metrics.csv') for d in args.dir]
|
58 |
+
|
59 |
+
for d in list_files:
|
60 |
+
assert os.path.exists(d), "Missing metrics.csv file in: " + os.path.split(d)[0]
|
61 |
+
|
62 |
+
print('Metric files found: {}'.format(list_files))
|
63 |
+
|
64 |
+
dataframes = list()
|
65 |
+
if len(list_files):
|
66 |
+
for d in list_files:
|
67 |
+
df = pd.read_csv(d, sep=';', header=0)
|
68 |
+
model = row_name(d)
|
69 |
+
|
70 |
+
df.insert(0, "Model", model)
|
71 |
+
df.drop(columns=list(df.filter(regex='Unnamed')), inplace=True)
|
72 |
+
df.drop(columns=['File', 'MSE', 'No_missing_lbls'], inplace=True)
|
73 |
+
dataframes.append(df)
|
74 |
+
|
75 |
+
full_table = pd.concat(dataframes)
|
76 |
+
if args.removemetrics is not None:
|
77 |
+
full_table = full_table.drop(columns=args.removemetrics)
|
78 |
+
mean_table = full_table.copy()
|
79 |
+
# mean_table.insert(column='Type', value='Avg.', loc=1)
|
80 |
+
# mean_table = mean_table.groupby(['Type', 'Model']).mean().round(3)
|
81 |
+
mean_table = mean_table.groupby(['Model'])
|
82 |
+
hd95 = mean_table.HD.quantile(0.95).map('{:.2f}'.format)
|
83 |
+
mean_table = mean_table.mean().round(3)
|
84 |
+
|
85 |
+
std_table = full_table.copy()
|
86 |
+
# std_table.insert(column='Type', value='STD', loc=1)
|
87 |
+
# std_table = std_table.groupby(['Type', 'Model']).std().round(3)
|
88 |
+
std_table = std_table.groupby(['Model']).std().round(3)
|
89 |
+
|
90 |
+
# metrics_table = pd.concat([mean_table, std_table]).swaplevel(axis='rows')
|
91 |
+
metrics_table = mean_table.applymap('{:.2f}'.format) + u"\u00B1" + std_table.applymap('{:.2f}'.format)
|
92 |
+
time_col = metrics_table.pop('Time')
|
93 |
+
metrics_table.insert(len(metrics_table.columns), 'Time', time_col)
|
94 |
+
metrics_table.insert(5, 'HD 95%ile', hd95)
|
95 |
+
|
96 |
+
metrics_file = os.path.join(args.output, args.filename + '.tex')
|
97 |
+
if os.path.exists(metrics_file) and args.overwrite:
|
98 |
+
shutil.rmtree(metrics_file, ignore_errors=True)
|
99 |
+
metrics_table.to_latex(metrics_file,
|
100 |
+
column_format='l' + 'c' * len(metrics_table.columns),
|
101 |
+
caption='Average and standard deviation of the metrics: MSE, NCC, SSIM, DICE and HD. As well as the number of missing labels in the predicted images.')
|
102 |
+
elif os.path.exists(metrics_file):
|
103 |
+
warnings.warn('File {} already exists. Skipping'.format(metrics_file))
|
104 |
+
else:
|
105 |
+
metrics_table.to_latex(metrics_file,
|
106 |
+
column_format='l' + 'c' * len(metrics_table.columns),
|
107 |
+
caption='Average and standard deviation of the metrics: MSE, NCC, SSIM, DICE and HD. As well as the number of missing labels in the predicted images.')
|
108 |
+
|
109 |
+
print('Done')
|
110 |
+
else:
|
111 |
+
print('No files found in {}!'.format(args.dir))
|
Brain_study/Build_test_set.py
CHANGED
@@ -18,7 +18,8 @@ import DeepDeformationMapRegistration.utils.constants as C
|
|
18 |
from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
|
19 |
from DeepDeformationMapRegistration.layers import AugmentationLayer
|
20 |
from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
|
21 |
-
from DeepDeformationMapRegistration.utils.misc import get_segmentations_centroids
|
|
|
22 |
from tqdm import tqdm
|
23 |
|
24 |
from Brain_study.data_generator import BatchGenerator
|
@@ -37,9 +38,15 @@ POINTS = None
|
|
37 |
MISSING_CENTROID = np.asarray([[np.nan]*3])
|
38 |
|
39 |
|
40 |
-
def get_mov_centroids(fix_seg, disp_map):
|
41 |
-
|
42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
return fix_centroids, fix_centroids + disp, disp
|
44 |
|
45 |
|
@@ -50,6 +57,7 @@ if __name__ == '__main__':
|
|
50 |
parser.add_argument('--gpu', type=int, help='GPU', default=0)
|
51 |
parser.add_argument('--dataset', type=str, help='Dataset to build the test set', default='')
|
52 |
parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
|
|
|
53 |
args = parser.parse_args()
|
54 |
|
55 |
assert args.dataset != '', "Missing original dataset dataset"
|
@@ -70,12 +78,20 @@ if __name__ == '__main__':
|
|
70 |
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
|
71 |
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) # Check availability before running using 'nvidia-smi'
|
72 |
|
73 |
-
data_generator = BatchGenerator(DATASET, 1, False, 1.0, False, ['all'])
|
74 |
|
75 |
img_generator = data_generator.get_train_generator()
|
76 |
nb_labels = len(img_generator.get_segmentation_labels())
|
77 |
image_input_shape = img_generator.get_data_shape()[-1][:-1]
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
# Build model
|
80 |
|
81 |
xx = np.linspace(0, image_output_shape[0], image_output_shape[0], endpoint=False)
|
@@ -102,19 +118,27 @@ if __name__ == '__main__':
|
|
102 |
return_displacement_map=True)
|
103 |
augm_model = tf.keras.Model(inputs=input_augm, outputs=augm_layer(input_augm))
|
104 |
|
|
|
|
|
|
|
|
|
105 |
config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
|
106 |
config.gpu_options.allow_growth = True
|
107 |
config.log_device_placement = False ## to log device placement (on which device the operation ran)
|
108 |
|
|
|
|
|
109 |
sess = tf.Session(config=config)
|
110 |
tf.keras.backend.set_session(sess)
|
111 |
with sess.as_default():
|
112 |
sess.run(tf.global_variables_initializer())
|
113 |
progress_bar = tqdm(enumerate(img_generator, 1), desc='Generating samples', total=len(img_generator))
|
114 |
-
for step, (in_batch, _) in progress_bar:
|
115 |
-
fix_img, mov_img, fix_seg, mov_seg, disp_map = augm_model.predict(in_batch)
|
|
|
|
|
116 |
|
117 |
-
fix_centroids, mov_centroids, disp_centroids = get_mov_centroids(fix_seg, disp_map)
|
118 |
|
119 |
out_file = os.path.join(OUTPUT_FOLDER_DIR, 'test_sample_{:04d}.h5'.format(step))
|
120 |
out_file_dm = os.path.join(OUTPUT_FOLDER_DIR, 'test_sample_dm_{:04d}.h5'.format(step))
|
@@ -129,7 +153,7 @@ if __name__ == '__main__':
|
|
129 |
f.create_dataset('mov_segmentations', shape=segm_shape[1:], dtype=np.uint8, data=mov_seg[0, ...])
|
130 |
f.create_dataset('fix_centroids', shape=centroids_shape, dtype=np.float32, data=fix_centroids)
|
131 |
f.create_dataset('mov_centroids', shape=centroids_shape, dtype=np.float32, data=mov_centroids)
|
132 |
-
|
133 |
with h5py.File(out_file_dm, 'w') as f:
|
134 |
f.create_dataset('disp_map', shape=disp_shape[1:], dtype=np.float32, data=disp_map)
|
135 |
f.create_dataset('disp_centroids', shape=centroids_shape, dtype=np.float32, data=disp_centroids)
|
|
|
18 |
from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
|
19 |
from DeepDeformationMapRegistration.layers import AugmentationLayer
|
20 |
from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
|
21 |
+
from DeepDeformationMapRegistration.utils.misc import get_segmentations_centroids, DisplacementMapInterpolator
|
22 |
+
|
23 |
from tqdm import tqdm
|
24 |
|
25 |
from Brain_study.data_generator import BatchGenerator
|
|
|
38 |
MISSING_CENTROID = np.asarray([[np.nan]*3])
|
39 |
|
40 |
|
41 |
+
def get_mov_centroids(fix_seg, disp_map, nb_labels=28, exclude_background_lbl=False, brain_study=True, dm_interp=None):
|
42 |
+
if exclude_background_lbl:
|
43 |
+
fix_centroids, _ = get_segmentations_centroids(fix_seg[0, ..., 1:], ohe=True, expected_lbls=range(1, nb_labels), brain_study=brain_study)
|
44 |
+
else:
|
45 |
+
fix_centroids, _ = get_segmentations_centroids(fix_seg[0, ...], ohe=True, expected_lbls=range(1, nb_labels), brain_study=brain_study)
|
46 |
+
if dm_interp is None:
|
47 |
+
disp = griddata(POINTS, disp_map.reshape([-1, 3]), fix_centroids, method='linear')
|
48 |
+
else:
|
49 |
+
disp = dm_interp(disp_map, fix_centroids, backwards=False)
|
50 |
return fix_centroids, fix_centroids + disp, disp
|
51 |
|
52 |
|
|
|
57 |
parser.add_argument('--gpu', type=int, help='GPU', default=0)
|
58 |
parser.add_argument('--dataset', type=str, help='Dataset to build the test set', default='')
|
59 |
parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
|
60 |
+
parser.add_argument('--output_shape', help='If an int, a cubic shape is presumed. Otherwise provide it as a space separated sequence', nargs='+', default=128)
|
61 |
args = parser.parse_args()
|
62 |
|
63 |
assert args.dataset != '', "Missing original dataset dataset"
|
|
|
78 |
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
|
79 |
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) # Check availability before running using 'nvidia-smi'
|
80 |
|
81 |
+
data_generator = BatchGenerator(DATASET, 1, False, 1.0, False, ['all'], return_isotropic_shape=True)
|
82 |
|
83 |
img_generator = data_generator.get_train_generator()
|
84 |
nb_labels = len(img_generator.get_segmentation_labels())
|
85 |
image_input_shape = img_generator.get_data_shape()[-1][:-1]
|
86 |
+
|
87 |
+
if isinstance(args.output_shape, int):
|
88 |
+
image_output_shape = [args.output_shape] * 3
|
89 |
+
elif isinstance(args.output_shape, list):
|
90 |
+
assert len(args.output_shape) == 3, 'Invalid output shape, expected three values and got {}'.format(len(args.output_shape))
|
91 |
+
image_output_shape = [int(s) for s in args.output_shape]
|
92 |
+
else:
|
93 |
+
raise ValueError('Invalid output_shape. Must be an int or a space-separated sequence of ints')
|
94 |
+
print('Scaling to: ', image_output_shape)
|
95 |
# Build model
|
96 |
|
97 |
xx = np.linspace(0, image_output_shape[0], image_output_shape[0], endpoint=False)
|
|
|
118 |
return_displacement_map=True)
|
119 |
augm_model = tf.keras.Model(inputs=input_augm, outputs=augm_layer(input_augm))
|
120 |
|
121 |
+
fix_img_ph = tf.placeholder(dtype=tf.float32, shape=[1,] + image_input_shape + [1+nb_labels,], name='fix_image')
|
122 |
+
|
123 |
+
augmentation_pipeline = augm_model(fix_img_ph)
|
124 |
+
|
125 |
config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
|
126 |
config.gpu_options.allow_growth = True
|
127 |
config.log_device_placement = False ## to log device placement (on which device the operation ran)
|
128 |
|
129 |
+
dm_interp = DisplacementMapInterpolator(image_output_shape, 'griddata', step=8)
|
130 |
+
|
131 |
sess = tf.Session(config=config)
|
132 |
tf.keras.backend.set_session(sess)
|
133 |
with sess.as_default():
|
134 |
sess.run(tf.global_variables_initializer())
|
135 |
progress_bar = tqdm(enumerate(img_generator, 1), desc='Generating samples', total=len(img_generator))
|
136 |
+
for step, (in_batch, _, isotropic_shape) in progress_bar:
|
137 |
+
# fix_img, mov_img, fix_seg, mov_seg, disp_map = augm_model.predict(in_batch)
|
138 |
+
fix_img, mov_img, fix_seg, mov_seg, disp_map = sess.run(augmentation_pipeline,
|
139 |
+
feed_dict={'fix_image:0': in_batch})
|
140 |
|
141 |
+
fix_centroids, mov_centroids, disp_centroids = get_mov_centroids(fix_seg, disp_map, nb_labels, dm_interp=dm_interp)
|
142 |
|
143 |
out_file = os.path.join(OUTPUT_FOLDER_DIR, 'test_sample_{:04d}.h5'.format(step))
|
144 |
out_file_dm = os.path.join(OUTPUT_FOLDER_DIR, 'test_sample_dm_{:04d}.h5'.format(step))
|
|
|
153 |
f.create_dataset('mov_segmentations', shape=segm_shape[1:], dtype=np.uint8, data=mov_seg[0, ...])
|
154 |
f.create_dataset('fix_centroids', shape=centroids_shape, dtype=np.float32, data=fix_centroids)
|
155 |
f.create_dataset('mov_centroids', shape=centroids_shape, dtype=np.float32, data=mov_centroids)
|
156 |
+
f.create_dataset('isotropic_shape', data=np.squeeze(isotropic_shape))
|
157 |
with h5py.File(out_file_dm, 'w') as f:
|
158 |
f.create_dataset('disp_map', shape=disp_shape[1:], dtype=np.float32, data=disp_map)
|
159 |
f.create_dataset('disp_centroids', shape=centroids_shape, dtype=np.float32, data=disp_centroids)
|
Brain_study/Evaluate_network.py
CHANGED
@@ -23,6 +23,7 @@ from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplifie
|
|
23 |
from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
|
24 |
from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
|
25 |
from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
|
|
|
26 |
from EvaluationScripts.Evaluate_class import EvaluationFigures, resize_pts_to_original_space, resize_img_to_original_space, resize_transformation
|
27 |
from scipy.interpolate import RegularGridInterpolator
|
28 |
from tqdm import tqdm
|
@@ -147,9 +148,9 @@ if __name__ == '__main__':
|
|
147 |
dice = GeneralizedDICEScore(image_output_shape + [img_generator.get_data_shape()[2][-1]], num_labels=nb_labels).metric(fix_seg, pred_seg).eval()
|
148 |
hd = HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [img_generator.get_data_shape()[2][-1]]).metric(fix_seg, pred_seg).eval()
|
149 |
|
150 |
-
pred_seg =
|
151 |
-
mov_seg =
|
152 |
-
fix_seg =
|
153 |
|
154 |
mov_coords = np.stack(np.meshgrid(*[np.arange(0, 64)]*3), axis=-1)
|
155 |
dest_coords = mov_coords + disp_map[0, ...]
|
@@ -178,8 +179,8 @@ if __name__ == '__main__':
|
|
178 |
plt.savefig(os.path.join(output_folder, '{:03d}_hist_mag_ssim_{:.03f}_dice_{:.03f}.png'.format(step, ssim, dice)))
|
179 |
plt.close()
|
180 |
|
181 |
-
plot_predictions(fix_img, mov_img, disp_map,
|
182 |
-
plot_predictions(
|
183 |
save_disp_map_img(disp_map, 'Displacement map', os.path.join(output_folder, '{:03d}_disp_map_fig.png'.format(step)), show=False)
|
184 |
|
185 |
progress_bar.set_description('SSIM {:.04f}\tDICE: {:.04f}'.format(ssim, dice))
|
|
|
23 |
from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
|
24 |
from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
|
25 |
from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
|
26 |
+
from DeepDeformationMapRegistration.utils.misc import segmentation_ohe_to_cardinal
|
27 |
from EvaluationScripts.Evaluate_class import EvaluationFigures, resize_pts_to_original_space, resize_img_to_original_space, resize_transformation
|
28 |
from scipy.interpolate import RegularGridInterpolator
|
29 |
from tqdm import tqdm
|
|
|
148 |
dice = GeneralizedDICEScore(image_output_shape + [img_generator.get_data_shape()[2][-1]], num_labels=nb_labels).metric(fix_seg, pred_seg).eval()
|
149 |
hd = HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [img_generator.get_data_shape()[2][-1]]).metric(fix_seg, pred_seg).eval()
|
150 |
|
151 |
+
pred_seg = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
|
152 |
+
mov_seg = segmentation_ohe_to_cardinal(mov_seg).astype(np.float32)
|
153 |
+
fix_seg = segmentation_ohe_to_cardinal(fix_seg).astype(np.float32)
|
154 |
|
155 |
mov_coords = np.stack(np.meshgrid(*[np.arange(0, 64)]*3), axis=-1)
|
156 |
dest_coords = mov_coords + disp_map[0, ...]
|
|
|
179 |
plt.savefig(os.path.join(output_folder, '{:03d}_hist_mag_ssim_{:.03f}_dice_{:.03f}.png'.format(step, ssim, dice)))
|
180 |
plt.close()
|
181 |
|
182 |
+
plot_predictions(img_batches=[fix_img, mov_img, pred_img], disp_map_batch=disp_map, seg_batches=[fix_seg, mov_seg, pred_seg], filename=os.path.join(output_folder, '{:03d}_figures_seg.png'.format(step)), show=False)
|
183 |
+
plot_predictions(img_batches=[fix_img, mov_img, pred_img], disp_map_batch=disp_map, filename=os.path.join(output_folder, '{:03d}_figures_img.png'.format(step)), show=False)
|
184 |
save_disp_map_img(disp_map, 'Displacement map', os.path.join(output_folder, '{:03d}_disp_map_fig.png'.format(step)), show=False)
|
185 |
|
186 |
progress_bar.set_description('SSIM {:.04f}\tDICE: {:.04f}'.format(ssim, dice))
|
Brain_study/Evaluate_network__test_fixed.py
CHANGED
@@ -18,20 +18,23 @@ import pandas as pd
|
|
18 |
import voxelmorph as vxm
|
19 |
|
20 |
import DeepDeformationMapRegistration.utils.constants as C
|
|
|
21 |
from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
|
22 |
from DeepDeformationMapRegistration.layers import AugmentationLayer, UncertaintyWeighting
|
23 |
from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplified, NCC, GeneralizedDICEScore, HausdorffDistanceErosion, target_registration_error
|
24 |
from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
|
25 |
from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
|
26 |
from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
|
|
|
27 |
from DeepDeformationMapRegistration.utils.misc import DisplacementMapInterpolator, get_segmentations_centroids, segmentation_ohe_to_cardinal
|
28 |
from EvaluationScripts.Evaluate_class import EvaluationFigures, resize_pts_to_original_space, resize_img_to_original_space, resize_transformation
|
29 |
-
from scipy.
|
30 |
from tqdm import tqdm
|
31 |
-
|
32 |
import h5py
|
33 |
import re
|
34 |
from Brain_study.data_generator import BatchGenerator
|
|
|
35 |
|
36 |
import argparse
|
37 |
|
@@ -49,9 +52,10 @@ if __name__ == '__main__':
|
|
49 |
parser.add_argument('-d', '--dir', nargs='+', type=str, help='Directory where ./checkpoints/best_model.h5 is located', default=None)
|
50 |
parser.add_argument('--gpu', type=int, help='GPU', default=0)
|
51 |
parser.add_argument('--dataset', type=str, help='Dataset to run predictions on',
|
52 |
-
default='/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/
|
53 |
parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
|
54 |
parser.add_argument('--outdirname', type=str, default='Evaluate')
|
|
|
55 |
args = parser.parse_args()
|
56 |
if args.model is not None:
|
57 |
assert '.h5' in args.model[0], 'No checkpoint file provided, use -d/--dir instead'
|
@@ -83,8 +87,8 @@ if __name__ == '__main__':
|
|
83 |
config.log_device_placement = False ## to log device placement (on which device the operation ran)
|
84 |
config.allow_soft_placement = True
|
85 |
|
86 |
-
sess = tf.Session(config=config)
|
87 |
-
tf.keras.backend.set_session(sess)
|
88 |
|
89 |
# Loss and metric functions. Common to all models
|
90 |
loss_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).loss,
|
@@ -101,10 +105,10 @@ if __name__ == '__main__':
|
|
101 |
GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric_macro]
|
102 |
|
103 |
### METRICS GRAPH ###
|
104 |
-
fix_img_ph = tf.placeholder(tf.float32, (1,
|
105 |
-
pred_img_ph = tf.placeholder(tf.float32, (1,
|
106 |
-
fix_seg_ph = tf.placeholder(tf.float32, (1,
|
107 |
-
pred_seg_ph = tf.placeholder(tf.float32, (1,
|
108 |
|
109 |
ssim_tf = metric_fncs[0](fix_img_ph, pred_img_ph)
|
110 |
ncc_tf = metric_fncs[1](fix_img_ph, pred_img_ph)
|
@@ -118,7 +122,7 @@ if __name__ == '__main__':
|
|
118 |
# Needed for VxmDense type of network
|
119 |
warp_segmentation = vxm.networks.Transform(image_output_shape, interp_method='nearest', nb_feats=nb_labels)
|
120 |
|
121 |
-
dm_interp = DisplacementMapInterpolator(image_output_shape, 'griddata')
|
122 |
|
123 |
for MODEL_FILE, DATA_ROOT_DIR in zip(MODEL_FILE_LIST, DATA_ROOT_DIR_LIST):
|
124 |
print('MODEL LOCATION: ', MODEL_FILE)
|
@@ -171,6 +175,8 @@ if __name__ == '__main__':
|
|
171 |
fix_seg = f['fix_segmentations'][:][np.newaxis, ...].astype(np.float32)
|
172 |
mov_seg = f['mov_segmentations'][:][np.newaxis, ...].astype(np.float32)
|
173 |
fix_centroids = f['fix_centroids'][:]
|
|
|
|
|
174 |
|
175 |
if network.name == 'vxm_dense_semi_supervised_seg':
|
176 |
t0 = time.time()
|
@@ -182,29 +188,42 @@ if __name__ == '__main__':
|
|
182 |
pred_seg = warp_segmentation.predict([mov_seg, disp_map])
|
183 |
t1 = time.time()
|
184 |
|
185 |
-
|
|
|
186 |
# pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) # with tps, it returns the pred_centroids directly
|
187 |
pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) + mov_centroids
|
188 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
# I need the labels to be OHE to compute the segmentation metrics.
|
190 |
-
dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf], {'fix_seg:0': fix_seg, 'pred_seg:0': pred_seg})
|
|
|
|
|
|
|
191 |
|
192 |
pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
|
193 |
mov_seg_card = segmentation_ohe_to_cardinal(mov_seg).astype(np.float32)
|
194 |
fix_seg_card = segmentation_ohe_to_cardinal(fix_seg).astype(np.float32)
|
195 |
|
196 |
-
ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf], {'fix_img:0':
|
197 |
-
|
|
|
198 |
|
199 |
# Rescale the points back to isotropic space, where we have a correspondence voxel <-> mm
|
200 |
-
upsample_scale = 128 / 64
|
201 |
-
fix_centroids_isotropic = fix_centroids *
|
202 |
# mov_centroids_isotropic = mov_centroids * upsample_scale
|
203 |
-
pred_centroids_isotropic = pred_centroids *
|
204 |
|
205 |
-
fix_centroids_isotropic = np.divide(fix_centroids_isotropic, C.IXI_DATASET_iso_to_cubic_scales)
|
206 |
-
# mov_centroids_isotropic = np.divide(mov_centroids_isotropic, C.IXI_DATASET_iso_to_cubic_scales)
|
207 |
-
pred_centroids_isotropic = np.divide(pred_centroids_isotropic, C.IXI_DATASET_iso_to_cubic_scales)
|
208 |
# Now we can measure the TRE in mm
|
209 |
tre_array = target_registration_error(fix_centroids_isotropic, pred_centroids_isotropic, False).eval()
|
210 |
tre = np.mean([v for v in tre_array if not np.isnan(v)])
|
@@ -235,14 +254,20 @@ if __name__ == '__main__':
|
|
235 |
# plt.savefig(os.path.join(output_folder, '{:03d}_hist_mag_ssim_{:.03f}_dice_{:.03f}.png'.format(step, ssim, dice)))
|
236 |
# plt.close()
|
237 |
|
238 |
-
plot_predictions(fix_img, mov_img, disp_map,
|
239 |
-
plot_predictions(
|
240 |
-
save_disp_map_img(disp_map, 'Displacement map', os.path.join(output_folder, '{:03d}_disp_map_fig.png'.format(step)), show=False)
|
241 |
|
242 |
progress_bar.set_description('SSIM {:.04f}\tDICE: {:.04f}'.format(ssim, dice))
|
243 |
|
244 |
print('Summary\n=======\n')
|
245 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
246 |
print('\n=======\n')
|
247 |
tf.keras.backend.clear_session()
|
248 |
# sess.close()
|
|
|
18 |
import voxelmorph as vxm
|
19 |
|
20 |
import DeepDeformationMapRegistration.utils.constants as C
|
21 |
+
from DeepDeformationMapRegistration.utils.operators import min_max_norm, safe_medpy_metric
|
22 |
from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
|
23 |
from DeepDeformationMapRegistration.layers import AugmentationLayer, UncertaintyWeighting
|
24 |
from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplified, NCC, GeneralizedDICEScore, HausdorffDistanceErosion, target_registration_error
|
25 |
from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
|
26 |
from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
|
27 |
from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
|
28 |
+
from DeepDeformationMapRegistration.utils.misc import resize_displacement_map, scale_transformation
|
29 |
from DeepDeformationMapRegistration.utils.misc import DisplacementMapInterpolator, get_segmentations_centroids, segmentation_ohe_to_cardinal
|
30 |
from EvaluationScripts.Evaluate_class import EvaluationFigures, resize_pts_to_original_space, resize_img_to_original_space, resize_transformation
|
31 |
+
from scipy.ndimage import zoom
|
32 |
from tqdm import tqdm
|
33 |
+
import medpy.metric as medpy_metrics
|
34 |
import h5py
|
35 |
import re
|
36 |
from Brain_study.data_generator import BatchGenerator
|
37 |
+
from voxelmorph.tf.layers import SpatialTransformer
|
38 |
|
39 |
import argparse
|
40 |
|
|
|
52 |
parser.add_argument('-d', '--dir', nargs='+', type=str, help='Directory where ./checkpoints/best_model.h5 is located', default=None)
|
53 |
parser.add_argument('--gpu', type=int, help='GPU', default=0)
|
54 |
parser.add_argument('--dataset', type=str, help='Dataset to run predictions on',
|
55 |
+
default='/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/test')
|
56 |
parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
|
57 |
parser.add_argument('--outdirname', type=str, default='Evaluate')
|
58 |
+
|
59 |
args = parser.parse_args()
|
60 |
if args.model is not None:
|
61 |
assert '.h5' in args.model[0], 'No checkpoint file provided, use -d/--dir instead'
|
|
|
87 |
config.log_device_placement = False ## to log device placement (on which device the operation ran)
|
88 |
config.allow_soft_placement = True
|
89 |
|
90 |
+
sess = tf.compat.v1.Session(config=config)
|
91 |
+
tf.compat.v1.keras.backend.set_session(sess)
|
92 |
|
93 |
# Loss and metric functions. Common to all models
|
94 |
loss_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).loss,
|
|
|
105 |
GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric_macro]
|
106 |
|
107 |
### METRICS GRAPH ###
|
108 |
+
fix_img_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, 1), name='fix_img')
|
109 |
+
pred_img_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, 1), name='pred_img')
|
110 |
+
fix_seg_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, nb_labels), name='fix_seg')
|
111 |
+
pred_seg_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, nb_labels), name='pred_seg')
|
112 |
|
113 |
ssim_tf = metric_fncs[0](fix_img_ph, pred_img_ph)
|
114 |
ncc_tf = metric_fncs[1](fix_img_ph, pred_img_ph)
|
|
|
122 |
# Needed for VxmDense type of network
|
123 |
warp_segmentation = vxm.networks.Transform(image_output_shape, interp_method='nearest', nb_feats=nb_labels)
|
124 |
|
125 |
+
dm_interp = DisplacementMapInterpolator(image_output_shape, 'griddata', step=4)
|
126 |
|
127 |
for MODEL_FILE, DATA_ROOT_DIR in zip(MODEL_FILE_LIST, DATA_ROOT_DIR_LIST):
|
128 |
print('MODEL LOCATION: ', MODEL_FILE)
|
|
|
175 |
fix_seg = f['fix_segmentations'][:][np.newaxis, ...].astype(np.float32)
|
176 |
mov_seg = f['mov_segmentations'][:][np.newaxis, ...].astype(np.float32)
|
177 |
fix_centroids = f['fix_centroids'][:]
|
178 |
+
isotropic_shape = f['isotropic_shape'][:]
|
179 |
+
voxel_size = np.divide(fix_img.shape[1:-1], isotropic_shape)
|
180 |
|
181 |
if network.name == 'vxm_dense_semi_supervised_seg':
|
182 |
t0 = time.time()
|
|
|
188 |
pred_seg = warp_segmentation.predict([mov_seg, disp_map])
|
189 |
t1 = time.time()
|
190 |
|
191 |
+
pred_img = min_max_norm(pred_img)
|
192 |
+
mov_centroids, missing_lbls = get_segmentations_centroids(mov_seg[0, ...], ohe=True, expected_lbls=range(1, nb_labels + 1))
|
193 |
# pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) # with tps, it returns the pred_centroids directly
|
194 |
pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) + mov_centroids
|
195 |
|
196 |
+
# Up sample the segmentation masks to isotropic resolution
|
197 |
+
zoom_factors = np.diag(scale_transformation(image_output_shape, isotropic_shape))
|
198 |
+
pred_seg_isot = zoom(pred_seg[0, ...], zoom_factors, order=0)[np.newaxis, ...]
|
199 |
+
fix_seg_isot = zoom(fix_seg[0, ...], zoom_factors, order=0)[np.newaxis, ...]
|
200 |
+
|
201 |
+
pred_img_isot = zoom(pred_img[0, ...], zoom_factors, order=3)[np.newaxis, ...]
|
202 |
+
fix_img_isot = zoom(fix_img[0, ...], zoom_factors, order=3)[np.newaxis, ...]
|
203 |
+
|
204 |
# I need the labels to be OHE to compute the segmentation metrics.
|
205 |
+
# dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf], {'fix_seg:0': fix_seg, 'pred_seg:0': pred_seg})
|
206 |
+
dice = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) / np.sum(fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
|
207 |
+
hd = np.mean(safe_medpy_metric(pred_seg_isot[0, ...], fix_seg_isot[0, ...], nb_labels, medpy_metrics.hd, {'voxelspacing': voxel_size}))
|
208 |
+
dice_macro = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
|
209 |
|
210 |
pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
|
211 |
mov_seg_card = segmentation_ohe_to_cardinal(mov_seg).astype(np.float32)
|
212 |
fix_seg_card = segmentation_ohe_to_cardinal(fix_seg).astype(np.float32)
|
213 |
|
214 |
+
ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf], {'fix_img:0': fix_img_isot, 'pred_img:0': pred_img_isot})
|
215 |
+
ssim = np.mean(ssim) # returns a list of values, which correspond to the ssim of each patch
|
216 |
+
ms_ssim = ms_ssim[0] # returns an array of shape (1,)
|
217 |
|
218 |
# Rescale the points back to isotropic space, where we have a correspondence voxel <-> mm
|
219 |
+
# upsample_scale = 128 / 64
|
220 |
+
fix_centroids_isotropic = fix_centroids * voxel_size
|
221 |
# mov_centroids_isotropic = mov_centroids * upsample_scale
|
222 |
+
pred_centroids_isotropic = pred_centroids * voxel_size
|
223 |
|
224 |
+
# fix_centroids_isotropic = np.divide(fix_centroids_isotropic, C.IXI_DATASET_iso_to_cubic_scales)
|
225 |
+
# # mov_centroids_isotropic = np.divide(mov_centroids_isotropic, C.IXI_DATASET_iso_to_cubic_scales)
|
226 |
+
# pred_centroids_isotropic = np.divide(pred_centroids_isotropic, C.IXI_DATASET_iso_to_cubic_scales)
|
227 |
# Now we can measure the TRE in mm
|
228 |
tre_array = target_registration_error(fix_centroids_isotropic, pred_centroids_isotropic, False).eval()
|
229 |
tre = np.mean([v for v in tre_array if not np.isnan(v)])
|
|
|
254 |
# plt.savefig(os.path.join(output_folder, '{:03d}_hist_mag_ssim_{:.03f}_dice_{:.03f}.png'.format(step, ssim, dice)))
|
255 |
# plt.close()
|
256 |
|
257 |
+
plot_predictions(img_batches=[fix_img, mov_img, pred_img], disp_map_batch=disp_map, seg_batches=[fix_seg_card, mov_seg_card, pred_seg_card], filename=os.path.join(output_folder, '{:03d}_figures_seg.png'.format(step)), show=False, step=16)
|
258 |
+
plot_predictions(img_batches=[fix_img, mov_img, pred_img], disp_map_batch=disp_map, filename=os.path.join(output_folder, '{:03d}_figures_img.png'.format(step)), show=False, step=16)
|
259 |
+
save_disp_map_img(disp_map, 'Displacement map', os.path.join(output_folder, '{:03d}_disp_map_fig.png'.format(step)), show=False, step=16)
|
260 |
|
261 |
progress_bar.set_description('SSIM {:.04f}\tDICE: {:.04f}'.format(ssim, dice))
|
262 |
|
263 |
print('Summary\n=======\n')
|
264 |
+
metrics_df = pd.read_csv(metrics_file, sep=';', header=0)
|
265 |
+
print('\nAVG:\n')
|
266 |
+
print(metrics_df.mean(axis=0))
|
267 |
+
print('\nSTD:\n')
|
268 |
+
print(metrics_df.std(axis=0))
|
269 |
+
print('\nHD95perc:\n')
|
270 |
+
print(metrics_df['HD'].describe(percentiles=[.95]))
|
271 |
print('\n=======\n')
|
272 |
tf.keras.backend.clear_session()
|
273 |
# sess.close()
|
Brain_study/MultiTrain_config.py
CHANGED
@@ -62,6 +62,11 @@ if __name__ == '__main__':
|
|
62 |
except KeyError as e:
|
63 |
head = [16, 16]
|
64 |
|
|
|
|
|
|
|
|
|
|
|
65 |
launch_train(dataset_folder=datasetConfig['train'],
|
66 |
validation_folder=datasetConfig['validation'],
|
67 |
output_folder=output_folder,
|
@@ -75,4 +80,5 @@ if __name__ == '__main__':
|
|
75 |
early_stop_patience=eval(trainConfig['earlyStopPatience']),
|
76 |
unet=unet,
|
77 |
head=head,
|
|
|
78 |
**loss_config)
|
|
|
62 |
except KeyError as e:
|
63 |
head = [16, 16]
|
64 |
|
65 |
+
try:
|
66 |
+
resume_checkpoint = trainConfig['resumeCheckpoint']
|
67 |
+
except KeyError as e:
|
68 |
+
resume_checkpoint = None
|
69 |
+
|
70 |
launch_train(dataset_folder=datasetConfig['train'],
|
71 |
validation_folder=datasetConfig['validation'],
|
72 |
output_folder=output_folder,
|
|
|
80 |
early_stop_patience=eval(trainConfig['earlyStopPatience']),
|
81 |
unet=unet,
|
82 |
head=head,
|
83 |
+
resume=resume_checkpoint,
|
84 |
**loss_config)
|
Brain_study/Train_Baseline.py
CHANGED
@@ -8,7 +8,7 @@ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
|
8 |
import numpy as np
|
9 |
import tensorflow as tf
|
10 |
|
11 |
-
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
|
12 |
from tensorflow.keras import Input
|
13 |
from tensorflow.keras.models import Model
|
14 |
from tensorflow.python.keras.utils import Progbar
|
@@ -32,11 +32,12 @@ from Brain_study.utils import SummaryDictionary, named_logs
|
|
32 |
|
33 |
from tqdm import tqdm
|
34 |
from datetime import datetime
|
|
|
35 |
|
36 |
|
37 |
def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr=1e-4, rw=5e-3, simil='ssim',
|
38 |
acc_gradients=16, batch_size=1, max_epochs=10000, early_stop_patience=1000, image_size=64,
|
39 |
-
unet=[16, 32, 64, 128, 256], head=[16, 16]):
|
40 |
assert dataset_folder is not None and output_folder is not None
|
41 |
|
42 |
os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
|
@@ -46,7 +47,16 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
|
|
46 |
if batch_size != 1 and acc_gradients != 1:
|
47 |
warnings.warn('WARNING: Batch size and Accumulative gradient step are set!')
|
48 |
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
os.makedirs(output_folder, exist_ok=True)
|
51 |
# dataset_copy = DatasetCopy(dataset_folder, os.path.join(output_folder, 'temp'))
|
52 |
log_file = open(os.path.join(output_folder, 'log.txt'), 'w')
|
@@ -141,6 +151,26 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
|
|
141 |
nb_unet_features=nb_features,
|
142 |
int_steps=0)
|
143 |
network.summary(line_length=150)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
# Losses and loss weights
|
145 |
SSIM_KER_SIZE = 5
|
146 |
MS_SSIM_WEIGHTS = _MSSSIM_WEIGHTS[:3]
|
@@ -178,14 +208,14 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
|
|
178 |
|
179 |
# Train
|
180 |
os.makedirs(output_folder, exist_ok=True)
|
181 |
-
os.makedirs(os.path.join(output_folder, 'checkpoints'), exist_ok=True)
|
182 |
os.makedirs(os.path.join(output_folder, 'tensorboard'), exist_ok=True)
|
183 |
os.makedirs(os.path.join(output_folder, 'history'), exist_ok=True)
|
184 |
|
185 |
callback_best_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'),
|
186 |
save_best_only=True, monitor='val_loss', verbose=1, mode='min')
|
187 |
-
|
188 |
-
|
189 |
# CSVLogger(train_log_name, ';'),
|
190 |
# UpdateLossweights([haus_weight, dice_weight], [const.MODEL+'_resampler_seg', const.MODEL+'_resampler_seg'])
|
191 |
callback_tensorboard = TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
|
@@ -194,6 +224,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
|
|
194 |
write_graph=True, write_grads=True
|
195 |
)
|
196 |
callback_early_stop = EarlyStopping(monitor='val_loss', verbose=1, patience=C.EARLY_STOP_PATIENCE, min_delta=0.00001)
|
|
|
197 |
|
198 |
losses = {'transformer': loss_fnc,
|
199 |
'flow': vxm.losses.Grad('l2').loss}
|
@@ -215,8 +246,9 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
|
|
215 |
|
216 |
callback_tensorboard.set_model(network)
|
217 |
callback_best_model.set_model(network)
|
218 |
-
|
219 |
callback_early_stop.set_model(network)
|
|
|
220 |
# TODO: https://towardsdatascience.com/writing-tensorflow-2-custom-loops-438b1ab6eb6c
|
221 |
|
222 |
summary = SummaryDictionary(network, C.BATCH_SIZE)
|
@@ -226,13 +258,13 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
|
|
226 |
callback_tensorboard.on_train_begin()
|
227 |
callback_early_stop.on_train_begin()
|
228 |
callback_best_model.on_train_begin()
|
229 |
-
|
230 |
-
for epoch in range(C.EPOCHS):
|
231 |
callback_tensorboard.on_epoch_begin(epoch)
|
232 |
callback_early_stop.on_epoch_begin(epoch)
|
233 |
callback_best_model.on_epoch_begin(epoch)
|
234 |
-
|
235 |
-
|
236 |
print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
|
237 |
print('TRAINING')
|
238 |
|
@@ -241,9 +273,9 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
|
|
241 |
for step, (in_batch, _) in enumerate(train_generator, 1):
|
242 |
# callback_tensorboard.on_train_batch_begin(step)
|
243 |
callback_best_model.on_train_batch_begin(step)
|
244 |
-
|
245 |
callback_early_stop.on_train_batch_begin(step)
|
246 |
-
|
247 |
try:
|
248 |
fix_img, mov_img, *_ = augm_model_train.predict(in_batch)
|
249 |
np.nan_to_num(fix_img, copy=False)
|
@@ -273,8 +305,9 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
|
|
273 |
summary.on_train_batch_end(ret)
|
274 |
# callback_tensorboard.on_train_batch_end(step, named_logs(network, ret))
|
275 |
callback_best_model.on_train_batch_end(step, named_logs(network, ret))
|
276 |
-
|
277 |
callback_early_stop.on_train_batch_end(step, named_logs(network, ret))
|
|
|
278 |
progress_bar.update(step, zip(names, ret))
|
279 |
log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
|
280 |
val_values = progress_bar._values.copy()
|
@@ -309,13 +342,15 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
|
|
309 |
callback_tensorboard.on_epoch_end(epoch, epoch_summary)
|
310 |
callback_early_stop.on_epoch_end(epoch, epoch_summary)
|
311 |
callback_best_model.on_epoch_end(epoch, epoch_summary)
|
312 |
-
|
|
|
313 |
print('End of epoch {}: '.format(epoch), ret, '\n')
|
314 |
|
315 |
callback_tensorboard.on_train_end()
|
316 |
-
|
317 |
callback_best_model.on_train_end()
|
318 |
callback_early_stop.on_train_end()
|
|
|
319 |
|
320 |
|
321 |
if __name__ == '__main__':
|
|
|
8 |
import numpy as np
|
9 |
import tensorflow as tf
|
10 |
|
11 |
+
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping, ReduceLROnPlateau
|
12 |
from tensorflow.keras import Input
|
13 |
from tensorflow.keras.models import Model
|
14 |
from tensorflow.python.keras.utils import Progbar
|
|
|
32 |
|
33 |
from tqdm import tqdm
|
34 |
from datetime import datetime
|
35 |
+
import re
|
36 |
|
37 |
|
38 |
def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr=1e-4, rw=5e-3, simil='ssim',
|
39 |
acc_gradients=16, batch_size=1, max_epochs=10000, early_stop_patience=1000, image_size=64,
|
40 |
+
unet=[16, 32, 64, 128, 256], head=[16, 16], resume=None):
|
41 |
assert dataset_folder is not None and output_folder is not None
|
42 |
|
43 |
os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
|
|
|
47 |
if batch_size != 1 and acc_gradients != 1:
|
48 |
warnings.warn('WARNING: Batch size and Accumulative gradient step are set!')
|
49 |
|
50 |
+
if resume is not None:
|
51 |
+
try:
|
52 |
+
assert os.path.exists(resume) and len(os.listdir(os.path.join(resume, 'checkpoints'))), 'Invalid directory: ' + resume
|
53 |
+
output_folder = resume
|
54 |
+
resume = True
|
55 |
+
except AssertionError:
|
56 |
+
output_folder = os.path.join(output_folder + '_' + datetime.now().strftime("%H%M%S-%d%m%Y"))
|
57 |
+
resume = False
|
58 |
+
else:
|
59 |
+
resume = False
|
60 |
os.makedirs(output_folder, exist_ok=True)
|
61 |
# dataset_copy = DatasetCopy(dataset_folder, os.path.join(output_folder, 'temp'))
|
62 |
log_file = open(os.path.join(output_folder, 'log.txt'), 'w')
|
|
|
151 |
nb_unet_features=nb_features,
|
152 |
int_steps=0)
|
153 |
network.summary(line_length=150)
|
154 |
+
|
155 |
+
resume_epoch = 0
|
156 |
+
if resume:
|
157 |
+
cp_dir = os.path.join(output_folder, 'checkpoints')
|
158 |
+
cp_file_list = [os.path.join(cp_dir, f) for f in os.listdir(cp_dir) if (f.startswith('checkpoint') and f.endswith('.h5'))]
|
159 |
+
if len(cp_file_list):
|
160 |
+
cp_file_list.sort()
|
161 |
+
checkpoint_file = cp_file_list[-1]
|
162 |
+
if os.path.exists(checkpoint_file):
|
163 |
+
network.load_weights(checkpoint_file, by_name=True)
|
164 |
+
print('Loaded checkpoint file: ' + checkpoint_file)
|
165 |
+
try:
|
166 |
+
resume_epoch = int(re.match('checkpoint\.(\d+)-*.h5', os.path.split(checkpoint_file)[-1])[1])
|
167 |
+
except TypeError:
|
168 |
+
# Checkpoint file has no epoch number in the name
|
169 |
+
resume_epoch = 0
|
170 |
+
print('Resuming from epoch: {:d}'.format(resume_epoch))
|
171 |
+
else:
|
172 |
+
warnings.warn('Checkpoint file NOT found. Training from scratch')
|
173 |
+
|
174 |
# Losses and loss weights
|
175 |
SSIM_KER_SIZE = 5
|
176 |
MS_SSIM_WEIGHTS = _MSSSIM_WEIGHTS[:3]
|
|
|
208 |
|
209 |
# Train
|
210 |
os.makedirs(output_folder, exist_ok=True)
|
211 |
+
os.makedirs(os.path.join(output_folder, 'checkpoints'), exist_ok=True) # exist_ok=True leaves directory unaltered.
|
212 |
os.makedirs(os.path.join(output_folder, 'tensorboard'), exist_ok=True)
|
213 |
os.makedirs(os.path.join(output_folder, 'history'), exist_ok=True)
|
214 |
|
215 |
callback_best_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'),
|
216 |
save_best_only=True, monitor='val_loss', verbose=1, mode='min')
|
217 |
+
callback_save_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'checkpoint.{epoch:05d}-{val_loss:.2f}.h5'),
|
218 |
+
save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
|
219 |
# CSVLogger(train_log_name, ';'),
|
220 |
# UpdateLossweights([haus_weight, dice_weight], [const.MODEL+'_resampler_seg', const.MODEL+'_resampler_seg'])
|
221 |
callback_tensorboard = TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
|
|
|
224 |
write_graph=True, write_grads=True
|
225 |
)
|
226 |
callback_early_stop = EarlyStopping(monitor='val_loss', verbose=1, patience=C.EARLY_STOP_PATIENCE, min_delta=0.00001)
|
227 |
+
callback_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10)
|
228 |
|
229 |
losses = {'transformer': loss_fnc,
|
230 |
'flow': vxm.losses.Grad('l2').loss}
|
|
|
246 |
|
247 |
callback_tensorboard.set_model(network)
|
248 |
callback_best_model.set_model(network)
|
249 |
+
callback_save_model.set_model(network)
|
250 |
callback_early_stop.set_model(network)
|
251 |
+
callback_lr.set_model(network)
|
252 |
# TODO: https://towardsdatascience.com/writing-tensorflow-2-custom-loops-438b1ab6eb6c
|
253 |
|
254 |
summary = SummaryDictionary(network, C.BATCH_SIZE)
|
|
|
258 |
callback_tensorboard.on_train_begin()
|
259 |
callback_early_stop.on_train_begin()
|
260 |
callback_best_model.on_train_begin()
|
261 |
+
callback_save_model.on_train_begin()
|
262 |
+
for epoch in range(resume_epoch, C.EPOCHS):
|
263 |
callback_tensorboard.on_epoch_begin(epoch)
|
264 |
callback_early_stop.on_epoch_begin(epoch)
|
265 |
callback_best_model.on_epoch_begin(epoch)
|
266 |
+
callback_save_model.on_epoch_begin(epoch)
|
267 |
+
callback_lr.on_epoch_begin(epoch)
|
268 |
print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
|
269 |
print('TRAINING')
|
270 |
|
|
|
273 |
for step, (in_batch, _) in enumerate(train_generator, 1):
|
274 |
# callback_tensorboard.on_train_batch_begin(step)
|
275 |
callback_best_model.on_train_batch_begin(step)
|
276 |
+
callback_save_model.on_train_batch_begin(step)
|
277 |
callback_early_stop.on_train_batch_begin(step)
|
278 |
+
callback_lr.on_train_batch_begin(step)
|
279 |
try:
|
280 |
fix_img, mov_img, *_ = augm_model_train.predict(in_batch)
|
281 |
np.nan_to_num(fix_img, copy=False)
|
|
|
305 |
summary.on_train_batch_end(ret)
|
306 |
# callback_tensorboard.on_train_batch_end(step, named_logs(network, ret))
|
307 |
callback_best_model.on_train_batch_end(step, named_logs(network, ret))
|
308 |
+
callback_save_model.on_train_batch_end(step, named_logs(network, ret))
|
309 |
callback_early_stop.on_train_batch_end(step, named_logs(network, ret))
|
310 |
+
callback_lr.on_train_batch_end(step, named_logs(network, ret))
|
311 |
progress_bar.update(step, zip(names, ret))
|
312 |
log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
|
313 |
val_values = progress_bar._values.copy()
|
|
|
342 |
callback_tensorboard.on_epoch_end(epoch, epoch_summary)
|
343 |
callback_early_stop.on_epoch_end(epoch, epoch_summary)
|
344 |
callback_best_model.on_epoch_end(epoch, epoch_summary)
|
345 |
+
callback_save_model.on_epoch_end(epoch, epoch_summary)
|
346 |
+
callback_lr.on_epoch_end(epoch, epoch_summary)
|
347 |
print('End of epoch {}: '.format(epoch), ret, '\n')
|
348 |
|
349 |
callback_tensorboard.on_train_end()
|
350 |
+
callback_save_model.on_train_end()
|
351 |
callback_best_model.on_train_end()
|
352 |
callback_early_stop.on_train_end()
|
353 |
+
callback_lr.on_train_end()
|
354 |
|
355 |
|
356 |
if __name__ == '__main__':
|
Brain_study/Train_SegmentationGuided.py
CHANGED
@@ -1,11 +1,12 @@
|
|
1 |
import os, sys
|
|
|
2 |
currentdir = os.path.dirname(os.path.realpath(__file__))
|
3 |
parentdir = os.path.dirname(currentdir)
|
4 |
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
5 |
|
6 |
import numpy as np
|
7 |
import tensorflow as tf
|
8 |
-
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
|
9 |
from tensorflow.keras import Input
|
10 |
from tensorflow.keras.models import Model
|
11 |
from tensorflow.python.keras.utils import Progbar
|
@@ -30,11 +31,14 @@ from Brain_study.utils import SummaryDictionary, named_logs
|
|
30 |
|
31 |
import time
|
32 |
import warnings
|
|
|
|
|
33 |
|
34 |
|
35 |
def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr=1e-4, rw=5e-3, simil='ssim', segm='hd',
|
36 |
acc_gradients=16, batch_size=1, max_epochs=10000, early_stop_patience=1000, image_size=64,
|
37 |
-
unet=[16, 32, 64, 128, 256], head=[16, 16]):
|
|
|
38 |
assert dataset_folder is not None and output_folder is not None
|
39 |
|
40 |
os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
|
@@ -44,7 +48,16 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
|
|
44 |
if batch_size != 1 and acc_gradients != 1:
|
45 |
warnings.warn('WARNING: Batch size and Accumulative gradient step are set!')
|
46 |
|
47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
os.makedirs(output_folder, exist_ok=True)
|
49 |
log_file = open(os.path.join(output_folder, 'log.txt'), 'w')
|
50 |
C.TRAINING_DATASET = dataset_folder
|
@@ -85,6 +98,10 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
|
|
85 |
directory_val=C.VALIDATION_DATASET)
|
86 |
|
87 |
train_generator = data_generator.get_train_generator()
|
|
|
|
|
|
|
|
|
88 |
validation_generator = data_generator.get_validation_generator()
|
89 |
|
90 |
image_input_shape = train_generator.get_data_shape()[1][:-1]
|
@@ -96,7 +113,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
|
|
96 |
config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
|
97 |
config.gpu_options.allow_growth = True
|
98 |
config.log_device_placement = False ## to log device placement (on which device the operation ran)
|
99 |
-
config.allow_soft_placement =
|
100 |
sess = tf.Session(config=config)
|
101 |
tf.keras.backend.set_session(sess)
|
102 |
|
@@ -128,6 +145,27 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
|
|
128 |
int_steps=0,
|
129 |
int_downsize=1,
|
130 |
seg_downsize=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
|
132 |
# Compile the model
|
133 |
SSIM_KER_SIZE = 5
|
@@ -190,8 +228,8 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
|
|
190 |
|
191 |
callback_best_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'),
|
192 |
save_best_only=True, monitor='val_loss', verbose=1, mode='min')
|
193 |
-
|
194 |
-
|
195 |
# CSVLogger(train_log_name, ';'),
|
196 |
# UpdateLossweights([haus_weight, dice_weight], [const.MODEL+'_resampler_seg', const.MODEL+'_resampler_seg'])
|
197 |
callback_tensorboard = TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
|
@@ -200,6 +238,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
|
|
200 |
write_graph=True, write_grads=True
|
201 |
)
|
202 |
callback_early_stop = EarlyStopping(monitor='val_loss', verbose=1, patience=C.EARLY_STOP_PATIENCE, min_delta=0.00001)
|
|
|
203 |
|
204 |
# Compile the model
|
205 |
optimizer = AdamAccumulated(C.ACCUM_GRADIENT_STEP, lr=C.LEARNING_RATE)
|
@@ -210,9 +249,9 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
|
|
210 |
|
211 |
callback_tensorboard.set_model(network)
|
212 |
callback_best_model.set_model(network)
|
213 |
-
|
214 |
callback_early_stop.set_model(network)
|
215 |
-
|
216 |
summary = SummaryDictionary(network, C.BATCH_SIZE, C.ACCUM_GRADIENT_STEP)
|
217 |
names = network.metrics_names # It give both the loss and metric names
|
218 |
log_file.write('\n\n[{}]\tINFO:\tStart training\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y')))
|
@@ -221,13 +260,14 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
|
|
221 |
callback_tensorboard.on_train_begin()
|
222 |
callback_early_stop.on_train_begin()
|
223 |
callback_best_model.on_train_begin()
|
224 |
-
|
225 |
-
|
|
|
226 |
callback_tensorboard.on_epoch_begin(epoch)
|
227 |
callback_early_stop.on_epoch_begin(epoch)
|
228 |
callback_best_model.on_epoch_begin(epoch)
|
229 |
-
|
230 |
-
|
231 |
print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
|
232 |
print('TRAINING')
|
233 |
|
@@ -238,12 +278,13 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
|
|
238 |
#print('Loaded in {} s'.format(time.time() - t0))
|
239 |
# callback_tensorboard.on_train_batch_begin(step)
|
240 |
callback_best_model.on_train_batch_begin(step)
|
241 |
-
|
242 |
callback_early_stop.on_train_batch_begin(step)
|
243 |
-
|
244 |
try:
|
245 |
t0 = time.time()
|
246 |
fix_img, mov_img, fix_seg, mov_seg = augm_model.predict(in_batch)
|
|
|
247 |
#print('Augmented in {} s'.format(time.time() - t0))
|
248 |
np.nan_to_num(fix_img, copy=False)
|
249 |
np.nan_to_num(mov_img, copy=False)
|
@@ -275,8 +316,9 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
|
|
275 |
summary.on_train_batch_end(ret)
|
276 |
# callback_tensorboard.on_train_batch_end(step, named_logs(network, ret))
|
277 |
callback_best_model.on_train_batch_end(step, named_logs(network, ret))
|
278 |
-
|
279 |
callback_early_stop.on_train_batch_end(step, named_logs(network, ret))
|
|
|
280 |
progress_bar.update(step, zip(names, ret))
|
281 |
log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
|
282 |
t0 = time.time()
|
@@ -313,13 +355,14 @@ def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr
|
|
313 |
callback_tensorboard.on_epoch_end(epoch, epoch_summary)
|
314 |
callback_early_stop.on_epoch_end(epoch, epoch_summary)
|
315 |
callback_best_model.on_epoch_end(epoch, epoch_summary)
|
316 |
-
|
|
|
317 |
|
318 |
callback_tensorboard.on_train_end()
|
319 |
-
|
320 |
callback_best_model.on_train_end()
|
321 |
callback_early_stop.on_train_end()
|
322 |
-
|
323 |
|
324 |
if __name__ == '__main__':
|
325 |
os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
|
|
|
1 |
import os, sys
|
2 |
+
|
3 |
currentdir = os.path.dirname(os.path.realpath(__file__))
|
4 |
parentdir = os.path.dirname(currentdir)
|
5 |
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
6 |
|
7 |
import numpy as np
|
8 |
import tensorflow as tf
|
9 |
+
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping, ReduceLROnPlateau
|
10 |
from tensorflow.keras import Input
|
11 |
from tensorflow.keras.models import Model
|
12 |
from tensorflow.python.keras.utils import Progbar
|
|
|
31 |
|
32 |
import time
|
33 |
import warnings
|
34 |
+
import re
|
35 |
+
import tqdm
|
36 |
|
37 |
|
38 |
def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr=1e-4, rw=5e-3, simil='ssim', segm='hd',
|
39 |
acc_gradients=16, batch_size=1, max_epochs=10000, early_stop_patience=1000, image_size=64,
|
40 |
+
unet=[16, 32, 64, 128, 256], head=[16, 16], resume=None):
|
41 |
+
|
42 |
assert dataset_folder is not None and output_folder is not None
|
43 |
|
44 |
os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
|
|
|
48 |
if batch_size != 1 and acc_gradients != 1:
|
49 |
warnings.warn('WARNING: Batch size and Accumulative gradient step are set!')
|
50 |
|
51 |
+
if resume is not None:
|
52 |
+
try:
|
53 |
+
assert os.path.exists(resume) and len(os.listdir(os.path.join(resume, 'checkpoints'))), 'Invalid directory: ' + resume
|
54 |
+
output_folder = resume
|
55 |
+
resume = True
|
56 |
+
except AssertionError:
|
57 |
+
output_folder = os.path.join(output_folder + '_' + datetime.now().strftime("%H%M%S-%d%m%Y"))
|
58 |
+
resume = False
|
59 |
+
else:
|
60 |
+
resume = False
|
61 |
os.makedirs(output_folder, exist_ok=True)
|
62 |
log_file = open(os.path.join(output_folder, 'log.txt'), 'w')
|
63 |
C.TRAINING_DATASET = dataset_folder
|
|
|
98 |
directory_val=C.VALIDATION_DATASET)
|
99 |
|
100 |
train_generator = data_generator.get_train_generator()
|
101 |
+
|
102 |
+
# for l in tqdm.tqdm(train_generator, smoothing=0):
|
103 |
+
# pass
|
104 |
+
# exit()
|
105 |
validation_generator = data_generator.get_validation_generator()
|
106 |
|
107 |
image_input_shape = train_generator.get_data_shape()[1][:-1]
|
|
|
113 |
config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
|
114 |
config.gpu_options.allow_growth = True
|
115 |
config.log_device_placement = False ## to log device placement (on which device the operation ran)
|
116 |
+
# config.allow_soft_placement = False # https://github.com/tensorflow/tensorflow/issues/30782
|
117 |
sess = tf.Session(config=config)
|
118 |
tf.keras.backend.set_session(sess)
|
119 |
|
|
|
145 |
int_steps=0,
|
146 |
int_downsize=1,
|
147 |
seg_downsize=1)
|
148 |
+
network.summary(line_length=C.SUMMARY_LINE_LENGTH)
|
149 |
+
network.summary(line_length=C.SUMMARY_LINE_LENGTH, print_fn=log_file.writelines)
|
150 |
+
|
151 |
+
resume_epoch = 0
|
152 |
+
if resume:
|
153 |
+
cp_dir = os.path.join(output_folder, 'checkpoints')
|
154 |
+
cp_file_list = [os.path.join(cp_dir, f) for f in os.listdir(cp_dir) if (f.startswith('checkpoint') and f.endswith('.h5'))]
|
155 |
+
if len(cp_file_list):
|
156 |
+
cp_file_list.sort()
|
157 |
+
checkpoint_file = cp_file_list[-1]
|
158 |
+
if os.path.exists(checkpoint_file):
|
159 |
+
network.load_weights(checkpoint_file, by_name=True)
|
160 |
+
print('Loaded checkpoint file: ' + checkpoint_file)
|
161 |
+
try:
|
162 |
+
resume_epoch = int(re.match('checkpoint\.(\d+)-*.h5', os.path.split(checkpoint_file)[-1])[1])
|
163 |
+
except TypeError:
|
164 |
+
# Checkpoint file has no epoch number in the name
|
165 |
+
resume_epoch = 0
|
166 |
+
print('Resuming from epoch: {:d}'.format(resume_epoch))
|
167 |
+
else:
|
168 |
+
warnings.warn('Checkpoint file NOT found. Training from scratch')
|
169 |
|
170 |
# Compile the model
|
171 |
SSIM_KER_SIZE = 5
|
|
|
228 |
|
229 |
callback_best_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'),
|
230 |
save_best_only=True, monitor='val_loss', verbose=1, mode='min')
|
231 |
+
callback_save_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'checkpoint.{epoch:05d}-{val_loss:.2f}.h5'),
|
232 |
+
save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
|
233 |
# CSVLogger(train_log_name, ';'),
|
234 |
# UpdateLossweights([haus_weight, dice_weight], [const.MODEL+'_resampler_seg', const.MODEL+'_resampler_seg'])
|
235 |
callback_tensorboard = TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
|
|
|
238 |
write_graph=True, write_grads=True
|
239 |
)
|
240 |
callback_early_stop = EarlyStopping(monitor='val_loss', verbose=1, patience=C.EARLY_STOP_PATIENCE, min_delta=0.00001)
|
241 |
+
callback_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10)
|
242 |
|
243 |
# Compile the model
|
244 |
optimizer = AdamAccumulated(C.ACCUM_GRADIENT_STEP, lr=C.LEARNING_RATE)
|
|
|
249 |
|
250 |
callback_tensorboard.set_model(network)
|
251 |
callback_best_model.set_model(network)
|
252 |
+
callback_save_model.set_model(network)
|
253 |
callback_early_stop.set_model(network)
|
254 |
+
callback_lr.set_model(network)
|
255 |
summary = SummaryDictionary(network, C.BATCH_SIZE, C.ACCUM_GRADIENT_STEP)
|
256 |
names = network.metrics_names # It give both the loss and metric names
|
257 |
log_file.write('\n\n[{}]\tINFO:\tStart training\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y')))
|
|
|
260 |
callback_tensorboard.on_train_begin()
|
261 |
callback_early_stop.on_train_begin()
|
262 |
callback_best_model.on_train_begin()
|
263 |
+
callback_save_model.on_train_begin()
|
264 |
+
callback_lr.on_train_begin()
|
265 |
+
for epoch in range(resume_epoch, C.EPOCHS):
|
266 |
callback_tensorboard.on_epoch_begin(epoch)
|
267 |
callback_early_stop.on_epoch_begin(epoch)
|
268 |
callback_best_model.on_epoch_begin(epoch)
|
269 |
+
callback_save_model.on_epoch_begin(epoch)
|
270 |
+
callback_lr.on_epoch_begin(epoch)
|
271 |
print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
|
272 |
print('TRAINING')
|
273 |
|
|
|
278 |
#print('Loaded in {} s'.format(time.time() - t0))
|
279 |
# callback_tensorboard.on_train_batch_begin(step)
|
280 |
callback_best_model.on_train_batch_begin(step)
|
281 |
+
callback_save_model.on_train_batch_begin(step)
|
282 |
callback_early_stop.on_train_batch_begin(step)
|
283 |
+
callback_lr.on_train_batch_begin(step)
|
284 |
try:
|
285 |
t0 = time.time()
|
286 |
fix_img, mov_img, fix_seg, mov_seg = augm_model.predict(in_batch)
|
287 |
+
|
288 |
#print('Augmented in {} s'.format(time.time() - t0))
|
289 |
np.nan_to_num(fix_img, copy=False)
|
290 |
np.nan_to_num(mov_img, copy=False)
|
|
|
316 |
summary.on_train_batch_end(ret)
|
317 |
# callback_tensorboard.on_train_batch_end(step, named_logs(network, ret))
|
318 |
callback_best_model.on_train_batch_end(step, named_logs(network, ret))
|
319 |
+
callback_save_model.on_train_batch_end(step, named_logs(network, ret))
|
320 |
callback_early_stop.on_train_batch_end(step, named_logs(network, ret))
|
321 |
+
callback_lr.on_train_batch_end(step, named_logs(network, ret))
|
322 |
progress_bar.update(step, zip(names, ret))
|
323 |
log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
|
324 |
t0 = time.time()
|
|
|
355 |
callback_tensorboard.on_epoch_end(epoch, epoch_summary)
|
356 |
callback_early_stop.on_epoch_end(epoch, epoch_summary)
|
357 |
callback_best_model.on_epoch_end(epoch, epoch_summary)
|
358 |
+
callback_save_model.on_epoch_end(epoch, epoch_summary)
|
359 |
+
callback_lr.on_epoch_end(epoch, epoch_summary)
|
360 |
|
361 |
callback_tensorboard.on_train_end()
|
362 |
+
callback_save_model.on_train_end()
|
363 |
callback_best_model.on_train_end()
|
364 |
callback_early_stop.on_train_end()
|
365 |
+
callback_lr.on_train_end()
|
366 |
|
367 |
if __name__ == '__main__':
|
368 |
os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
|
Brain_study/Train_UncertaintyWeighted.py
CHANGED
@@ -5,7 +5,7 @@ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
|
5 |
|
6 |
import numpy as np
|
7 |
import tensorflow as tf
|
8 |
-
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
|
9 |
from tensorflow.keras import Input
|
10 |
from tensorflow.keras.models import Model
|
11 |
from tensorflow.python.keras.utils import Progbar
|
@@ -27,11 +27,12 @@ from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccum
|
|
27 |
from Brain_study.data_generator import BatchGenerator
|
28 |
from Brain_study.utils import SummaryDictionary, named_logs
|
29 |
import warnings
|
|
|
30 |
|
31 |
|
32 |
def launch_train(dataset_folder, validation_folder, output_folder, prior_reg_w=5e-3, lr=1e-4, rw=5e-3,
|
33 |
gpu_num=0, simil=['mse'], segm=['dice'], acc_gradients=16, batch_size=1, max_epochs=10000,
|
34 |
-
early_stop_patience=1000, image_size=64, unet=[16, 32, 64, 128, 256], head=[16, 16]):
|
35 |
assert dataset_folder is not None and output_folder is not None
|
36 |
|
37 |
os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
|
@@ -41,8 +42,16 @@ def launch_train(dataset_folder, validation_folder, output_folder, prior_reg_w=5
|
|
41 |
if batch_size != 1 and acc_gradients != 1:
|
42 |
warnings.warn('WARNING: Batch size and Accumulative gradient step are set!')
|
43 |
|
44 |
-
|
45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
os.makedirs(output_folder, exist_ok=True)
|
47 |
log_file = open(os.path.join(output_folder, 'log.txt'), 'w')
|
48 |
C.TRAINING_DATASET = dataset_folder #dataset_copy.copy_dataset()
|
@@ -94,7 +103,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, prior_reg_w=5
|
|
94 |
config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
|
95 |
config.gpu_options.allow_growth = True
|
96 |
config.log_device_placement = False ## to log device placement (on which device the operation ran)
|
97 |
-
config.allow_soft_placement = True
|
98 |
sess = tf.Session(config=config)
|
99 |
tf.keras.backend.set_session(sess)
|
100 |
|
@@ -161,6 +170,26 @@ def launch_train(dataset_folder, validation_folder, output_folder, prior_reg_w=5
|
|
161 |
int_steps=0,
|
162 |
int_downsize=1,
|
163 |
seg_downsize=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
# Network inputs: mov_img, fix_img, mov_seg
|
165 |
# Network outputs: pred_img, disp_map, pred_seg
|
166 |
grad = tf.keras.Input(shape=(*image_output_shape, 3), name='multiLoss_grad_input', dtype=tf.float32)
|
@@ -194,8 +223,8 @@ def launch_train(dataset_folder, validation_folder, output_folder, prior_reg_w=5
|
|
194 |
|
195 |
callback_best_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'),
|
196 |
save_best_only=True, monitor='val_loss', verbose=1, mode='min')
|
197 |
-
|
198 |
-
|
199 |
# CSVLogger(train_log_name, ';'),
|
200 |
# UpdateLossweights([haus_weight, dice_weight], [const.MODEL+'_resampler_seg', const.MODEL+'_resampler_seg'])
|
201 |
callback_tensorboard = TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
|
@@ -204,6 +233,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, prior_reg_w=5
|
|
204 |
write_graph=True, write_grads=True
|
205 |
)
|
206 |
callback_early_stop = EarlyStopping(monitor='val_loss', verbose=1, patience=C.EARLY_STOP_PATIENCE, min_delta=0.00001)
|
|
|
207 |
|
208 |
# Compile the model
|
209 |
optimizer = AdamAccumulated(C.ACCUM_GRADIENT_STEP, lr=C.LEARNING_RATE)
|
@@ -211,8 +241,10 @@ def launch_train(dataset_folder, validation_folder, output_folder, prior_reg_w=5
|
|
211 |
|
212 |
callback_tensorboard.set_model(full_model)
|
213 |
callback_best_model.set_model(network) # ONLY SAVE THE NETWORK!!!
|
214 |
-
|
215 |
callback_early_stop.set_model(full_model)
|
|
|
|
|
216 |
# TODO: https://towardsdatascience.com/writing-tensorflow-2-custom-loops-438b1ab6eb6c
|
217 |
|
218 |
summary = SummaryDictionary(full_model, C.BATCH_SIZE)
|
@@ -223,12 +255,15 @@ def launch_train(dataset_folder, validation_folder, output_folder, prior_reg_w=5
|
|
223 |
callback_tensorboard.on_train_begin()
|
224 |
callback_early_stop.on_train_begin()
|
225 |
callback_best_model.on_train_begin()
|
226 |
-
|
227 |
-
|
|
|
|
|
228 |
callback_tensorboard.on_epoch_begin(epoch)
|
229 |
callback_early_stop.on_epoch_begin(epoch)
|
230 |
callback_best_model.on_epoch_begin(epoch)
|
231 |
-
|
|
|
232 |
|
233 |
print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
|
234 |
print('TRAINING')
|
@@ -238,8 +273,9 @@ def launch_train(dataset_folder, validation_folder, output_folder, prior_reg_w=5
|
|
238 |
for step, (in_batch, _) in enumerate(train_generator, 1):
|
239 |
# callback_tensorboard.on_train_batch_begin(step)
|
240 |
callback_best_model.on_train_batch_begin(step)
|
241 |
-
|
242 |
callback_early_stop.on_train_batch_begin(step)
|
|
|
243 |
|
244 |
try:
|
245 |
fix_img, mov_img, fix_seg, mov_seg = augmentation_model.predict(in_batch)
|
@@ -255,8 +291,9 @@ def launch_train(dataset_folder, validation_folder, output_folder, prior_reg_w=5
|
|
255 |
summary.on_train_batch_end(ret)
|
256 |
# callback_tensorboard.on_train_batch_end(step, named_logs(full_model, ret))
|
257 |
callback_best_model.on_train_batch_end(step, named_logs(full_model, ret))
|
258 |
-
|
259 |
callback_early_stop.on_train_batch_end(step, named_logs(full_model, ret))
|
|
|
260 |
log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
|
261 |
# print(ret, '\n')
|
262 |
progress_bar.update(step, zip(names, ret))
|
@@ -292,12 +329,14 @@ def launch_train(dataset_folder, validation_folder, output_folder, prior_reg_w=5
|
|
292 |
callback_tensorboard.on_epoch_end(epoch, epoch_summary)
|
293 |
callback_best_model.on_epoch_end(epoch, epoch_summary)
|
294 |
callback_early_stop.on_epoch_end(epoch, epoch_summary)
|
295 |
-
|
|
|
296 |
|
297 |
callback_tensorboard.on_train_end()
|
298 |
-
|
299 |
callback_best_model.on_train_end()
|
300 |
callback_early_stop.on_train_end()
|
|
|
301 |
|
302 |
|
303 |
if __name__ == '__main__':
|
|
|
5 |
|
6 |
import numpy as np
|
7 |
import tensorflow as tf
|
8 |
+
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping, ReduceLROnPlateau
|
9 |
from tensorflow.keras import Input
|
10 |
from tensorflow.keras.models import Model
|
11 |
from tensorflow.python.keras.utils import Progbar
|
|
|
27 |
from Brain_study.data_generator import BatchGenerator
|
28 |
from Brain_study.utils import SummaryDictionary, named_logs
|
29 |
import warnings
|
30 |
+
import re
|
31 |
|
32 |
|
33 |
def launch_train(dataset_folder, validation_folder, output_folder, prior_reg_w=5e-3, lr=1e-4, rw=5e-3,
|
34 |
gpu_num=0, simil=['mse'], segm=['dice'], acc_gradients=16, batch_size=1, max_epochs=10000,
|
35 |
+
early_stop_patience=1000, image_size=64, unet=[16, 32, 64, 128, 256], head=[16, 16], resume=None):
|
36 |
assert dataset_folder is not None and output_folder is not None
|
37 |
|
38 |
os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
|
|
|
42 |
if batch_size != 1 and acc_gradients != 1:
|
43 |
warnings.warn('WARNING: Batch size and Accumulative gradient step are set!')
|
44 |
|
45 |
+
if resume is not None:
|
46 |
+
try:
|
47 |
+
assert os.path.exists(resume) and len(os.listdir(os.path.join(resume, 'checkpoints'))), 'Invalid directory: ' + resume
|
48 |
+
output_folder = resume
|
49 |
+
resume = True
|
50 |
+
except AssertionError:
|
51 |
+
output_folder = os.path.join(output_folder + '_' + datetime.now().strftime("%H%M%S-%d%m%Y"))
|
52 |
+
resume = False
|
53 |
+
else:
|
54 |
+
resume = False
|
55 |
os.makedirs(output_folder, exist_ok=True)
|
56 |
log_file = open(os.path.join(output_folder, 'log.txt'), 'w')
|
57 |
C.TRAINING_DATASET = dataset_folder #dataset_copy.copy_dataset()
|
|
|
103 |
config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
|
104 |
config.gpu_options.allow_growth = True
|
105 |
config.log_device_placement = False ## to log device placement (on which device the operation ran)
|
106 |
+
# config.allow_soft_placement = True
|
107 |
sess = tf.Session(config=config)
|
108 |
tf.keras.backend.set_session(sess)
|
109 |
|
|
|
170 |
int_steps=0,
|
171 |
int_downsize=1,
|
172 |
seg_downsize=1)
|
173 |
+
|
174 |
+
resume_epoch = 0
|
175 |
+
if resume:
|
176 |
+
cp_dir = os.path.join(output_folder, 'checkpoints')
|
177 |
+
cp_file_list = [os.path.join(cp_dir, f) for f in os.listdir(cp_dir) if (f.startswith('checkpoint') and f.endswith('.h5'))]
|
178 |
+
if len(cp_file_list):
|
179 |
+
cp_file_list.sort()
|
180 |
+
checkpoint_file = cp_file_list[-1]
|
181 |
+
if os.path.exists(checkpoint_file):
|
182 |
+
network.load_weights(checkpoint_file, by_name=True)
|
183 |
+
print('Loaded checkpoint file: ' + checkpoint_file)
|
184 |
+
try:
|
185 |
+
resume_epoch = int(re.match('checkpoint\.(\d+)-*.h5', os.path.split(checkpoint_file)[-1])[1])
|
186 |
+
except TypeError:
|
187 |
+
# Checkpoint file has no epoch number in the name
|
188 |
+
resume_epoch = 0
|
189 |
+
print('Resuming from epoch: {:d}'.format(resume_epoch))
|
190 |
+
else:
|
191 |
+
warnings.warn('Checkpoint file NOT found. Training from scratch')
|
192 |
+
|
193 |
# Network inputs: mov_img, fix_img, mov_seg
|
194 |
# Network outputs: pred_img, disp_map, pred_seg
|
195 |
grad = tf.keras.Input(shape=(*image_output_shape, 3), name='multiLoss_grad_input', dtype=tf.float32)
|
|
|
223 |
|
224 |
callback_best_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'),
|
225 |
save_best_only=True, monitor='val_loss', verbose=1, mode='min')
|
226 |
+
callback_save_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'checkpoint.{epoch:05d}-{val_loss:.2f}.h5'),
|
227 |
+
save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
|
228 |
# CSVLogger(train_log_name, ';'),
|
229 |
# UpdateLossweights([haus_weight, dice_weight], [const.MODEL+'_resampler_seg', const.MODEL+'_resampler_seg'])
|
230 |
callback_tensorboard = TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
|
|
|
233 |
write_graph=True, write_grads=True
|
234 |
)
|
235 |
callback_early_stop = EarlyStopping(monitor='val_loss', verbose=1, patience=C.EARLY_STOP_PATIENCE, min_delta=0.00001)
|
236 |
+
callback_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10)
|
237 |
|
238 |
# Compile the model
|
239 |
optimizer = AdamAccumulated(C.ACCUM_GRADIENT_STEP, lr=C.LEARNING_RATE)
|
|
|
241 |
|
242 |
callback_tensorboard.set_model(full_model)
|
243 |
callback_best_model.set_model(network) # ONLY SAVE THE NETWORK!!!
|
244 |
+
callback_save_model.set_model(network)
|
245 |
callback_early_stop.set_model(full_model)
|
246 |
+
callback_lr.set_model(full_model)
|
247 |
+
|
248 |
# TODO: https://towardsdatascience.com/writing-tensorflow-2-custom-loops-438b1ab6eb6c
|
249 |
|
250 |
summary = SummaryDictionary(full_model, C.BATCH_SIZE)
|
|
|
255 |
callback_tensorboard.on_train_begin()
|
256 |
callback_early_stop.on_train_begin()
|
257 |
callback_best_model.on_train_begin()
|
258 |
+
callback_save_model.on_train_begin()
|
259 |
+
callback_lr.on_train_begin()
|
260 |
+
|
261 |
+
for epoch in range(resume_epoch, C.EPOCHS):
|
262 |
callback_tensorboard.on_epoch_begin(epoch)
|
263 |
callback_early_stop.on_epoch_begin(epoch)
|
264 |
callback_best_model.on_epoch_begin(epoch)
|
265 |
+
callback_save_model.on_epoch_begin(epoch)
|
266 |
+
callback_lr.on_epoch_begin(epoch)
|
267 |
|
268 |
print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
|
269 |
print('TRAINING')
|
|
|
273 |
for step, (in_batch, _) in enumerate(train_generator, 1):
|
274 |
# callback_tensorboard.on_train_batch_begin(step)
|
275 |
callback_best_model.on_train_batch_begin(step)
|
276 |
+
callback_save_model.on_train_batch_begin(step)
|
277 |
callback_early_stop.on_train_batch_begin(step)
|
278 |
+
callback_lr.on_train_batch_begin(step)
|
279 |
|
280 |
try:
|
281 |
fix_img, mov_img, fix_seg, mov_seg = augmentation_model.predict(in_batch)
|
|
|
291 |
summary.on_train_batch_end(ret)
|
292 |
# callback_tensorboard.on_train_batch_end(step, named_logs(full_model, ret))
|
293 |
callback_best_model.on_train_batch_end(step, named_logs(full_model, ret))
|
294 |
+
callback_save_model.on_train_batch_end(step, named_logs(full_model, ret))
|
295 |
callback_early_stop.on_train_batch_end(step, named_logs(full_model, ret))
|
296 |
+
callback_lr.on_train_batch_end(step, named_logs(full_model, ret))
|
297 |
log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
|
298 |
# print(ret, '\n')
|
299 |
progress_bar.update(step, zip(names, ret))
|
|
|
329 |
callback_tensorboard.on_epoch_end(epoch, epoch_summary)
|
330 |
callback_best_model.on_epoch_end(epoch, epoch_summary)
|
331 |
callback_early_stop.on_epoch_end(epoch, epoch_summary)
|
332 |
+
callback_save_model.on_epoch_end(epoch, epoch_summary)
|
333 |
+
callback_lr.on_epoch_end(epoch, epoch_summary)
|
334 |
|
335 |
callback_tensorboard.on_train_end()
|
336 |
+
callback_save_model.on_train_end()
|
337 |
callback_best_model.on_train_end()
|
338 |
callback_early_stop.on_train_end()
|
339 |
+
callback_lr.on_train_end()
|
340 |
|
341 |
|
342 |
if __name__ == '__main__':
|
Brain_study/data_generator.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import warnings
|
2 |
-
|
3 |
import numpy as np
|
4 |
from tensorflow import keras
|
5 |
import os
|
@@ -14,6 +14,7 @@ import tensorflow as tf
|
|
14 |
|
15 |
import DeepDeformationMapRegistration.utils.constants as C
|
16 |
from DeepDeformationMapRegistration.utils.operators import min_max_norm
|
|
|
17 |
from DeepDeformationMapRegistration.utils.thin_plate_splines import ThinPlateSplines
|
18 |
from voxelmorph.tf.layers import SpatialTransformer
|
19 |
from Brain_study.format_dataset import SEGMENTATION_NR2LBL_LUT, SEGMENTATION_LBL2NR_LUT
|
@@ -22,6 +23,10 @@ from tensorflow.python.keras.preprocessing.image import Iterator
|
|
22 |
from tensorflow.python.keras.utils import Sequence
|
23 |
import sys
|
24 |
|
|
|
|
|
|
|
|
|
25 |
#import concurrent.futures
|
26 |
#import multiprocessing as mp
|
27 |
import time
|
@@ -34,13 +39,15 @@ class BatchGenerator:
|
|
34 |
split=0.7,
|
35 |
combine_segmentations=True,
|
36 |
labels=['all'],
|
37 |
-
directory_val=None
|
|
|
38 |
self.file_directory = directory
|
39 |
self.batch_size = batch_size
|
40 |
self.combine_segmentations = combine_segmentations
|
41 |
self.labels = labels
|
42 |
self.shuffle = shuffle
|
43 |
self.split = split
|
|
|
44 |
|
45 |
if directory_val is None:
|
46 |
self.file_list = [os.path.join(directory, f) for f in os.listdir(directory) if f.endswith(('h5', 'hd5'))]
|
@@ -48,11 +55,11 @@ class BatchGenerator:
|
|
48 |
self.num_samples = len(self.file_list)
|
49 |
training_samples = self.file_list[:int(self.num_samples * self.split)]
|
50 |
|
51 |
-
self.train_iter = BatchIterator(training_samples, batch_size, shuffle, combine_segmentations, labels)
|
52 |
if self.split < 1.:
|
53 |
validation_samples = list(set(self.file_list) - set(training_samples))
|
54 |
self.validation_iter = BatchIterator(validation_samples, batch_size, shuffle, combine_segmentations, ['all'],
|
55 |
-
validation=True)
|
56 |
else:
|
57 |
self.validation_iter = None
|
58 |
else:
|
@@ -66,7 +73,7 @@ class BatchGenerator:
|
|
66 |
self.file_list = training_samples + validation_samples
|
67 |
|
68 |
self.train_iter = BatchIterator(training_samples, batch_size, shuffle, combine_segmentations, labels)
|
69 |
-
self.validation_iter = BatchIterator(validation_samples, batch_size, shuffle, combine_segmentations,
|
70 |
validation=True)
|
71 |
|
72 |
def get_train_generator(self):
|
@@ -92,7 +99,8 @@ ALL_LABELS_LOC = {label: loc for label, loc in zip(ALL_LABELS, range(0, len(ALL_
|
|
92 |
|
93 |
class BatchIterator(Sequence):
|
94 |
def __init__(self, file_list, batch_size, shuffle, combine_segmentations=True, labels=['all'],
|
95 |
-
zero_grads=[64, 64, 64, 3], validation=False,
|
|
|
96 |
# super(BatchIterator, self).__init__(n=len(file_list),
|
97 |
# batch_size=batch_size,
|
98 |
# shuffle=shuffle,
|
@@ -103,31 +111,51 @@ class BatchIterator(Sequence):
|
|
103 |
self.file_list = file_list
|
104 |
self.combine_segmentations = combine_segmentations
|
105 |
self.labels = labels
|
106 |
-
self.zero_grads = zero_grads
|
107 |
self.idx_list = np.arange(0, len(self.file_list))
|
108 |
self.validation = validation
|
|
|
|
|
109 |
self._initialize()
|
110 |
self.shuffle_samples()
|
111 |
|
112 |
def _initialize(self):
|
113 |
-
|
114 |
-
self.image_shape = list(f['image'][:].shape)
|
115 |
-
self.segm_shape = list(f['segmentation'][:].shape)
|
116 |
-
if not self.combine_segmentations:
|
117 |
-
self.segm_shape[-1] = len(f['segmentation_labels'][:]) if self.labels[0].lower() == 'all' else len(self.labels)
|
118 |
-
|
119 |
-
self.batch_shape = self.image_shape.copy()
|
120 |
-
if self.labels[0].lower() != 'none':
|
121 |
-
self.batch_shape[-1] = 2 if self.combine_segmentations else 1 + self.segm_shape[-1] # +1 because we have the fix and the moving images
|
122 |
-
|
123 |
if self.labels[0] != 'all':
|
124 |
-
if
|
125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
|
127 |
self.num_steps = len(self.file_list) // self.batch_size + (1 if len(self.file_list) % self.batch_size else 0)
|
128 |
#self.executor = concurrent.futures.ProcessPoolExecutor(max_workers=self.batch_size)
|
129 |
#self.mp_pool = mp.Pool(self.batch_size)
|
130 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
def shuffle_samples(self):
|
132 |
np.random.shuffle(self.idx_list)
|
133 |
|
@@ -137,7 +165,7 @@ class BatchIterator(Sequence):
|
|
137 |
def _filter_segmentations(self, segm, segm_labels):
|
138 |
if self.combine_segmentations:
|
139 |
# TODO
|
140 |
-
warnings.warn('Cannot select labels when
|
141 |
if self.labels[0] != 'all':
|
142 |
if set(self.labels).issubset(set(segm_labels)):
|
143 |
# If labels in self.labels are in segm
|
@@ -155,31 +183,38 @@ class BatchIterator(Sequence):
|
|
155 |
def _load_sample(self, file_path):
|
156 |
with h5py.File(file_path, 'r') as f:
|
157 |
img = f['image'][:]
|
158 |
-
|
159 |
-
|
160 |
-
segm = f['segmentation'][:]
|
161 |
-
else:
|
162 |
-
segm = f['segmentation_expanded'][:]
|
163 |
-
if segm.shape[-1] != self.segm_shape[-1]:
|
164 |
-
aux = np.zeros(self.segm_shape)
|
165 |
-
aux[..., :segm.shape[-1]] = segm # Ensure the same shape in case there are missing labels in aux
|
166 |
-
segm = aux
|
167 |
-
# TODO: selection label segm = aux[..., self.labels] but:
|
168 |
-
# what if aux does not have a label in self.labels??
|
169 |
|
170 |
-
if
|
171 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
|
173 |
if self.validation:
|
174 |
-
ret_val = np.concatenate([img, segm], axis=-1), (img, segm,
|
175 |
else:
|
176 |
-
ret_val = np.concatenate([img, segm], axis=-1), (img,
|
177 |
else:
|
178 |
-
ret_val = img, (img,
|
179 |
return ret_val
|
180 |
|
181 |
def __getitem__(self, idx):
|
182 |
in_batch = list()
|
|
|
183 |
# out_batch = list()
|
184 |
|
185 |
batch_idxs = self.idx_list[idx * self.batch_size:(idx + 1) * self.batch_size]
|
@@ -193,14 +228,22 @@ class BatchIterator(Sequence):
|
|
193 |
# # out_batch.append(i)
|
194 |
# else:
|
195 |
# No need for multithreading, we are loading a single file
|
196 |
-
|
197 |
-
|
|
|
|
|
|
|
|
|
198 |
in_batch.append(b)
|
199 |
# out_batch.append(i)
|
200 |
|
201 |
-
in_batch = np.asarray(in_batch)
|
|
|
|
|
|
|
|
|
202 |
# out_batch = np.asarray(out_batch)
|
203 |
-
return
|
204 |
|
205 |
def __iter__(self):
|
206 |
"""Create a generator that iterate over the Sequence."""
|
@@ -217,9 +260,7 @@ class BatchIterator(Sequence):
|
|
217 |
if self.combine_segmentations:
|
218 |
labels = [1]
|
219 |
else:
|
220 |
-
|
221 |
-
labels = np.unique(f['segmentation'][:])
|
222 |
-
labels = np.sort(labels)[1:] # Ignore the background
|
223 |
return labels
|
224 |
|
225 |
|
|
|
1 |
import warnings
|
2 |
+
import time
|
3 |
import numpy as np
|
4 |
from tensorflow import keras
|
5 |
import os
|
|
|
14 |
|
15 |
import DeepDeformationMapRegistration.utils.constants as C
|
16 |
from DeepDeformationMapRegistration.utils.operators import min_max_norm
|
17 |
+
from DeepDeformationMapRegistration.utils.misc import segmentation_cardinal_to_ohe
|
18 |
from DeepDeformationMapRegistration.utils.thin_plate_splines import ThinPlateSplines
|
19 |
from voxelmorph.tf.layers import SpatialTransformer
|
20 |
from Brain_study.format_dataset import SEGMENTATION_NR2LBL_LUT, SEGMENTATION_LBL2NR_LUT
|
|
|
23 |
from tensorflow.python.keras.utils import Sequence
|
24 |
import sys
|
25 |
|
26 |
+
from collections import defaultdict
|
27 |
+
|
28 |
+
from Brain_study.format_dataset import SEGMENTATION_LOC
|
29 |
+
|
30 |
#import concurrent.futures
|
31 |
#import multiprocessing as mp
|
32 |
import time
|
|
|
39 |
split=0.7,
|
40 |
combine_segmentations=True,
|
41 |
labels=['all'],
|
42 |
+
directory_val=None,
|
43 |
+
return_isotropic_shape=False):
|
44 |
self.file_directory = directory
|
45 |
self.batch_size = batch_size
|
46 |
self.combine_segmentations = combine_segmentations
|
47 |
self.labels = labels
|
48 |
self.shuffle = shuffle
|
49 |
self.split = split
|
50 |
+
self.return_isotropic_shape=return_isotropic_shape
|
51 |
|
52 |
if directory_val is None:
|
53 |
self.file_list = [os.path.join(directory, f) for f in os.listdir(directory) if f.endswith(('h5', 'hd5'))]
|
|
|
55 |
self.num_samples = len(self.file_list)
|
56 |
training_samples = self.file_list[:int(self.num_samples * self.split)]
|
57 |
|
58 |
+
self.train_iter = BatchIterator(training_samples, batch_size, shuffle, combine_segmentations, labels, return_isotropic_shape=return_isotropic_shape)
|
59 |
if self.split < 1.:
|
60 |
validation_samples = list(set(self.file_list) - set(training_samples))
|
61 |
self.validation_iter = BatchIterator(validation_samples, batch_size, shuffle, combine_segmentations, ['all'],
|
62 |
+
validation=True, return_isotropic_shape=return_isotropic_shape)
|
63 |
else:
|
64 |
self.validation_iter = None
|
65 |
else:
|
|
|
73 |
self.file_list = training_samples + validation_samples
|
74 |
|
75 |
self.train_iter = BatchIterator(training_samples, batch_size, shuffle, combine_segmentations, labels)
|
76 |
+
self.validation_iter = BatchIterator(validation_samples, batch_size, shuffle, combine_segmentations, labels,
|
77 |
validation=True)
|
78 |
|
79 |
def get_train_generator(self):
|
|
|
99 |
|
100 |
class BatchIterator(Sequence):
|
101 |
def __init__(self, file_list, batch_size, shuffle, combine_segmentations=True, labels=['all'],
|
102 |
+
zero_grads=[64, 64, 64, 3], validation=False, sequential_labels=True,
|
103 |
+
return_isotropic_shape=False, **kwargs):
|
104 |
# super(BatchIterator, self).__init__(n=len(file_list),
|
105 |
# batch_size=batch_size,
|
106 |
# shuffle=shuffle,
|
|
|
111 |
self.file_list = file_list
|
112 |
self.combine_segmentations = combine_segmentations
|
113 |
self.labels = labels
|
114 |
+
self.zero_grads = np.zeros(zero_grads)
|
115 |
self.idx_list = np.arange(0, len(self.file_list))
|
116 |
self.validation = validation
|
117 |
+
self.sequential_labels = sequential_labels
|
118 |
+
self.return_isotropic_shape = return_isotropic_shape
|
119 |
self._initialize()
|
120 |
self.shuffle_samples()
|
121 |
|
122 |
def _initialize(self):
|
123 |
+
if (isinstance(self.labels[0], str) and self.labels[0].lower() != 'none'):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
if self.labels[0] != 'all':
|
125 |
+
# Labels are tag names. Convert to numeric and check if the expected labels are in sequence or not
|
126 |
+
self.labels = [SEGMENTATION_LBL2NR_LUT[lbl] for lbl in self.labels]
|
127 |
+
if not self.sequential_labels:
|
128 |
+
self.labels = [SEGMENTATION_LOC[lbl] for lbl in self.labels]
|
129 |
+
self.labels_dict = lambda x: SEGMENTATION_LOC[x] if x in self.labels else 0
|
130 |
+
else:
|
131 |
+
self.labels_dict = lambda x: ALL_LABELS_LOC[x] if x in self.labels else 0
|
132 |
+
else:
|
133 |
+
# Use all labels
|
134 |
+
if self.sequential_labels:
|
135 |
+
self.labels = list(set(SEGMENTATION_LOC.values()))
|
136 |
+
self.labels_dict = lambda x: SEGMENTATION_LOC[x] if x else 0
|
137 |
+
else:
|
138 |
+
self.labels = list(ALL_LABELS)
|
139 |
+
self.labels_dict = lambda x: ALL_LABELS_LOC[x] if x in self.labels else 0
|
140 |
+
elif hasattr(self.labels[0], 'lower') and self.labels[0].lower() == 'none':
|
141 |
+
# self.labels = list()
|
142 |
+
self.labels_dict = dict()
|
143 |
+
else:
|
144 |
+
assert np.all([isinstance(lbl, (int, float)) for lbl in self.labels]), "Labels must be a str, int or float"
|
145 |
+
# Nothing to do, the self.labels contains a list of numbers
|
146 |
|
147 |
self.num_steps = len(self.file_list) // self.batch_size + (1 if len(self.file_list) % self.batch_size else 0)
|
148 |
#self.executor = concurrent.futures.ProcessPoolExecutor(max_workers=self.batch_size)
|
149 |
#self.mp_pool = mp.Pool(self.batch_size)
|
150 |
|
151 |
+
with h5py.File(self.file_list[0], 'r') as f:
|
152 |
+
self.image_shape = list(f['image'][:].shape)
|
153 |
+
self.segm_shape = self.image_shape.copy()
|
154 |
+
self.segm_shape[-1] = len(self.labels) if not self.combine_segmentations else 1
|
155 |
+
|
156 |
+
self.batch_shape = self.image_shape.copy()
|
157 |
+
self.batch_shape[-1] = self.image_shape[-1] + self.segm_shape[-1]
|
158 |
+
|
159 |
def shuffle_samples(self):
|
160 |
np.random.shuffle(self.idx_list)
|
161 |
|
|
|
165 |
def _filter_segmentations(self, segm, segm_labels):
|
166 |
if self.combine_segmentations:
|
167 |
# TODO
|
168 |
+
warnings.warn('Cannot select labels when combine_segmentations options is active')
|
169 |
if self.labels[0] != 'all':
|
170 |
if set(self.labels).issubset(set(segm_labels)):
|
171 |
# If labels in self.labels are in segm
|
|
|
183 |
def _load_sample(self, file_path):
|
184 |
with h5py.File(file_path, 'r') as f:
|
185 |
img = f['image'][:]
|
186 |
+
segm = f['segmentation'][:]
|
187 |
+
isot_shape = f['isotropic_shape'][:]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
188 |
|
189 |
+
if not self.combine_segmentations:
|
190 |
+
if self.sequential_labels:
|
191 |
+
# TODO: I am assuming I want all the labels
|
192 |
+
segm = np.squeeze(np.eye(len(self.labels))[segm])
|
193 |
+
else:
|
194 |
+
lbls_list = list(ALL_LABELS) if self.labels[0] == 'all' else self.labels
|
195 |
+
segm = segmentation_cardinal_to_ohe(segm, lbls_list) # Filtering is done here
|
196 |
+
# aux = np.zeros(self.segm_shape)
|
197 |
+
# aux[..., :segm.shape[-1]] = segm # Ensure the same shape in case there are missing labels in aux
|
198 |
+
# segm = aux
|
199 |
+
# TODO: selection label segm = aux[..., self.labels] but:
|
200 |
+
# what if aux does not have a label in self.labels??
|
201 |
+
|
202 |
+
img = np.asarray(img, dtype=np.float32)
|
203 |
+
segm = np.asarray(segm, dtype=np.float32)
|
204 |
+
if not isinstance(self.labels[0], str) or self.labels[0].lower() != 'none' or self.validation: # I expect to ask for the segmentations during val
|
205 |
+
# segm = self._filter_segmentations(segm, segm_labels)
|
206 |
|
207 |
if self.validation:
|
208 |
+
ret_val = np.concatenate([img, segm], axis=-1), (img, segm, self.zero_grads), isot_shape
|
209 |
else:
|
210 |
+
ret_val = np.concatenate([img, segm], axis=-1), (img, self.zero_grads), isot_shape
|
211 |
else:
|
212 |
+
ret_val = img, (img, self.zero_grads), isot_shape
|
213 |
return ret_val
|
214 |
|
215 |
def __getitem__(self, idx):
|
216 |
in_batch = list()
|
217 |
+
isotropic_shape = list()
|
218 |
# out_batch = list()
|
219 |
|
220 |
batch_idxs = self.idx_list[idx * self.batch_size:(idx + 1) * self.batch_size]
|
|
|
228 |
# # out_batch.append(i)
|
229 |
# else:
|
230 |
# No need for multithreading, we are loading a single file
|
231 |
+
# in_batch = np.zeros([self.batch_size] + self.batch_shape, dtype=np.float32)
|
232 |
+
for batch_idx, f in enumerate(file_list):
|
233 |
+
b, i, isot_shape = self._load_sample(f)
|
234 |
+
# in_batch[batch_idx, :, :, :, :] = b
|
235 |
+
if self.return_isotropic_shape:
|
236 |
+
isotropic_shape.append(isot_shape)
|
237 |
in_batch.append(b)
|
238 |
# out_batch.append(i)
|
239 |
|
240 |
+
in_batch = np.asarray(in_batch, dtype=np.float32)
|
241 |
+
ret_val = (in_batch, in_batch)
|
242 |
+
if self.return_isotropic_shape:
|
243 |
+
isotropic_shape = np.asarray(isotropic_shape, dtype=np.int)
|
244 |
+
ret_val += (isotropic_shape,)
|
245 |
# out_batch = np.asarray(out_batch)
|
246 |
+
return ret_val
|
247 |
|
248 |
def __iter__(self):
|
249 |
"""Create a generator that iterate over the Sequence."""
|
|
|
260 |
if self.combine_segmentations:
|
261 |
labels = [1]
|
262 |
else:
|
263 |
+
labels = self.labels
|
|
|
|
|
264 |
return labels
|
265 |
|
266 |
|
Brain_study/format_dataset.py
CHANGED
@@ -1,12 +1,18 @@
|
|
1 |
import h5py
|
2 |
import nibabel as nib
|
3 |
from nilearn.image import resample_img
|
4 |
-
import os
|
5 |
import re
|
6 |
import numpy as np
|
7 |
from scipy.ndimage import zoom
|
8 |
from tqdm import tqdm
|
9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
SEGMENTATION_NR2LBL_LUT = {0: 'background',
|
11 |
2: 'parietal-right-gm',
|
12 |
3: 'lateral-ventricle-left',
|
@@ -36,15 +42,26 @@ SEGMENTATION_NR2LBL_LUT = {0: 'background',
|
|
36 |
233: '4th-ventricle',
|
37 |
254: 'fornix-right',
|
38 |
255: 'csf'}
|
|
|
39 |
SEGMENTATION_LBL2NR_LUT = {v: k for k, v in SEGMENTATION_NR2LBL_LUT.items()}
|
40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
IMG_DIRECTORY = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1'
|
42 |
SEG_DIRECTORY = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/anatomical_masks'
|
43 |
|
44 |
IMG_NAME_PATTERN = '(.*).nii.gz'
|
45 |
SEG_NAME_PATTERN = '(.*)_lobes.nii.gz'
|
46 |
|
47 |
-
OUT_DIRECTORY = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/
|
48 |
|
49 |
if __name__ == '__main__':
|
50 |
img_list = [os.path.join(IMG_DIRECTORY, f) for f in os.listdir(IMG_DIRECTORY) if f.endswith('.nii.gz')]
|
@@ -54,6 +71,9 @@ if __name__ == '__main__':
|
|
54 |
seg_list.sort()
|
55 |
|
56 |
os.makedirs(OUT_DIRECTORY, exist_ok=True)
|
|
|
|
|
|
|
57 |
for seg_file in tqdm(seg_list):
|
58 |
img_name = re.match(SEG_NAME_PATTERN, os.path.split(seg_file)[-1])[1]
|
59 |
img_file = os.path.join(IMG_DIRECTORY, img_name + '.nii.gz')
|
@@ -70,6 +90,8 @@ if __name__ == '__main__':
|
|
70 |
seg = np.asarray(seg.dataobj)
|
71 |
seg = zoom(seg, np.asarray([128]*3) / np.asarray(isot_shape), order=0)
|
72 |
|
|
|
|
|
73 |
unique_lbls = np.unique(seg)[1:] # Omit background
|
74 |
seg_expanded = np.tile(np.zeros_like(seg)[..., np.newaxis], (1, 1, 1, len(unique_lbls)))
|
75 |
for ch, lbl in enumerate(unique_lbls):
|
@@ -79,12 +101,13 @@ if __name__ == '__main__':
|
|
79 |
|
80 |
h5_file.create_dataset('image', data=img[..., np.newaxis], dtype=np.float32)
|
81 |
h5_file.create_dataset('segmentation', data=seg[..., np.newaxis].astype(np.uint8), dtype=np.uint8)
|
82 |
-
h5_file.create_dataset('segmentation_expanded', data=seg_expanded.astype(np.uint8), dtype=np.uint8)
|
83 |
h5_file.create_dataset('segmentation_labels', data=unique_lbls)
|
84 |
h5_file.create_dataset('isotropic_shape', data=isot_shape)
|
85 |
|
86 |
h5_file.close()
|
87 |
-
|
|
|
88 |
|
89 |
|
90 |
|
|
|
1 |
import h5py
|
2 |
import nibabel as nib
|
3 |
from nilearn.image import resample_img
|
4 |
+
import os, sys
|
5 |
import re
|
6 |
import numpy as np
|
7 |
from scipy.ndimage import zoom
|
8 |
from tqdm import tqdm
|
9 |
|
10 |
+
currentdir = os.path.dirname(os.path.realpath(__file__))
|
11 |
+
parentdir = os.path.dirname(currentdir)
|
12 |
+
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
13 |
+
|
14 |
+
from Brain_study.split_dataset import split
|
15 |
+
|
16 |
SEGMENTATION_NR2LBL_LUT = {0: 'background',
|
17 |
2: 'parietal-right-gm',
|
18 |
3: 'lateral-ventricle-left',
|
|
|
42 |
233: '4th-ventricle',
|
43 |
254: 'fornix-right',
|
44 |
255: 'csf'}
|
45 |
+
|
46 |
SEGMENTATION_LBL2NR_LUT = {v: k for k, v in SEGMENTATION_NR2LBL_LUT.items()}
|
47 |
|
48 |
+
ALL_LABELS = {2., 3., 4., 6., 8., 9., 11., 12., 14., 16., 20., 23., 29., 33., 39., 53., 67., 76., 102., 203., 210.,
|
49 |
+
211., 218., 219., 232., 233., 254., 255.}
|
50 |
+
LABELS_COMBINED = {0, (2, 6), (3, 9), (4, 8), (11, 12), (14, 16), 20, (23, 33), (29, 254), (39, 53), (67, 76), (102, 203), (210, 211), (218, 219), 232, 233, 255}
|
51 |
+
SEGMENTATION_LOC = {}
|
52 |
+
for loc, label in enumerate(LABELS_COMBINED):
|
53 |
+
if isinstance(label, tuple):
|
54 |
+
SEGMENTATION_LOC.update(dict.fromkeys(label, loc))
|
55 |
+
else:
|
56 |
+
SEGMENTATION_LOC[label] = loc
|
57 |
+
|
58 |
IMG_DIRECTORY = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1'
|
59 |
SEG_DIRECTORY = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/anatomical_masks'
|
60 |
|
61 |
IMG_NAME_PATTERN = '(.*).nii.gz'
|
62 |
SEG_NAME_PATTERN = '(.*)_lobes.nii.gz'
|
63 |
|
64 |
+
OUT_DIRECTORY = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/ERASEME_sequential'
|
65 |
|
66 |
if __name__ == '__main__':
|
67 |
img_list = [os.path.join(IMG_DIRECTORY, f) for f in os.listdir(IMG_DIRECTORY) if f.endswith('.nii.gz')]
|
|
|
71 |
seg_list.sort()
|
72 |
|
73 |
os.makedirs(OUT_DIRECTORY, exist_ok=True)
|
74 |
+
|
75 |
+
vectorize_fnc = np.vectorize(lambda x: SEGMENTATION_LOC[x] if x in SEGMENTATION_LOC.keys() else 0)
|
76 |
+
change_labels = lambda x: np.reshape(vectorize_fnc(x.ravel()), x.shape)
|
77 |
for seg_file in tqdm(seg_list):
|
78 |
img_name = re.match(SEG_NAME_PATTERN, os.path.split(seg_file)[-1])[1]
|
79 |
img_file = os.path.join(IMG_DIRECTORY, img_name + '.nii.gz')
|
|
|
90 |
seg = np.asarray(seg.dataobj)
|
91 |
seg = zoom(seg, np.asarray([128]*3) / np.asarray(isot_shape), order=0)
|
92 |
|
93 |
+
seg = change_labels(seg) # This way the segmentation numbering is continuous
|
94 |
+
|
95 |
unique_lbls = np.unique(seg)[1:] # Omit background
|
96 |
seg_expanded = np.tile(np.zeros_like(seg)[..., np.newaxis], (1, 1, 1, len(unique_lbls)))
|
97 |
for ch, lbl in enumerate(unique_lbls):
|
|
|
101 |
|
102 |
h5_file.create_dataset('image', data=img[..., np.newaxis], dtype=np.float32)
|
103 |
h5_file.create_dataset('segmentation', data=seg[..., np.newaxis].astype(np.uint8), dtype=np.uint8)
|
104 |
+
# h5_file.create_dataset('segmentation_expanded', data=seg_expanded.astype(np.uint8), dtype=np.uint8)
|
105 |
h5_file.create_dataset('segmentation_labels', data=unique_lbls)
|
106 |
h5_file.create_dataset('isotropic_shape', data=isot_shape)
|
107 |
|
108 |
h5_file.close()
|
109 |
+
# We should only have train and test. The val split is done by the batch generator
|
110 |
+
split(train_perc=0.70, validation_perc=0.15, test_perc=0.15, data_dir=OUT_DIRECTORY, move_files=True)
|
111 |
|
112 |
|
113 |
|
Brain_study/split_dataset.py
CHANGED
@@ -4,51 +4,55 @@ import random
|
|
4 |
import warnings
|
5 |
|
6 |
import math
|
7 |
-
from shutil import copyfile
|
8 |
from tqdm import tqdm
|
9 |
import concurrent.futures
|
10 |
import numpy as np
|
11 |
|
12 |
|
13 |
-
def
|
14 |
s, d = s_d
|
15 |
file_name = os.path.split(s)[-1]
|
16 |
copyfile(s, os.path.join(d, file_name))
|
17 |
return int(os.path.exists(d))
|
18 |
|
19 |
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
parser.add_argument('-d', '--dir', type=str, help='Directory where the data is')
|
26 |
-
parser.add_argument('-f', '--format', type=str, help='Format of the data files. Default: h5', default='h5')
|
27 |
-
parser.add_argument('-r', '--random', type=bool, help='Randomly split the dataset or not. Default: True', default=True)
|
28 |
|
29 |
-
args = parser.parse_args()
|
30 |
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
-
file_set = [os.path.join(
|
34 |
-
random.shuffle(file_set) if
|
35 |
|
36 |
num_files = len(file_set)
|
37 |
-
num_validation = math.floor(num_files *
|
38 |
-
num_test = math.floor(num_files *
|
39 |
num_train = num_files - num_test - num_validation
|
40 |
|
41 |
-
dataset_root, dataset_name = os.path.split(
|
42 |
-
dst_train = os.path.join(dataset_root, 'SPLIT_'+dataset_name, 'train_set')
|
43 |
-
dst_validation = os.path.join(dataset_root, 'SPLIT_'+dataset_name, 'validation_set')
|
44 |
-
dst_test = os.path.join(dataset_root, 'SPLIT_'+dataset_name, 'test_set')
|
45 |
|
46 |
print('OUTPUT INFORMATION\n=============')
|
47 |
print('Train:\t\t{}'.format(num_train))
|
48 |
print('Validation:\t{}'.format(num_validation))
|
49 |
print('Test:\t\t{}'.format(num_test))
|
50 |
print('Num. samples\t{}'.format(num_files))
|
51 |
-
print('Path:\t\t', os.path.join(dataset_root, 'SPLIT_'+dataset_name))
|
52 |
|
53 |
dest = [dst_train] * num_train + [dst_validation] * num_validation + [dst_test] * num_test
|
54 |
|
@@ -57,8 +61,10 @@ if __name__ == '__main__':
|
|
57 |
os.makedirs(dst_test, exist_ok=True)
|
58 |
|
59 |
progress_bar = tqdm(zip(file_set, dest), desc='Copying files', total=num_files)
|
|
|
|
|
60 |
with concurrent.futures.ProcessPoolExecutor(max_workers=10) as ex:
|
61 |
-
results = list(tqdm(ex.map(
|
62 |
|
63 |
num_copies = np.sum(results)
|
64 |
if num_copies == num_files:
|
@@ -66,3 +72,18 @@ if __name__ == '__main__':
|
|
66 |
else:
|
67 |
warnings.warn('Missing files: {}'.format(num_files - num_copies))
|
68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
import warnings
|
5 |
|
6 |
import math
|
7 |
+
from shutil import copyfile, move
|
8 |
from tqdm import tqdm
|
9 |
import concurrent.futures
|
10 |
import numpy as np
|
11 |
|
12 |
|
13 |
+
def copy_file_fnc(s_d):
|
14 |
s, d = s_d
|
15 |
file_name = os.path.split(s)[-1]
|
16 |
copyfile(s, os.path.join(d, file_name))
|
17 |
return int(os.path.exists(d))
|
18 |
|
19 |
|
20 |
+
def move_file_fnc(s_d):
|
21 |
+
s, d = s_d
|
22 |
+
file_name = os.path.split(s)[-1]
|
23 |
+
move(s, os.path.join(d, file_name))
|
24 |
+
return int(os.path.exists(d))
|
|
|
|
|
|
|
25 |
|
|
|
26 |
|
27 |
+
def split(train_perc: float=0.7,
|
28 |
+
validation_perc: float=0.15,
|
29 |
+
test_perc: float=0.15,
|
30 |
+
data_dir: str='',
|
31 |
+
file_format: str='h5',
|
32 |
+
random_split: bool=True,
|
33 |
+
move_files: bool=False):
|
34 |
+
assert train_perc + validation_perc + test_perc == 1.0, 'Train+Validation+Test != 1 (100%)'
|
35 |
+
assert train_perc > 0 and test_perc > 0, 'Train and test percentages must be greater than zero'
|
36 |
|
37 |
+
file_set = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith(file_format)]
|
38 |
+
random.shuffle(file_set) if random_split else file_set.sort()
|
39 |
|
40 |
num_files = len(file_set)
|
41 |
+
num_validation = math.floor(num_files * validation_perc)
|
42 |
+
num_test = math.floor(num_files * test_perc)
|
43 |
num_train = num_files - num_test - num_validation
|
44 |
|
45 |
+
dataset_root, dataset_name = os.path.split(data_dir)
|
46 |
+
dst_train = os.path.join(dataset_root, 'SPLIT_' + dataset_name, 'train_set')
|
47 |
+
dst_validation = os.path.join(dataset_root, 'SPLIT_' + dataset_name, 'validation_set')
|
48 |
+
dst_test = os.path.join(dataset_root, 'SPLIT_' + dataset_name, 'test_set')
|
49 |
|
50 |
print('OUTPUT INFORMATION\n=============')
|
51 |
print('Train:\t\t{}'.format(num_train))
|
52 |
print('Validation:\t{}'.format(num_validation))
|
53 |
print('Test:\t\t{}'.format(num_test))
|
54 |
print('Num. samples\t{}'.format(num_files))
|
55 |
+
print('Path:\t\t', os.path.join(dataset_root, 'SPLIT_' + dataset_name))
|
56 |
|
57 |
dest = [dst_train] * num_train + [dst_validation] * num_validation + [dst_test] * num_test
|
58 |
|
|
|
61 |
os.makedirs(dst_test, exist_ok=True)
|
62 |
|
63 |
progress_bar = tqdm(zip(file_set, dest), desc='Copying files', total=num_files)
|
64 |
+
operation = move_file_fnc if move_files else copy_file_fnc
|
65 |
+
desc = 'Moving files' if move_files else 'Copying files'
|
66 |
with concurrent.futures.ProcessPoolExecutor(max_workers=10) as ex:
|
67 |
+
results = list(tqdm(ex.map(operation, zip(file_set, dest)), desc=desc, total=num_files))
|
68 |
|
69 |
num_copies = np.sum(results)
|
70 |
if num_copies == num_files:
|
|
|
72 |
else:
|
73 |
warnings.warn('Missing files: {}'.format(num_files - num_copies))
|
74 |
|
75 |
+
if __name__ == '__main__':
|
76 |
+
parser = argparse.ArgumentParser()
|
77 |
+
parser.add_argument('--train', '-t', type=float, default=.70, help='Train percentage. Default: 0.70')
|
78 |
+
parser.add_argument('--validation', '-v', type=float, default=0.15, help='Validation percentage. Default: 0.15')
|
79 |
+
parser.add_argument('--test', '-s', type=float, default=0.15, help='Test percentage. Default: 0.15')
|
80 |
+
parser.add_argument('-d', '--dir', type=str, help='Directory where the data is')
|
81 |
+
parser.add_argument('-f', '--format', type=str, help='Format of the data files. Default: h5', default='h5')
|
82 |
+
parser.add_argument('-r', '--random', help='Randomly split the dataset or not. Default: True', action='store_true', default=True)
|
83 |
+
parser.add_argument('-m', '--movefiles', help='Move files. Otherwise copy. Default: False', action='store_true', default=False)
|
84 |
+
|
85 |
+
args = parser.parse_args()
|
86 |
+
|
87 |
+
split(args.train, args.validation, args.test, args.dir, args.format, args.random, args.movefiles)
|
88 |
+
|
89 |
+
|
COMET/Build_test_set.py
CHANGED
@@ -18,10 +18,11 @@ import DeepDeformationMapRegistration.utils.constants as C
|
|
18 |
from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
|
19 |
from DeepDeformationMapRegistration.layers import AugmentationLayer
|
20 |
from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
|
21 |
-
from DeepDeformationMapRegistration.utils.misc import
|
22 |
from tqdm import tqdm
|
23 |
|
24 |
from Brain_study.data_generator import BatchGenerator
|
|
|
25 |
|
26 |
from skimage.measure import regionprops
|
27 |
from scipy.interpolate import griddata
|
@@ -35,12 +36,6 @@ POINTS = None
|
|
35 |
MISSING_CENTROID = np.asarray([[np.nan]*3])
|
36 |
|
37 |
|
38 |
-
def get_mov_centroids(fix_seg, disp_map, nb_labels=28, brain_study=True):
|
39 |
-
fix_centroids, _ = get_segmentations_centroids(fix_seg[0, ...], ohe=True, expected_lbls=range(0, nb_labels), brain_study=brain_study)
|
40 |
-
disp = griddata(POINTS, disp_map.reshape([-1, 3]), fix_centroids, method='linear')
|
41 |
-
return fix_centroids, fix_centroids + disp, disp
|
42 |
-
|
43 |
-
|
44 |
if __name__ == '__main__':
|
45 |
parser = argparse.ArgumentParser()
|
46 |
parser.add_argument('-d', '--dir', type=str, help='Directory where to store the files', default='')
|
@@ -48,6 +43,7 @@ if __name__ == '__main__':
|
|
48 |
parser.add_argument('--gpu', type=int, help='GPU', default=0)
|
49 |
parser.add_argument('--dataset', type=str, help='Dataset to build the test set', default='')
|
50 |
parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
|
|
|
51 |
args = parser.parse_args()
|
52 |
|
53 |
assert args.dataset != '', "Missing original dataset dataset"
|
@@ -68,12 +64,20 @@ if __name__ == '__main__':
|
|
68 |
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
|
69 |
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) # Check availability before running using 'nvidia-smi'
|
70 |
|
71 |
-
data_generator = BatchGenerator(DATASET, 1, False, 1.0, False, [
|
72 |
|
73 |
img_generator = data_generator.get_train_generator()
|
74 |
nb_labels = len(img_generator.get_segmentation_labels())
|
75 |
image_input_shape = img_generator.get_data_shape()[-1][:-1]
|
76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
# Build model
|
78 |
|
79 |
xx = np.linspace(0, image_output_shape[0], image_output_shape[0], endpoint=False)
|
@@ -104,15 +108,17 @@ if __name__ == '__main__':
|
|
104 |
config.gpu_options.allow_growth = True
|
105 |
config.log_device_placement = False ## to log device placement (on which device the operation ran)
|
106 |
|
|
|
|
|
107 |
sess = tf.Session(config=config)
|
108 |
tf.keras.backend.set_session(sess)
|
109 |
with sess.as_default():
|
110 |
sess.run(tf.global_variables_initializer())
|
111 |
progress_bar = tqdm(enumerate(img_generator, 1), desc='Generating samples', total=len(img_generator))
|
112 |
-
for step, (in_batch, _) in progress_bar:
|
113 |
fix_img, mov_img, fix_seg, mov_seg, disp_map = augm_model.predict(in_batch)
|
114 |
|
115 |
-
fix_centroids, mov_centroids, disp_centroids = get_mov_centroids(fix_seg, disp_map, nb_labels, False)
|
116 |
|
117 |
out_file = os.path.join(OUTPUT_FOLDER_DIR, 'test_sample_{:04d}.h5'.format(step))
|
118 |
out_file_dm = os.path.join(OUTPUT_FOLDER_DIR, 'test_sample_dm_{:04d}.h5'.format(step))
|
@@ -127,7 +133,7 @@ if __name__ == '__main__':
|
|
127 |
f.create_dataset('mov_segmentations', shape=segm_shape[1:], dtype=np.uint8, data=mov_seg[0, ...])
|
128 |
f.create_dataset('fix_centroids', shape=centroids_shape, dtype=np.float32, data=fix_centroids)
|
129 |
f.create_dataset('mov_centroids', shape=centroids_shape, dtype=np.float32, data=mov_centroids)
|
130 |
-
|
131 |
with h5py.File(out_file_dm, 'w') as f:
|
132 |
f.create_dataset('disp_map', shape=disp_shape[1:], dtype=np.float32, data=disp_map)
|
133 |
f.create_dataset('disp_centroids', shape=centroids_shape, dtype=np.float32, data=disp_centroids)
|
|
|
18 |
from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
|
19 |
from DeepDeformationMapRegistration.layers import AugmentationLayer
|
20 |
from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
|
21 |
+
from DeepDeformationMapRegistration.utils.misc import DisplacementMapInterpolator
|
22 |
from tqdm import tqdm
|
23 |
|
24 |
from Brain_study.data_generator import BatchGenerator
|
25 |
+
from Brain_study.Build_test_set import get_mov_centroids
|
26 |
|
27 |
from skimage.measure import regionprops
|
28 |
from scipy.interpolate import griddata
|
|
|
36 |
MISSING_CENTROID = np.asarray([[np.nan]*3])
|
37 |
|
38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
if __name__ == '__main__':
|
40 |
parser = argparse.ArgumentParser()
|
41 |
parser.add_argument('-d', '--dir', type=str, help='Directory where to store the files', default='')
|
|
|
43 |
parser.add_argument('--gpu', type=int, help='GPU', default=0)
|
44 |
parser.add_argument('--dataset', type=str, help='Dataset to build the test set', default='')
|
45 |
parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
|
46 |
+
parser.add_argument('--output_shape', help='If an int, a cubic shape is presumed. Otherwise provide it as a space separated sequence', nargs='+', default=128)
|
47 |
args = parser.parse_args()
|
48 |
|
49 |
assert args.dataset != '', "Missing original dataset dataset"
|
|
|
64 |
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
|
65 |
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) # Check availability before running using 'nvidia-smi'
|
66 |
|
67 |
+
data_generator = BatchGenerator(DATASET, 1, False, 1.0, False, [0, 1, 2], return_isotropic_shape=True)
|
68 |
|
69 |
img_generator = data_generator.get_train_generator()
|
70 |
nb_labels = len(img_generator.get_segmentation_labels())
|
71 |
image_input_shape = img_generator.get_data_shape()[-1][:-1]
|
72 |
+
|
73 |
+
if isinstance(args.output_shape, int):
|
74 |
+
image_output_shape = [args.output_shape] * 3
|
75 |
+
elif isinstance(args.output_shape, list):
|
76 |
+
assert len(args.output_shape) == 3, 'Invalid output shape, expected three values and got {}'.format(len(args.output_shape))
|
77 |
+
image_output_shape = [int(s) for s in args.output_shape]
|
78 |
+
else:
|
79 |
+
raise ValueError('Invalid output_shape. Must be an int or a space-separated sequence of ints')
|
80 |
+
print('Scaling to: ', image_output_shape)
|
81 |
# Build model
|
82 |
|
83 |
xx = np.linspace(0, image_output_shape[0], image_output_shape[0], endpoint=False)
|
|
|
108 |
config.gpu_options.allow_growth = True
|
109 |
config.log_device_placement = False ## to log device placement (on which device the operation ran)
|
110 |
|
111 |
+
dm_interp = DisplacementMapInterpolator(image_output_shape, 'griddata', step=8)
|
112 |
+
|
113 |
sess = tf.Session(config=config)
|
114 |
tf.keras.backend.set_session(sess)
|
115 |
with sess.as_default():
|
116 |
sess.run(tf.global_variables_initializer())
|
117 |
progress_bar = tqdm(enumerate(img_generator, 1), desc='Generating samples', total=len(img_generator))
|
118 |
+
for step, (in_batch, _, isotropic_shape) in progress_bar:
|
119 |
fix_img, mov_img, fix_seg, mov_seg, disp_map = augm_model.predict(in_batch)
|
120 |
|
121 |
+
fix_centroids, mov_centroids, disp_centroids = get_mov_centroids(fix_seg, disp_map, nb_labels, False, dm_interp=dm_interp)
|
122 |
|
123 |
out_file = os.path.join(OUTPUT_FOLDER_DIR, 'test_sample_{:04d}.h5'.format(step))
|
124 |
out_file_dm = os.path.join(OUTPUT_FOLDER_DIR, 'test_sample_dm_{:04d}.h5'.format(step))
|
|
|
133 |
f.create_dataset('mov_segmentations', shape=segm_shape[1:], dtype=np.uint8, data=mov_seg[0, ...])
|
134 |
f.create_dataset('fix_centroids', shape=centroids_shape, dtype=np.float32, data=fix_centroids)
|
135 |
f.create_dataset('mov_centroids', shape=centroids_shape, dtype=np.float32, data=mov_centroids)
|
136 |
+
f.create_dataset('isotropic_shape', data=np.squeeze(isotropic_shape))
|
137 |
with h5py.File(out_file_dm, 'w') as f:
|
138 |
f.create_dataset('disp_map', shape=disp_shape[1:], dtype=np.float32, data=disp_map)
|
139 |
f.create_dataset('disp_centroids', shape=centroids_shape, dtype=np.float32, data=disp_centroids)
|
COMET/COMET_train.py
CHANGED
@@ -8,7 +8,7 @@ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
|
8 |
|
9 |
from datetime import datetime
|
10 |
|
11 |
-
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
|
12 |
from tensorflow.python.keras.utils import Progbar
|
13 |
from tensorflow.keras import Input
|
14 |
from tensorflow.keras.models import Model
|
@@ -27,6 +27,7 @@ from Brain_study.data_generator import BatchGenerator
|
|
27 |
from Brain_study.utils import SummaryDictionary, named_logs
|
28 |
|
29 |
import COMET.augmentation_constants as COMET_C
|
|
|
30 |
|
31 |
import numpy as np
|
32 |
import tensorflow as tf
|
@@ -40,7 +41,7 @@ import warnings
|
|
40 |
def launch_train(dataset_folder, validation_folder, output_folder, model_file, gpu_num=0, lr=1e-4, rw=5e-3, simil='ssim',
|
41 |
segm='dice', max_epochs=C.EPOCHS, early_stop_patience=1000, freeze_layers=None,
|
42 |
acc_gradients=1, batch_size=16, image_size=64,
|
43 |
-
unet=[16, 32, 64, 128, 256], head=[16, 16]):
|
44 |
# 0. Input checks
|
45 |
assert dataset_folder is not None and output_folder is not None
|
46 |
if model_file != '':
|
@@ -54,12 +55,16 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
|
|
54 |
if batch_size != 1 and acc_gradients != 1:
|
55 |
warnings.warn('WARNING: Batch size and Accumulative gradient step are set!')
|
56 |
|
57 |
-
if
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
|
|
|
|
|
|
|
|
63 |
|
64 |
os.makedirs(output_folder, exist_ok=True)
|
65 |
# dataset_copy = DatasetCopy(dataset_folder, os.path.join(output_folder, 'temp'))
|
@@ -159,34 +164,37 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
|
|
159 |
if model_file != '':
|
160 |
network.load_weights(model_file, by_name=True)
|
161 |
print('MODEL LOCATION: ', model_file)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
# 4. Freeze/unfreeze model layers
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
# msg = "[INF]: Frozen layers {} to {}".format(0, len(network.layers) - 8)
|
167 |
-
# print(msg)
|
168 |
-
# log_file.write("INF: Frozen layers {} to {}".format(0, len(network.layers) - 8))
|
169 |
-
if freeze_layers is not None:
|
170 |
-
aux = list()
|
171 |
-
for r in freeze_layers:
|
172 |
-
for l in range(*r):
|
173 |
-
network.layers[l].trainable = False
|
174 |
-
aux.append(l)
|
175 |
-
aux.sort()
|
176 |
-
msg = "[INF]: Frozen layers {}".format(', '.join([str(a) for a in aux]))
|
177 |
else:
|
178 |
msg = "[INF] None frozen layers"
|
179 |
print(msg)
|
180 |
log_file.write(msg)
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
# x = base_model(input_new_model, training=False)
|
185 |
-
# x =
|
186 |
-
# network = keras.Model(input_new_model, x)
|
187 |
-
|
188 |
-
network.summary()
|
189 |
-
network.summary(print_fn=log_file.writelines)
|
190 |
# Complete the model with the augmentation layer
|
191 |
augm_train_input_shape = train_generator.get_data_shape()[-1]
|
192 |
input_layer_train = Input(shape=augm_train_input_shape, name='input_train')
|
@@ -269,6 +277,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
|
|
269 |
save_best_only=True, monitor='val_loss', verbose=1, mode='min')
|
270 |
callback_save_checkpoint = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'checkpoint.h5'),
|
271 |
save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
|
|
|
272 |
|
273 |
losses = {'transformer': loss_fnc,
|
274 |
'flow': vxm.losses.Grad('l2').loss}
|
@@ -281,7 +290,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
|
|
281 |
loss_weights = {'transformer': 1.,
|
282 |
'flow': rw}
|
283 |
|
284 |
-
optimizer = AdamAccumulated(C.ACCUM_GRADIENT_STEP, C.LEARNING_RATE)
|
285 |
network.compile(optimizer=optimizer,
|
286 |
loss=losses,
|
287 |
loss_weights=loss_weights,
|
@@ -292,6 +301,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
|
|
292 |
callback_early_stop.set_model(network)
|
293 |
callback_best_model.set_model(network)
|
294 |
callback_save_checkpoint.set_model(network)
|
|
|
295 |
|
296 |
summary = SummaryDictionary(network, C.BATCH_SIZE)
|
297 |
names = network.metrics_names
|
@@ -303,12 +313,14 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
|
|
303 |
callback_early_stop.on_train_begin()
|
304 |
callback_best_model.on_train_begin()
|
305 |
callback_save_checkpoint.on_train_begin()
|
|
|
306 |
|
307 |
-
for epoch in range(C.EPOCHS):
|
308 |
callback_tensorboard.on_epoch_begin(epoch)
|
309 |
callback_early_stop.on_epoch_begin(epoch)
|
310 |
callback_best_model.on_epoch_begin(epoch)
|
311 |
callback_save_checkpoint.on_epoch_begin(epoch)
|
|
|
312 |
print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
|
313 |
print("TRAIN")
|
314 |
|
@@ -318,6 +330,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
|
|
318 |
callback_best_model.on_train_batch_begin(step)
|
319 |
callback_save_checkpoint.on_train_batch_begin(step)
|
320 |
callback_early_stop.on_train_batch_begin(step)
|
|
|
321 |
|
322 |
try:
|
323 |
fix_img, mov_img, fix_seg, mov_seg = augm_model_train.predict(in_batch)
|
@@ -351,6 +364,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
|
|
351 |
callback_best_model.on_train_batch_end(step, named_logs(network, ret))
|
352 |
callback_save_checkpoint.on_train_batch_end(step, named_logs(network, ret))
|
353 |
callback_early_stop.on_train_batch_end(step, named_logs(network, ret))
|
|
|
354 |
progress_bar.update(step, zip(names, ret))
|
355 |
log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
|
356 |
val_values = progress_bar._values.copy()
|
@@ -384,10 +398,12 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
|
|
384 |
callback_best_model.on_epoch_end(epoch, epoch_summary)
|
385 |
callback_save_checkpoint.on_epoch_end(epoch, epoch_summary)
|
386 |
callback_early_stop.on_epoch_end(epoch, epoch_summary)
|
|
|
387 |
print('End of epoch {}: '.format(epoch), ret, '\n')
|
388 |
|
389 |
callback_tensorboard.on_train_end()
|
390 |
callback_best_model.on_train_end()
|
391 |
callback_save_checkpoint.on_train_end()
|
392 |
callback_early_stop.on_train_end()
|
|
|
393 |
# 7. Wrap up
|
|
|
8 |
|
9 |
from datetime import datetime
|
10 |
|
11 |
+
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping, ReduceLROnPlateau
|
12 |
from tensorflow.python.keras.utils import Progbar
|
13 |
from tensorflow.keras import Input
|
14 |
from tensorflow.keras.models import Model
|
|
|
27 |
from Brain_study.utils import SummaryDictionary, named_logs
|
28 |
|
29 |
import COMET.augmentation_constants as COMET_C
|
30 |
+
from COMET.utils import freeze_layers_by_group
|
31 |
|
32 |
import numpy as np
|
33 |
import tensorflow as tf
|
|
|
41 |
def launch_train(dataset_folder, validation_folder, output_folder, model_file, gpu_num=0, lr=1e-4, rw=5e-3, simil='ssim',
|
42 |
segm='dice', max_epochs=C.EPOCHS, early_stop_patience=1000, freeze_layers=None,
|
43 |
acc_gradients=1, batch_size=16, image_size=64,
|
44 |
+
unet=[16, 32, 64, 128, 256], head=[16, 16], resume=None):
|
45 |
# 0. Input checks
|
46 |
assert dataset_folder is not None and output_folder is not None
|
47 |
if model_file != '':
|
|
|
55 |
if batch_size != 1 and acc_gradients != 1:
|
56 |
warnings.warn('WARNING: Batch size and Accumulative gradient step are set!')
|
57 |
|
58 |
+
if resume is not None:
|
59 |
+
try:
|
60 |
+
assert os.path.exists(resume) and len(os.listdir(os.path.join(resume, 'checkpoints'))), 'Invalid directory: ' + resume
|
61 |
+
output_folder = resume
|
62 |
+
resume = True
|
63 |
+
except AssertionError:
|
64 |
+
output_folder = os.path.join(output_folder + '_' + datetime.now().strftime("%H%M%S-%d%m%Y"))
|
65 |
+
resume = False
|
66 |
+
else:
|
67 |
+
resume = False
|
68 |
|
69 |
os.makedirs(output_folder, exist_ok=True)
|
70 |
# dataset_copy = DatasetCopy(dataset_folder, os.path.join(output_folder, 'temp'))
|
|
|
164 |
if model_file != '':
|
165 |
network.load_weights(model_file, by_name=True)
|
166 |
print('MODEL LOCATION: ', model_file)
|
167 |
+
|
168 |
+
resume_epoch = 0
|
169 |
+
if resume:
|
170 |
+
cp_dir = os.path.join(output_folder, 'checkpoints')
|
171 |
+
cp_file_list = [os.path.join(cp_dir, f) for f in os.listdir(cp_dir) if (f.startswith('checkpoint') and f.endswith('.h5'))]
|
172 |
+
if len(cp_file_list):
|
173 |
+
cp_file_list.sort()
|
174 |
+
checkpoint_file = cp_file_list[-1]
|
175 |
+
if os.path.exists(checkpoint_file):
|
176 |
+
network.load_weights(checkpoint_file, by_name=True)
|
177 |
+
print('Loaded checkpoint file: ' + checkpoint_file)
|
178 |
+
try:
|
179 |
+
resume_epoch = int(re.match('checkpoint\.(\d+)-*.h5', os.path.split(checkpoint_file)[-1])[1])
|
180 |
+
except TypeError:
|
181 |
+
# Checkpoint file has no epoch number in the name
|
182 |
+
resume_epoch = 0
|
183 |
+
print('Resuming from epoch: {:d}'.format(resume_epoch))
|
184 |
+
else:
|
185 |
+
warnings.warn('Checkpoint file NOT found. Training from scratch')
|
186 |
+
|
187 |
# 4. Freeze/unfreeze model layers
|
188 |
+
_, frozen_layers = freeze_layers_by_group(network, freeze_layers)
|
189 |
+
if frozen_layers is not None:
|
190 |
+
msg = "[INF]: Frozen layers {}".format(', '.join([str(a) for a in frozen_layers]))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
else:
|
192 |
msg = "[INF] None frozen layers"
|
193 |
print(msg)
|
194 |
log_file.write(msg)
|
195 |
+
|
196 |
+
network.summary(line_length=C.SUMMARY_LINE_LENGTH)
|
197 |
+
network.summary(line_length=C.SUMMARY_LINE_LENGTH, print_fn=log_file.writelines)
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
# Complete the model with the augmentation layer
|
199 |
augm_train_input_shape = train_generator.get_data_shape()[-1]
|
200 |
input_layer_train = Input(shape=augm_train_input_shape, name='input_train')
|
|
|
277 |
save_best_only=True, monitor='val_loss', verbose=1, mode='min')
|
278 |
callback_save_checkpoint = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'checkpoint.h5'),
|
279 |
save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
|
280 |
+
callback_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10)
|
281 |
|
282 |
losses = {'transformer': loss_fnc,
|
283 |
'flow': vxm.losses.Grad('l2').loss}
|
|
|
290 |
loss_weights = {'transformer': 1.,
|
291 |
'flow': rw}
|
292 |
|
293 |
+
optimizer = AdamAccumulated(accumulation_steps=C.ACCUM_GRADIENT_STEP, learning_rate=C.LEARNING_RATE)
|
294 |
network.compile(optimizer=optimizer,
|
295 |
loss=losses,
|
296 |
loss_weights=loss_weights,
|
|
|
301 |
callback_early_stop.set_model(network)
|
302 |
callback_best_model.set_model(network)
|
303 |
callback_save_checkpoint.set_model(network)
|
304 |
+
callback_lr.set_model(network)
|
305 |
|
306 |
summary = SummaryDictionary(network, C.BATCH_SIZE)
|
307 |
names = network.metrics_names
|
|
|
313 |
callback_early_stop.on_train_begin()
|
314 |
callback_best_model.on_train_begin()
|
315 |
callback_save_checkpoint.on_train_begin()
|
316 |
+
callback_lr.on_train_begin()
|
317 |
|
318 |
+
for epoch in range(resume_epoch, C.EPOCHS):
|
319 |
callback_tensorboard.on_epoch_begin(epoch)
|
320 |
callback_early_stop.on_epoch_begin(epoch)
|
321 |
callback_best_model.on_epoch_begin(epoch)
|
322 |
callback_save_checkpoint.on_epoch_begin(epoch)
|
323 |
+
callback_lr.on_epoch_begin(epoch)
|
324 |
print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
|
325 |
print("TRAIN")
|
326 |
|
|
|
330 |
callback_best_model.on_train_batch_begin(step)
|
331 |
callback_save_checkpoint.on_train_batch_begin(step)
|
332 |
callback_early_stop.on_train_batch_begin(step)
|
333 |
+
callback_lr.on_train_batch_begin(step)
|
334 |
|
335 |
try:
|
336 |
fix_img, mov_img, fix_seg, mov_seg = augm_model_train.predict(in_batch)
|
|
|
364 |
callback_best_model.on_train_batch_end(step, named_logs(network, ret))
|
365 |
callback_save_checkpoint.on_train_batch_end(step, named_logs(network, ret))
|
366 |
callback_early_stop.on_train_batch_end(step, named_logs(network, ret))
|
367 |
+
callback_lr.on_train_batch_end(step, named_logs(network, ret))
|
368 |
progress_bar.update(step, zip(names, ret))
|
369 |
log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
|
370 |
val_values = progress_bar._values.copy()
|
|
|
398 |
callback_best_model.on_epoch_end(epoch, epoch_summary)
|
399 |
callback_save_checkpoint.on_epoch_end(epoch, epoch_summary)
|
400 |
callback_early_stop.on_epoch_end(epoch, epoch_summary)
|
401 |
+
callback_lr.on_epoch_end(epoch, epoch_summary)
|
402 |
print('End of epoch {}: '.format(epoch), ret, '\n')
|
403 |
|
404 |
callback_tensorboard.on_train_end()
|
405 |
callback_best_model.on_train_end()
|
406 |
callback_save_checkpoint.on_train_end()
|
407 |
callback_early_stop.on_train_end()
|
408 |
+
callback_lr.on_train_end()
|
409 |
# 7. Wrap up
|
COMET/COMET_train_UW.py
CHANGED
@@ -6,7 +6,7 @@ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
|
6 |
|
7 |
from datetime import datetime
|
8 |
|
9 |
-
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
|
10 |
from tensorflow.python.keras.utils import Progbar
|
11 |
from tensorflow.keras import Input
|
12 |
from tensorflow.keras.models import Model
|
@@ -26,6 +26,7 @@ from Brain_study.data_generator import BatchGenerator
|
|
26 |
from Brain_study.utils import SummaryDictionary, named_logs
|
27 |
|
28 |
import COMET.augmentation_constants as COMET_C
|
|
|
29 |
|
30 |
import numpy as np
|
31 |
import tensorflow as tf
|
@@ -39,7 +40,7 @@ import warnings
|
|
39 |
def launch_train(dataset_folder, validation_folder, output_folder, model_file, gpu_num=0, lr=1e-4, rw=5e-3,
|
40 |
simil=['ssim'], segm=['dice'], max_epochs=C.EPOCHS, early_stop_patience=1000, prior_reg_w=5e-3,
|
41 |
freeze_layers=None, acc_gradients=1, batch_size=16, image_size=64,
|
42 |
-
unet=[16, 32, 64, 128, 256], head=[16, 16]):
|
43 |
# 0. Input checks
|
44 |
assert dataset_folder is not None and output_folder is not None
|
45 |
if model_file != '':
|
@@ -53,15 +54,16 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
|
|
53 |
if batch_size != 1 and acc_gradients != 1:
|
54 |
warnings.warn('WARNING: Batch size and Accumulative gradient step are set!')
|
55 |
|
56 |
-
if
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
|
|
65 |
|
66 |
os.makedirs(output_folder, exist_ok=True)
|
67 |
# dataset_copy = DatasetCopy(dataset_folder, os.path.join(output_folder, 'temp'))
|
@@ -103,8 +105,9 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
|
|
103 |
print(aux)
|
104 |
|
105 |
# 2. Data generator
|
|
|
106 |
data_generator = BatchGenerator(C.TRAINING_DATASET, C.BATCH_SIZE if C.ACCUM_GRADIENT_STEP == 1 else 1, True,
|
107 |
-
C.TRAINING_PERC, labels=
|
108 |
directory_val=C.VALIDATION_DATASET)
|
109 |
|
110 |
train_generator = data_generator.get_train_generator()
|
@@ -137,22 +140,36 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
|
|
137 |
print('MODEL LOCATION: ', model_file)
|
138 |
network.load_weights(model_file, by_name=True)
|
139 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
# 4. Freeze/unfreeze model layers
|
141 |
-
|
142 |
-
|
143 |
-
for
|
144 |
-
for l in range(*r):
|
145 |
-
network.layers[l].trainable = False
|
146 |
-
aux.append(l)
|
147 |
-
aux.sort()
|
148 |
-
msg = "[INF]: Frozen layers {}".format(', '.join([str(a) for a in aux]))
|
149 |
else:
|
150 |
msg = "[INF] None frozen layers"
|
151 |
print(msg)
|
152 |
log_file.write(msg)
|
153 |
|
154 |
-
network.summary()
|
155 |
-
network.summary(print_fn=log_file.write)
|
156 |
# Complete the model with the augmentation layer
|
157 |
input_layer_train = Input(shape=train_generator.get_data_shape()[0], name='input_train')
|
158 |
augm_layer = AugmentationLayer(max_displacement=COMET_C.MAX_AUG_DISP, # Max 30 mm in isotropic space
|
@@ -250,7 +267,9 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
|
|
250 |
callback_save_checkpoint = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'checkpoint.h5'),
|
251 |
save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
|
252 |
|
253 |
-
|
|
|
|
|
254 |
full_model.compile(optimizer=optimizer,
|
255 |
loss=None, )
|
256 |
|
@@ -259,6 +278,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
|
|
259 |
callback_early_stop.set_model(full_model)
|
260 |
callback_best_model.set_model(network) # ONLY SAVE THE NETWORK!!!
|
261 |
callback_save_checkpoint.set_model(network) # ONLY SAVE THE NETWORK!!!
|
|
|
262 |
|
263 |
summary = SummaryDictionary(full_model, C.BATCH_SIZE)
|
264 |
names = full_model.metrics_names
|
@@ -271,12 +291,15 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
|
|
271 |
callback_early_stop.on_train_begin()
|
272 |
callback_best_model.on_train_begin()
|
273 |
callback_save_checkpoint.on_train_begin()
|
|
|
274 |
|
275 |
-
for epoch in range(C.EPOCHS):
|
276 |
callback_tensorboard.on_epoch_begin(epoch)
|
277 |
callback_early_stop.on_epoch_begin(epoch)
|
278 |
callback_best_model.on_epoch_begin(epoch)
|
279 |
callback_save_checkpoint.on_epoch_begin(epoch)
|
|
|
|
|
280 |
print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
|
281 |
print("TRAIN")
|
282 |
|
@@ -287,6 +310,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
|
|
287 |
callback_best_model.on_train_batch_begin(step)
|
288 |
callback_save_checkpoint.on_train_batch_begin(step)
|
289 |
callback_early_stop.on_train_batch_begin(step)
|
|
|
290 |
|
291 |
try:
|
292 |
fix_img, mov_img, fix_seg, mov_seg = augm_model.predict(in_batch)
|
@@ -310,6 +334,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
|
|
310 |
callback_best_model.on_train_batch_end(step, named_logs(full_model, ret))
|
311 |
callback_save_checkpoint.on_train_batch_end(step, named_logs(full_model, ret))
|
312 |
callback_early_stop.on_train_batch_end(step, named_logs(full_model, ret))
|
|
|
313 |
progress_bar.update(step, zip(names, ret))
|
314 |
log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
|
315 |
val_values = progress_bar._values.copy()
|
@@ -341,10 +366,12 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
|
|
341 |
callback_best_model.on_epoch_end(epoch, epoch_summary)
|
342 |
callback_save_checkpoint.on_epoch_end(epoch, epoch_summary)
|
343 |
callback_early_stop.on_epoch_end(epoch, epoch_summary)
|
|
|
344 |
print('End of epoch {}: '.format(epoch), ret, '\n')
|
345 |
|
346 |
callback_tensorboard.on_train_end()
|
347 |
callback_best_model.on_train_end()
|
348 |
callback_save_checkpoint.on_train_end()
|
349 |
callback_early_stop.on_train_end()
|
|
|
350 |
# 7. Wrap up
|
|
|
6 |
|
7 |
from datetime import datetime
|
8 |
|
9 |
+
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping, ReduceLROnPlateau
|
10 |
from tensorflow.python.keras.utils import Progbar
|
11 |
from tensorflow.keras import Input
|
12 |
from tensorflow.keras.models import Model
|
|
|
26 |
from Brain_study.utils import SummaryDictionary, named_logs
|
27 |
|
28 |
import COMET.augmentation_constants as COMET_C
|
29 |
+
from COMET.utils import freeze_layers_by_group
|
30 |
|
31 |
import numpy as np
|
32 |
import tensorflow as tf
|
|
|
40 |
def launch_train(dataset_folder, validation_folder, output_folder, model_file, gpu_num=0, lr=1e-4, rw=5e-3,
|
41 |
simil=['ssim'], segm=['dice'], max_epochs=C.EPOCHS, early_stop_patience=1000, prior_reg_w=5e-3,
|
42 |
freeze_layers=None, acc_gradients=1, batch_size=16, image_size=64,
|
43 |
+
unet=[16, 32, 64, 128, 256], head=[16, 16], resume=None):
|
44 |
# 0. Input checks
|
45 |
assert dataset_folder is not None and output_folder is not None
|
46 |
if model_file != '':
|
|
|
54 |
if batch_size != 1 and acc_gradients != 1:
|
55 |
warnings.warn('WARNING: Batch size and Accumulative gradient step are set!')
|
56 |
|
57 |
+
if resume is not None:
|
58 |
+
try:
|
59 |
+
assert os.path.exists(resume) and len(os.listdir(os.path.join(resume, 'checkpoints'))), 'Invalid directory: ' + resume
|
60 |
+
output_folder = resume
|
61 |
+
resume = True
|
62 |
+
except AssertionError:
|
63 |
+
output_folder = os.path.join(output_folder + '_' + datetime.now().strftime("%H%M%S-%d%m%Y"))
|
64 |
+
resume = False
|
65 |
+
else:
|
66 |
+
resume = False
|
67 |
|
68 |
os.makedirs(output_folder, exist_ok=True)
|
69 |
# dataset_copy = DatasetCopy(dataset_folder, os.path.join(output_folder, 'temp'))
|
|
|
105 |
print(aux)
|
106 |
|
107 |
# 2. Data generator
|
108 |
+
used_labels = [0, 1, 2]
|
109 |
data_generator = BatchGenerator(C.TRAINING_DATASET, C.BATCH_SIZE if C.ACCUM_GRADIENT_STEP == 1 else 1, True,
|
110 |
+
C.TRAINING_PERC, labels=used_labels, combine_segmentations=False,
|
111 |
directory_val=C.VALIDATION_DATASET)
|
112 |
|
113 |
train_generator = data_generator.get_train_generator()
|
|
|
140 |
print('MODEL LOCATION: ', model_file)
|
141 |
network.load_weights(model_file, by_name=True)
|
142 |
|
143 |
+
resume_epoch = 0
|
144 |
+
if resume:
|
145 |
+
cp_dir = os.path.join(output_folder, 'checkpoints')
|
146 |
+
cp_file_list = [os.path.join(cp_dir, f) for f in os.listdir(cp_dir) if (f.startswith('checkpoint') and f.endswith('.h5'))]
|
147 |
+
if len(cp_file_list):
|
148 |
+
cp_file_list.sort()
|
149 |
+
checkpoint_file = cp_file_list[-1]
|
150 |
+
if os.path.exists(checkpoint_file):
|
151 |
+
network.load_weights(checkpoint_file, by_name=True)
|
152 |
+
print('Loaded checkpoint file: ' + checkpoint_file)
|
153 |
+
try:
|
154 |
+
resume_epoch = int(re.match('checkpoint\.(\d+)-*.h5', os.path.split(checkpoint_file)[-1])[1])
|
155 |
+
except TypeError:
|
156 |
+
# Checkpoint file has no epoch number in the name
|
157 |
+
resume_epoch = 0
|
158 |
+
print('Resuming from epoch: {:d}'.format(resume_epoch))
|
159 |
+
else:
|
160 |
+
warnings.warn('Checkpoint file NOT found. Training from scratch')
|
161 |
+
|
162 |
# 4. Freeze/unfreeze model layers
|
163 |
+
_, frozen_layers = freeze_layers_by_group(network, freeze_layers)
|
164 |
+
if frozen_layers is not None:
|
165 |
+
msg = "[INF]: Frozen layers {}".format(', '.join([str(a) for a in frozen_layers]))
|
|
|
|
|
|
|
|
|
|
|
166 |
else:
|
167 |
msg = "[INF] None frozen layers"
|
168 |
print(msg)
|
169 |
log_file.write(msg)
|
170 |
|
171 |
+
network.summary(line_length=C.SUMMARY_LINE_LENGTH)
|
172 |
+
network.summary(line_length=C.SUMMARY_LINE_LENGTH, print_fn=log_file.write)
|
173 |
# Complete the model with the augmentation layer
|
174 |
input_layer_train = Input(shape=train_generator.get_data_shape()[0], name='input_train')
|
175 |
augm_layer = AugmentationLayer(max_displacement=COMET_C.MAX_AUG_DISP, # Max 30 mm in isotropic space
|
|
|
267 |
callback_save_checkpoint = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'checkpoint.h5'),
|
268 |
save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
|
269 |
|
270 |
+
callback_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10)
|
271 |
+
|
272 |
+
optimizer = AdamAccumulated(accumulation_steps=C.ACCUM_GRADIENT_STEP, learning_rate=C.LEARNING_RATE)
|
273 |
full_model.compile(optimizer=optimizer,
|
274 |
loss=None, )
|
275 |
|
|
|
278 |
callback_early_stop.set_model(full_model)
|
279 |
callback_best_model.set_model(network) # ONLY SAVE THE NETWORK!!!
|
280 |
callback_save_checkpoint.set_model(network) # ONLY SAVE THE NETWORK!!!
|
281 |
+
callback_lr.set_model(full_model)
|
282 |
|
283 |
summary = SummaryDictionary(full_model, C.BATCH_SIZE)
|
284 |
names = full_model.metrics_names
|
|
|
291 |
callback_early_stop.on_train_begin()
|
292 |
callback_best_model.on_train_begin()
|
293 |
callback_save_checkpoint.on_train_begin()
|
294 |
+
callback_lr.on_train_begin()
|
295 |
|
296 |
+
for epoch in range(resume_epoch, C.EPOCHS):
|
297 |
callback_tensorboard.on_epoch_begin(epoch)
|
298 |
callback_early_stop.on_epoch_begin(epoch)
|
299 |
callback_best_model.on_epoch_begin(epoch)
|
300 |
callback_save_checkpoint.on_epoch_begin(epoch)
|
301 |
+
callback_lr.on_epoch_begin(epoch)
|
302 |
+
|
303 |
print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
|
304 |
print("TRAIN")
|
305 |
|
|
|
310 |
callback_best_model.on_train_batch_begin(step)
|
311 |
callback_save_checkpoint.on_train_batch_begin(step)
|
312 |
callback_early_stop.on_train_batch_begin(step)
|
313 |
+
callback_lr.on_train_batch_begin(step)
|
314 |
|
315 |
try:
|
316 |
fix_img, mov_img, fix_seg, mov_seg = augm_model.predict(in_batch)
|
|
|
334 |
callback_best_model.on_train_batch_end(step, named_logs(full_model, ret))
|
335 |
callback_save_checkpoint.on_train_batch_end(step, named_logs(full_model, ret))
|
336 |
callback_early_stop.on_train_batch_end(step, named_logs(full_model, ret))
|
337 |
+
callback_lr.on_train_batch_end(step, named_logs(full_model, ret))
|
338 |
progress_bar.update(step, zip(names, ret))
|
339 |
log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
|
340 |
val_values = progress_bar._values.copy()
|
|
|
366 |
callback_best_model.on_epoch_end(epoch, epoch_summary)
|
367 |
callback_save_checkpoint.on_epoch_end(epoch, epoch_summary)
|
368 |
callback_early_stop.on_epoch_end(epoch, epoch_summary)
|
369 |
+
callback_lr.on_epoch_end(epoch, epoch_summary)
|
370 |
print('End of epoch {}: '.format(epoch), ret, '\n')
|
371 |
|
372 |
callback_tensorboard.on_train_end()
|
373 |
callback_best_model.on_train_end()
|
374 |
callback_save_checkpoint.on_train_end()
|
375 |
callback_early_stop.on_train_end()
|
376 |
+
callback_lr.on_train_end()
|
377 |
# 7. Wrap up
|
COMET/COMET_train_seggguided.py
CHANGED
@@ -8,7 +8,7 @@ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
|
8 |
|
9 |
from datetime import datetime
|
10 |
|
11 |
-
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
|
12 |
from tensorflow.python.keras.utils import Progbar
|
13 |
from tensorflow.keras import Input
|
14 |
from tensorflow.keras.models import Model
|
@@ -27,6 +27,7 @@ from Brain_study.data_generator import BatchGenerator
|
|
27 |
from Brain_study.utils import SummaryDictionary, named_logs
|
28 |
|
29 |
import COMET.augmentation_constants as COMET_C
|
|
|
30 |
|
31 |
import numpy as np
|
32 |
import tensorflow as tf
|
@@ -40,7 +41,7 @@ import warnings
|
|
40 |
def launch_train(dataset_folder, validation_folder, output_folder, model_file, gpu_num=0, lr=1e-4, rw=5e-3, simil='ssim',
|
41 |
segm='dice', max_epochs=C.EPOCHS, early_stop_patience=1000, freeze_layers=None,
|
42 |
acc_gradients=1, batch_size=16, image_size=64,
|
43 |
-
unet=[16, 32, 64, 128, 256], head=[16, 16]):
|
44 |
# 0. Input checks
|
45 |
assert dataset_folder is not None and output_folder is not None
|
46 |
if model_file != '':
|
@@ -54,12 +55,16 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
|
|
54 |
if batch_size != 1 and acc_gradients != 1:
|
55 |
warnings.warn('WARNING: Batch size and Accumulative gradient step are set!')
|
56 |
|
57 |
-
if
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
|
|
|
|
|
|
|
|
63 |
|
64 |
os.makedirs(output_folder, exist_ok=True)
|
65 |
# dataset_copy = DatasetCopy(dataset_folder, os.path.join(output_folder, 'temp'))
|
@@ -101,9 +106,9 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
|
|
101 |
print(aux)
|
102 |
|
103 |
# 2. Data generator
|
104 |
-
used_labels =
|
105 |
data_generator = BatchGenerator(C.TRAINING_DATASET, C.BATCH_SIZE if C.ACCUM_GRADIENT_STEP == 1 else 1, True,
|
106 |
-
C.TRAINING_PERC, labels=
|
107 |
directory_val=C.VALIDATION_DATASET)
|
108 |
|
109 |
train_generator = data_generator.get_train_generator()
|
@@ -163,34 +168,37 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
|
|
163 |
if model_file != '':
|
164 |
network.load_weights(model_file, by_name=True)
|
165 |
print('MODEL LOCATION: ', model_file)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
# 4. Freeze/unfreeze model layers
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
# msg = "[INF]: Frozen layers {} to {}".format(0, len(network.layers) - 8)
|
171 |
-
# print(msg)
|
172 |
-
# log_file.write("INF: Frozen layers {} to {}".format(0, len(network.layers) - 8))
|
173 |
-
if freeze_layers is not None:
|
174 |
-
aux = list()
|
175 |
-
for r in freeze_layers:
|
176 |
-
for l in range(*r):
|
177 |
-
network.layers[l].trainable = False
|
178 |
-
aux.append(l)
|
179 |
-
aux.sort()
|
180 |
-
msg = "[INF]: Frozen layers {}".format(', '.join([str(a) for a in aux]))
|
181 |
else:
|
182 |
msg = "[INF] None frozen layers"
|
183 |
print(msg)
|
184 |
log_file.write(msg)
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
# x = base_model(input_new_model, training=False)
|
189 |
-
# x =
|
190 |
-
# network = keras.Model(input_new_model, x)
|
191 |
-
|
192 |
-
network.summary()
|
193 |
-
network.summary(print_fn=log_file.writelines)
|
194 |
# Complete the model with the augmentation layer
|
195 |
augm_train_input_shape = train_generator.get_data_shape()[0]
|
196 |
input_layer_train = Input(shape=augm_train_input_shape, name='input_train')
|
@@ -282,6 +290,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
|
|
282 |
save_best_only=True, monitor='val_loss', verbose=1, mode='min')
|
283 |
callback_save_checkpoint = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'checkpoint.h5'),
|
284 |
save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
|
|
|
285 |
|
286 |
losses = {'transformer': loss_fnc,
|
287 |
'seg_transformer': loss_segm,
|
@@ -300,8 +309,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
|
|
300 |
'seg_transformer': 1.,
|
301 |
'flow': rw}
|
302 |
|
303 |
-
|
304 |
-
optimizer = AdamAccumulated(C.ACCUM_GRADIENT_STEP, C.LEARNING_RATE)
|
305 |
network.compile(optimizer=optimizer,
|
306 |
loss=losses,
|
307 |
loss_weights=loss_weights,
|
@@ -312,6 +320,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
|
|
312 |
callback_early_stop.set_model(network)
|
313 |
callback_best_model.set_model(network)
|
314 |
callback_save_checkpoint.set_model(network)
|
|
|
315 |
|
316 |
summary = SummaryDictionary(network, C.BATCH_SIZE)
|
317 |
names = network.metrics_names
|
@@ -323,12 +332,15 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
|
|
323 |
callback_early_stop.on_train_begin()
|
324 |
callback_best_model.on_train_begin()
|
325 |
callback_save_checkpoint.on_train_begin()
|
|
|
326 |
|
327 |
-
for epoch in range(C.EPOCHS):
|
328 |
callback_tensorboard.on_epoch_begin(epoch)
|
329 |
callback_early_stop.on_epoch_begin(epoch)
|
330 |
callback_best_model.on_epoch_begin(epoch)
|
331 |
callback_save_checkpoint.on_epoch_begin(epoch)
|
|
|
|
|
332 |
print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
|
333 |
print("TRAIN")
|
334 |
|
@@ -338,6 +350,7 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
|
|
338 |
callback_best_model.on_train_batch_begin(step)
|
339 |
callback_save_checkpoint.on_train_batch_begin(step)
|
340 |
callback_early_stop.on_train_batch_begin(step)
|
|
|
341 |
|
342 |
try:
|
343 |
fix_img, mov_img, fix_seg, mov_seg = augm_model_train.predict(in_batch)
|
@@ -371,6 +384,8 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
|
|
371 |
callback_best_model.on_train_batch_end(step, named_logs(network, ret))
|
372 |
callback_save_checkpoint.on_train_batch_end(step, named_logs(network, ret))
|
373 |
callback_early_stop.on_train_batch_end(step, named_logs(network, ret))
|
|
|
|
|
374 |
progress_bar.update(step, zip(names, ret))
|
375 |
log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
|
376 |
val_values = progress_bar._values.copy()
|
@@ -405,10 +420,13 @@ def launch_train(dataset_folder, validation_folder, output_folder, model_file, g
|
|
405 |
callback_best_model.on_epoch_end(epoch, epoch_summary)
|
406 |
callback_save_checkpoint.on_epoch_end(epoch, epoch_summary)
|
407 |
callback_early_stop.on_epoch_end(epoch, epoch_summary)
|
|
|
|
|
408 |
print('End of epoch {}: '.format(epoch), ret, '\n')
|
409 |
|
410 |
callback_tensorboard.on_train_end()
|
411 |
callback_best_model.on_train_end()
|
412 |
callback_save_checkpoint.on_train_end()
|
413 |
callback_early_stop.on_train_end()
|
|
|
414 |
# 7. Wrap up
|
|
|
8 |
|
9 |
from datetime import datetime
|
10 |
|
11 |
+
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping, ReduceLROnPlateau
|
12 |
from tensorflow.python.keras.utils import Progbar
|
13 |
from tensorflow.keras import Input
|
14 |
from tensorflow.keras.models import Model
|
|
|
27 |
from Brain_study.utils import SummaryDictionary, named_logs
|
28 |
|
29 |
import COMET.augmentation_constants as COMET_C
|
30 |
+
from COMET.utils import freeze_layers_by_group
|
31 |
|
32 |
import numpy as np
|
33 |
import tensorflow as tf
|
|
|
41 |
def launch_train(dataset_folder, validation_folder, output_folder, model_file, gpu_num=0, lr=1e-4, rw=5e-3, simil='ssim',
|
42 |
segm='dice', max_epochs=C.EPOCHS, early_stop_patience=1000, freeze_layers=None,
|
43 |
acc_gradients=1, batch_size=16, image_size=64,
|
44 |
+
unet=[16, 32, 64, 128, 256], head=[16, 16], resume=None):
|
45 |
# 0. Input checks
|
46 |
assert dataset_folder is not None and output_folder is not None
|
47 |
if model_file != '':
|
|
|
55 |
if batch_size != 1 and acc_gradients != 1:
|
56 |
warnings.warn('WARNING: Batch size and Accumulative gradient step are set!')
|
57 |
|
58 |
+
if resume is not None:
|
59 |
+
try:
|
60 |
+
assert os.path.exists(resume) and len(os.listdir(os.path.join(resume, 'checkpoints'))), 'Invalid directory: ' + resume
|
61 |
+
output_folder = resume
|
62 |
+
resume = True
|
63 |
+
except AssertionError:
|
64 |
+
output_folder = os.path.join(output_folder + '_' + datetime.now().strftime("%H%M%S-%d%m%Y"))
|
65 |
+
resume = False
|
66 |
+
else:
|
67 |
+
resume = False
|
68 |
|
69 |
os.makedirs(output_folder, exist_ok=True)
|
70 |
# dataset_copy = DatasetCopy(dataset_folder, os.path.join(output_folder, 'temp'))
|
|
|
106 |
print(aux)
|
107 |
|
108 |
# 2. Data generator
|
109 |
+
used_labels = [0, 1, 2]
|
110 |
data_generator = BatchGenerator(C.TRAINING_DATASET, C.BATCH_SIZE if C.ACCUM_GRADIENT_STEP == 1 else 1, True,
|
111 |
+
C.TRAINING_PERC, labels=used_labels, combine_segmentations=False,
|
112 |
directory_val=C.VALIDATION_DATASET)
|
113 |
|
114 |
train_generator = data_generator.get_train_generator()
|
|
|
168 |
if model_file != '':
|
169 |
network.load_weights(model_file, by_name=True)
|
170 |
print('MODEL LOCATION: ', model_file)
|
171 |
+
|
172 |
+
resume_epoch = 0
|
173 |
+
if resume:
|
174 |
+
cp_dir = os.path.join(output_folder, 'checkpoints')
|
175 |
+
cp_file_list = [os.path.join(cp_dir, f) for f in os.listdir(cp_dir) if (f.startswith('checkpoint') and f.endswith('.h5'))]
|
176 |
+
if len(cp_file_list):
|
177 |
+
cp_file_list.sort()
|
178 |
+
checkpoint_file = cp_file_list[-1]
|
179 |
+
if os.path.exists(checkpoint_file):
|
180 |
+
network.load_weights(checkpoint_file, by_name=True)
|
181 |
+
print('Loaded checkpoint file: ' + checkpoint_file)
|
182 |
+
try:
|
183 |
+
resume_epoch = int(re.match('checkpoint\.(\d+)-*.h5', os.path.split(checkpoint_file)[-1])[1])
|
184 |
+
except TypeError:
|
185 |
+
# Checkpoint file has no epoch number in the name
|
186 |
+
resume_epoch = 0
|
187 |
+
print('Resuming from epoch: {:d}'.format(resume_epoch))
|
188 |
+
else:
|
189 |
+
warnings.warn('Checkpoint file NOT found. Training from scratch')
|
190 |
+
|
191 |
# 4. Freeze/unfreeze model layers
|
192 |
+
_, frozen_layers = freeze_layers_by_group(network, freeze_layers)
|
193 |
+
if frozen_layers is not None:
|
194 |
+
msg = "[INF]: Frozen layers {}".format(', '.join([str(a) for a in frozen_layers]))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
195 |
else:
|
196 |
msg = "[INF] None frozen layers"
|
197 |
print(msg)
|
198 |
log_file.write(msg)
|
199 |
+
|
200 |
+
network.summary(line_length=C.SUMMARY_LINE_LENGTH)
|
201 |
+
network.summary(line_length=C.SUMMARY_LINE_LENGTH, print_fn=log_file.writelines)
|
|
|
|
|
|
|
|
|
|
|
|
|
202 |
# Complete the model with the augmentation layer
|
203 |
augm_train_input_shape = train_generator.get_data_shape()[0]
|
204 |
input_layer_train = Input(shape=augm_train_input_shape, name='input_train')
|
|
|
290 |
save_best_only=True, monitor='val_loss', verbose=1, mode='min')
|
291 |
callback_save_checkpoint = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'checkpoint.h5'),
|
292 |
save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
|
293 |
+
callback_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10)
|
294 |
|
295 |
losses = {'transformer': loss_fnc,
|
296 |
'seg_transformer': loss_segm,
|
|
|
309 |
'seg_transformer': 1.,
|
310 |
'flow': rw}
|
311 |
|
312 |
+
optimizer = AdamAccumulated(accumulation_steps=C.ACCUM_GRADIENT_STEP, learning_rate=C.LEARNING_RATE)
|
|
|
313 |
network.compile(optimizer=optimizer,
|
314 |
loss=losses,
|
315 |
loss_weights=loss_weights,
|
|
|
320 |
callback_early_stop.set_model(network)
|
321 |
callback_best_model.set_model(network)
|
322 |
callback_save_checkpoint.set_model(network)
|
323 |
+
callback_lr.set_model(network)
|
324 |
|
325 |
summary = SummaryDictionary(network, C.BATCH_SIZE)
|
326 |
names = network.metrics_names
|
|
|
332 |
callback_early_stop.on_train_begin()
|
333 |
callback_best_model.on_train_begin()
|
334 |
callback_save_checkpoint.on_train_begin()
|
335 |
+
callback_lr.on_train_begin()
|
336 |
|
337 |
+
for epoch in range(resume_epoch, C.EPOCHS):
|
338 |
callback_tensorboard.on_epoch_begin(epoch)
|
339 |
callback_early_stop.on_epoch_begin(epoch)
|
340 |
callback_best_model.on_epoch_begin(epoch)
|
341 |
callback_save_checkpoint.on_epoch_begin(epoch)
|
342 |
+
callback_lr.on_epoch_begin(epoch)
|
343 |
+
|
344 |
print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
|
345 |
print("TRAIN")
|
346 |
|
|
|
350 |
callback_best_model.on_train_batch_begin(step)
|
351 |
callback_save_checkpoint.on_train_batch_begin(step)
|
352 |
callback_early_stop.on_train_batch_begin(step)
|
353 |
+
callback_lr.on_train_batch_begin(step)
|
354 |
|
355 |
try:
|
356 |
fix_img, mov_img, fix_seg, mov_seg = augm_model_train.predict(in_batch)
|
|
|
384 |
callback_best_model.on_train_batch_end(step, named_logs(network, ret))
|
385 |
callback_save_checkpoint.on_train_batch_end(step, named_logs(network, ret))
|
386 |
callback_early_stop.on_train_batch_end(step, named_logs(network, ret))
|
387 |
+
callback_lr.on_predict_batch_end(step, named_logs(network, ret))
|
388 |
+
|
389 |
progress_bar.update(step, zip(names, ret))
|
390 |
log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
|
391 |
val_values = progress_bar._values.copy()
|
|
|
420 |
callback_best_model.on_epoch_end(epoch, epoch_summary)
|
421 |
callback_save_checkpoint.on_epoch_end(epoch, epoch_summary)
|
422 |
callback_early_stop.on_epoch_end(epoch, epoch_summary)
|
423 |
+
callback_lr.on_epoch_end(epoch,epoch_summary)
|
424 |
+
|
425 |
print('End of epoch {}: '.format(epoch), ret, '\n')
|
426 |
|
427 |
callback_tensorboard.on_train_end()
|
428 |
callback_best_model.on_train_end()
|
429 |
callback_save_checkpoint.on_train_end()
|
430 |
callback_early_stop.on_train_end()
|
431 |
+
callback_lr.on_train_end()
|
432 |
# 7. Wrap up
|
COMET/Evaluate_network.py
CHANGED
@@ -17,9 +17,10 @@ import tensorflow as tf
|
|
17 |
import numpy as np
|
18 |
import pandas as pd
|
19 |
import voxelmorph as vxm
|
|
|
20 |
|
21 |
import DeepDeformationMapRegistration.utils.constants as C
|
22 |
-
from DeepDeformationMapRegistration.utils.operators import min_max_norm
|
23 |
from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
|
24 |
from DeepDeformationMapRegistration.layers import AugmentationLayer, UncertaintyWeighting
|
25 |
from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplified, NCC, GeneralizedDICEScore, HausdorffDistanceErosion, target_registration_error
|
@@ -27,10 +28,14 @@ from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimila
|
|
27 |
from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
|
28 |
from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
|
29 |
from DeepDeformationMapRegistration.utils.misc import DisplacementMapInterpolator, get_segmentations_centroids, segmentation_ohe_to_cardinal, segmentation_cardinal_to_ohe
|
|
|
|
|
30 |
from EvaluationScripts.Evaluate_class import EvaluationFigures, resize_pts_to_original_space, resize_img_to_original_space, resize_transformation
|
31 |
from scipy.interpolate import RegularGridInterpolator
|
32 |
from tqdm import tqdm
|
33 |
|
|
|
|
|
34 |
import h5py
|
35 |
import re
|
36 |
from Brain_study.data_generator import BatchGenerator
|
@@ -44,6 +49,7 @@ import neurite as ne
|
|
44 |
|
45 |
|
46 |
DATASET = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/test_fixed'
|
|
|
47 |
MODEL_FILE = '/mnt/EncryptedData1/Users/javier/train_output/COMET/ERASE/COMET_L_ssim__MET_mse_ncc_ssim_141343-01122021/checkpoints/best_model.h5'
|
48 |
DATA_ROOT_DIR = '/mnt/EncryptedData1/Users/javier/train_output/COMET/ERASE/COMET_L_ssim__MET_mse_ncc_ssim_141343-01122021/'
|
49 |
|
@@ -56,6 +62,7 @@ if __name__ == '__main__':
|
|
56 |
parser.add_argument('--dataset', type=str, help='Dataset to run predictions on', default=DATASET)
|
57 |
parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
|
58 |
parser.add_argument('--outdirname', type=str, default='Evaluate')
|
|
|
59 |
args = parser.parse_args()
|
60 |
if args.model is not None:
|
61 |
assert '.h5' in args.model[0], 'No checkpoint file provided, use -d/--dir instead'
|
@@ -82,6 +89,9 @@ if __name__ == '__main__':
|
|
82 |
DATASET = args.dataset
|
83 |
list_test_files = [os.path.join(DATASET, f) for f in os.listdir(DATASET) if f.endswith('h5') and 'dm' not in f]
|
84 |
list_test_files.sort()
|
|
|
|
|
|
|
85 |
|
86 |
with h5py.File(list_test_files[0], 'r') as f:
|
87 |
image_input_shape = image_output_shape = list(f['fix_image'][:].shape[:-1])
|
@@ -96,8 +106,8 @@ if __name__ == '__main__':
|
|
96 |
config.log_device_placement = False ## to log device placement (on which device the operation ran)
|
97 |
config.allow_soft_placement = True
|
98 |
|
99 |
-
sess = tf.Session(config=config)
|
100 |
-
tf.keras.backend.set_session(sess)
|
101 |
|
102 |
# Loss and metric functions. Common to all models
|
103 |
loss_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).loss,
|
@@ -114,24 +124,24 @@ if __name__ == '__main__':
|
|
114 |
GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric_macro]
|
115 |
|
116 |
### METRICS GRAPH ###
|
117 |
-
fix_img_ph = tf.placeholder(tf.float32, (1,
|
118 |
-
pred_img_ph = tf.placeholder(tf.float32, (1,
|
119 |
-
fix_seg_ph = tf.placeholder(tf.float32, (1,
|
120 |
-
pred_seg_ph = tf.placeholder(tf.float32, (1,
|
121 |
|
122 |
ssim_tf = metric_fncs[0](fix_img_ph, pred_img_ph)
|
123 |
ncc_tf = metric_fncs[1](fix_img_ph, pred_img_ph)
|
124 |
mse_tf = metric_fncs[2](fix_img_ph, pred_img_ph)
|
125 |
ms_ssim_tf = metric_fncs[3](fix_img_ph, pred_img_ph)
|
126 |
-
dice_tf = metric_fncs[4](fix_seg_ph, pred_seg_ph)
|
127 |
-
hd_tf = metric_fncs[5](fix_seg_ph, pred_seg_ph)
|
128 |
-
dice_macro_tf = metric_fncs[6](fix_seg_ph, pred_seg_ph)
|
129 |
# hd_exact_tf = HausdorffDistance_exact(fix_seg_ph, pred_seg_ph, ohe=True)
|
130 |
|
131 |
# Needed for VxmDense type of network
|
132 |
warp_segmentation = vxm.networks.Transform(image_output_shape, interp_method='nearest', nb_feats=nb_labels)
|
133 |
|
134 |
-
dm_interp = DisplacementMapInterpolator(image_output_shape, 'griddata')
|
135 |
|
136 |
for MODEL_FILE, DATA_ROOT_DIR in zip(MODEL_FILE_LIST, DATA_ROOT_DIR_LIST):
|
137 |
print('MODEL LOCATION: ', MODEL_FILE)
|
@@ -144,6 +154,14 @@ if __name__ == '__main__':
|
|
144 |
os.makedirs(output_folder, exist_ok=True)
|
145 |
print('DESTINATION FOLDER: ', output_folder)
|
146 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
try:
|
148 |
network = tf.keras.models.load_model(MODEL_FILE, {'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
|
149 |
'VxmDense': vxm.networks.VxmDense,
|
@@ -172,10 +190,16 @@ if __name__ == '__main__':
|
|
172 |
with open(metrics_file, 'w') as f:
|
173 |
f.write(';'.join(csv_header)+'\n')
|
174 |
|
|
|
|
|
|
|
|
|
|
|
175 |
ssim = ncc = mse = ms_ssim = dice = hd = 0
|
176 |
with sess.as_default():
|
177 |
sess.run(tf.global_variables_initializer())
|
178 |
network.load_weights(MODEL_FILE, by_name=True)
|
|
|
179 |
progress_bar = tqdm(enumerate(list_test_files, 1), desc='Evaluation', total=len(list_test_files))
|
180 |
for step, in_batch in progress_bar:
|
181 |
with h5py.File(in_batch, 'r') as f:
|
@@ -184,6 +208,8 @@ if __name__ == '__main__':
|
|
184 |
fix_seg = f['fix_segmentations'][:][np.newaxis, ...].astype(np.float32)
|
185 |
mov_seg = f['mov_segmentations'][:][np.newaxis, ...].astype(np.float32)
|
186 |
fix_centroids = f['fix_centroids'][:]
|
|
|
|
|
187 |
|
188 |
if network.name == 'vxm_dense_semi_supervised_seg':
|
189 |
t0 = time.time()
|
@@ -196,29 +222,40 @@ if __name__ == '__main__':
|
|
196 |
t1 = time.time()
|
197 |
|
198 |
pred_img = min_max_norm(pred_img)
|
199 |
-
mov_centroids, missing_lbls = get_segmentations_centroids(mov_seg[0, ...], ohe=True, expected_lbls=range(
|
200 |
# pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) # with tps, it returns the pred_centroids directly
|
201 |
pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) + mov_centroids
|
202 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
203 |
# I need the labels to be OHE to compute the segmentation metrics.
|
204 |
-
dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf], {'fix_seg:0': fix_seg, 'pred_seg:0': pred_seg})
|
|
|
|
|
|
|
205 |
|
206 |
pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
|
207 |
mov_seg_card = segmentation_ohe_to_cardinal(mov_seg).astype(np.float32)
|
208 |
fix_seg_card = segmentation_ohe_to_cardinal(fix_seg).astype(np.float32)
|
209 |
|
210 |
ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf], {'fix_img:0': fix_img, 'pred_img:0': pred_img})
|
|
|
211 |
ms_ssim = ms_ssim[0]
|
212 |
|
213 |
# Rescale the points back to isotropic space, where we have a correspondence voxel <-> mm
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
#
|
221 |
-
pred_centroids_isotropic = np.divide(pred_centroids_isotropic, C.COMET_DATASET_iso_to_cubic_scales)
|
222 |
# Now we can measure the TRE in mm
|
223 |
tre_array = target_registration_error(fix_centroids_isotropic, pred_centroids_isotropic, False).eval()
|
224 |
tre = np.mean([v for v in tre_array if not np.isnan(v)])
|
@@ -249,15 +286,103 @@ if __name__ == '__main__':
|
|
249 |
# plt.savefig(os.path.join(output_folder, '{:03d}_hist_mag_ssim_{:.03f}_dice_{:.03f}.png'.format(step, ssim, dice)))
|
250 |
# plt.close()
|
251 |
|
252 |
-
plot_predictions(fix_img, mov_img, disp_map,
|
253 |
-
plot_predictions(
|
254 |
-
save_disp_map_img(disp_map, 'Displacement map', os.path.join(output_folder, '{:03d}_disp_map_fig.png'.format(step)), show=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
255 |
|
256 |
-
|
257 |
|
258 |
print('Summary\n=======\n')
|
259 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
260 |
print('\n=======\n')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
261 |
tf.keras.backend.clear_session()
|
262 |
# sess.close()
|
263 |
del network
|
|
|
17 |
import numpy as np
|
18 |
import pandas as pd
|
19 |
import voxelmorph as vxm
|
20 |
+
from voxelmorph.tf.layers import SpatialTransformer
|
21 |
|
22 |
import DeepDeformationMapRegistration.utils.constants as C
|
23 |
+
from DeepDeformationMapRegistration.utils.operators import min_max_norm, safe_medpy_metric
|
24 |
from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
|
25 |
from DeepDeformationMapRegistration.layers import AugmentationLayer, UncertaintyWeighting
|
26 |
from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplified, NCC, GeneralizedDICEScore, HausdorffDistanceErosion, target_registration_error
|
|
|
28 |
from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
|
29 |
from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
|
30 |
from DeepDeformationMapRegistration.utils.misc import DisplacementMapInterpolator, get_segmentations_centroids, segmentation_ohe_to_cardinal, segmentation_cardinal_to_ohe
|
31 |
+
from DeepDeformationMapRegistration.utils.misc import resize_displacement_map, scale_transformation, GaussianFilter
|
32 |
+
import medpy.metric as medpy_metrics
|
33 |
from EvaluationScripts.Evaluate_class import EvaluationFigures, resize_pts_to_original_space, resize_img_to_original_space, resize_transformation
|
34 |
from scipy.interpolate import RegularGridInterpolator
|
35 |
from tqdm import tqdm
|
36 |
|
37 |
+
from scipy.ndimage import gaussian_filter, zoom
|
38 |
+
|
39 |
import h5py
|
40 |
import re
|
41 |
from Brain_study.data_generator import BatchGenerator
|
|
|
49 |
|
50 |
|
51 |
DATASET = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/test_fixed'
|
52 |
+
DATASET_FR = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/test_fixed/full_res'
|
53 |
MODEL_FILE = '/mnt/EncryptedData1/Users/javier/train_output/COMET/ERASE/COMET_L_ssim__MET_mse_ncc_ssim_141343-01122021/checkpoints/best_model.h5'
|
54 |
DATA_ROOT_DIR = '/mnt/EncryptedData1/Users/javier/train_output/COMET/ERASE/COMET_L_ssim__MET_mse_ncc_ssim_141343-01122021/'
|
55 |
|
|
|
62 |
parser.add_argument('--dataset', type=str, help='Dataset to run predictions on', default=DATASET)
|
63 |
parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
|
64 |
parser.add_argument('--outdirname', type=str, default='Evaluate')
|
65 |
+
parser.add_argument('--fullres', action='store_true', default=False)
|
66 |
args = parser.parse_args()
|
67 |
if args.model is not None:
|
68 |
assert '.h5' in args.model[0], 'No checkpoint file provided, use -d/--dir instead'
|
|
|
89 |
DATASET = args.dataset
|
90 |
list_test_files = [os.path.join(DATASET, f) for f in os.listdir(DATASET) if f.endswith('h5') and 'dm' not in f]
|
91 |
list_test_files.sort()
|
92 |
+
if args.fullres:
|
93 |
+
list_test_fr_files = [os.path.join(DATASET_FR, f) for f in os.listdir(DATASET_FR) if f.endswith('h5') and 'dm' not in f]
|
94 |
+
list_test_fr_files.sort()
|
95 |
|
96 |
with h5py.File(list_test_files[0], 'r') as f:
|
97 |
image_input_shape = image_output_shape = list(f['fix_image'][:].shape[:-1])
|
|
|
106 |
config.log_device_placement = False ## to log device placement (on which device the operation ran)
|
107 |
config.allow_soft_placement = True
|
108 |
|
109 |
+
sess = tf.compat.v1.Session(config=config)
|
110 |
+
tf.compat.v1.keras.backend.set_session(sess)
|
111 |
|
112 |
# Loss and metric functions. Common to all models
|
113 |
loss_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).loss,
|
|
|
124 |
GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric_macro]
|
125 |
|
126 |
### METRICS GRAPH ###
|
127 |
+
fix_img_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, 1), name='fix_img')
|
128 |
+
pred_img_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, 1), name='pred_img')
|
129 |
+
fix_seg_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, nb_labels), name='fix_seg')
|
130 |
+
pred_seg_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, nb_labels), name='pred_seg')
|
131 |
|
132 |
ssim_tf = metric_fncs[0](fix_img_ph, pred_img_ph)
|
133 |
ncc_tf = metric_fncs[1](fix_img_ph, pred_img_ph)
|
134 |
mse_tf = metric_fncs[2](fix_img_ph, pred_img_ph)
|
135 |
ms_ssim_tf = metric_fncs[3](fix_img_ph, pred_img_ph)
|
136 |
+
# dice_tf = metric_fncs[4](fix_seg_ph, pred_seg_ph)
|
137 |
+
# hd_tf = metric_fncs[5](fix_seg_ph, pred_seg_ph)
|
138 |
+
# dice_macro_tf = metric_fncs[6](fix_seg_ph, pred_seg_ph)
|
139 |
# hd_exact_tf = HausdorffDistance_exact(fix_seg_ph, pred_seg_ph, ohe=True)
|
140 |
|
141 |
# Needed for VxmDense type of network
|
142 |
warp_segmentation = vxm.networks.Transform(image_output_shape, interp_method='nearest', nb_feats=nb_labels)
|
143 |
|
144 |
+
dm_interp = DisplacementMapInterpolator(image_output_shape, 'griddata', step=4)
|
145 |
|
146 |
for MODEL_FILE, DATA_ROOT_DIR in zip(MODEL_FILE_LIST, DATA_ROOT_DIR_LIST):
|
147 |
print('MODEL LOCATION: ', MODEL_FILE)
|
|
|
154 |
os.makedirs(output_folder, exist_ok=True)
|
155 |
print('DESTINATION FOLDER: ', output_folder)
|
156 |
|
157 |
+
if args.fullres:
|
158 |
+
output_folder_fr = os.path.join(DATA_ROOT_DIR, args.outdirname, 'full_resolution') # '/mnt/EncryptedData1/Users/javier/train_output/DDMR/THESIS/eval/BASELINE_TRAIN_Affine_ncc_EVAL_Affine'
|
159 |
+
# os.makedirs(os.path.join(output_folder, 'images'), exist_ok=True)
|
160 |
+
if args.erase:
|
161 |
+
shutil.rmtree(output_folder_fr, ignore_errors=True)
|
162 |
+
os.makedirs(output_folder_fr, exist_ok=True)
|
163 |
+
print('DESTINATION FOLDER FULL RESOLUTION: ', output_folder_fr)
|
164 |
+
|
165 |
try:
|
166 |
network = tf.keras.models.load_model(MODEL_FILE, {'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
|
167 |
'VxmDense': vxm.networks.VxmDense,
|
|
|
190 |
with open(metrics_file, 'w') as f:
|
191 |
f.write(';'.join(csv_header)+'\n')
|
192 |
|
193 |
+
if args.fullres:
|
194 |
+
metrics_file_fr = os.path.join(output_folder_fr, 'metrics.csv')
|
195 |
+
with open(metrics_file_fr, 'w') as f:
|
196 |
+
f.write(';'.join(csv_header) + '\n')
|
197 |
+
|
198 |
ssim = ncc = mse = ms_ssim = dice = hd = 0
|
199 |
with sess.as_default():
|
200 |
sess.run(tf.global_variables_initializer())
|
201 |
network.load_weights(MODEL_FILE, by_name=True)
|
202 |
+
network.summary(line_length=C.SUMMARY_LINE_LENGTH)
|
203 |
progress_bar = tqdm(enumerate(list_test_files, 1), desc='Evaluation', total=len(list_test_files))
|
204 |
for step, in_batch in progress_bar:
|
205 |
with h5py.File(in_batch, 'r') as f:
|
|
|
208 |
fix_seg = f['fix_segmentations'][:][np.newaxis, ...].astype(np.float32)
|
209 |
mov_seg = f['mov_segmentations'][:][np.newaxis, ...].astype(np.float32)
|
210 |
fix_centroids = f['fix_centroids'][:]
|
211 |
+
isotropic_shape = f['isotropic_shape'][:]
|
212 |
+
voxel_size = np.divide(fix_img.shape[1:-1], isotropic_shape)
|
213 |
|
214 |
if network.name == 'vxm_dense_semi_supervised_seg':
|
215 |
t0 = time.time()
|
|
|
222 |
t1 = time.time()
|
223 |
|
224 |
pred_img = min_max_norm(pred_img)
|
225 |
+
mov_centroids, missing_lbls = get_segmentations_centroids(mov_seg[0, ...], ohe=True, expected_lbls=range(1, nb_labels+1), brain_study=False)
|
226 |
# pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) # with tps, it returns the pred_centroids directly
|
227 |
pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) + mov_centroids
|
228 |
|
229 |
+
# Up sample the segmentation masks to isotropic resolution
|
230 |
+
zoom_factors = np.diag(scale_transformation(image_output_shape, isotropic_shape))
|
231 |
+
pred_seg_isot = zoom(pred_seg[0, ...], zoom_factors, order=0)[np.newaxis, ...]
|
232 |
+
fix_seg_isot = zoom(fix_seg[0, ...], zoom_factors, order=0)[np.newaxis, ...]
|
233 |
+
|
234 |
+
pred_img_isot = zoom(pred_img[0, ...], zoom_factors, order=3)[np.newaxis, ...]
|
235 |
+
fix_img_isot = zoom(fix_img[0, ...], zoom_factors, order=3)[np.newaxis, ...]
|
236 |
+
|
237 |
# I need the labels to be OHE to compute the segmentation metrics.
|
238 |
+
# dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf], {'fix_seg:0': fix_seg, 'pred_seg:0': pred_seg})
|
239 |
+
dice = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) / np.sum(fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
|
240 |
+
hd = np.mean(safe_medpy_metric(pred_seg_isot[0, ...], fix_seg_isot[0, ...], nb_labels, medpy_metrics.hd, {'voxelspacing': voxel_size}))
|
241 |
+
dice_macro = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
|
242 |
|
243 |
pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
|
244 |
mov_seg_card = segmentation_ohe_to_cardinal(mov_seg).astype(np.float32)
|
245 |
fix_seg_card = segmentation_ohe_to_cardinal(fix_seg).astype(np.float32)
|
246 |
|
247 |
ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf], {'fix_img:0': fix_img, 'pred_img:0': pred_img})
|
248 |
+
ssim = np.mean(ssim)
|
249 |
ms_ssim = ms_ssim[0]
|
250 |
|
251 |
# Rescale the points back to isotropic space, where we have a correspondence voxel <-> mm
|
252 |
+
fix_centroids_isotropic = fix_centroids * voxel_size
|
253 |
+
# mov_centroids_isotropic = mov_centroids * voxel_size
|
254 |
+
pred_centroids_isotropic = pred_centroids * voxel_size
|
255 |
+
|
256 |
+
# fix_centroids_isotropic = np.divide(fix_centroids_isotropic, C.COMET_DATASET_iso_to_cubic_scales)
|
257 |
+
# # mov_centroids_isotropic = np.divide(mov_centroids_isotropic, C.COMET_DATASET_iso_to_cubic_scales)
|
258 |
+
# pred_centroids_isotropic = np.divide(pred_centroids_isotropic, C.COMET_DATASET_iso_to_cubic_scales)
|
|
|
259 |
# Now we can measure the TRE in mm
|
260 |
tre_array = target_registration_error(fix_centroids_isotropic, pred_centroids_isotropic, False).eval()
|
261 |
tre = np.mean([v for v in tre_array if not np.isnan(v)])
|
|
|
286 |
# plt.savefig(os.path.join(output_folder, '{:03d}_hist_mag_ssim_{:.03f}_dice_{:.03f}.png'.format(step, ssim, dice)))
|
287 |
# plt.close()
|
288 |
|
289 |
+
plot_predictions(img_batches=[fix_img, mov_img, pred_img], disp_map_batch=disp_map, seg_batches=[fix_seg_card, mov_seg_card, pred_seg_card], filename=os.path.join(output_folder, '{:03d}_figures_seg.png'.format(step)), show=False, step=16)
|
290 |
+
plot_predictions(img_batches=[fix_img, mov_img, pred_img], disp_map_batch=disp_map, filename=os.path.join(output_folder, '{:03d}_figures_img.png'.format(step)), show=False, step=16)
|
291 |
+
save_disp_map_img(disp_map, 'Displacement map', os.path.join(output_folder, '{:03d}_disp_map_fig.png'.format(step)), show=False, step=16)
|
292 |
+
|
293 |
+
progress_bar.set_description('SSIM {:.04f}\tM_DICE: {:.04f}'.format(ssim, dice_macro))
|
294 |
+
|
295 |
+
if args.fullres:
|
296 |
+
with h5py.File(list_test_fr_files[step - 1], 'r') as f:
|
297 |
+
fix_img = f['fix_image'][:][np.newaxis, ...] # Add batch axis
|
298 |
+
mov_img = f['mov_image'][:][np.newaxis, ...]
|
299 |
+
fix_seg = f['fix_segmentations'][:][np.newaxis, ...].astype(np.float32)
|
300 |
+
mov_seg = f['mov_segmentations'][:][np.newaxis, ...].astype(np.float32)
|
301 |
+
fix_centroids = f['fix_centroids'][:]
|
302 |
+
|
303 |
+
# Up sample the displacement map to the full res
|
304 |
+
trf = scale_transformation(image_output_shape, fix_img.shape[1:-1])
|
305 |
+
disp_map_fr = resize_displacement_map(np.squeeze(disp_map), None, trf)[np.newaxis, ...]
|
306 |
+
disp_map_fr = gaussian_filter(disp_map_fr, 5)
|
307 |
+
# disp_mad_fr = sess.run(smooth_filter, feed_dict={'dm:0': disp_map_fr})
|
308 |
+
|
309 |
+
# Predicted image
|
310 |
+
pred_img_fr = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([mov_img, disp_map_fr]).eval()
|
311 |
+
pred_seg_fr = SpatialTransformer(interp_method='nearest', indexing='ij', single_transform=False)([mov_seg, disp_map_fr]).eval()
|
312 |
+
|
313 |
+
# Predicted centroids
|
314 |
+
dm_interp_fr = DisplacementMapInterpolator(fix_img.shape[1:-1], 'griddata', step=2)
|
315 |
+
pred_centroids = dm_interp_fr(disp_map_fr, mov_centroids, backwards=True) + mov_centroids
|
316 |
+
|
317 |
+
# Metrics - segmentation
|
318 |
+
dice = np.mean([medpy_metrics.dc(pred_seg_fr[..., l], fix_seg[..., l]) / np.sum(fix_seg[..., l]) for l in range(nb_labels)])
|
319 |
+
hd = np.mean(safe_medpy_metric(pred_seg[0, ...], fix_seg[0, ...], nb_labels, medpy_metrics.hd, {'voxelspacing': voxel_size}))
|
320 |
+
dice_macro = np.mean([medpy_metrics.dc(pred_seg_fr[..., l], fix_seg[..., l]) for l in range(nb_labels)])
|
321 |
+
|
322 |
+
pred_seg_card_fr = segmentation_ohe_to_cardinal(pred_seg_fr).astype(np.float32)
|
323 |
+
mov_seg_card_fr = segmentation_ohe_to_cardinal(mov_seg).astype(np.float32)
|
324 |
+
fix_seg_card_fr = segmentation_ohe_to_cardinal(fix_seg).astype(np.float32)
|
325 |
+
|
326 |
+
# Metrics - image
|
327 |
+
ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf],
|
328 |
+
{'fix_img:0': fix_img, 'pred_img:0': pred_img_fr})
|
329 |
+
ssim = np.mean(ssim)
|
330 |
+
ms_ssim = ms_ssim[0]
|
331 |
+
|
332 |
+
# Metrics - registration
|
333 |
+
tre_array = target_registration_error(fix_centroids, pred_centroids, False).eval()
|
334 |
+
|
335 |
+
new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, t1 - t0, tre, len(missing_lbls),
|
336 |
+
missing_lbls]
|
337 |
+
with open(metrics_file_fr, 'a') as f:
|
338 |
+
f.write(';'.join(map(str, new_line)) + '\n')
|
339 |
+
|
340 |
+
save_nifti(fix_img[0, ...], os.path.join(output_folder_fr, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
341 |
+
save_nifti(mov_img[0, ...], os.path.join(output_folder_fr, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
342 |
+
save_nifti(pred_img[0, ...], os.path.join(output_folder_fr, '{:03d}_pred_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
343 |
+
save_nifti(fix_seg_card[0, ...], os.path.join(output_folder_fr, '{:03d}_fix_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
344 |
+
save_nifti(mov_seg_card[0, ...], os.path.join(output_folder_fr, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
345 |
+
save_nifti(pred_seg_card[0, ...], os.path.join(output_folder_fr, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
346 |
+
|
347 |
+
# with h5py.File(os.path.join(output_folder, '{:03d}_centroids.h5'.format(step)), 'w') as f:
|
348 |
+
# f.create_dataset('fix_centroids', dtype=np.float32, data=fix_centroids)
|
349 |
+
# f.create_dataset('mov_centroids', dtype=np.float32, data=mov_centroids)
|
350 |
+
# f.create_dataset('pred_centroids', dtype=np.float32, data=pred_centroids)
|
351 |
+
# f.create_dataset('fix_centroids_isotropic', dtype=np.float32, data=fix_centroids_isotropic)
|
352 |
+
# f.create_dataset('mov_centroids_isotropic', dtype=np.float32, data=mov_centroids_isotropic)
|
353 |
+
|
354 |
+
# magnitude = np.sqrt(np.sum(disp_map[0, ...] ** 2, axis=-1))
|
355 |
+
# _ = plt.hist(magnitude.flatten())
|
356 |
+
# plt.title('Histogram of disp. magnitudes')
|
357 |
+
# plt.show(block=False)
|
358 |
+
# plt.savefig(os.path.join(output_folder, '{:03d}_hist_mag_ssim_{:.03f}_dice_{:.03f}.png'.format(step, ssim, dice)))
|
359 |
+
# plt.close()
|
360 |
+
|
361 |
+
plot_predictions(img_batches=[fix_img, mov_img, pred_img_fr], disp_map_batch=disp_map_fr, seg_batches=[fix_seg_card_fr, mov_seg_card_fr, pred_seg_card_fr], filename=os.path.join(output_folder_fr, '{:03d}_figures_seg.png'.format(step)), show=False, step=10)
|
362 |
+
plot_predictions(img_batches=[fix_img, mov_img, pred_img_fr], disp_map_batch=disp_map_fr, filename=os.path.join(output_folder_fr, '{:03d}_figures_img.png'.format(step)), show=False, step=10)
|
363 |
+
# save_disp_map_img(disp_map_fr, 'Displacement map', os.path.join(output_folder_fr, '{:03d}_disp_map_fig.png'.format(step)), show=False, step=10)
|
364 |
|
365 |
+
progress_bar.set_description('[FR] SSIM {:.04f}\tM_DICE: {:.04f}'.format(ssim, dice_macro))
|
366 |
|
367 |
print('Summary\n=======\n')
|
368 |
+
metrics_df = pd.read_csv(metrics_file, sep=';', header=0)
|
369 |
+
print('\nAVG:\n')
|
370 |
+
print(metrics_df.mean(axis=0))
|
371 |
+
print('\nSTD:\n')
|
372 |
+
print(metrics_df.std(axis=0))
|
373 |
+
print('\nHD95perc:\n')
|
374 |
+
print(metrics_df['HD'].describe(percentiles=[.95]))
|
375 |
print('\n=======\n')
|
376 |
+
if args.fullres:
|
377 |
+
print('Summary full resolution\n=======\n')
|
378 |
+
metrics_df = pd.read_csv(metrics_file_fr, sep=';', header=0)
|
379 |
+
print('\nAVG:\n')
|
380 |
+
print(metrics_df.mean(axis=0))
|
381 |
+
print('\nSTD:\n')
|
382 |
+
print(metrics_df.std(axis=0))
|
383 |
+
print('\nHD95perc:\n')
|
384 |
+
print(metrics_df['HD'].describe(percentiles=[.95]))
|
385 |
+
print('\n=======\n')
|
386 |
tf.keras.backend.clear_session()
|
387 |
# sess.close()
|
388 |
del network
|
COMET/MultiTrain_config.py
CHANGED
@@ -9,6 +9,8 @@ from shutil import copy2
|
|
9 |
import os
|
10 |
from datetime import datetime
|
11 |
import DeepDeformationMapRegistration.utils.constants as C
|
|
|
|
|
12 |
TRAIN_DATASET = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/train'
|
13 |
|
14 |
err = list()
|
@@ -58,10 +60,11 @@ if __name__ == '__main__':
|
|
58 |
except KeyError as err:
|
59 |
froozen_layers = None
|
60 |
except NameError as err:
|
61 |
-
froozen_layers =
|
|
|
62 |
if froozen_layers is not None:
|
63 |
-
assert all(s in
|
64 |
-
'Invalid option for "freeze". Expected one or several of:
|
65 |
froozen_layers = list(set(froozen_layers)) # Unique elements
|
66 |
|
67 |
if augmentationConfig:
|
@@ -70,11 +73,20 @@ if __name__ == '__main__':
|
|
70 |
|
71 |
|
72 |
# copy the configuration file to the destionation folder
|
73 |
-
os.makedirs(output_folder, exist_ok=True)
|
74 |
copy2(args.ini, os.path.join(output_folder, os.path.split(args.ini)[-1]))
|
75 |
|
76 |
-
|
77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
|
79 |
launch_train(dataset_folder=datasetConfig['train'],
|
80 |
validation_folder=datasetConfig['validation'],
|
@@ -85,10 +97,12 @@ if __name__ == '__main__':
|
|
85 |
simil=simil,
|
86 |
segm=segm,
|
87 |
max_epochs=eval(trainConfig['epochs']),
|
|
|
88 |
early_stop_patience=eval(trainConfig['earlyStopPatience']),
|
89 |
model_file=trainConfig['model'],
|
90 |
freeze_layers=froozen_layers,
|
91 |
acc_gradients=eval(trainConfig['accumulativeGradients']),
|
92 |
batch_size=eval(trainConfig['batchSize']),
|
93 |
unet=unet,
|
94 |
-
head=head
|
|
|
|
9 |
import os
|
10 |
from datetime import datetime
|
11 |
import DeepDeformationMapRegistration.utils.constants as C
|
12 |
+
import re
|
13 |
+
from COMET.augmentation_constants import LAYER_SELECTION
|
14 |
TRAIN_DATASET = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/train'
|
15 |
|
16 |
err = list()
|
|
|
60 |
except KeyError as err:
|
61 |
froozen_layers = None
|
62 |
except NameError as err:
|
63 |
+
froozen_layers = list(filter(lambda x: x != '', re.split(';|\s|,|,\s|;\s', trainConfig['freeze'].upper())))
|
64 |
+
|
65 |
if froozen_layers is not None:
|
66 |
+
assert all(s in LAYER_SELECTION.keys() for s in froozen_layers), \
|
67 |
+
'Invalid option for "freeze". Expected one or several of: ' + ', '.join(LAYER_SELECTION.keys())
|
68 |
froozen_layers = list(set(froozen_layers)) # Unique elements
|
69 |
|
70 |
if augmentationConfig:
|
|
|
73 |
|
74 |
|
75 |
# copy the configuration file to the destionation folder
|
76 |
+
os.makedirs(output_folder, exist_ok=True) # TODO: move this within the "resume" if case, and bring here the creation of the resume-output folder!
|
77 |
copy2(args.ini, os.path.join(output_folder, os.path.split(args.ini)[-1]))
|
78 |
|
79 |
+
try:
|
80 |
+
unet = [int(x) for x in trainConfig['unet'].split(',')] if trainConfig['unet'] else [16, 32, 64, 128, 256]
|
81 |
+
head = [int(x) for x in trainConfig['head'].split(',')] if trainConfig['head'] else [16, 16]
|
82 |
+
except KeyError as err:
|
83 |
+
unet = [16, 32, 64, 128, 256]
|
84 |
+
head = [16, 16]
|
85 |
+
|
86 |
+
try:
|
87 |
+
resume_checkpoint = trainConfig['resumeCheckpoint']
|
88 |
+
except KeyError as e:
|
89 |
+
resume_checkpoint = None
|
90 |
|
91 |
launch_train(dataset_folder=datasetConfig['train'],
|
92 |
validation_folder=datasetConfig['validation'],
|
|
|
97 |
simil=simil,
|
98 |
segm=segm,
|
99 |
max_epochs=eval(trainConfig['epochs']),
|
100 |
+
image_size=eval(trainConfig['imageSize']),
|
101 |
early_stop_patience=eval(trainConfig['earlyStopPatience']),
|
102 |
model_file=trainConfig['model'],
|
103 |
freeze_layers=froozen_layers,
|
104 |
acc_gradients=eval(trainConfig['accumulativeGradients']),
|
105 |
batch_size=eval(trainConfig['batchSize']),
|
106 |
unet=unet,
|
107 |
+
head=head,
|
108 |
+
resume=resume_checkpoint)
|
COMET/augmentation_constants.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import numpy as np
|
|
|
2 |
|
3 |
# Constants for augmentation layer
|
4 |
# .../T1/training/zoom_factors.csv contain the scale factors of all the training samples from isotropic to 128x128x128
|
@@ -31,4 +32,26 @@ LAYER_RANGES = {'INPUT': (IN_LAYERS),
|
|
31 |
'ENCODER': (ENCONDER_LAYERS),
|
32 |
'DECODER': (DECODER_LAYERS),
|
33 |
'TOP': (TOP_LAYERS_ENC, TOP_LAYERS_DEC),
|
34 |
-
'BOTTOM': (BOTTOM_LAYERS)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import numpy as np
|
2 |
+
import re
|
3 |
|
4 |
# Constants for augmentation layer
|
5 |
# .../T1/training/zoom_factors.csv contain the scale factors of all the training samples from isotropic to 128x128x128
|
|
|
32 |
'ENCODER': (ENCONDER_LAYERS),
|
33 |
'DECODER': (DECODER_LAYERS),
|
34 |
'TOP': (TOP_LAYERS_ENC, TOP_LAYERS_DEC),
|
35 |
+
'BOTTOM': (BOTTOM_LAYERS)}
|
36 |
+
|
37 |
+
# LAYER names:
|
38 |
+
IN_LAYER_REGEXP = '.*input'
|
39 |
+
FC_LAYER_REGEXP = '.*final.*'
|
40 |
+
OUT_LAYER_REGEXP = '(?:flow|transformer)'
|
41 |
+
ENC_LAYER_REGEXP = '.*enc_(?:conv|pooling)_(\d).*'
|
42 |
+
DEC_LAYER_REGEXP = '.*dec_(?:conv|upsample)_(\d).*'
|
43 |
+
LEVEL_NUMBER = lambda x: re.match('.*(?:enc|dec)_(?:conv|upsample|pooling)_(\d).*', x)
|
44 |
+
IS_TOP_LEVEL = lambda x: int(LEVEL_NUMBER(x)[1]) < 3 if LEVEL_NUMBER(x) is not None else False or bool(re.match(FC_LAYER_REGEXP, x))
|
45 |
+
IS_BOTTOM_LEVEL = lambda x: int(LEVEL_NUMBER(x)[1]) >= 3 if LEVEL_NUMBER(x) is not None else False
|
46 |
+
|
47 |
+
LAYER_SELECTION = {'INPUT': lambda x: bool(re.match(IN_LAYER_REGEXP, x)),
|
48 |
+
'FULLYCONNECTED': lambda x: bool(re.match(FC_LAYER_REGEXP, x)),
|
49 |
+
'ENCODER': lambda x: bool(re.match(ENC_LAYER_REGEXP, x)),
|
50 |
+
'DECODER': lambda x: bool(re.match(DEC_LAYER_REGEXP, x)),
|
51 |
+
'TOP': lambda x: IS_TOP_LEVEL(x),
|
52 |
+
'BOTTOM': lambda x: IS_BOTTOM_LEVEL(x)
|
53 |
+
}
|
54 |
+
|
55 |
+
# STUPID IDEA THAT COMPLICATES THINGS. The points was to allow combinations of the layer groups
|
56 |
+
# OR_GROUPS = ['ENCODER', 'DECODER', 'INPUT', 'OUTPUT', 'FULLYCONNECTED'] # These groups can be OR'ed with the AND_GROUPS and among them. E.g., Top layers of the encoder and decoder: ENCODER or DECODER and TOP
|
57 |
+
# AND_GROUPS = ['TOP', 'BOTTOM'] # These groups can be AND'ed with the OR_GROUPS and among them
|
COMET/format_dataset.py
CHANGED
@@ -30,7 +30,7 @@ SEG_DIRECTORY = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSL
|
|
30 |
IMG_NAME_PATTERN = '(.*).nii.gz'
|
31 |
SEG_NAME_PATTERN = '(.*).nii.gz'
|
32 |
|
33 |
-
OUT_DIRECTORY = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128'
|
34 |
|
35 |
|
36 |
if __name__ == '__main__':
|
@@ -60,6 +60,7 @@ if __name__ == '__main__':
|
|
60 |
seg = np.asarray(seg.dataobj)
|
61 |
|
62 |
segs_are_ohe = bool(len(seg.shape) > 3 and seg.shape[3] > 1)
|
|
|
63 |
if args.crop:
|
64 |
parenchyma = regionprops(seg[..., 0])[0]
|
65 |
bbox = np.asarray(parenchyma.bbox) + [*[-args.offset]*3, *[args.offset]*3]
|
@@ -72,10 +73,11 @@ if __name__ == '__main__':
|
|
72 |
isot_shape = img.shape
|
73 |
|
74 |
zoom_factors = (np.asarray([128]*3) / np.asarray(img.shape)).tolist()
|
75 |
-
|
76 |
img = zoom(img, zoom_factors, order=3)
|
77 |
if args.dilate_segmentations:
|
78 |
seg = binary_dilation(seg, binary_ball, iterations=1)
|
|
|
79 |
seg = zoom(seg, zoom_factors + [1]*(len(seg.shape) - len(img.shape)), order=0)
|
80 |
zoom_file = zoom_file.append({'scale_i': zoom_factors[0],
|
81 |
'scale_j': zoom_factors[1],
|
@@ -92,11 +94,15 @@ if __name__ == '__main__':
|
|
92 |
h5_file = h5py.File(os.path.join(OUT_DIRECTORY, img_name + '.h5'), 'w')
|
93 |
|
94 |
h5_file.create_dataset('image', data=img[..., np.newaxis], dtype=np.float32)
|
|
|
|
|
95 |
h5_file.create_dataset('segmentation', data=seg.astype(np.uint8), dtype=np.uint8)
|
96 |
h5_file.create_dataset('segmentation_expanded', data=seg_expanded.astype(np.uint8), dtype=np.uint8)
|
97 |
h5_file.create_dataset('segmentation_labels', data=np.unique(seg)[1:]) # Remove the 0 (background label)
|
98 |
h5_file.create_dataset('isotropic_shape', data=isot_shape)
|
99 |
-
|
|
|
|
|
100 |
print('{}: Segmentation labels {}'.format(img_name, np.unique(seg)[1:]))
|
101 |
h5_file.close()
|
102 |
|
|
|
30 |
IMG_NAME_PATTERN = '(.*).nii.gz'
|
31 |
SEG_NAME_PATTERN = '(.*).nii.gz'
|
32 |
|
33 |
+
OUT_DIRECTORY = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/w_bboxes'
|
34 |
|
35 |
|
36 |
if __name__ == '__main__':
|
|
|
60 |
seg = np.asarray(seg.dataobj)
|
61 |
|
62 |
segs_are_ohe = bool(len(seg.shape) > 3 and seg.shape[3] > 1)
|
63 |
+
bbox = [0]*6
|
64 |
if args.crop:
|
65 |
parenchyma = regionprops(seg[..., 0])[0]
|
66 |
bbox = np.asarray(parenchyma.bbox) + [*[-args.offset]*3, *[args.offset]*3]
|
|
|
73 |
isot_shape = img.shape
|
74 |
|
75 |
zoom_factors = (np.asarray([128]*3) / np.asarray(img.shape)).tolist()
|
76 |
+
img_isotropic = np.copy(img)
|
77 |
img = zoom(img, zoom_factors, order=3)
|
78 |
if args.dilate_segmentations:
|
79 |
seg = binary_dilation(seg, binary_ball, iterations=1)
|
80 |
+
seg_isotropic = np.copy(seg)
|
81 |
seg = zoom(seg, zoom_factors + [1]*(len(seg.shape) - len(img.shape)), order=0)
|
82 |
zoom_file = zoom_file.append({'scale_i': zoom_factors[0],
|
83 |
'scale_j': zoom_factors[1],
|
|
|
94 |
h5_file = h5py.File(os.path.join(OUT_DIRECTORY, img_name + '.h5'), 'w')
|
95 |
|
96 |
h5_file.create_dataset('image', data=img[..., np.newaxis], dtype=np.float32)
|
97 |
+
h5_file.create_dataset('image_isotropic', data=img_isotropic[..., np.newaxis], dtype=np.float32)
|
98 |
+
h5_file.create_dataset('segmentation_isotropic', data=seg_isotropic.astype(np.uint8), dtype=np.uint8)
|
99 |
h5_file.create_dataset('segmentation', data=seg.astype(np.uint8), dtype=np.uint8)
|
100 |
h5_file.create_dataset('segmentation_expanded', data=seg_expanded.astype(np.uint8), dtype=np.uint8)
|
101 |
h5_file.create_dataset('segmentation_labels', data=np.unique(seg)[1:]) # Remove the 0 (background label)
|
102 |
h5_file.create_dataset('isotropic_shape', data=isot_shape)
|
103 |
+
if args.crop:
|
104 |
+
h5_file.create_dataset('bounding_box_origin', data=bbox[:3])
|
105 |
+
h5_file.create_dataset('bounding_box_shape', data=bbox[3:] - bbox[:3])
|
106 |
print('{}: Segmentation labels {}'.format(img_name, np.unique(seg)[1:]))
|
107 |
h5_file.close()
|
108 |
|
DeepDeformationMapRegistration/layers/augmentation.py
CHANGED
@@ -133,7 +133,7 @@ class AugmentationLayer(kl.Layer):
|
|
133 |
mov_img = tf.zeros_like(fix_img)
|
134 |
mov_segm = tf.zeros_like(fix_segm)
|
135 |
|
136 |
-
disp_map = tf.tile(tf.zeros_like(fix_img), [1, 1, 1, 1, 3])
|
137 |
|
138 |
if self.out_img_shape is not None:
|
139 |
fix_img = self.downsize_image(fix_img)
|
|
|
133 |
mov_img = tf.zeros_like(fix_img)
|
134 |
mov_segm = tf.zeros_like(fix_segm)
|
135 |
|
136 |
+
disp_map = tf.tile(tf.zeros_like(fix_img), [1, 1, 1, 1, 3]) # TODO: change, don't use tile!!
|
137 |
|
138 |
if self.out_img_shape is not None:
|
139 |
fix_img = self.downsize_image(fix_img)
|
DeepDeformationMapRegistration/layers/upsampling.py
CHANGED
@@ -485,6 +485,8 @@ def UpInterpolate3D(x,
|
|
485 |
nb, nr, nc, nd, nh = tf.TensorShape(x).as_list()
|
486 |
elif data_format == 'channels_first':
|
487 |
nb, nh, nr, nc, nd = tf.TensorShape(x).as_list()
|
|
|
|
|
488 |
|
489 |
r = size[0]
|
490 |
c = size[1]
|
|
|
485 |
nb, nr, nc, nd, nh = tf.TensorShape(x).as_list()
|
486 |
elif data_format == 'channels_first':
|
487 |
nb, nh, nr, nc, nd = tf.TensorShape(x).as_list()
|
488 |
+
else:
|
489 |
+
raise ValueError('Invalid option: ', data_format)
|
490 |
|
491 |
r = size[0]
|
492 |
c = size[1]
|
DeepDeformationMapRegistration/losses.py
CHANGED
@@ -41,7 +41,7 @@ class HausdorffDistanceErosion:
|
|
41 |
def _erode(self, in_tensor):
|
42 |
indiv_channels = tf.split(in_tensor, self.im_shape[-1], -1)
|
43 |
res = list()
|
44 |
-
with tf.variable_scope('erode', reuse=tf.AUTO_REUSE):
|
45 |
for ch in indiv_channels:
|
46 |
res.append(self.conv(tf.expand_dims(ch, 0), self.kernel, [1] * (self.ndims + 2), 'SAME'))
|
47 |
# out = -tf.nn.max_pool3d(-tf.expand_dims(in_tensor, 0), [3]*self.ndims, [1]*self.ndims, 'SAME', name='HDE_erosion')
|
@@ -626,7 +626,8 @@ class StructuralSimilarityGaussian:
|
|
626 |
self.__GF = tf.while_loop(lambda iterator, g_1d: tf.less(iterator, self.dim),
|
627 |
lambda iterator, g_1d: (iterator + 1, tf.expand_dims(g_1d, -1) * tf.transpose(g_1d_expanded)),
|
628 |
[iterator, g_1d],
|
629 |
-
[iterator.get_shape(), tf.TensorShape([None]*self.dim)] # Shape invariants
|
|
|
630 |
)[-1]
|
631 |
|
632 |
self.__GF = tf.divide(self.__GF, tf.reduce_sum(self.__GF)) # Normalization
|
@@ -759,48 +760,57 @@ class GeneralizedDICEScore:
|
|
759 |
Learning Los Function for Highly Unbalanced Segmentations" https://arxiv.org/abs/1707.03237
|
760 |
:param input_shape: Shape of the input image, without the batch dimension, e.g., 2D: [H, W, C], 3D: [H, W, D, C]
|
761 |
"""
|
|
|
|
|
762 |
if input_shape[-1] > 1:
|
763 |
-
|
764 |
-
|
|
|
|
|
|
|
765 |
elif num_labels is not None:
|
766 |
-
|
767 |
-
|
768 |
-
|
769 |
-
|
|
|
|
|
|
|
770 |
else:
|
771 |
-
raise ValueError('If input_shape
|
|
|
772 |
|
773 |
def one_hot_encoding(self, in_img, name=''):
|
774 |
# TODO: Test if differentiable!
|
775 |
labels, indices = tf.unique(tf.reshape(in_img, [-1]), tf.int32, name=name+'_unique')
|
776 |
-
one_hot = tf.one_hot(indices,
|
777 |
-
one_hot = tf.reshape(one_hot, self.
|
778 |
-
one_hot = tf.slice(one_hot, [0]*len(self.
|
779 |
name=name + '_remove_bg')
|
780 |
return one_hot
|
781 |
|
782 |
def weigthed_dice(self, y_true, y_pred):
|
783 |
# y_true = [B, -1, L]
|
784 |
# y_pred = [B, -1, L]
|
785 |
-
if self.
|
786 |
-
|
787 |
-
|
788 |
y_true = tf.reshape(y_true, self.flat_shape, name='GDICE_reshape_y_true') # Flatten along the volume dimensions
|
789 |
y_pred = tf.reshape(y_pred, self.flat_shape, name='GDICE_reshape_y_pred') # Flatten along the volume dimensions
|
790 |
|
791 |
size_y_true = tf.reduce_sum(y_true, axis=1, name='GDICE_size_y_true')
|
792 |
size_y_pred = tf.reduce_sum(y_pred, axis=1, name='GDICE_size_y_pred')
|
793 |
-
w = tf.
|
794 |
numerator = w * tf.reduce_sum(y_true * y_pred, axis=1)
|
795 |
denominator = w * (size_y_true + size_y_pred)
|
796 |
-
return tf.div_no_nan(2 * tf.reduce_sum(numerator, axis=-1), tf.reduce_sum(denominator, axis=-1))
|
797 |
|
798 |
def macro_dice(self, y_true, y_pred):
|
799 |
# y_true = [B, -1, L]
|
800 |
# y_pred = [B, -1, L]
|
801 |
-
if self.
|
802 |
-
|
803 |
-
|
804 |
y_true = tf.reshape(y_true, self.flat_shape, name='GDICE_reshape_y_true') # Flatten along the volume dimensions
|
805 |
y_pred = tf.reshape(y_pred, self.flat_shape, name='GDICE_reshape_y_pred') # Flatten along the volume dimensions
|
806 |
|
@@ -808,7 +818,7 @@ class GeneralizedDICEScore:
|
|
808 |
size_y_pred = tf.reduce_sum(y_pred, axis=1, name='GDICE_size_y_pred')
|
809 |
numerator = tf.reduce_sum(y_true * y_pred, axis=1)
|
810 |
denominator = (size_y_true + size_y_pred)
|
811 |
-
return tf.div_no_nan(2 * numerator, denominator)
|
812 |
|
813 |
@function_decorator('GeneralizeDICE__loss')
|
814 |
def loss(self, y_true, y_pred):
|
|
|
41 |
def _erode(self, in_tensor):
|
42 |
indiv_channels = tf.split(in_tensor, self.im_shape[-1], -1)
|
43 |
res = list()
|
44 |
+
with tf.compat.v1.variable_scope('erode', reuse=tf.AUTO_REUSE):
|
45 |
for ch in indiv_channels:
|
46 |
res.append(self.conv(tf.expand_dims(ch, 0), self.kernel, [1] * (self.ndims + 2), 'SAME'))
|
47 |
# out = -tf.nn.max_pool3d(-tf.expand_dims(in_tensor, 0), [3]*self.ndims, [1]*self.ndims, 'SAME', name='HDE_erosion')
|
|
|
626 |
self.__GF = tf.while_loop(lambda iterator, g_1d: tf.less(iterator, self.dim),
|
627 |
lambda iterator, g_1d: (iterator + 1, tf.expand_dims(g_1d, -1) * tf.transpose(g_1d_expanded)),
|
628 |
[iterator, g_1d],
|
629 |
+
[iterator.get_shape(), tf.TensorShape([None]*self.dim)], # Shape invariants
|
630 |
+
back_prop=False,
|
631 |
)[-1]
|
632 |
|
633 |
self.__GF = tf.divide(self.__GF, tf.reduce_sum(self.__GF)) # Normalization
|
|
|
760 |
Learning Los Function for Highly Unbalanced Segmentations" https://arxiv.org/abs/1707.03237
|
761 |
:param input_shape: Shape of the input image, without the batch dimension, e.g., 2D: [H, W, C], 3D: [H, W, D, C]
|
762 |
"""
|
763 |
+
self.smooth = 1e-10 # If y_pred = y_true = null -> dice should be 1
|
764 |
+
self.num_labels = num_labels
|
765 |
if input_shape[-1] > 1:
|
766 |
+
try:
|
767 |
+
self.flat_shape = [-1, np.prod(np.asarray(input_shape[:-1])), input_shape[-1]]
|
768 |
+
except TypeError as err:
|
769 |
+
self.flat_shape = [-1, None, input_shape[-1]]
|
770 |
+
self.cardinal_encoded = False
|
771 |
elif num_labels is not None:
|
772 |
+
try:
|
773 |
+
self.flat_shape = [-1, np.prod(np.asarray(input_shape[:-1])), input_shape[-1]]
|
774 |
+
except TypeError as err:
|
775 |
+
self.flat_shape = [-1, None, input_shape[-1]]
|
776 |
+
self.cardinal_enc_shape = [-1, *input_shape[:-1]]
|
777 |
+
self.cardinal_encoded = True
|
778 |
+
warnings.warn('Differentiable cardinal encoding not yet implemented')
|
779 |
else:
|
780 |
+
raise ValueError('If input_shape does not correspond to cardinally encoded,'
|
781 |
+
'then num_labels must be provided')
|
782 |
|
783 |
def one_hot_encoding(self, in_img, name=''):
|
784 |
# TODO: Test if differentiable!
|
785 |
labels, indices = tf.unique(tf.reshape(in_img, [-1]), tf.int32, name=name+'_unique')
|
786 |
+
one_hot = tf.one_hot(indices, self.num_labels, name=name + '_one_hot')
|
787 |
+
one_hot = tf.reshape(one_hot, self.cardinal_enc_shape + [self.num_labels], name=name + '_reshape')
|
788 |
+
one_hot = tf.slice(one_hot, [0] * len(self.cardinal_enc_shape) + [1], [-1] * (len(self.cardinal_enc_shape) + 1),
|
789 |
name=name + '_remove_bg')
|
790 |
return one_hot
|
791 |
|
792 |
def weigthed_dice(self, y_true, y_pred):
|
793 |
# y_true = [B, -1, L]
|
794 |
# y_pred = [B, -1, L]
|
795 |
+
# if self.cardinal_encoded:
|
796 |
+
# y_true = self.one_hot_encoding(y_true, name='GDICE_one_hot_encoding_y_true')
|
797 |
+
# y_pred = self.one_hot_encoding(y_pred, name='GDICE_one_hot_encoding_y_pred')
|
798 |
y_true = tf.reshape(y_true, self.flat_shape, name='GDICE_reshape_y_true') # Flatten along the volume dimensions
|
799 |
y_pred = tf.reshape(y_pred, self.flat_shape, name='GDICE_reshape_y_pred') # Flatten along the volume dimensions
|
800 |
|
801 |
size_y_true = tf.reduce_sum(y_true, axis=1, name='GDICE_size_y_true')
|
802 |
size_y_pred = tf.reduce_sum(y_pred, axis=1, name='GDICE_size_y_pred')
|
803 |
+
w = tf.math.divide_no_nan(1., tf.pow(size_y_true, 2), name='GDICE_weight')
|
804 |
numerator = w * tf.reduce_sum(y_true * y_pred, axis=1)
|
805 |
denominator = w * (size_y_true + size_y_pred)
|
806 |
+
return tf.div_no_nan(2 * tf.reduce_sum(numerator, axis=-1) + self.smooth, tf.reduce_sum(denominator, axis=-1) + self.smooth)
|
807 |
|
808 |
def macro_dice(self, y_true, y_pred):
|
809 |
# y_true = [B, -1, L]
|
810 |
# y_pred = [B, -1, L]
|
811 |
+
# if self.cardinal_encoded:
|
812 |
+
# y_true = self.one_hot_encoding(y_true, name='GDICE_one_hot_encoding_y_true')
|
813 |
+
# y_pred = self.one_hot_encoding(y_pred, name='GDICE_one_hot_encoding_y_pred')
|
814 |
y_true = tf.reshape(y_true, self.flat_shape, name='GDICE_reshape_y_true') # Flatten along the volume dimensions
|
815 |
y_pred = tf.reshape(y_pred, self.flat_shape, name='GDICE_reshape_y_pred') # Flatten along the volume dimensions
|
816 |
|
|
|
818 |
size_y_pred = tf.reduce_sum(y_pred, axis=1, name='GDICE_size_y_pred')
|
819 |
numerator = tf.reduce_sum(y_true * y_pred, axis=1)
|
820 |
denominator = (size_y_true + size_y_pred)
|
821 |
+
return tf.div_no_nan(2 * numerator + self.smooth, denominator + self.smooth)
|
822 |
|
823 |
@function_decorator('GeneralizeDICE__loss')
|
824 |
def loss(self, y_true, y_pred):
|
DeepDeformationMapRegistration/utils/constants.py
CHANGED
@@ -413,6 +413,7 @@ WAR = 30 # Warning
|
|
413 |
ERR = 40 # Error
|
414 |
DEB = 10 # Debug
|
415 |
CRI = 50 # Critical
|
|
|
416 |
|
417 |
SEVERITY_STR = {INF: 'INFO',
|
418 |
WAR: 'WARNING',
|
@@ -511,8 +512,8 @@ REG_MANUAL_W = [1.] * len(REG_PRIOR_W)
|
|
511 |
IXI_DATASET_iso_to_cubic_scales = np.asarray([0.655491 + 0.039223, 0.496783 + 0.029349, 0.499691 + 0.028155])
|
512 |
# ...OSLO_COMET_CT/Formatted_128x128x128/zoom_factors.csv contain the scale factors of all the training samples from isotropic to 128x128x128
|
513 |
COMET_DATASET_iso_to_cubic_scales = np.asarray([0.455259 + 0.048027, 0.492012 + 0.044298, 0.577552 + 0.051708])
|
514 |
-
MAX_AUG_DISP_ISOT = 30
|
515 |
-
MAX_AUG_DEF_ISOT = 6
|
516 |
MAX_AUG_DISP = np.max(MAX_AUG_DISP_ISOT * IXI_DATASET_iso_to_cubic_scales) # Scaled displacements
|
517 |
MAX_AUG_DEF = np.max(MAX_AUG_DEF_ISOT * IXI_DATASET_iso_to_cubic_scales) # Scaled deformations
|
518 |
MAX_AUG_ANGLE = np.max([np.arctan(np.tan(10*np.pi/180) * IXI_DATASET_iso_to_cubic_scales[1] / IXI_DATASET_iso_to_cubic_scales[0]) * 180 / np.pi,
|
|
|
413 |
ERR = 40 # Error
|
414 |
DEB = 10 # Debug
|
415 |
CRI = 50 # Critical
|
416 |
+
SUMMARY_LINE_LENGTH = 150
|
417 |
|
418 |
SEVERITY_STR = {INF: 'INFO',
|
419 |
WAR: 'WARNING',
|
|
|
512 |
IXI_DATASET_iso_to_cubic_scales = np.asarray([0.655491 + 0.039223, 0.496783 + 0.029349, 0.499691 + 0.028155])
|
513 |
# ...OSLO_COMET_CT/Formatted_128x128x128/zoom_factors.csv contain the scale factors of all the training samples from isotropic to 128x128x128
|
514 |
COMET_DATASET_iso_to_cubic_scales = np.asarray([0.455259 + 0.048027, 0.492012 + 0.044298, 0.577552 + 0.051708])
|
515 |
+
MAX_AUG_DISP_ISOT = 30 # mm
|
516 |
+
MAX_AUG_DEF_ISOT = 6 # mm
|
517 |
MAX_AUG_DISP = np.max(MAX_AUG_DISP_ISOT * IXI_DATASET_iso_to_cubic_scales) # Scaled displacements
|
518 |
MAX_AUG_DEF = np.max(MAX_AUG_DEF_ISOT * IXI_DATASET_iso_to_cubic_scales) # Scaled deformations
|
519 |
MAX_AUG_ANGLE = np.max([np.arctan(np.tan(10*np.pi/180) * IXI_DATASET_iso_to_cubic_scales[1] / IXI_DATASET_iso_to_cubic_scales[0]) * 180 / np.pi,
|
DeepDeformationMapRegistration/utils/misc.py
CHANGED
@@ -8,6 +8,7 @@ from DeepDeformationMapRegistration.layers.b_splines import interpolate_spline
|
|
8 |
from DeepDeformationMapRegistration.utils.thin_plate_splines import ThinPlateSplines
|
9 |
from tensorflow import squeeze
|
10 |
from scipy.ndimage import zoom
|
|
|
11 |
|
12 |
|
13 |
def try_mkdir(dir, verbose=True):
|
@@ -55,24 +56,28 @@ class DatasetCopy:
|
|
55 |
class DisplacementMapInterpolator:
|
56 |
def __init__(self,
|
57 |
image_shape=[64, 64, 64],
|
58 |
-
method='rbf'
|
|
|
59 |
assert method in ['rbf', 'griddata', 'tf', 'tps'], "Method must be 'rbf' or 'griddata'"
|
60 |
self.method = method
|
61 |
self.image_shape = image_shape
|
|
|
62 |
|
63 |
self.grid = self.__regular_grid()
|
64 |
|
65 |
def __regular_grid(self):
|
66 |
xx = np.linspace(0, self.image_shape[0], self.image_shape[0], endpoint=False, dtype=np.uint16)
|
67 |
-
yy = np.linspace(0, self.image_shape[
|
68 |
-
zz = np.linspace(0, self.image_shape[
|
69 |
|
70 |
xx, yy, zz = np.meshgrid(xx, yy, zz)
|
71 |
|
72 |
-
return np.stack([xx.
|
|
|
|
|
73 |
|
74 |
def __call__(self, disp_map, interp_points, backwards=False):
|
75 |
-
disp_map = disp_map.reshape([-1, 3])
|
76 |
grid_pts = self.grid.copy()
|
77 |
if backwards:
|
78 |
grid_pts = np.add(grid_pts, disp_map).astype(np.float32)
|
@@ -115,15 +120,22 @@ class DisplacementMapInterpolator:
|
|
115 |
return disp
|
116 |
|
117 |
|
118 |
-
def get_segmentations_centroids(segmentations, ohe=True, expected_lbls=range(
|
119 |
segmentations = np.squeeze(segmentations)
|
120 |
if ohe:
|
121 |
-
segmentations =
|
122 |
-
|
123 |
-
|
124 |
-
|
|
|
125 |
else:
|
126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
|
128 |
seg_props = regionprops(segmentations)
|
129 |
centroids = np.asarray([c.centroid for c in seg_props]).astype(np.float32)
|
@@ -143,11 +155,15 @@ def segmentation_ohe_to_cardinal(segmentation):
|
|
143 |
return np.argmax(cpy, axis=-1)[..., np.newaxis]
|
144 |
|
145 |
|
146 |
-
def segmentation_cardinal_to_ohe(segmentation):
|
147 |
# Keep in mind that we don't handle the overlap between the segmentations!
|
148 |
-
|
149 |
-
|
150 |
-
|
|
|
|
|
|
|
|
|
151 |
return cpy
|
152 |
|
153 |
|
@@ -180,3 +196,50 @@ def scale_transformation(original_shape: [list, tuple, np.ndarray], dest_shape:
|
|
180 |
|
181 |
return trf
|
182 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
from DeepDeformationMapRegistration.utils.thin_plate_splines import ThinPlateSplines
|
9 |
from tensorflow import squeeze
|
10 |
from scipy.ndimage import zoom
|
11 |
+
import tensorflow as tf
|
12 |
|
13 |
|
14 |
def try_mkdir(dir, verbose=True):
|
|
|
56 |
class DisplacementMapInterpolator:
|
57 |
def __init__(self,
|
58 |
image_shape=[64, 64, 64],
|
59 |
+
method='rbf',
|
60 |
+
step=1):
|
61 |
assert method in ['rbf', 'griddata', 'tf', 'tps'], "Method must be 'rbf' or 'griddata'"
|
62 |
self.method = method
|
63 |
self.image_shape = image_shape
|
64 |
+
self.step = step # If to use every point or even N-th point
|
65 |
|
66 |
self.grid = self.__regular_grid()
|
67 |
|
68 |
def __regular_grid(self):
|
69 |
xx = np.linspace(0, self.image_shape[0], self.image_shape[0], endpoint=False, dtype=np.uint16)
|
70 |
+
yy = np.linspace(0, self.image_shape[1], self.image_shape[1], endpoint=False, dtype=np.uint16)
|
71 |
+
zz = np.linspace(0, self.image_shape[2], self.image_shape[2], endpoint=False, dtype=np.uint16)
|
72 |
|
73 |
xx, yy, zz = np.meshgrid(xx, yy, zz)
|
74 |
|
75 |
+
return np.stack([xx[::self.step, ::self.step, ::self.step].flatten(),
|
76 |
+
yy[::self.step, ::self.step, ::self.step].flatten(),
|
77 |
+
zz[::self.step, ::self.step, ::self.step].flatten()], axis=0).T
|
78 |
|
79 |
def __call__(self, disp_map, interp_points, backwards=False):
|
80 |
+
disp_map = disp_map.squeeze()[::self.step, ::self.step, ::self.step, ...].reshape([-1, 3])
|
81 |
grid_pts = self.grid.copy()
|
82 |
if backwards:
|
83 |
grid_pts = np.add(grid_pts, disp_map).astype(np.float32)
|
|
|
120 |
return disp
|
121 |
|
122 |
|
123 |
+
def get_segmentations_centroids(segmentations, ohe=True, expected_lbls=range(1, 28), missing_centroid=[np.nan]*3, brain_study=True):
|
124 |
segmentations = np.squeeze(segmentations)
|
125 |
if ohe:
|
126 |
+
segmentations = segmentation_ohe_to_cardinal(segmentations)
|
127 |
+
lbls = set(np.unique(segmentations)) - {0} # Remove the 0 value returned by np.unique, no label
|
128 |
+
# missing_lbls = set(expected_lbls) - lbls
|
129 |
+
# if brain_study:
|
130 |
+
# segmentations += np.ones_like(segmentations) # Regionsprops neglect the label 0. But we need it, so offset all labels by 1
|
131 |
else:
|
132 |
+
lbls = set(np.unique(segmentations)) if 0 in expected_lbls else set(np.unique(segmentations)) - {0}
|
133 |
+
missing_lbls = set(expected_lbls) - lbls
|
134 |
+
|
135 |
+
if 0 in expected_lbls:
|
136 |
+
segmentations += np.ones_like(segmentations) # Regionsprops neglects the label 0. But we need it, so offset all labels by 1
|
137 |
+
|
138 |
+
segmentations = np.squeeze(segmentations) # remove channel dimension, not needed anyway
|
139 |
|
140 |
seg_props = regionprops(segmentations)
|
141 |
centroids = np.asarray([c.centroid for c in seg_props]).astype(np.float32)
|
|
|
155 |
return np.argmax(cpy, axis=-1)[..., np.newaxis]
|
156 |
|
157 |
|
158 |
+
def segmentation_cardinal_to_ohe(segmentation, labels_list: list = None):
|
159 |
# Keep in mind that we don't handle the overlap between the segmentations!
|
160 |
+
#labels_list = np.unique(segmentation)[1:] if labels_list is None else labels_list
|
161 |
+
num_labels = len(labels_list)
|
162 |
+
expected_shape = segmentation.shape[:-1] + (num_labels,)
|
163 |
+
cpy = np.zeros(expected_shape, dtype=np.uint8)
|
164 |
+
seg_squeezed = np.squeeze(segmentation, axis=-1)
|
165 |
+
for ch, lbl in enumerate(labels_list):
|
166 |
+
cpy[seg_squeezed == lbl, ch] = 1
|
167 |
return cpy
|
168 |
|
169 |
|
|
|
196 |
|
197 |
return trf
|
198 |
|
199 |
+
|
200 |
+
class GaussianFilter:
|
201 |
+
def __init__(self, size, sigma, dim, num_channels, stride=None, batch: bool=True):
|
202 |
+
"""
|
203 |
+
Gaussian filter
|
204 |
+
:param size: Kernel size
|
205 |
+
:param sigma: Sigma of the Gaussian filter.
|
206 |
+
:param dim: Data dimensionality. Must be {2, 3}.
|
207 |
+
:param num_channels: Number of channels of the image to filter.
|
208 |
+
"""
|
209 |
+
self.size = size
|
210 |
+
self.dim = dim
|
211 |
+
self.sigma = float(sigma)
|
212 |
+
self.num_channels = num_channels
|
213 |
+
self.stride = size // 2 if stride is None else int(stride)
|
214 |
+
if batch:
|
215 |
+
self.stride = [1] + [self.stride] * self.dim + [1] # No support for strides in the batch and channel dims
|
216 |
+
else:
|
217 |
+
self.stride = [self.stride] * self.dim + [1] # No support for strides in the batch and channel dims
|
218 |
+
|
219 |
+
self.convDN = getattr(tf.nn, 'conv%dd' % dim)
|
220 |
+
self.__GF = None
|
221 |
+
|
222 |
+
self.__build_gaussian_filter()
|
223 |
+
|
224 |
+
def __build_gaussian_filter(self):
|
225 |
+
range_1d = tf.range(-(self.size/2) + 1, self.size//2 + 1)
|
226 |
+
g_1d = tf.math.exp(-1.0 * tf.pow(range_1d, 2) / (2. * tf.pow(self.sigma, 2)))
|
227 |
+
g_1d_expanded = tf.expand_dims(g_1d, -1)
|
228 |
+
iterator = tf.constant(1)
|
229 |
+
self.__GF = tf.while_loop(lambda iterator, g_1d: tf.less(iterator, self.dim),
|
230 |
+
lambda iterator, g_1d: (iterator + 1, tf.expand_dims(g_1d, -1) * tf.transpose(g_1d_expanded)),
|
231 |
+
[iterator, g_1d],
|
232 |
+
[iterator.get_shape(), tf.TensorShape(None)], # Shape invariants
|
233 |
+
back_prop=False
|
234 |
+
)[-1]
|
235 |
+
|
236 |
+
self.__GF = tf.divide(self.__GF, tf.reduce_sum(self.__GF)) # Normalization
|
237 |
+
self.__GF = tf.reshape(self.__GF, (*[self.size]*self.dim, 1, 1)) # Add Ch_in and Ch_out for convolution
|
238 |
+
self.__GF = tf.tile(self.__GF, (*[1] * self.dim, self.num_channels, self.num_channels,))
|
239 |
+
|
240 |
+
def apply_filter(self, in_image):
|
241 |
+
return self.convDN(in_image, self.__GF, self.stride, 'SAME')
|
242 |
+
|
243 |
+
@property
|
244 |
+
def kernel(self):
|
245 |
+
return self.__GF
|
DeepDeformationMapRegistration/utils/operators.py
CHANGED
@@ -63,3 +63,18 @@ def sample_unique(population, samples, tout=tf.int32):
|
|
63 |
_, indices = tf.nn.top_k(z, samples)
|
64 |
ret_val = tf.gather(population, indices)
|
65 |
return tf.cast(ret_val, tout)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
_, indices = tf.nn.top_k(z, samples)
|
64 |
ret_val = tf.gather(population, indices)
|
65 |
return tf.cast(ret_val, tout)
|
66 |
+
|
67 |
+
|
68 |
+
def safe_medpy_metric(prediction, reference, nb_labels, metric_fnc, fnc_args: dict={}):
|
69 |
+
vals = list()
|
70 |
+
if 'voxelspacing' in fnc_args.keys():
|
71 |
+
diag = np.power(reference.shape[:-1] * fnc_args['voxelspacing'], 2)
|
72 |
+
else:
|
73 |
+
diag = np.power(reference.shape[:-1], 2)
|
74 |
+
diag = np.sqrt(np.sum(diag))
|
75 |
+
for l in range(nb_labels):
|
76 |
+
try:
|
77 |
+
vals.append(metric_fnc(prediction[..., l], reference[..., l], **fnc_args))
|
78 |
+
except RuntimeError:
|
79 |
+
vals.append(diag)
|
80 |
+
return vals
|
DeepDeformationMapRegistration/utils/visualization.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import matplotlib
|
2 |
-
|
3 |
import matplotlib.pyplot as plt
|
4 |
from mpl_toolkits.mplot3d import Axes3D
|
5 |
import matplotlib.colors as mcolors
|
@@ -17,7 +17,7 @@ THRES = 0.9
|
|
17 |
|
18 |
# COLOR MAPS
|
19 |
chunks = np.linspace(0, 1, 10)
|
20 |
-
cmap1 = plt.get_cmap('hsv',
|
21 |
# cmaplist = [cmap1(i) for i in range(cmap1.N)]
|
22 |
cmaplist = [(1, 1, 1, 1), (0, 0, 1, 1), (230 / 255, 97 / 255, 1 / 255, 1), (128 / 255, 0 / 255, 32 / 255, 1)]
|
23 |
cmaplist[0] = (1, 1, 1, 1.0)
|
@@ -34,6 +34,14 @@ cmap4 = mcolors.LinearSegmentedColormap.from_list('mycmap', colors, N=100)
|
|
34 |
|
35 |
cmap_bin = cm.get_cmap('viridis', 3) # viridis is the default colormap
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
def view_centerline_sample(sample: np.ndarray, dimensionality: int, ax=None, c=None, name=None):
|
39 |
if dimensionality == 2:
|
@@ -321,7 +329,7 @@ def save_centreline_img(img, title, filename, fig=None):
|
|
321 |
plt.close()
|
322 |
|
323 |
|
324 |
-
def save_disp_map_img(disp_map, title, filename, affine_transf=False, fig=None, show=False):
|
325 |
if fig is not None:
|
326 |
fig.clear()
|
327 |
plt.figure(fig.number)
|
@@ -333,7 +341,7 @@ def save_disp_map_img(disp_map, title, filename, affine_transf=False, fig=None,
|
|
333 |
if dim == 2:
|
334 |
ax_x = fig.add_subplot(131)
|
335 |
ax_x.set_title('H displacement')
|
336 |
-
im_x = ax_x.imshow(disp_map[..., C.H_DISP])
|
337 |
ax_x.tick_params(axis='both',
|
338 |
which='both',
|
339 |
bottom=False,
|
@@ -344,7 +352,7 @@ def save_disp_map_img(disp_map, title, filename, affine_transf=False, fig=None,
|
|
344 |
|
345 |
ax_y = fig.add_subplot(132)
|
346 |
ax_y.set_title('W displacement')
|
347 |
-
im_y = ax_y.imshow(disp_map[..., C.W_DISP])
|
348 |
ax_y.tick_params(axis='both',
|
349 |
which='both',
|
350 |
bottom=False,
|
@@ -371,7 +379,7 @@ def save_disp_map_img(disp_map, title, filename, affine_transf=False, fig=None,
|
|
371 |
ax.text(i, j, transf_mat[i, j], ha="center", va="center", color="b")
|
372 |
|
373 |
else:
|
374 |
-
c, d, s = _prepare_quiver_map(disp_map, dim=dim)
|
375 |
im = ax.imshow(s, interpolation='none', aspect='equal')
|
376 |
ax.quiver(c[C.H_DISP], c[C.W_DISP], d[C.H_DISP], d[C.W_DISP],
|
377 |
scale=C.QUIVER_PARAMS.arrow_scale)
|
@@ -386,7 +394,7 @@ def save_disp_map_img(disp_map, title, filename, affine_transf=False, fig=None,
|
|
386 |
fig.suptitle(title)
|
387 |
else:
|
388 |
ax = fig.add_subplot(111, projection='3d')
|
389 |
-
c, d, s = _prepare_quiver_map(disp_map[0, ...], dim=dim)
|
390 |
ax.quiver(c[C.H_DISP], c[C.W_DISP], c[C.D_DISP], d[C.H_DISP], d[C.W_DISP], d[C.D_DISP])
|
391 |
_square_3d_plot(np.arange(0, dim_h-1), np.arange(0, dim_w-1), np.arange(0, dim_d-1), ax)
|
392 |
fig.suptitle('Displacement map')
|
@@ -810,7 +818,12 @@ def plot_dataset_3d(img_sets):
|
|
810 |
return fig
|
811 |
|
812 |
|
813 |
-
def plot_predictions(
|
|
|
|
|
|
|
|
|
|
|
814 |
num_rows = fix_img_batch.shape[0]
|
815 |
img_dim = len(fix_img_batch.shape) - 2
|
816 |
img_size = fix_img_batch.shape[1:-1]
|
@@ -828,6 +841,10 @@ def plot_predictions(fix_img_batch, mov_img_batch, disp_map_batch, pred_img_batc
|
|
828 |
fix_img_batch = fix_img_batch[:, selected_slice, ...]
|
829 |
mov_img_batch = mov_img_batch[:, selected_slice, ...]
|
830 |
pred_img_batch = pred_img_batch[:, selected_slice, ...]
|
|
|
|
|
|
|
|
|
831 |
disp_map_batch = disp_map_batch[:, selected_slice, ..., 1:] # Only the sagittal and longitudinal axes
|
832 |
img_size = fix_img_batch.shape[1:-1]
|
833 |
elif img_dim != 2:
|
@@ -836,16 +853,24 @@ def plot_predictions(fix_img_batch, mov_img_batch, disp_map_batch, pred_img_batc
|
|
836 |
for row in range(num_rows):
|
837 |
fix_img = fix_img_batch[row, :, :, 0].transpose()
|
838 |
mov_img = mov_img_batch[row, :, :, 0].transpose()
|
839 |
-
disp_map = disp_map_batch[row, :, :, :].transpose((1, 0, 2))
|
840 |
pred_img = pred_img_batch[row, :, :, 0].transpose()
|
841 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
842 |
ax[row, 0].tick_params(axis='both',
|
843 |
which='both',
|
844 |
bottom=False,
|
845 |
left=False,
|
846 |
labelleft=False,
|
847 |
labelbottom=False)
|
848 |
-
ax[row, 1].imshow(mov_img, origin='lower')
|
|
|
|
|
849 |
ax[row, 1].tick_params(axis='both',
|
850 |
which='both',
|
851 |
bottom=False,
|
@@ -853,7 +878,7 @@ def plot_predictions(fix_img_batch, mov_img_batch, disp_map_batch, pred_img_batc
|
|
853 |
labelleft=False,
|
854 |
labelbottom=False)
|
855 |
|
856 |
-
c, d, s = _prepare_quiver_map(disp_map, spc=
|
857 |
cx, cy = c
|
858 |
dx, dy = d
|
859 |
disp_map_color = _prepare_colormap(disp_map)
|
@@ -866,7 +891,9 @@ def plot_predictions(fix_img_batch, mov_img_batch, disp_map_batch, pred_img_batc
|
|
866 |
labelleft=False,
|
867 |
labelbottom=False)
|
868 |
|
869 |
-
ax[row, 3].imshow(mov_img, origin='lower')
|
|
|
|
|
870 |
ax[row, 3].quiver(cx, cy, dx, dy, units='dots', scale=1, color='w')
|
871 |
ax[row, 3].tick_params(axis='both',
|
872 |
which='both',
|
@@ -875,7 +902,9 @@ def plot_predictions(fix_img_batch, mov_img_batch, disp_map_batch, pred_img_batc
|
|
875 |
labelleft=False,
|
876 |
labelbottom=False)
|
877 |
|
878 |
-
ax[row, 4].imshow(pred_img, origin='lower')
|
|
|
|
|
879 |
ax[row, 4].tick_params(axis='both',
|
880 |
which='both',
|
881 |
bottom=False,
|
|
|
1 |
import matplotlib
|
2 |
+
matplotlib.use('WebAgg')
|
3 |
import matplotlib.pyplot as plt
|
4 |
from mpl_toolkits.mplot3d import Axes3D
|
5 |
import matplotlib.colors as mcolors
|
|
|
17 |
|
18 |
# COLOR MAPS
|
19 |
chunks = np.linspace(0, 1, 10)
|
20 |
+
cmap1 = plt.get_cmap('hsv', 30)
|
21 |
# cmaplist = [cmap1(i) for i in range(cmap1.N)]
|
22 |
cmaplist = [(1, 1, 1, 1), (0, 0, 1, 1), (230 / 255, 97 / 255, 1 / 255, 1), (128 / 255, 0 / 255, 32 / 255, 1)]
|
23 |
cmaplist[0] = (1, 1, 1, 1.0)
|
|
|
34 |
|
35 |
cmap_bin = cm.get_cmap('viridis', 3) # viridis is the default colormap
|
36 |
|
37 |
+
cmap_segs = np.asarray([mcolors.to_rgba(mcolors.CSS4_COLORS[c], 1) for c in mcolors.CSS4_COLORS.keys()])
|
38 |
+
cmap_segs.sort()
|
39 |
+
# rnd_idxs = [30, 17, 72, 90, 74, 39, 120, 63, 52, 79, 140, 68, 131, 109, 57, 49, 11, 132, 29, 46, 51, 26, 53, 7, 89, 47, 43, 121, 31, 28, 106, 92, 130, 117, 91, 118, 61, 5, 80, 93, 58, 133, 14, 98, 116, 76, 113, 111, 136, 142, 95, 122, 86, 77, 36, 97, 141, 115, 18, 81, 88, 87, 44, 146, 103, 67, 147, 48, 42, 83, 128, 65, 139, 69, 27, 135, 94, 134, 50, 19, 114, 0, 96, 10, 138, 75, 13, 12, 102, 32, 66, 16, 8, 73, 85, 145, 54, 37, 70, 143]
|
40 |
+
# cmap_segs = cmap_segs[rnd_idxs]
|
41 |
+
np.random.shuffle(cmap_segs)
|
42 |
+
cmap_segs[0, -1] = 0
|
43 |
+
cmap_segs = mcolors.ListedColormap(cmap_segs)
|
44 |
+
|
45 |
|
46 |
def view_centerline_sample(sample: np.ndarray, dimensionality: int, ax=None, c=None, name=None):
|
47 |
if dimensionality == 2:
|
|
|
329 |
plt.close()
|
330 |
|
331 |
|
332 |
+
def save_disp_map_img(disp_map, title, filename, affine_transf=False, fig=None, show=False, step=1):
|
333 |
if fig is not None:
|
334 |
fig.clear()
|
335 |
plt.figure(fig.number)
|
|
|
341 |
if dim == 2:
|
342 |
ax_x = fig.add_subplot(131)
|
343 |
ax_x.set_title('H displacement')
|
344 |
+
im_x = ax_x.imshow(disp_map[..., ::step, ::step, C.H_DISP])
|
345 |
ax_x.tick_params(axis='both',
|
346 |
which='both',
|
347 |
bottom=False,
|
|
|
352 |
|
353 |
ax_y = fig.add_subplot(132)
|
354 |
ax_y.set_title('W displacement')
|
355 |
+
im_y = ax_y.imshow(disp_map[..., ::step, ::step, C.W_DISP])
|
356 |
ax_y.tick_params(axis='both',
|
357 |
which='both',
|
358 |
bottom=False,
|
|
|
379 |
ax.text(i, j, transf_mat[i, j], ha="center", va="center", color="b")
|
380 |
|
381 |
else:
|
382 |
+
c, d, s = _prepare_quiver_map(disp_map, dim=dim, spc=step)
|
383 |
im = ax.imshow(s, interpolation='none', aspect='equal')
|
384 |
ax.quiver(c[C.H_DISP], c[C.W_DISP], d[C.H_DISP], d[C.W_DISP],
|
385 |
scale=C.QUIVER_PARAMS.arrow_scale)
|
|
|
394 |
fig.suptitle(title)
|
395 |
else:
|
396 |
ax = fig.add_subplot(111, projection='3d')
|
397 |
+
c, d, s = _prepare_quiver_map(disp_map[0, ...], dim=dim, spc=step)
|
398 |
ax.quiver(c[C.H_DISP], c[C.W_DISP], c[C.D_DISP], d[C.H_DISP], d[C.W_DISP], d[C.D_DISP])
|
399 |
_square_3d_plot(np.arange(0, dim_h-1), np.arange(0, dim_w-1), np.arange(0, dim_d-1), ax)
|
400 |
fig.suptitle('Displacement map')
|
|
|
818 |
return fig
|
819 |
|
820 |
|
821 |
+
def plot_predictions(img_batches, disp_map_batch, seg_batches=None, step=1, filename='predictions', fig=None, show=False):
|
822 |
+
fix_img_batch, mov_img_batch, pred_img_batch = img_batches
|
823 |
+
if seg_batches != None:
|
824 |
+
fix_seg_batch, mov_seg_batch, pred_seg_batch = seg_batches
|
825 |
+
else:
|
826 |
+
fix_seg_batch = mov_seg_batch = pred_seg_batch = None
|
827 |
num_rows = fix_img_batch.shape[0]
|
828 |
img_dim = len(fix_img_batch.shape) - 2
|
829 |
img_size = fix_img_batch.shape[1:-1]
|
|
|
841 |
fix_img_batch = fix_img_batch[:, selected_slice, ...]
|
842 |
mov_img_batch = mov_img_batch[:, selected_slice, ...]
|
843 |
pred_img_batch = pred_img_batch[:, selected_slice, ...]
|
844 |
+
if seg_batches != None:
|
845 |
+
fix_seg_batch = fix_seg_batch[:, selected_slice, ...]
|
846 |
+
mov_seg_batch = mov_seg_batch[:, selected_slice, ...]
|
847 |
+
pred_seg_batch = pred_seg_batch[:, selected_slice, ...]
|
848 |
disp_map_batch = disp_map_batch[:, selected_slice, ..., 1:] # Only the sagittal and longitudinal axes
|
849 |
img_size = fix_img_batch.shape[1:-1]
|
850 |
elif img_dim != 2:
|
|
|
853 |
for row in range(num_rows):
|
854 |
fix_img = fix_img_batch[row, :, :, 0].transpose()
|
855 |
mov_img = mov_img_batch[row, :, :, 0].transpose()
|
|
|
856 |
pred_img = pred_img_batch[row, :, :, 0].transpose()
|
857 |
+
if seg_batches != None:
|
858 |
+
fix_seg = fix_seg_batch[row, :, :, 0].transpose()
|
859 |
+
mov_seg= mov_seg_batch[row, :, :, 0].transpose()
|
860 |
+
pred_seg = pred_seg_batch[row, :, :, 0].transpose()
|
861 |
+
disp_map = disp_map_batch[row, :, :, :].transpose((1, 0, 2))
|
862 |
+
ax[row, 0].imshow(fix_img, origin='lower', cmap='gray')
|
863 |
+
if seg_batches != None:
|
864 |
+
ax[row, 0].imshow(fix_seg, origin='lower', cmap=cmap_segs)
|
865 |
ax[row, 0].tick_params(axis='both',
|
866 |
which='both',
|
867 |
bottom=False,
|
868 |
left=False,
|
869 |
labelleft=False,
|
870 |
labelbottom=False)
|
871 |
+
ax[row, 1].imshow(mov_img, origin='lower', cmap='gray')
|
872 |
+
if seg_batches != None:
|
873 |
+
ax[row, 1].imshow(mov_seg, origin='lower', cmap=cmap_segs)
|
874 |
ax[row, 1].tick_params(axis='both',
|
875 |
which='both',
|
876 |
bottom=False,
|
|
|
878 |
labelleft=False,
|
879 |
labelbottom=False)
|
880 |
|
881 |
+
c, d, s = _prepare_quiver_map(disp_map, spc=step)
|
882 |
cx, cy = c
|
883 |
dx, dy = d
|
884 |
disp_map_color = _prepare_colormap(disp_map)
|
|
|
891 |
labelleft=False,
|
892 |
labelbottom=False)
|
893 |
|
894 |
+
ax[row, 3].imshow(mov_img, origin='lower', cmap='gray')
|
895 |
+
if seg_batches != None:
|
896 |
+
ax[row, 3].imshow(mov_seg, origin='lower', cmap=cmap_segs)
|
897 |
ax[row, 3].quiver(cx, cy, dx, dy, units='dots', scale=1, color='w')
|
898 |
ax[row, 3].tick_params(axis='both',
|
899 |
which='both',
|
|
|
902 |
labelleft=False,
|
903 |
labelbottom=False)
|
904 |
|
905 |
+
ax[row, 4].imshow(pred_img, origin='lower', cmap='gray')
|
906 |
+
if seg_batches != None:
|
907 |
+
ax[row, 4].imshow(pred_seg, origin='lower', cmap=cmap_segs)
|
908 |
ax[row, 4].tick_params(axis='both',
|
909 |
which='both',
|
910 |
bottom=False,
|
SoA_methods/eval_ants.py
CHANGED
@@ -19,6 +19,8 @@ from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
|
|
19 |
from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
|
20 |
import DeepDeformationMapRegistration.utils.constants as C
|
21 |
|
|
|
|
|
22 |
import voxelmorph as vxm
|
23 |
|
24 |
from argparse import ArgumentParser
|
@@ -35,22 +37,24 @@ WARPED_FIX = 'warpedfixout'
|
|
35 |
FWD_TRFS = 'fwdtransforms'
|
36 |
INV_TRFS = 'invtransforms'
|
37 |
|
38 |
-
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
|
39 |
-
os.environ['CUDA_VISIBLE_DEVICES'] = '2'
|
40 |
|
41 |
if __name__ == '__main__':
|
42 |
parser = ArgumentParser()
|
43 |
parser.add_argument('--dataset', type=str, help='Directory with the images')
|
44 |
parser.add_argument('--outdir', type=str, help='Output directory')
|
|
|
45 |
args = parser.parse_args()
|
46 |
|
|
|
|
|
|
|
47 |
os.makedirs(args.outdir, exist_ok=True)
|
48 |
os.makedirs(os.path.join(args.outdir, 'SyN'), exist_ok=True)
|
49 |
os.makedirs(os.path.join(args.outdir, 'SyNCC'), exist_ok=True)
|
50 |
dataset_files = os.listdir(args.dataset)
|
51 |
dataset_files.sort()
|
52 |
dataset_files = [os.path.join(args.dataset, f) for f in dataset_files if re.match(DATASET_NAMES, f)]
|
53 |
-
|
54 |
dataset_iterator = tqdm(enumerate(dataset_files), desc="Running ANTs")
|
55 |
|
56 |
f = h5py.File(dataset_files[0], 'r')
|
@@ -67,18 +71,18 @@ if __name__ == '__main__':
|
|
67 |
HausdorffDistanceErosion(3, 10, im_shape=image_shape + [nb_labels]).metric,
|
68 |
GeneralizedDICEScore(image_shape + [nb_labels], num_labels=nb_labels).metric_macro]
|
69 |
|
70 |
-
fix_img_ph = tf.placeholder(tf.float32, (1,
|
71 |
-
pred_img_ph = tf.placeholder(tf.float32, (1,
|
72 |
-
fix_seg_ph = tf.placeholder(tf.float32, (1,
|
73 |
-
pred_seg_ph = tf.placeholder(tf.float32, (1,
|
74 |
|
75 |
ssim_tf = metric_fncs[0](fix_img_ph, pred_img_ph)
|
76 |
ncc_tf = metric_fncs[1](fix_img_ph, pred_img_ph)
|
77 |
mse_tf = metric_fncs[2](fix_img_ph, pred_img_ph)
|
78 |
ms_ssim_tf = metric_fncs[3](fix_img_ph, pred_img_ph)
|
79 |
-
dice_tf = metric_fncs[4](fix_seg_ph, pred_seg_ph)
|
80 |
-
hd_tf = metric_fncs[5](fix_seg_ph, pred_seg_ph)
|
81 |
-
dice_macro_tf = metric_fncs[6](fix_seg_ph, pred_seg_ph)
|
82 |
|
83 |
config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
|
84 |
config.gpu_options.allow_growth = True
|
@@ -90,7 +94,7 @@ if __name__ == '__main__':
|
|
90 |
####
|
91 |
os.environ["ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS"] = "{:d}".format(os.cpu_count()) #https://github.com/ANTsX/ANTsPy/issues/261
|
92 |
print("Running ANTs using {} threads".format(os.environ.get("ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS")))
|
93 |
-
dm_interp = DisplacementMapInterpolator(image_shape, 'griddata')
|
94 |
# Header of the metrics csv file
|
95 |
csv_header = ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'DICE_MACRO', 'HD', 'Time', 'TRE']
|
96 |
|
@@ -147,10 +151,17 @@ if __name__ == '__main__':
|
|
147 |
|
148 |
dataset_iterator.set_description('{} ({}): Getting metrics {}'.format(file_num, file_path, reg_method))
|
149 |
with sess.as_default():
|
150 |
-
dice
|
151 |
-
|
152 |
-
|
153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
|
155 |
pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
|
156 |
mov_seg_card = segmentation_ohe_to_cardinal(mov_seg).astype(np.float32)
|
@@ -160,26 +171,28 @@ if __name__ == '__main__':
|
|
160 |
{'fix_img:0': fix_img[np.newaxis, ...], # Batch axis
|
161 |
'pred_img:0': pred_img[np.newaxis, ...] # Batch axis
|
162 |
})
|
|
|
163 |
ms_ssim = ms_ssim[0]
|
164 |
|
165 |
# TRE
|
166 |
disp_map = np.squeeze(np.asarray(nb.load(mov_to_fix_trf_list[0]).dataobj))
|
|
|
167 |
pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) + mov_centroids
|
168 |
-
upsample_scale = 128 / 64
|
169 |
-
fix_centroids_isotropic = fix_centroids * upsample_scale
|
170 |
-
pred_centroids_isotropic = pred_centroids * upsample_scale
|
171 |
|
172 |
-
fix_centroids_isotropic = np.divide(fix_centroids_isotropic, C.COMET_DATASET_iso_to_cubic_scales)
|
173 |
-
pred_centroids_isotropic = np.divide(pred_centroids_isotropic, C.COMET_DATASET_iso_to_cubic_scales)
|
174 |
-
tre_array = target_registration_error(
|
175 |
tre = np.mean([v for v in tre_array if not np.isnan(v)])
|
176 |
|
177 |
-
dataset_iterator.set_description('{} ({}): Saving data {}'.format(file_num, file_path, reg_method))
|
178 |
-
new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd,
|
179 |
-
|
180 |
-
|
181 |
-
with open(metrics_file[reg_method], 'a') as f:
|
182 |
-
|
183 |
|
184 |
save_nifti(fix_img[0, ...], os.path.join(args.outdir, reg_method, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
185 |
save_nifti(mov_img[0, ...], os.path.join(args.outdir, reg_method, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
@@ -188,12 +201,17 @@ if __name__ == '__main__':
|
|
188 |
save_nifti(mov_seg_card[0, ...], os.path.join(args.outdir, reg_method, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
189 |
save_nifti(pred_seg_card[0, ...], os.path.join(args.outdir, reg_method, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
190 |
|
191 |
-
plot_predictions(fix_img[np.newaxis, ...], mov_img[np.newaxis, ...], disp_map[np.newaxis, ...],
|
192 |
-
plot_predictions(
|
193 |
save_disp_map_img(disp_map[np.newaxis, ...], 'Displacement map', os.path.join(args.outdir, reg_method, '{:03d}_disp_map_fig.png'.format(step)), show=False)
|
194 |
|
195 |
for k in metrics_file.keys():
|
196 |
print('Summary {}\n=======\n'.format(k))
|
197 |
-
|
198 |
-
|
|
|
|
|
|
|
|
|
|
|
199 |
print('\n=======\n')
|
|
|
19 |
from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
|
20 |
import DeepDeformationMapRegistration.utils.constants as C
|
21 |
|
22 |
+
import medpy.metric as medpy_metrics
|
23 |
+
|
24 |
import voxelmorph as vxm
|
25 |
|
26 |
from argparse import ArgumentParser
|
|
|
37 |
FWD_TRFS = 'fwdtransforms'
|
38 |
INV_TRFS = 'invtransforms'
|
39 |
|
|
|
|
|
40 |
|
41 |
if __name__ == '__main__':
|
42 |
parser = ArgumentParser()
|
43 |
parser.add_argument('--dataset', type=str, help='Directory with the images')
|
44 |
parser.add_argument('--outdir', type=str, help='Output directory')
|
45 |
+
parser.add_argument('--gpu', type=int, help='GPU')
|
46 |
args = parser.parse_args()
|
47 |
|
48 |
+
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
|
49 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
|
50 |
+
|
51 |
os.makedirs(args.outdir, exist_ok=True)
|
52 |
os.makedirs(os.path.join(args.outdir, 'SyN'), exist_ok=True)
|
53 |
os.makedirs(os.path.join(args.outdir, 'SyNCC'), exist_ok=True)
|
54 |
dataset_files = os.listdir(args.dataset)
|
55 |
dataset_files.sort()
|
56 |
dataset_files = [os.path.join(args.dataset, f) for f in dataset_files if re.match(DATASET_NAMES, f)]
|
57 |
+
dataset_files.sort()
|
58 |
dataset_iterator = tqdm(enumerate(dataset_files), desc="Running ANTs")
|
59 |
|
60 |
f = h5py.File(dataset_files[0], 'r')
|
|
|
71 |
HausdorffDistanceErosion(3, 10, im_shape=image_shape + [nb_labels]).metric,
|
72 |
GeneralizedDICEScore(image_shape + [nb_labels], num_labels=nb_labels).metric_macro]
|
73 |
|
74 |
+
fix_img_ph = tf.placeholder(tf.float32, (1, None, None, None, 1), name='fix_img')
|
75 |
+
pred_img_ph = tf.placeholder(tf.float32, (1, None, None, None, 1), name='pred_img')
|
76 |
+
fix_seg_ph = tf.placeholder(tf.float32, (1, None, None, None, nb_labels), name='fix_seg')
|
77 |
+
pred_seg_ph = tf.placeholder(tf.float32, (1, None, None, None, nb_labels), name='pred_seg')
|
78 |
|
79 |
ssim_tf = metric_fncs[0](fix_img_ph, pred_img_ph)
|
80 |
ncc_tf = metric_fncs[1](fix_img_ph, pred_img_ph)
|
81 |
mse_tf = metric_fncs[2](fix_img_ph, pred_img_ph)
|
82 |
ms_ssim_tf = metric_fncs[3](fix_img_ph, pred_img_ph)
|
83 |
+
# dice_tf = metric_fncs[4](fix_seg_ph, pred_seg_ph)
|
84 |
+
# hd_tf = metric_fncs[5](fix_seg_ph, pred_seg_ph)
|
85 |
+
# dice_macro_tf = metric_fncs[6](fix_seg_ph, pred_seg_ph)
|
86 |
|
87 |
config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
|
88 |
config.gpu_options.allow_growth = True
|
|
|
94 |
####
|
95 |
os.environ["ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS"] = "{:d}".format(os.cpu_count()) #https://github.com/ANTsX/ANTsPy/issues/261
|
96 |
print("Running ANTs using {} threads".format(os.environ.get("ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS")))
|
97 |
+
# dm_interp = DisplacementMapInterpolator(image_shape, 'griddata')
|
98 |
# Header of the metrics csv file
|
99 |
csv_header = ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'DICE_MACRO', 'HD', 'Time', 'TRE']
|
100 |
|
|
|
151 |
|
152 |
dataset_iterator.set_description('{} ({}): Getting metrics {}'.format(file_num, file_path, reg_method))
|
153 |
with sess.as_default():
|
154 |
+
dice = np.mean([medpy_metrics.dc(pred_seg[np.newaxis, ..., l], fix_seg[np.newaxis,..., l]) / np.sum(
|
155 |
+
fix_seg[np.newaxis,..., l]) for l in range(nb_labels)])
|
156 |
+
hd = np.mean(
|
157 |
+
[medpy_metrics.hd(pred_seg[np.newaxis,..., l], fix_seg[np.newaxis,..., l]) for l in range(nb_labels)])
|
158 |
+
dice_macro = np.mean(
|
159 |
+
[medpy_metrics.dc(pred_seg[np.newaxis,..., l], fix_seg[np.newaxis,..., l]) for l in range(nb_labels)])
|
160 |
+
|
161 |
+
# dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf],
|
162 |
+
# {'fix_seg:0': fix_seg[np.newaxis, ...], # Batch axis
|
163 |
+
# 'pred_seg:0': pred_seg[np.newaxis, ...] # Batch axis
|
164 |
+
# })
|
165 |
|
166 |
pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
|
167 |
mov_seg_card = segmentation_ohe_to_cardinal(mov_seg).astype(np.float32)
|
|
|
171 |
{'fix_img:0': fix_img[np.newaxis, ...], # Batch axis
|
172 |
'pred_img:0': pred_img[np.newaxis, ...] # Batch axis
|
173 |
})
|
174 |
+
ssim = np.mean(ssim)
|
175 |
ms_ssim = ms_ssim[0]
|
176 |
|
177 |
# TRE
|
178 |
disp_map = np.squeeze(np.asarray(nb.load(mov_to_fix_trf_list[0]).dataobj))
|
179 |
+
dm_interp = DisplacementMapInterpolator(fix_img.shape[:-1], 'griddata', step=2)
|
180 |
pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) + mov_centroids
|
181 |
+
# upsample_scale = 128 / 64
|
182 |
+
# fix_centroids_isotropic = fix_centroids * upsample_scale
|
183 |
+
# pred_centroids_isotropic = pred_centroids * upsample_scale
|
184 |
|
185 |
+
# fix_centroids_isotropic = np.divide(fix_centroids_isotropic, C.COMET_DATASET_iso_to_cubic_scales)
|
186 |
+
# pred_centroids_isotropic = np.divide(pred_centroids_isotropic, C.COMET_DATASET_iso_to_cubic_scales)
|
187 |
+
tre_array = target_registration_error(fix_centroids, pred_centroids, False).eval()
|
188 |
tre = np.mean([v for v in tre_array if not np.isnan(v)])
|
189 |
|
190 |
+
# dataset_iterator.set_description('{} ({}): Saving data {}'.format(file_num, file_path, reg_method))
|
191 |
+
# new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd,
|
192 |
+
# t1_syn-t0_syn if reg_method == 'SyN' else t1_syncc-t0_syncc,
|
193 |
+
# tre]
|
194 |
+
# with open(metrics_file[reg_method], 'a') as f:
|
195 |
+
# f.write(';'.join(map(str, new_line))+'\n')
|
196 |
|
197 |
save_nifti(fix_img[0, ...], os.path.join(args.outdir, reg_method, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
198 |
save_nifti(mov_img[0, ...], os.path.join(args.outdir, reg_method, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
|
|
201 |
save_nifti(mov_seg_card[0, ...], os.path.join(args.outdir, reg_method, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
202 |
save_nifti(pred_seg_card[0, ...], os.path.join(args.outdir, reg_method, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
203 |
|
204 |
+
plot_predictions(img_batches=[fix_img[np.newaxis, ...], mov_img[np.newaxis, ...], pred_img[np.newaxis, ...]], disp_map_batch=disp_map[np.newaxis, ...], seg_batches=[fix_seg_card[np.newaxis, ...], mov_seg_card[np.newaxis, ...], pred_seg_card[np.newaxis, ...]], filename=os.path.join(args.outdir, reg_method, '{:03d}_figures_seg.png'.format(step)), show=False)
|
205 |
+
plot_predictions(img_batches=[fix_img[np.newaxis, ...], mov_img[np.newaxis, ...], pred_img[np.newaxis, ...]], disp_map_batch=disp_map[np.newaxis, ...], filename=os.path.join(args.outdir, reg_method, '{:03d}_figures_img.png'.format(step)), show=False)
|
206 |
save_disp_map_img(disp_map[np.newaxis, ...], 'Displacement map', os.path.join(args.outdir, reg_method, '{:03d}_disp_map_fig.png'.format(step)), show=False)
|
207 |
|
208 |
for k in metrics_file.keys():
|
209 |
print('Summary {}\n=======\n'.format(k))
|
210 |
+
metrics_df = pd.read_csv(metrics_file[k], sep=';', header=0)
|
211 |
+
print('\nAVG:\n')
|
212 |
+
print(metrics_df.mean(axis=0))
|
213 |
+
print('\nSTD:\n')
|
214 |
+
print(metrics_df.std(axis=0))
|
215 |
+
print('\nHD95perc:\n')
|
216 |
+
print(metrics_df['HD'].describe(percentiles=[.95]))
|
217 |
print('\n=======\n')
|
requirements.txt
CHANGED
@@ -27,20 +27,23 @@ et-xmlfile==1.0.1
|
|
27 |
fastrlock==0.6
|
28 |
flatbuffers==1.12
|
29 |
future==0.18.2
|
30 |
-
gast==0.
|
31 |
google-auth==1.35.0
|
32 |
google-auth-oauthlib==0.4.6
|
33 |
google-pasta==0.2.0
|
34 |
googleapis-common-protos==1.53.0
|
35 |
grpcio==1.40.0
|
36 |
-
h5py==
|
37 |
idna==2.10
|
38 |
imageio==2.9.0
|
39 |
importlib-metadata==3.4.0
|
40 |
importlib-resources==5.2.2
|
|
|
41 |
ipykernel==5.5.3
|
42 |
ipython==7.16.1
|
43 |
ipython-genutils==0.2.0
|
|
|
|
|
44 |
ipywidgets==7.6.3
|
45 |
jedi==0.18.0
|
46 |
Jinja2==2.11.3
|
@@ -84,7 +87,7 @@ patsy==0.5.1
|
|
84 |
pexpect==4.8.0
|
85 |
pickleshare==0.7.5
|
86 |
Pillow==8.1.0
|
87 |
-
|
88 |
plotly==4.14.3
|
89 |
plyfile==0.7.3
|
90 |
probreg==0.3.1
|
@@ -107,6 +110,7 @@ pyrsistent==0.17.3
|
|
107 |
pystrum==0.1
|
108 |
python-dateutil==2.8.1
|
109 |
python-utils==2.5.6
|
|
|
110 |
pytz==2021.1
|
111 |
PyWavelets==1.1.1
|
112 |
PyYAML==5.4.1
|
@@ -123,13 +127,13 @@ SimpleITK==2.0.2
|
|
123 |
six==1.15.0
|
124 |
sklearn==0.0
|
125 |
statsmodels==0.12.2
|
126 |
-
|
|
|
127 |
tensorboard-data-server==0.6.1
|
128 |
-
tensorboard-plugin-wit==1.8.0
|
129 |
tensorflow-addons==0.14.0
|
130 |
tensorflow-datasets==4.4.0
|
131 |
-
tensorflow-estimator==1.
|
132 |
-
tensorflow-gpu==1.
|
133 |
tensorflow-metadata==1.2.0
|
134 |
termcolor==1.1.0
|
135 |
terminado==0.9.4
|
@@ -141,6 +145,7 @@ tikzplotlib==0.9.7
|
|
141 |
tornado==6.1
|
142 |
tqdm==4.56.0
|
143 |
traitlets==4.3.3
|
|
|
144 |
transformations==2020.1.1
|
145 |
trimesh==3.9.29
|
146 |
typeguard==2.12.1
|
|
|
27 |
fastrlock==0.6
|
28 |
flatbuffers==1.12
|
29 |
future==0.18.2
|
30 |
+
gast==0.2.2
|
31 |
google-auth==1.35.0
|
32 |
google-auth-oauthlib==0.4.6
|
33 |
google-pasta==0.2.0
|
34 |
googleapis-common-protos==1.53.0
|
35 |
grpcio==1.40.0
|
36 |
+
h5py==2.10.0
|
37 |
idna==2.10
|
38 |
imageio==2.9.0
|
39 |
importlib-metadata==3.4.0
|
40 |
importlib-resources==5.2.2
|
41 |
+
ipydatawidgets==4.2.0
|
42 |
ipykernel==5.5.3
|
43 |
ipython==7.16.1
|
44 |
ipython-genutils==0.2.0
|
45 |
+
ipyvolume==0.5.2
|
46 |
+
ipywebrtc==0.6.0
|
47 |
ipywidgets==7.6.3
|
48 |
jedi==0.18.0
|
49 |
Jinja2==2.11.3
|
|
|
87 |
pexpect==4.8.0
|
88 |
pickleshare==0.7.5
|
89 |
Pillow==8.1.0
|
90 |
+
pkg_resources==0.0.0
|
91 |
plotly==4.14.3
|
92 |
plyfile==0.7.3
|
93 |
probreg==0.3.1
|
|
|
110 |
pystrum==0.1
|
111 |
python-dateutil==2.8.1
|
112 |
python-utils==2.5.6
|
113 |
+
pythreejs==2.3.0
|
114 |
pytz==2021.1
|
115 |
PyWavelets==1.1.1
|
116 |
PyYAML==5.4.1
|
|
|
127 |
six==1.15.0
|
128 |
sklearn==0.0
|
129 |
statsmodels==0.12.2
|
130 |
+
tabulate==0.8.9
|
131 |
+
tensorboard==1.14.0
|
132 |
tensorboard-data-server==0.6.1
|
|
|
133 |
tensorflow-addons==0.14.0
|
134 |
tensorflow-datasets==4.4.0
|
135 |
+
tensorflow-estimator==1.14.0
|
136 |
+
tensorflow-gpu==1.14.0
|
137 |
tensorflow-metadata==1.2.0
|
138 |
termcolor==1.1.0
|
139 |
terminado==0.9.4
|
|
|
145 |
tornado==6.1
|
146 |
tqdm==4.56.0
|
147 |
traitlets==4.3.3
|
148 |
+
traittypes==0.2.1
|
149 |
transformations==2020.1.1
|
150 |
trimesh==3.9.29
|
151 |
typeguard==2.12.1
|