Commit
·
6a4f823
1
Parent(s):
74c6a32
Scripts for training on the IXI T1 MRI Dataset
Browse files- Brain_study/Build_test_set.py +137 -0
- Brain_study/Evaluate_network.py +187 -0
- Brain_study/Evaluate_network__test_fixed.py +249 -0
- Brain_study/MultiTrain_Baseline.py +32 -0
- Brain_study/MultiTrain_SegGuided.py +34 -0
- Brain_study/MultiTrain_UW.py +38 -0
- Brain_study/Train_Baseline.py +322 -0
- Brain_study/Train_SegmentationGuided.py +323 -0
- Brain_study/Train_UncertaintyWeighted.py +304 -0
- Brain_study/__init__.py +0 -0
- Brain_study/data_generator.py +453 -0
- Brain_study/format_dataset.py +91 -0
- Brain_study/split_dataset.py +68 -0
- Brain_study/test_datagenerator.py +91 -0
- Brain_study/utils.py +41 -0
- EvaluationScripts/Evaluate_class.py +184 -0
Brain_study/Build_test_set.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
|
3 |
+
import shutil
|
4 |
+
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
|
7 |
+
currentdir = os.path.dirname(os.path.realpath(__file__))
|
8 |
+
parentdir = os.path.dirname(currentdir)
|
9 |
+
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
10 |
+
|
11 |
+
import tensorflow as tf
|
12 |
+
# tf.enable_eager_execution(config=config)
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
import h5py
|
16 |
+
|
17 |
+
import DeepDeformationMapRegistration.utils.constants as C
|
18 |
+
from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
|
19 |
+
from DeepDeformationMapRegistration.layers import AugmentationLayer
|
20 |
+
from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
|
21 |
+
from DeepDeformationMapRegistration.utils.misc import get_segmentations_centroids
|
22 |
+
from tqdm import tqdm
|
23 |
+
|
24 |
+
from Brain_study.data_generator import BatchGenerator
|
25 |
+
|
26 |
+
from skimage.measure import regionprops
|
27 |
+
from scipy.interpolate import griddata
|
28 |
+
|
29 |
+
import argparse
|
30 |
+
|
31 |
+
|
32 |
+
DATASET = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/training'
|
33 |
+
MODEL_FILE = '/mnt/EncryptedData1/Users/javier/train_output/Brain_study/ERASE/MS_SSIM/BASELINE_L_ssim__MET_mse_ncc_ssim_162756-29062021/checkpoints/best_model.h5'
|
34 |
+
DATA_ROOT_DIR = '/mnt/EncryptedData1/Users/javier/train_output/Brain_study/ERASE/MS_SSIM/BASELINE_L_ssim__MET_mse_ncc_ssim_162756-29062021/'
|
35 |
+
|
36 |
+
POINTS = None
|
37 |
+
MISSING_CENTROID = np.asarray([[np.nan]*3])
|
38 |
+
|
39 |
+
|
40 |
+
def get_mov_centroids(fix_seg, disp_map):
|
41 |
+
fix_centroids, _ = get_segmentations_centroids(fix_seg[0, ...], ohe=True, expected_lbls=range(0, 28))
|
42 |
+
disp = griddata(POINTS, disp_map.reshape([-1, 3]), fix_centroids, method='linear')
|
43 |
+
return fix_centroids, fix_centroids + disp, disp
|
44 |
+
|
45 |
+
|
46 |
+
if __name__ == '__main__':
|
47 |
+
parser = argparse.ArgumentParser()
|
48 |
+
parser.add_argument('-d', '--dir', type=str, help='Directory where to store the files', default='')
|
49 |
+
parser.add_argument('--reldir', type=str, help='Relative path to dataset, in where to store the files', default='')
|
50 |
+
parser.add_argument('--gpu', type=int, help='GPU', default=0)
|
51 |
+
parser.add_argument('--dataset', type=str, help='Dataset to build the test set', default='')
|
52 |
+
parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
|
53 |
+
args = parser.parse_args()
|
54 |
+
|
55 |
+
assert args.dataset != '', "Missing original dataset dataset"
|
56 |
+
if args.dir == '' and args.reldir != '':
|
57 |
+
OUTPUT_FOLDER_DIR = os.path.join(args.dataset, 'test_dataset')
|
58 |
+
elif args.dir != '' and args.reldir == '':
|
59 |
+
OUTPUT_FOLDER_DIR = args.dir
|
60 |
+
else:
|
61 |
+
raise ValueError("Either provide 'dir' or 'reldir'")
|
62 |
+
|
63 |
+
if args.erase:
|
64 |
+
shutil.rmtree(OUTPUT_FOLDER_DIR, ignore_errors=True)
|
65 |
+
os.makedirs(OUTPUT_FOLDER_DIR, exist_ok=True)
|
66 |
+
print('DESTINATION FOLDER: ', OUTPUT_FOLDER_DIR)
|
67 |
+
|
68 |
+
DATASET = args.dataset
|
69 |
+
|
70 |
+
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
|
71 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) # Check availability before running using 'nvidia-smi'
|
72 |
+
|
73 |
+
data_generator = BatchGenerator(DATASET, 1, False, 1.0, False, ['all'])
|
74 |
+
|
75 |
+
img_generator = data_generator.get_train_generator()
|
76 |
+
nb_labels = len(img_generator.get_segmentation_labels())
|
77 |
+
image_input_shape = img_generator.get_data_shape()[-1][:-1]
|
78 |
+
image_output_shape = [64] * 3
|
79 |
+
# Build model
|
80 |
+
|
81 |
+
xx = np.linspace(0, image_output_shape[0], image_output_shape[0], endpoint=False)
|
82 |
+
yy = np.linspace(0, image_output_shape[1], image_output_shape[2], endpoint=False)
|
83 |
+
zz = np.linspace(0, image_output_shape[2], image_output_shape[1], endpoint=False)
|
84 |
+
|
85 |
+
xx, yy, zz = np.meshgrid(xx, yy, zz)
|
86 |
+
|
87 |
+
POINTS = np.stack([xx.flatten(), yy.flatten(), zz.flatten()], axis=0).T
|
88 |
+
|
89 |
+
input_augm = tf.keras.Input(shape=img_generator.get_data_shape()[0], name='input_augm')
|
90 |
+
augm_layer = AugmentationLayer(max_displacement=C.MAX_AUG_DISP, # Max 30 mm in isotropic space
|
91 |
+
max_deformation=C.MAX_AUG_DEF, # Max 6 mm in isotropic space
|
92 |
+
max_rotation=C.MAX_AUG_ANGLE, # Max 10 deg in isotropic space
|
93 |
+
num_control_points=C.NUM_CONTROL_PTS_AUG,
|
94 |
+
num_augmentations=C.NUM_AUGMENTATIONS,
|
95 |
+
gamma_augmentation=C.GAMMA_AUGMENTATION,
|
96 |
+
brightness_augmentation=C.BRIGHTNESS_AUGMENTATION,
|
97 |
+
in_img_shape=image_input_shape,
|
98 |
+
out_img_shape=image_output_shape,
|
99 |
+
only_image=False,
|
100 |
+
only_resize=False,
|
101 |
+
trainable=False,
|
102 |
+
return_displacement_map=True)
|
103 |
+
augm_model = tf.keras.Model(inputs=input_augm, outputs=augm_layer(input_augm))
|
104 |
+
|
105 |
+
config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
|
106 |
+
config.gpu_options.allow_growth = True
|
107 |
+
config.log_device_placement = False ## to log device placement (on which device the operation ran)
|
108 |
+
|
109 |
+
sess = tf.Session(config=config)
|
110 |
+
tf.keras.backend.set_session(sess)
|
111 |
+
with sess.as_default():
|
112 |
+
sess.run(tf.global_variables_initializer())
|
113 |
+
progress_bar = tqdm(enumerate(img_generator, 1), desc='Generating samples', total=len(img_generator))
|
114 |
+
for step, (in_batch, _) in progress_bar:
|
115 |
+
fix_img, mov_img, fix_seg, mov_seg, disp_map = augm_model.predict(in_batch)
|
116 |
+
|
117 |
+
fix_centroids, mov_centroids, disp_centroids = get_mov_centroids(fix_seg, disp_map)
|
118 |
+
|
119 |
+
out_file = os.path.join(OUTPUT_FOLDER_DIR, 'test_sample_{:04d}.h5'.format(step))
|
120 |
+
out_file_dm = os.path.join(OUTPUT_FOLDER_DIR, 'test_sample_dm_{:04d}.h5'.format(step))
|
121 |
+
img_shape = fix_img.shape
|
122 |
+
segm_shape = fix_seg.shape
|
123 |
+
disp_shape = disp_map.shape
|
124 |
+
centroids_shape = fix_centroids.shape
|
125 |
+
with h5py.File(out_file, 'w') as f:
|
126 |
+
f.create_dataset('fix_image', shape=img_shape[1:], dtype=np.float32, data=fix_img[0, ...])
|
127 |
+
f.create_dataset('mov_image', shape=img_shape[1:], dtype=np.float32, data=mov_img[0, ...])
|
128 |
+
f.create_dataset('fix_segmentations', shape=segm_shape[1:], dtype=np.uint8, data=fix_seg[0, ...])
|
129 |
+
f.create_dataset('mov_segmentations', shape=segm_shape[1:], dtype=np.uint8, data=mov_seg[0, ...])
|
130 |
+
f.create_dataset('fix_centroids', shape=centroids_shape, dtype=np.float32, data=fix_centroids)
|
131 |
+
f.create_dataset('mov_centroids', shape=centroids_shape, dtype=np.float32, data=mov_centroids)
|
132 |
+
|
133 |
+
with h5py.File(out_file_dm, 'w') as f:
|
134 |
+
f.create_dataset('disp_map', shape=disp_shape[1:], dtype=np.float32, data=disp_map)
|
135 |
+
f.create_dataset('disp_centroids', shape=centroids_shape, dtype=np.float32, data=disp_centroids)
|
136 |
+
|
137 |
+
print('Done')
|
Brain_study/Evaluate_network.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
|
3 |
+
import shutil
|
4 |
+
|
5 |
+
import h5py
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
|
8 |
+
currentdir = os.path.dirname(os.path.realpath(__file__))
|
9 |
+
parentdir = os.path.dirname(currentdir)
|
10 |
+
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
11 |
+
|
12 |
+
import tensorflow as tf
|
13 |
+
# tf.enable_eager_execution(config=config)
|
14 |
+
|
15 |
+
import numpy as np
|
16 |
+
import pandas as pd
|
17 |
+
import voxelmorph as vxm
|
18 |
+
|
19 |
+
import DeepDeformationMapRegistration.utils.constants as C
|
20 |
+
from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
|
21 |
+
from DeepDeformationMapRegistration.layers import AugmentationLayer
|
22 |
+
from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplified, NCC, GeneralizedDICEScore, HausdorffDistanceErosion
|
23 |
+
from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
|
24 |
+
from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
|
25 |
+
from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
|
26 |
+
from EvaluationScripts.Evaluate_class import EvaluationFigures, resize_pts_to_original_space, resize_img_to_original_space, resize_transformation
|
27 |
+
from scipy.interpolate import RegularGridInterpolator
|
28 |
+
from tqdm import tqdm
|
29 |
+
|
30 |
+
import h5py
|
31 |
+
|
32 |
+
from Brain_study.data_generator import BatchGenerator
|
33 |
+
|
34 |
+
import argparse
|
35 |
+
|
36 |
+
from skimage.transform import warp
|
37 |
+
import neurite as ne
|
38 |
+
|
39 |
+
DATASET = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/training'
|
40 |
+
MODEL_FILE = '/mnt/EncryptedData1/Users/javier/train_output/Brain_study/ERASE/MS_SSIM/BASELINE_L_ssim__MET_mse_ncc_ssim_162756-29062021/checkpoints/best_model.h5'
|
41 |
+
DATA_ROOT_DIR = '/mnt/EncryptedData1/Users/javier/train_output/Brain_study/ERASE/MS_SSIM/BASELINE_L_ssim__MET_mse_ncc_ssim_162756-29062021/'
|
42 |
+
|
43 |
+
OUTPUT_FOLDER_NAME = 'Evaluate'
|
44 |
+
|
45 |
+
if __name__ == '__main__':
|
46 |
+
parser = argparse.ArgumentParser()
|
47 |
+
parser.add_argument('-m', '--model', type=str, help='.h5 of the model', default='')
|
48 |
+
parser.add_argument('-d', '--dir', type=str, help='Directory where ./checkpoints/best_model.h5 is located', default='')
|
49 |
+
parser.add_argument('--gpu', type=int, help='GPU', default=0)
|
50 |
+
parser.add_argument('--dataset', type=str, help='Dataset to run predictions on',
|
51 |
+
default='/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/training')
|
52 |
+
parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
|
53 |
+
args = parser.parse_args()
|
54 |
+
if args.model != '':
|
55 |
+
assert '.h5' in args.model, 'No checkpoint file provided, use -d/--dir instead'
|
56 |
+
MODEL_FILE = args.model
|
57 |
+
DATA_ROOT_DIR = os.path.split(args.model)[0]
|
58 |
+
elif args.dir != '':
|
59 |
+
assert '.h5' not in args.model, 'Provided checkpoint file, user -m/--model instead'
|
60 |
+
MODEL_FILE = os.path.join(args.dir, 'checkpoints', 'best_model.h5')
|
61 |
+
DATA_ROOT_DIR = args.dir
|
62 |
+
else:
|
63 |
+
raise ValueError("Provide either the model file or the directory ./containing checkpoints/best_model.h5")
|
64 |
+
|
65 |
+
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
|
66 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) # Check availability before running using 'nvidia-smi'
|
67 |
+
DATASET = args.dataset
|
68 |
+
|
69 |
+
print('MODEL LOCATION: ', MODEL_FILE)
|
70 |
+
|
71 |
+
# data_folder = '/mnt/EncryptedData1/Users/javier/train_output/DDMR/THESIS/BASELINE_Affine_ncc___mse_ncc_160606-25022021'
|
72 |
+
output_folder = os.path.join(DATA_ROOT_DIR, OUTPUT_FOLDER_NAME) # '/mnt/EncryptedData1/Users/javier/train_output/DDMR/THESIS/eval/BASELINE_TRAIN_Affine_ncc_EVAL_Affine'
|
73 |
+
# os.makedirs(os.path.join(output_folder, 'images'), exist_ok=True)
|
74 |
+
if args.erase:
|
75 |
+
shutil.rmtree(output_folder, ignore_errors=True)
|
76 |
+
os.makedirs(output_folder, exist_ok=True)
|
77 |
+
print('DESTINATION FOLDER: ', output_folder)
|
78 |
+
|
79 |
+
data_generator = BatchGenerator(DATASET, 1, False, 1.0, False, ['all'])
|
80 |
+
|
81 |
+
img_generator = data_generator.get_train_generator()
|
82 |
+
nb_labels = len(img_generator.get_segmentation_labels())
|
83 |
+
image_input_shape = img_generator.get_data_shape()[-1][:-1]
|
84 |
+
image_output_shape = [64] * 3
|
85 |
+
|
86 |
+
# Build model
|
87 |
+
|
88 |
+
input_augm = tf.keras.Input(shape=img_generator.get_data_shape()[0], name='input_augm')
|
89 |
+
augm_layer = AugmentationLayer(max_displacement=C.MAX_AUG_DISP, # Max 30 mm in isotropic space
|
90 |
+
max_deformation=C.MAX_AUG_DEF, # Max 6 mm in isotropic space
|
91 |
+
max_rotation=C.MAX_AUG_ANGLE, # Max 10 deg in isotropic space
|
92 |
+
num_control_points=C.NUM_CONTROL_PTS_AUG,
|
93 |
+
num_augmentations=C.NUM_AUGMENTATIONS,
|
94 |
+
gamma_augmentation=C.GAMMA_AUGMENTATION,
|
95 |
+
brightness_augmentation=C.BRIGHTNESS_AUGMENTATION,
|
96 |
+
in_img_shape=image_input_shape,
|
97 |
+
out_img_shape=image_output_shape,
|
98 |
+
only_image=False,
|
99 |
+
only_resize=False,
|
100 |
+
trainable=False)
|
101 |
+
augm_model = tf.keras.Model(inputs=input_augm, outputs=augm_layer(input_augm))
|
102 |
+
|
103 |
+
loss_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).loss,
|
104 |
+
NCC(image_input_shape).loss,
|
105 |
+
vxm.losses.MSE().loss,
|
106 |
+
MultiScaleStructuralSimilarity(max_val=1., filter_size=3).loss]
|
107 |
+
|
108 |
+
metric_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).metric,
|
109 |
+
NCC(image_input_shape).metric,
|
110 |
+
vxm.losses.MSE().loss,
|
111 |
+
MultiScaleStructuralSimilarity(max_val=1., filter_size=3).metric,
|
112 |
+
GeneralizedDICEScore(image_output_shape + [img_generator.get_data_shape()[2][-1]], num_labels=nb_labels).loss,
|
113 |
+
HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [img_generator.get_data_shape()[2][-1]]).loss]
|
114 |
+
|
115 |
+
network = tf.keras.models.load_model(MODEL_FILE, {'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
|
116 |
+
'VxmDense': vxm.networks.VxmDense,
|
117 |
+
'AdamAccumulated': AdamAccumulated,
|
118 |
+
'loss': loss_fncs,
|
119 |
+
'metric': metric_fncs},
|
120 |
+
compile=False)
|
121 |
+
|
122 |
+
# Needed for VxmDense type of network
|
123 |
+
warp_segmentation = vxm.networks.Transform(image_output_shape, interp_method='nearest', nb_feats=nb_labels)
|
124 |
+
|
125 |
+
# Record metrics
|
126 |
+
metrics = pd.DataFrame(columns=['File', 'SSIM', 'MS-SSIM', 'MSE', 'DICE', 'HD'])
|
127 |
+
config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
|
128 |
+
config.gpu_options.allow_growth = True
|
129 |
+
config.log_device_placement = False ## to log device placement (on which device the operation ran)
|
130 |
+
|
131 |
+
sess = tf.Session(config=config)
|
132 |
+
tf.keras.backend.set_session(sess)
|
133 |
+
with sess.as_default():
|
134 |
+
sess.run(tf.global_variables_initializer())
|
135 |
+
network.load_weights(MODEL_FILE, by_name=True)
|
136 |
+
progress_bar = tqdm(enumerate(img_generator, 1), desc='Evaluation', total=len(img_generator))
|
137 |
+
for step, (in_batch, _) in progress_bar:
|
138 |
+
fix_img, mov_img, fix_seg, mov_seg = augm_model.predict(in_batch)
|
139 |
+
|
140 |
+
if network.name == 'vxm_dense_semi_supervised_seg':
|
141 |
+
pred_img, disp_map, pred_seg = network.predict([mov_img, fix_img, mov_seg, fix_seg]) # predict([source, target])
|
142 |
+
else:
|
143 |
+
pred_img, disp_map = network.predict([mov_img, fix_img])
|
144 |
+
pred_seg = warp_segmentation.predict([mov_seg, disp_map])
|
145 |
+
|
146 |
+
# I need the labels to be OHE to compute the segmentation metrics.
|
147 |
+
dice = GeneralizedDICEScore(image_output_shape + [img_generator.get_data_shape()[2][-1]], num_labels=nb_labels).metric(fix_seg, pred_seg).eval()
|
148 |
+
hd = HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [img_generator.get_data_shape()[2][-1]]).metric(fix_seg, pred_seg).eval()
|
149 |
+
|
150 |
+
pred_seg = np.argmax(pred_seg, axis=-1)[..., np.newaxis].astype(np.float32)
|
151 |
+
mov_seg = np.argmax(mov_seg, axis=-1)[..., np.newaxis].astype(np.float32)
|
152 |
+
fix_seg = np.argmax(fix_seg, axis=-1)[..., np.newaxis].astype(np.float32)
|
153 |
+
|
154 |
+
mov_coords = np.stack(np.meshgrid(*[np.arange(0, 64)]*3), axis=-1)
|
155 |
+
dest_coords = mov_coords + disp_map[0, ...]
|
156 |
+
|
157 |
+
ssim = StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).metric(fix_img, pred_img).eval()
|
158 |
+
ms_ssim = MultiScaleStructuralSimilarity(max_val=1., filter_size=3).metric(fix_img, pred_img).eval()[0]
|
159 |
+
mse = vxm.losses.MSE().loss(fix_img, pred_img).eval()
|
160 |
+
|
161 |
+
metrics.append({'File': step,
|
162 |
+
'SSIM': ssim,
|
163 |
+
'MS-SSIM': ms_ssim,
|
164 |
+
'MSE': mse,
|
165 |
+
'DICE': dice,
|
166 |
+
'HD': hd}, ignore_index=True)
|
167 |
+
save_nifti(fix_img[0, ...], os.path.join(output_folder, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
168 |
+
save_nifti(mov_img[0, ...], os.path.join(output_folder, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
169 |
+
save_nifti(pred_img[0, ...], os.path.join(output_folder, '{:03d}_pred_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
170 |
+
save_nifti(fix_seg[0, ...], os.path.join(output_folder, '{:03d}_fix_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
171 |
+
save_nifti(mov_seg[0, ...], os.path.join(output_folder, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
172 |
+
save_nifti(pred_seg[0, ...], os.path.join(output_folder, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
173 |
+
|
174 |
+
magnitude = np.sqrt(np.sum(disp_map[0, ...] ** 2, axis=-1))
|
175 |
+
_ = plt.hist(magnitude.flatten())
|
176 |
+
plt.title('Histogram of disp. magnitudes')
|
177 |
+
# plt.show(block=False)
|
178 |
+
plt.savefig(os.path.join(output_folder, '{:03d}_hist_mag_ssim_{:.03f}_dice_{:.03f}.png'.format(step, ssim, dice)))
|
179 |
+
plt.close()
|
180 |
+
|
181 |
+
plot_predictions(fix_img, mov_img, disp_map, pred_img, os.path.join(output_folder, '{:03d}_figures.png'.format(step)), show=False)
|
182 |
+
save_disp_map_img(disp_map, 'Displacement map', os.path.join(output_folder, '{:03d}_disp_map_fig.png'.format(step)), show=False)
|
183 |
+
|
184 |
+
progress_bar.set_description('SSIM {:.04f}\tDICE: {:.04f}'.format(ssim, dice))
|
185 |
+
|
186 |
+
metrics.to_csv(os.path.join(output_folder, 'metrics.csv'))
|
187 |
+
print('Done')
|
Brain_study/Evaluate_network__test_fixed.py
ADDED
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
|
3 |
+
import shutil
|
4 |
+
import time
|
5 |
+
|
6 |
+
import h5py
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
|
9 |
+
currentdir = os.path.dirname(os.path.realpath(__file__))
|
10 |
+
parentdir = os.path.dirname(currentdir)
|
11 |
+
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
12 |
+
|
13 |
+
import tensorflow as tf
|
14 |
+
# tf.enable_eager_execution(config=config)
|
15 |
+
|
16 |
+
import numpy as np
|
17 |
+
import pandas as pd
|
18 |
+
import voxelmorph as vxm
|
19 |
+
|
20 |
+
import DeepDeformationMapRegistration.utils.constants as C
|
21 |
+
from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
|
22 |
+
from DeepDeformationMapRegistration.layers import AugmentationLayer, UncertaintyWeighting
|
23 |
+
from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplified, NCC, GeneralizedDICEScore, HausdorffDistanceErosion, target_registration_error
|
24 |
+
from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
|
25 |
+
from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
|
26 |
+
from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
|
27 |
+
from DeepDeformationMapRegistration.utils.misc import DisplacementMapInterpolator, get_segmentations_centroids, segmentation_ohe_to_cardinal
|
28 |
+
from EvaluationScripts.Evaluate_class import EvaluationFigures, resize_pts_to_original_space, resize_img_to_original_space, resize_transformation
|
29 |
+
from scipy.interpolate import RegularGridInterpolator
|
30 |
+
from tqdm import tqdm
|
31 |
+
|
32 |
+
import h5py
|
33 |
+
import re
|
34 |
+
from Brain_study.data_generator import BatchGenerator
|
35 |
+
|
36 |
+
import argparse
|
37 |
+
|
38 |
+
from skimage.transform import warp
|
39 |
+
import neurite as ne
|
40 |
+
|
41 |
+
DATASET = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/training'
|
42 |
+
MODEL_FILE = '/mnt/EncryptedData1/Users/javier/train_output/Brain_study/ERASE/MS_SSIM/BASELINE_L_ssim__MET_mse_ncc_ssim_162756-29062021/checkpoints/best_model.h5'
|
43 |
+
DATA_ROOT_DIR = '/mnt/EncryptedData1/Users/javier/train_output/Brain_study/ERASE/MS_SSIM/BASELINE_L_ssim__MET_mse_ncc_ssim_162756-29062021/'
|
44 |
+
|
45 |
+
|
46 |
+
if __name__ == '__main__':
|
47 |
+
parser = argparse.ArgumentParser()
|
48 |
+
parser.add_argument('-m', '--model', nargs='+', type=str, help='.h5 of the model', default=None)
|
49 |
+
parser.add_argument('-d', '--dir', nargs='+', type=str, help='Directory where ./checkpoints/best_model.h5 is located', default=None)
|
50 |
+
parser.add_argument('--gpu', type=int, help='GPU', default=0)
|
51 |
+
parser.add_argument('--dataset', type=str, help='Dataset to run predictions on',
|
52 |
+
default='/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/training')
|
53 |
+
parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
|
54 |
+
parser.add_argument('--outdirname', type=str, default='Evaluate')
|
55 |
+
args = parser.parse_args()
|
56 |
+
if args.model is not None:
|
57 |
+
assert '.h5' in args.model[0], 'No checkpoint file provided, use -d/--dir instead'
|
58 |
+
MODEL_FILE_LIST = args.model
|
59 |
+
DATA_ROOT_DIR_LIST = [os.path.split(model_path)[0] for model_path in args.model]
|
60 |
+
elif args.dir is not None:
|
61 |
+
assert '.h5' not in args.dir[0], 'Provided checkpoint file, user -m/--model instead'
|
62 |
+
MODEL_FILE_LIST = [os.path.join(dir_path, 'checkpoints', 'best_model.h5') for dir_path in args.dir]
|
63 |
+
DATA_ROOT_DIR_LIST = args.dir
|
64 |
+
else:
|
65 |
+
raise ValueError("Provide either the model file or the directory ./containing checkpoints/best_model.h5")
|
66 |
+
|
67 |
+
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
|
68 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) # Check availability before running using 'nvidia-smi'
|
69 |
+
DATASET = args.dataset
|
70 |
+
list_test_files = [os.path.join(DATASET, f) for f in os.listdir(DATASET) if f.endswith('h5') and 'dm' not in f]
|
71 |
+
list_test_files.sort()
|
72 |
+
|
73 |
+
with h5py.File(list_test_files[0], 'r') as f:
|
74 |
+
image_input_shape = image_output_shape = list(f['fix_image'][:].shape[:-1])
|
75 |
+
nb_labels = f['fix_segmentations'][:].shape[-1]
|
76 |
+
|
77 |
+
# Header of the metrics csv file
|
78 |
+
csv_header = ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'DICE_MACRO', 'HD', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']
|
79 |
+
|
80 |
+
# TF stuff
|
81 |
+
config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
|
82 |
+
config.gpu_options.allow_growth = True
|
83 |
+
config.log_device_placement = False ## to log device placement (on which device the operation ran)
|
84 |
+
config.allow_soft_placement = True
|
85 |
+
|
86 |
+
sess = tf.Session(config=config)
|
87 |
+
tf.keras.backend.set_session(sess)
|
88 |
+
|
89 |
+
# Loss and metric functions. Common to all models
|
90 |
+
loss_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).loss,
|
91 |
+
NCC(image_input_shape).loss,
|
92 |
+
vxm.losses.MSE().loss,
|
93 |
+
MultiScaleStructuralSimilarity(max_val=1., filter_size=3).loss]
|
94 |
+
|
95 |
+
metric_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).metric,
|
96 |
+
NCC(image_input_shape).metric,
|
97 |
+
vxm.losses.MSE().loss,
|
98 |
+
MultiScaleStructuralSimilarity(max_val=1., filter_size=3).metric,
|
99 |
+
GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric,
|
100 |
+
HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [nb_labels]).metric,
|
101 |
+
GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric_macro]
|
102 |
+
|
103 |
+
### METRICS GRAPH ###
|
104 |
+
fix_img_ph = tf.placeholder(tf.float32, (1, *image_output_shape, 1), name='fix_img')
|
105 |
+
pred_img_ph = tf.placeholder(tf.float32, (1, *image_output_shape, 1), name='pred_img')
|
106 |
+
fix_seg_ph = tf.placeholder(tf.float32, (1, *image_output_shape, nb_labels), name='fix_seg')
|
107 |
+
pred_seg_ph = tf.placeholder(tf.float32, (1, *image_output_shape, nb_labels), name='pred_seg')
|
108 |
+
|
109 |
+
ssim_tf = metric_fncs[0](fix_img_ph, pred_img_ph)
|
110 |
+
ncc_tf = metric_fncs[1](fix_img_ph, pred_img_ph)
|
111 |
+
mse_tf = metric_fncs[2](fix_img_ph, pred_img_ph)
|
112 |
+
ms_ssim_tf = metric_fncs[3](fix_img_ph, pred_img_ph)
|
113 |
+
dice_tf = metric_fncs[4](fix_seg_ph, pred_seg_ph)
|
114 |
+
hd_tf = metric_fncs[5](fix_seg_ph, pred_seg_ph)
|
115 |
+
dice_macro_tf = metric_fncs[6](fix_seg_ph, pred_seg_ph)
|
116 |
+
# hd_exact_tf = HausdorffDistance_exact(fix_seg_ph, pred_seg_ph, ohe=True)
|
117 |
+
|
118 |
+
# Needed for VxmDense type of network
|
119 |
+
warp_segmentation = vxm.networks.Transform(image_output_shape, interp_method='nearest', nb_feats=nb_labels)
|
120 |
+
|
121 |
+
dm_interp = DisplacementMapInterpolator(image_output_shape, 'griddata')
|
122 |
+
|
123 |
+
for MODEL_FILE, DATA_ROOT_DIR in zip(MODEL_FILE_LIST, DATA_ROOT_DIR_LIST):
|
124 |
+
print('MODEL LOCATION: ', MODEL_FILE)
|
125 |
+
|
126 |
+
# data_folder = '/mnt/EncryptedData1/Users/javier/train_output/DDMR/THESIS/BASELINE_Affine_ncc___mse_ncc_160606-25022021'
|
127 |
+
output_folder = os.path.join(DATA_ROOT_DIR, args.outdirname) # '/mnt/EncryptedData1/Users/javier/train_output/DDMR/THESIS/eval/BASELINE_TRAIN_Affine_ncc_EVAL_Affine'
|
128 |
+
# os.makedirs(os.path.join(output_folder, 'images'), exist_ok=True)
|
129 |
+
if args.erase:
|
130 |
+
shutil.rmtree(output_folder, ignore_errors=True)
|
131 |
+
os.makedirs(output_folder, exist_ok=True)
|
132 |
+
print('DESTINATION FOLDER: ', output_folder)
|
133 |
+
|
134 |
+
try:
|
135 |
+
network = tf.keras.models.load_model(MODEL_FILE, {'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
|
136 |
+
'VxmDense': vxm.networks.VxmDense,
|
137 |
+
'AdamAccumulated': AdamAccumulated,
|
138 |
+
'loss': loss_fncs,
|
139 |
+
'metric': metric_fncs},
|
140 |
+
compile=False)
|
141 |
+
except ValueError as e:
|
142 |
+
enc_features = [16, 32, 32, 32] # const.ENCODER_FILTERS
|
143 |
+
dec_features = [32, 32, 32, 32, 32, 16, 16] # const.ENCODER_FILTERS[::-1]
|
144 |
+
nb_features = [enc_features, dec_features]
|
145 |
+
if re.search('^UW|SEGGUIDED_', MODEL_FILE):
|
146 |
+
network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
|
147 |
+
nb_labels=nb_labels,
|
148 |
+
nb_unet_features=nb_features,
|
149 |
+
int_steps=0,
|
150 |
+
int_downsize=1,
|
151 |
+
seg_downsize=1)
|
152 |
+
else:
|
153 |
+
network = vxm.networks.VxmDense(inshape=image_output_shape,
|
154 |
+
nb_unet_features=nb_features,
|
155 |
+
int_steps=0)
|
156 |
+
network.load_weights(MODEL_FILE, by_name=True)
|
157 |
+
# Record metrics
|
158 |
+
metrics_file = os.path.join(output_folder, 'metrics.csv')
|
159 |
+
with open(metrics_file, 'w') as f:
|
160 |
+
f.write(';'.join(csv_header)+'\n')
|
161 |
+
|
162 |
+
ssim = ncc = mse = ms_ssim = dice = hd = 0
|
163 |
+
with sess.as_default():
|
164 |
+
sess.run(tf.global_variables_initializer())
|
165 |
+
network.load_weights(MODEL_FILE, by_name=True)
|
166 |
+
progress_bar = tqdm(enumerate(list_test_files, 1), desc='Evaluation', total=len(list_test_files))
|
167 |
+
for step, in_batch in progress_bar:
|
168 |
+
with h5py.File(in_batch, 'r') as f:
|
169 |
+
fix_img = f['fix_image'][:][np.newaxis, ...] # Add batch axis
|
170 |
+
mov_img = f['mov_image'][:][np.newaxis, ...]
|
171 |
+
fix_seg = f['fix_segmentations'][:][np.newaxis, ...].astype(np.float32)
|
172 |
+
mov_seg = f['mov_segmentations'][:][np.newaxis, ...].astype(np.float32)
|
173 |
+
fix_centroids = f['fix_centroids'][:]
|
174 |
+
|
175 |
+
if network.name == 'vxm_dense_semi_supervised_seg':
|
176 |
+
t0 = time.time()
|
177 |
+
pred_img, disp_map, pred_seg = network.predict([mov_img, fix_img, mov_seg, fix_seg]) # predict([source, target])
|
178 |
+
t1 = time.time()
|
179 |
+
else:
|
180 |
+
t0 = time.time()
|
181 |
+
pred_img, disp_map = network.predict([mov_img, fix_img])
|
182 |
+
pred_seg = warp_segmentation.predict([mov_seg, disp_map])
|
183 |
+
t1 = time.time()
|
184 |
+
|
185 |
+
mov_centroids, missing_lbls = get_segmentations_centroids(mov_seg[0, ...], ohe=True, expected_lbls=range(0, 28))
|
186 |
+
# pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) # with tps, it returns the pred_centroids directly
|
187 |
+
pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) + mov_centroids
|
188 |
+
|
189 |
+
# I need the labels to be OHE to compute the segmentation metrics.
|
190 |
+
dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf], {'fix_seg:0': fix_seg, 'pred_seg:0': pred_seg})
|
191 |
+
|
192 |
+
pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
|
193 |
+
mov_seg_card = segmentation_ohe_to_cardinal(mov_seg).astype(np.float32)
|
194 |
+
fix_seg_card = segmentation_ohe_to_cardinal(fix_seg).astype(np.float32)
|
195 |
+
|
196 |
+
ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf], {'fix_img:0': fix_img, 'pred_img:0': pred_img})
|
197 |
+
ms_ssim = ms_ssim[0]
|
198 |
+
|
199 |
+
# Rescale the points back to isotropic space, where we have a correspondence voxel <-> mm
|
200 |
+
upsample_scale = 128 / 64
|
201 |
+
fix_centroids_isotropic = fix_centroids * upsample_scale
|
202 |
+
# mov_centroids_isotropic = mov_centroids * upsample_scale
|
203 |
+
pred_centroids_isotropic = pred_centroids * upsample_scale
|
204 |
+
|
205 |
+
fix_centroids_isotropic = np.divide(fix_centroids_isotropic, C.IXI_DATASET_iso_to_cubic_scales)
|
206 |
+
# mov_centroids_isotropic = np.divide(mov_centroids_isotropic, C.IXI_DATASET_iso_to_cubic_scales)
|
207 |
+
pred_centroids_isotropic = np.divide(pred_centroids_isotropic, C.IXI_DATASET_iso_to_cubic_scales)
|
208 |
+
# Now we can measure the TRE in mm
|
209 |
+
tre_array = target_registration_error(fix_centroids_isotropic, pred_centroids_isotropic, False).eval()
|
210 |
+
tre = np.mean([v for v in tre_array if not np.isnan(v)])
|
211 |
+
# ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'HD', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']
|
212 |
+
|
213 |
+
new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, t1-t0, tre, len(missing_lbls), missing_lbls]
|
214 |
+
with open(metrics_file, 'a') as f:
|
215 |
+
f.write(';'.join(map(str, new_line))+'\n')
|
216 |
+
|
217 |
+
save_nifti(fix_img[0, ...], os.path.join(output_folder, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
218 |
+
save_nifti(mov_img[0, ...], os.path.join(output_folder, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
219 |
+
save_nifti(pred_img[0, ...], os.path.join(output_folder, '{:03d}_pred_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
220 |
+
save_nifti(fix_seg[0, ...], os.path.join(output_folder, '{:03d}_fix_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
221 |
+
save_nifti(mov_seg[0, ...], os.path.join(output_folder, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
222 |
+
save_nifti(pred_seg[0, ...], os.path.join(output_folder, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
223 |
+
|
224 |
+
# with h5py.File(os.path.join(output_folder, '{:03d}_centroids.h5'.format(step)), 'w') as f:
|
225 |
+
# f.create_dataset('fix_centroids', dtype=np.float32, data=fix_centroids)
|
226 |
+
# f.create_dataset('mov_centroids', dtype=np.float32, data=mov_centroids)
|
227 |
+
# f.create_dataset('pred_centroids', dtype=np.float32, data=pred_centroids)
|
228 |
+
# f.create_dataset('fix_centroids_isotropic', dtype=np.float32, data=fix_centroids_isotropic)
|
229 |
+
# f.create_dataset('mov_centroids_isotropic', dtype=np.float32, data=mov_centroids_isotropic)
|
230 |
+
|
231 |
+
# magnitude = np.sqrt(np.sum(disp_map[0, ...] ** 2, axis=-1))
|
232 |
+
# _ = plt.hist(magnitude.flatten())
|
233 |
+
# plt.title('Histogram of disp. magnitudes')
|
234 |
+
# plt.show(block=False)
|
235 |
+
# plt.savefig(os.path.join(output_folder, '{:03d}_hist_mag_ssim_{:.03f}_dice_{:.03f}.png'.format(step, ssim, dice)))
|
236 |
+
# plt.close()
|
237 |
+
|
238 |
+
plot_predictions(fix_img, mov_img, disp_map, pred_img, os.path.join(output_folder, '{:03d}_figures.png'.format(step)), show=False)
|
239 |
+
save_disp_map_img(disp_map, 'Displacement map', os.path.join(output_folder, '{:03d}_disp_map_fig.png'.format(step)), show=False)
|
240 |
+
|
241 |
+
progress_bar.set_description('SSIM {:.04f}\tDICE: {:.04f}'.format(ssim, dice))
|
242 |
+
|
243 |
+
print('Summary\n=======\n')
|
244 |
+
print(pd.read_csv(metrics_file, sep=';', header=0).mean(axis=0))
|
245 |
+
print('\n=======\n')
|
246 |
+
tf.keras.backend.clear_session()
|
247 |
+
# sess.close()
|
248 |
+
del network
|
249 |
+
print('Done')
|
Brain_study/MultiTrain_Baseline.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
currentdir = os.path.dirname(os.path.realpath(__file__))
|
3 |
+
parentdir = os.path.dirname(currentdir)
|
4 |
+
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
5 |
+
|
6 |
+
from Brain_study.Train_Baseline import launch_train
|
7 |
+
import argparse
|
8 |
+
|
9 |
+
TRAIN_DATASET = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/training'
|
10 |
+
|
11 |
+
err = list()
|
12 |
+
|
13 |
+
if __name__ == '__main__':
|
14 |
+
parser = argparse.ArgumentParser()
|
15 |
+
parser.add_argument('--dataset', type=str, help='Location of the training data', default=TRAIN_DATASET)
|
16 |
+
parser.add_argument('--validation', type=str, help='Location of the validation data', default=None)
|
17 |
+
parser.add_argument('--similarity', type=str, help='Similarity metric: mse, ncc, ssim')
|
18 |
+
parser.add_argument('--output', type=str, help='Output directory', default=TRAIN_DATASET)
|
19 |
+
parser.add_argument('--gpu', type=str, help='GPU number', default='0')
|
20 |
+
parser.add_argument('--lr', type=float, help='Learning rate', default=1e-4)
|
21 |
+
parser.add_argument('--rw', type=float, help='Regularization weigh', default=5e-3)
|
22 |
+
|
23 |
+
args = parser.parse_args()
|
24 |
+
|
25 |
+
print('TRAIN ' + args.dataset)
|
26 |
+
launch_train(dataset_folder=args.dataset,
|
27 |
+
validation_folder=args.validation,
|
28 |
+
output_folder=os.path.join(args.output, 'BASELINE_L_{}__MET_mse_ncc_ssim'.format(args.similarity)),
|
29 |
+
gpu_num=args.gpu,
|
30 |
+
lr=args.lr,
|
31 |
+
rw=args.rw,
|
32 |
+
simil=args.similarity)
|
Brain_study/MultiTrain_SegGuided.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
currentdir = os.path.dirname(os.path.realpath(__file__))
|
3 |
+
parentdir = os.path.dirname(currentdir)
|
4 |
+
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
5 |
+
|
6 |
+
from Brain_study.Train_SegmentationGuided import launch_train
|
7 |
+
import argparse
|
8 |
+
|
9 |
+
TRAIN_DATASET = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/training'
|
10 |
+
|
11 |
+
err = list()
|
12 |
+
|
13 |
+
if __name__ == '__main__':
|
14 |
+
parser = argparse.ArgumentParser()
|
15 |
+
parser.add_argument('--dataset', type=str, help='Location of the training data', default=TRAIN_DATASET)
|
16 |
+
parser.add_argument('--validation', type=str, help='Location of the validation data', default=None)
|
17 |
+
parser.add_argument('--similarity', type=str, help='Similarity loss function: mse, ncc, ssim')
|
18 |
+
parser.add_argument('--segmentation', type=str, help='Segmentation loss function: hd, dice')
|
19 |
+
parser.add_argument('--output', type=str, help='Output directory', default=TRAIN_DATASET)
|
20 |
+
parser.add_argument('--gpu', type=str, help='GPU number', default='0')
|
21 |
+
parser.add_argument('--lr', type=float, help='Learning rate', default=1e-4)
|
22 |
+
parser.add_argument('--rw', type=float, help='Regularization weigh', default=2e-2)
|
23 |
+
|
24 |
+
args = parser.parse_args()
|
25 |
+
|
26 |
+
print('TRAIN ' + args.dataset)
|
27 |
+
launch_train(dataset_folder=args.dataset,
|
28 |
+
validation_folder=args.validation,
|
29 |
+
output_folder=os.path.join(args.output, 'SEGGUIDED_Lsim_{}__Lseg_{}__MET_mse_ncc_ssim'.format(args.similarity, args.segmentation)),
|
30 |
+
gpu_num=args.gpu,
|
31 |
+
lr=args.lr,
|
32 |
+
rw=args.rw,
|
33 |
+
simil=args.similarity,
|
34 |
+
segm=args.segmentation)
|
Brain_study/MultiTrain_UW.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
currentdir = os.path.dirname(os.path.realpath(__file__))
|
3 |
+
parentdir = os.path.dirname(currentdir)
|
4 |
+
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
5 |
+
|
6 |
+
from Brain_study.Train_UncertaintyWeighted import launch_train
|
7 |
+
import argparse
|
8 |
+
|
9 |
+
TRAIN_DATASET = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/training'
|
10 |
+
|
11 |
+
err = list()
|
12 |
+
|
13 |
+
if __name__ == '__main__':
|
14 |
+
parser = argparse.ArgumentParser()
|
15 |
+
parser.add_argument('--dataset', type=str, help='Location of the training data', default=TRAIN_DATASET)
|
16 |
+
parser.add_argument('--validation', type=str, help='Location of the validation data', default=None)
|
17 |
+
parser.add_argument('--similarity', nargs='+', type=str, help='Similarity loss function: mse, ncc, ssim', default=[])
|
18 |
+
parser.add_argument('--segmentation', nargs='+', type=str, help='Segmentation loss function: hd, dice', default=[])
|
19 |
+
parser.add_argument('--output', type=str, help='Output directory', default=TRAIN_DATASET)
|
20 |
+
parser.add_argument('--gpu', type=str, help='GPU number', default='0')
|
21 |
+
parser.add_argument('--lr', type=float, help='Learning rate', default=1e-4)
|
22 |
+
parser.add_argument('--rw', type=float, help='Regularization weigh', default=5e-3)
|
23 |
+
|
24 |
+
args = parser.parse_args()
|
25 |
+
|
26 |
+
output_folder = os.path.join(args.output, 'UW_Lsim_{}__Lseg_{}__MET_mse_ncc_ssim'.format('__'.join(args.similarity),
|
27 |
+
'__'.join(args.segmentation)))
|
28 |
+
print('TRAIN ' + args.dataset)
|
29 |
+
launch_train(dataset_folder=args.dataset,
|
30 |
+
validation_folder=args.validation,
|
31 |
+
output_folder=output_folder,
|
32 |
+
gpu_num=args.gpu,
|
33 |
+
prior_reg_w=args.rw,
|
34 |
+
lr=args.lr,
|
35 |
+
simil=args.similarity,
|
36 |
+
segm=args.segmentation)
|
37 |
+
|
38 |
+
|
Brain_study/Train_Baseline.py
ADDED
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
currentdir = os.path.dirname(os.path.realpath(__file__))
|
3 |
+
parentdir = os.path.dirname(currentdir)
|
4 |
+
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import tensorflow as tf
|
8 |
+
|
9 |
+
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
|
10 |
+
from tensorflow.keras import Input
|
11 |
+
from tensorflow.keras.models import Model
|
12 |
+
from tensorflow.python.keras.utils import Progbar
|
13 |
+
from tensorflow.python.framework.errors import InvalidArgumentError
|
14 |
+
|
15 |
+
import voxelmorph as vxm
|
16 |
+
import neurite as ne
|
17 |
+
import h5py
|
18 |
+
import pickle
|
19 |
+
|
20 |
+
import DeepDeformationMapRegistration.utils.constants as C
|
21 |
+
from DeepDeformationMapRegistration.losses import NCC, StructuralSimilarity, StructuralSimilarity_simplified
|
22 |
+
from DeepDeformationMapRegistration.utils.misc import try_mkdir, DatasetCopy, function_decorator
|
23 |
+
from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
|
24 |
+
from DeepDeformationMapRegistration.layers import AugmentationLayer
|
25 |
+
from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
|
26 |
+
from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity, _MSSSIM_WEIGHTS
|
27 |
+
|
28 |
+
from Brain_study.data_generator import BatchGenerator
|
29 |
+
from Brain_study.utils import SummaryDictionary, named_logs
|
30 |
+
|
31 |
+
from tqdm import tqdm
|
32 |
+
from datetime import datetime
|
33 |
+
|
34 |
+
|
35 |
+
def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr=1e-4, rw=5e-3, simil='ssim'):
|
36 |
+
assert dataset_folder is not None and output_folder is not None
|
37 |
+
|
38 |
+
os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
|
39 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_num) # Check availability before running using 'nvidia-smi'
|
40 |
+
C.GPU_NUM = str(gpu_num)
|
41 |
+
|
42 |
+
output_folder = os.path.join(output_folder + '_' + datetime.now().strftime("%H%M%S-%d%m%Y"))
|
43 |
+
os.makedirs(output_folder, exist_ok=True)
|
44 |
+
# dataset_copy = DatasetCopy(dataset_folder, os.path.join(output_folder, 'temp'))
|
45 |
+
log_file = open(os.path.join(output_folder, 'log.txt'), 'w')
|
46 |
+
C.TRAINING_DATASET = dataset_folder #dataset_copy.copy_dataset()
|
47 |
+
C.VALIDATION_DATASET = validation_folder
|
48 |
+
C.ACCUM_GRADIENT_STEP = 16
|
49 |
+
C.BATCH_SIZE = 16 if C.ACCUM_GRADIENT_STEP == 1 else C.ACCUM_GRADIENT_STEP
|
50 |
+
C.EARLY_STOP_PATIENCE = 5 * (C.ACCUM_GRADIENT_STEP / 2 if C.ACCUM_GRADIENT_STEP != 1 else 1)
|
51 |
+
C.LEARNING_RATE = lr
|
52 |
+
C.LIMIT_NUM_SAMPLES = None
|
53 |
+
C.EPOCHS = 10000
|
54 |
+
|
55 |
+
aux = "[{}]\tINFO:\nTRAIN DATASET: {}\nVALIDATION DATASET: {}\n" \
|
56 |
+
"GPU: {}\n" \
|
57 |
+
"BATCH SIZE: {}\n" \
|
58 |
+
"LR: {}\n" \
|
59 |
+
"SIMILARITY: {}\n" \
|
60 |
+
"REG. WEIGHT: {}\n" \
|
61 |
+
"EPOCHS: {:d}\n" \
|
62 |
+
"ACCUM. GRAD: {}\n" \
|
63 |
+
"EARLY STOP PATIENCE: {}".format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'),
|
64 |
+
C.TRAINING_DATASET,
|
65 |
+
C.VALIDATION_DATASET,
|
66 |
+
C.GPU_NUM,
|
67 |
+
C.BATCH_SIZE,
|
68 |
+
C.LEARNING_RATE,
|
69 |
+
simil,
|
70 |
+
rw,
|
71 |
+
C.EPOCHS,
|
72 |
+
C.ACCUM_GRADIENT_STEP,
|
73 |
+
C.EARLY_STOP_PATIENCE)
|
74 |
+
|
75 |
+
log_file.write(aux)
|
76 |
+
print(aux)
|
77 |
+
|
78 |
+
# Load data
|
79 |
+
# Build data generator
|
80 |
+
data_generator = BatchGenerator(C.TRAINING_DATASET, C.BATCH_SIZE if C.ACCUM_GRADIENT_STEP == 1 else 1, True,
|
81 |
+
C.TRAINING_PERC, True, ['none'], directory_val=C.VALIDATION_DATASET)
|
82 |
+
|
83 |
+
train_generator = data_generator.get_train_generator()
|
84 |
+
validation_generator = data_generator.get_validation_generator()
|
85 |
+
|
86 |
+
image_input_shape = train_generator.get_data_shape()[-1][:-1]
|
87 |
+
image_output_shape = [64] * 3
|
88 |
+
|
89 |
+
# Config the training sessions
|
90 |
+
config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
|
91 |
+
config.gpu_options.allow_growth = True
|
92 |
+
config.log_device_placement = False ## to log device placement (on which device the operation ran)
|
93 |
+
sess = tf.Session(config=config)
|
94 |
+
tf.keras.backend.set_session(sess)
|
95 |
+
|
96 |
+
# Build model
|
97 |
+
input_layer_train = Input(shape=train_generator.get_data_shape()[-1], name='input_train')
|
98 |
+
augm_layer_train = AugmentationLayer(max_displacement=C.MAX_AUG_DISP, # Max 30 mm in isotropic space
|
99 |
+
max_deformation=C.MAX_AUG_DEF, # Max 6 mm in isotropic space
|
100 |
+
max_rotation=C.MAX_AUG_ANGLE, # Max 10 deg in isotropic space
|
101 |
+
num_control_points=C.NUM_CONTROL_PTS_AUG,
|
102 |
+
num_augmentations=C.NUM_AUGMENTATIONS,
|
103 |
+
gamma_augmentation=C.GAMMA_AUGMENTATION,
|
104 |
+
brightness_augmentation=C.BRIGHTNESS_AUGMENTATION,
|
105 |
+
in_img_shape=image_input_shape,
|
106 |
+
out_img_shape=image_output_shape,
|
107 |
+
only_image=True,
|
108 |
+
only_resize=False,
|
109 |
+
trainable=False)
|
110 |
+
augm_model_train = Model(inputs=input_layer_train, outputs=augm_layer_train(input_layer_train))
|
111 |
+
|
112 |
+
input_layer_valid = Input(shape=validation_generator.get_data_shape()[0], name='input_valid')
|
113 |
+
augm_layer_valid = AugmentationLayer(max_displacement=C.MAX_AUG_DISP, # Max 30 mm in isotropic space
|
114 |
+
max_deformation=C.MAX_AUG_DEF, # Max 6 mm in isotropic space
|
115 |
+
max_rotation=C.MAX_AUG_ANGLE, # Max 10 deg in isotropic space
|
116 |
+
num_control_points=C.NUM_CONTROL_PTS_AUG,
|
117 |
+
num_augmentations=C.NUM_AUGMENTATIONS,
|
118 |
+
gamma_augmentation=C.GAMMA_AUGMENTATION,
|
119 |
+
brightness_augmentation=C.BRIGHTNESS_AUGMENTATION,
|
120 |
+
in_img_shape=image_input_shape,
|
121 |
+
out_img_shape=image_output_shape,
|
122 |
+
only_image=False,
|
123 |
+
only_resize=False,
|
124 |
+
trainable=False)
|
125 |
+
augm_model_valid = Model(inputs=input_layer_valid, outputs=augm_layer_valid(input_layer_valid))
|
126 |
+
|
127 |
+
# Build model
|
128 |
+
enc_features = [16, 32, 32, 32] # const.ENCODER_FILTERS
|
129 |
+
dec_features = [32, 32, 32, 32, 32, 16, 16] # const.ENCODER_FILTERS[::-1]
|
130 |
+
nb_features = [enc_features, dec_features]
|
131 |
+
network = vxm.networks.VxmDense(inshape=image_output_shape,
|
132 |
+
nb_unet_features=nb_features,
|
133 |
+
int_steps=0)
|
134 |
+
|
135 |
+
# Losses and loss weights
|
136 |
+
SSIM_KER_SIZE = 5
|
137 |
+
MS_SSIM_WEIGHTS = _MSSSIM_WEIGHTS[:3]
|
138 |
+
MS_SSIM_WEIGHTS /= np.sum(MS_SSIM_WEIGHTS)
|
139 |
+
if simil.lower() == 'mse':
|
140 |
+
loss_fnc = vxm.losses.MSE().loss
|
141 |
+
elif simil.lower() == 'ncc':
|
142 |
+
loss_fnc = NCC(image_input_shape).loss
|
143 |
+
elif simil.lower() == 'ssim':
|
144 |
+
loss_fnc = StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss
|
145 |
+
elif simil.lower() == 'ms_ssim':
|
146 |
+
loss_fnc = MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).loss
|
147 |
+
elif simil.lower() == 'mse__ms_ssim' or simil.lower() == 'ms_ssim__mse':
|
148 |
+
@function_decorator('MSSSIM_MSE__loss')
|
149 |
+
def loss_fnc(y_true, y_pred):
|
150 |
+
return vxm.losses.MSE().loss(y_true, y_pred) +\
|
151 |
+
MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).loss(y_true, y_pred)
|
152 |
+
elif simil.lower() == 'ncc__ms_ssim' or simil.lower() == 'ms_ssim__ncc':
|
153 |
+
@function_decorator('MSSSIM_NCC__loss')
|
154 |
+
def loss_fnc(y_true, y_pred):
|
155 |
+
return NCC(image_input_shape).loss(y_true, y_pred) +\
|
156 |
+
MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).loss(y_true, y_pred)
|
157 |
+
elif simil.lower() == 'mse__ssim' or simil.lower() == 'ssim__mse':
|
158 |
+
@function_decorator('SSIM_MSE__loss')
|
159 |
+
def loss_fnc(y_true, y_pred):
|
160 |
+
return vxm.losses.MSE().loss(y_true, y_pred) +\
|
161 |
+
StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss(y_true, y_pred)
|
162 |
+
elif simil.lower() == 'ncc__ssim' or simil.lower() == 'ssim__ncc':
|
163 |
+
@function_decorator('SSIM_NCC__loss')
|
164 |
+
def loss_fnc(y_true, y_pred):
|
165 |
+
return NCC(image_input_shape).loss(y_true, y_pred) +\
|
166 |
+
StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss(y_true, y_pred)
|
167 |
+
else:
|
168 |
+
raise ValueError('Unknown similarity metric: ' + simil)
|
169 |
+
|
170 |
+
# Train
|
171 |
+
os.makedirs(output_folder, exist_ok=True)
|
172 |
+
os.makedirs(os.path.join(output_folder, 'checkpoints'), exist_ok=True)
|
173 |
+
os.makedirs(os.path.join(output_folder, 'tensorboard'), exist_ok=True)
|
174 |
+
os.makedirs(os.path.join(output_folder, 'history'), exist_ok=True)
|
175 |
+
|
176 |
+
callback_best_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'),
|
177 |
+
save_best_only=True, monitor='val_loss', verbose=1, mode='min')
|
178 |
+
# callback_save_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'weights.{epoch:05d}-{val_loss:.2f}.h5'),
|
179 |
+
# save_best_only=True, save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
|
180 |
+
# CSVLogger(train_log_name, ';'),
|
181 |
+
# UpdateLossweights([haus_weight, dice_weight], [const.MODEL+'_resampler_seg', const.MODEL+'_resampler_seg'])
|
182 |
+
callback_tensorboard = TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
|
183 |
+
batch_size=C.BATCH_SIZE, write_images=False, histogram_freq=0,
|
184 |
+
update_freq='epoch', # or 'batch' or integer
|
185 |
+
write_graph=True, write_grads=True
|
186 |
+
)
|
187 |
+
callback_early_stop = EarlyStopping(monitor='val_loss', verbose=1, patience=C.EARLY_STOP_PATIENCE, min_delta=0.00001)
|
188 |
+
|
189 |
+
losses = {'transformer': loss_fnc,
|
190 |
+
'flow': vxm.losses.Grad('l2').loss}
|
191 |
+
metrics = {'transformer': [StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).metric,
|
192 |
+
MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).metric,
|
193 |
+
tf.keras.losses.MSE,
|
194 |
+
NCC(image_input_shape).metric],
|
195 |
+
#'flow': vxm.losses.Grad('l2').loss
|
196 |
+
}
|
197 |
+
loss_weights = {'transformer': 1.,
|
198 |
+
'flow': rw}
|
199 |
+
|
200 |
+
# Compile the model
|
201 |
+
optimizer = AdamAccumulated(C.ACCUM_GRADIENT_STEP, C.LEARNING_RATE)
|
202 |
+
network.compile(optimizer=optimizer,
|
203 |
+
loss=losses,
|
204 |
+
loss_weights=loss_weights,
|
205 |
+
metrics=metrics)
|
206 |
+
|
207 |
+
callback_tensorboard.set_model(network)
|
208 |
+
callback_best_model.set_model(network)
|
209 |
+
# callback_save_model.set_model(network)
|
210 |
+
callback_early_stop.set_model(network)
|
211 |
+
# TODO: https://towardsdatascience.com/writing-tensorflow-2-custom-loops-438b1ab6eb6c
|
212 |
+
|
213 |
+
summary = SummaryDictionary(network, C.BATCH_SIZE)
|
214 |
+
names = network.metrics_names # It give both the loss and metric names
|
215 |
+
log_file.write('\n\n[{}]\tINFO:\tStart training\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y')))
|
216 |
+
with sess.as_default():
|
217 |
+
callback_tensorboard.on_train_begin()
|
218 |
+
callback_early_stop.on_train_begin()
|
219 |
+
callback_best_model.on_train_begin()
|
220 |
+
# callback_save_model.on_train_begin()
|
221 |
+
for epoch in range(C.EPOCHS):
|
222 |
+
callback_tensorboard.on_epoch_begin(epoch)
|
223 |
+
callback_early_stop.on_epoch_begin(epoch)
|
224 |
+
callback_best_model.on_epoch_begin(epoch)
|
225 |
+
# callback_save_model.on_epoch_begin(epoch)
|
226 |
+
|
227 |
+
print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
|
228 |
+
print('TRAINING')
|
229 |
+
|
230 |
+
log_file.write('\n\n[{}]\tINFO:\tTraining epoch {}\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), epoch))
|
231 |
+
progress_bar = Progbar(len(train_generator), width=30, verbose=1)
|
232 |
+
for step, (in_batch, _) in enumerate(train_generator, 1):
|
233 |
+
# callback_tensorboard.on_train_batch_begin(step)
|
234 |
+
callback_best_model.on_train_batch_begin(step)
|
235 |
+
# callback_save_model.on_train_batch_begin(step)
|
236 |
+
callback_early_stop.on_train_batch_begin(step)
|
237 |
+
|
238 |
+
try:
|
239 |
+
fix_img, mov_img, *_ = augm_model_train.predict(in_batch)
|
240 |
+
np.nan_to_num(fix_img, copy=False)
|
241 |
+
np.nan_to_num(mov_img, copy=False)
|
242 |
+
if np.isnan(np.sum(mov_img)) or np.isnan(np.sum(fix_img)) or np.isinf(np.sum(mov_img)) or np.isinf(np.sum(fix_img)):
|
243 |
+
msg = 'CORRUPTED DATA!! Unique: Fix: {}\tMoving: {}'.format(np.unique(fix_img),
|
244 |
+
np.unique(mov_img))
|
245 |
+
print(msg)
|
246 |
+
log_file.write('\n\n[{}]\tWAR: {}'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), msg))
|
247 |
+
|
248 |
+
except InvalidArgumentError as err:
|
249 |
+
print('TF Error : {}'.format(str(err)))
|
250 |
+
continue
|
251 |
+
|
252 |
+
ret = network.train_on_batch(x=(mov_img, fix_img),
|
253 |
+
y=(fix_img, fix_img)) # The second element doesn't matter
|
254 |
+
if np.isnan(ret).any():
|
255 |
+
os.makedirs(os.path.join(output_folder, 'corrupted'), exist_ok=True)
|
256 |
+
save_nifti(mov_img, os.path.join(output_folder, 'corrupted', 'mov_img_nan.nii.gz'))
|
257 |
+
save_nifti(fix_img, os.path.join(output_folder, 'corrupted', 'fix_img_nan.nii.gz'))
|
258 |
+
pred_img, dm = network((mov_img, fix_img))
|
259 |
+
save_nifti(pred_img, os.path.join(output_folder, 'corrupted', 'pred_img_nan.nii.gz'))
|
260 |
+
save_nifti(dm, os.path.join(output_folder, 'corrupted', 'dm_nan.nii.gz'))
|
261 |
+
log_file.write('\n\n[{}]\tERR: Corruption error'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y')))
|
262 |
+
raise ValueError('CORRUPTION ERROR: Halting training')
|
263 |
+
|
264 |
+
summary.on_train_batch_end(ret)
|
265 |
+
# callback_tensorboard.on_train_batch_end(step, named_logs(network, ret))
|
266 |
+
callback_best_model.on_train_batch_end(step, named_logs(network, ret))
|
267 |
+
# callback_save_model.on_train_batch_end(step, named_logs(network, ret))
|
268 |
+
callback_early_stop.on_train_batch_end(step, named_logs(network, ret))
|
269 |
+
progress_bar.update(step, zip(names, ret))
|
270 |
+
log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
|
271 |
+
val_values = progress_bar._values.copy()
|
272 |
+
ret = [val_values[x][0]/val_values[x][1] for x in names]
|
273 |
+
|
274 |
+
print('\nVALIDATION')
|
275 |
+
log_file.write('\n\n[{}]\tINFO:\tValidation epoch {}\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), epoch))
|
276 |
+
progress_bar = Progbar(len(validation_generator), width=30, verbose=1)
|
277 |
+
for step, (in_batch, _) in enumerate(validation_generator, 1):
|
278 |
+
# callback_tensorboard.on_test_batch_begin(step) # This is cursed, don't do it again
|
279 |
+
# callback_early_stop.on_test_batch_begin(step)
|
280 |
+
try:
|
281 |
+
fix_img, mov_img, *_ = augm_model_valid.predict(in_batch)
|
282 |
+
except InvalidArgumentError as err:
|
283 |
+
print('TF Error : {}'.format(str(err)))
|
284 |
+
continue
|
285 |
+
|
286 |
+
ret = network.test_on_batch(x=(mov_img, fix_img),
|
287 |
+
y=(fix_img, fix_img))
|
288 |
+
# pred_segm = network.register(mov_segm, fix_segm)
|
289 |
+
summary.on_validation_batch_end(ret)
|
290 |
+
# callback_early_stop.on_test_batch_end(step, named_logs(network, ret))
|
291 |
+
# callback_tensorboard.on_test_batch_end(step, named_logs(network, ret)) # This is cursed, don't do it again
|
292 |
+
progress_bar.update(step, zip(names, ret))
|
293 |
+
log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
|
294 |
+
val_values = progress_bar._values.copy()
|
295 |
+
ret = [val_values[x][0]/val_values[x][1] for x in names]
|
296 |
+
|
297 |
+
train_generator.on_epoch_end()
|
298 |
+
validation_generator.on_epoch_end()
|
299 |
+
epoch_summary = summary.on_epoch_end() # summary resets after on_epoch_end() call
|
300 |
+
callback_tensorboard.on_epoch_end(epoch, epoch_summary)
|
301 |
+
callback_early_stop.on_epoch_end(epoch, epoch_summary)
|
302 |
+
callback_best_model.on_epoch_end(epoch, epoch_summary)
|
303 |
+
# callback_save_model.on_epoch_end(epoch, epoch_summary)
|
304 |
+
print('End of epoch {}: '.format(epoch), ret, '\n')
|
305 |
+
|
306 |
+
callback_tensorboard.on_train_end()
|
307 |
+
# callback_save_model.on_train_end()
|
308 |
+
callback_best_model.on_train_end()
|
309 |
+
callback_early_stop.on_train_end()
|
310 |
+
|
311 |
+
|
312 |
+
if __name__ == '__main__':
|
313 |
+
os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
|
314 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '0' # Check availability before running using 'nvidia-smi'
|
315 |
+
|
316 |
+
config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
|
317 |
+
config.gpu_options.allow_growth = True
|
318 |
+
config.log_device_placement = False ## to log device placement (on which device the operation ran)
|
319 |
+
tf.keras.backend.set_session(tf.Session(config=config))
|
320 |
+
|
321 |
+
launch_train('/mnt/EncryptedData1/Users/javier/vessel_registration/LiTS/None',
|
322 |
+
'TrainOutput/THESIS/UW_None_mse_ssim_haus', 0, mse=True)
|
Brain_study/Train_SegmentationGuided.py
ADDED
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
currentdir = os.path.dirname(os.path.realpath(__file__))
|
3 |
+
parentdir = os.path.dirname(currentdir)
|
4 |
+
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import tensorflow as tf
|
8 |
+
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
|
9 |
+
from tensorflow.keras import Input
|
10 |
+
from tensorflow.keras.models import Model
|
11 |
+
from tensorflow.python.keras.utils import Progbar
|
12 |
+
from tensorflow.python.framework.errors import InvalidArgumentError
|
13 |
+
|
14 |
+
import voxelmorph as vxm
|
15 |
+
import neurite as ne
|
16 |
+
import h5py
|
17 |
+
from datetime import datetime
|
18 |
+
import pickle
|
19 |
+
|
20 |
+
import DeepDeformationMapRegistration.utils.constants as C
|
21 |
+
from DeepDeformationMapRegistration.utils.misc import try_mkdir, function_decorator
|
22 |
+
from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
|
23 |
+
from DeepDeformationMapRegistration.losses import NCC, HausdorffDistanceErosion, GeneralizedDICEScore, StructuralSimilarity_simplified
|
24 |
+
from DeepDeformationMapRegistration.layers import AugmentationLayer
|
25 |
+
from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity, _MSSSIM_WEIGHTS
|
26 |
+
from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
|
27 |
+
|
28 |
+
from Brain_study.data_generator import BatchGenerator
|
29 |
+
from Brain_study.utils import SummaryDictionary, named_logs
|
30 |
+
|
31 |
+
import time
|
32 |
+
|
33 |
+
def launch_train(dataset_folder, validation_folder, output_folder, gpu_num=0, lr=1e-4, rw=5e-3, simil='ssim', segm='hd'):
|
34 |
+
assert dataset_folder is not None and output_folder is not None
|
35 |
+
|
36 |
+
os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
|
37 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_num) # Check availability before running using 'nvidia-smi'
|
38 |
+
C.GPU_NUM = str(gpu_num)
|
39 |
+
|
40 |
+
output_folder = os.path.join(output_folder + '_' + datetime.now().strftime("%H%M%S-%d%m%Y"))
|
41 |
+
os.makedirs(output_folder, exist_ok=True)
|
42 |
+
log_file = open(os.path.join(output_folder, 'log.txt'), 'w')
|
43 |
+
C.TRAINING_DATASET = dataset_folder
|
44 |
+
C.VALIDATION_DATASET = validation_folder
|
45 |
+
C.ACCUM_GRADIENT_STEP = 16
|
46 |
+
C.BATCH_SIZE = 2 if C.ACCUM_GRADIENT_STEP == 1 else C.ACCUM_GRADIENT_STEP
|
47 |
+
C.EARLY_STOP_PATIENCE = 10 * (C.ACCUM_GRADIENT_STEP / 2 if C.ACCUM_GRADIENT_STEP != 1 else 1)
|
48 |
+
C.LEARNING_RATE = lr
|
49 |
+
C.LIMIT_NUM_SAMPLES = None
|
50 |
+
C.EPOCHS = 10000
|
51 |
+
|
52 |
+
aux = "[{}]\tINFO:\nTRAIN DATASET: {}\nVALIDATION DATASET: {}\n" \
|
53 |
+
"GPU: {}\n" \
|
54 |
+
"BATCH SIZE: {}\n" \
|
55 |
+
"LR: {}\n" \
|
56 |
+
"SIMILARITY: {}\n" \
|
57 |
+
"REG. WEIGHT: {}\n" \
|
58 |
+
"EPOCHS: {:d}\n" \
|
59 |
+
"ACCUM. GRAD: {}\n" \
|
60 |
+
"EARLY STOP PATIENCE: {}".format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'),
|
61 |
+
C.TRAINING_DATASET,
|
62 |
+
C.VALIDATION_DATASET,
|
63 |
+
C.GPU_NUM,
|
64 |
+
C.BATCH_SIZE,
|
65 |
+
C.LEARNING_RATE,
|
66 |
+
simil,
|
67 |
+
rw,
|
68 |
+
C.EPOCHS,
|
69 |
+
C.ACCUM_GRADIENT_STEP,
|
70 |
+
C.EARLY_STOP_PATIENCE)
|
71 |
+
log_file.write(aux)
|
72 |
+
print(aux)
|
73 |
+
|
74 |
+
# Load data
|
75 |
+
# Build data generator
|
76 |
+
data_generator = BatchGenerator(C.TRAINING_DATASET, C.BATCH_SIZE if C.ACCUM_GRADIENT_STEP == 1 else 1, True,
|
77 |
+
C.TRAINING_PERC, labels=['all'], combine_segmentations=False,
|
78 |
+
directory_val=C.VALIDATION_DATASET)
|
79 |
+
|
80 |
+
train_generator = data_generator.get_train_generator()
|
81 |
+
validation_generator = data_generator.get_validation_generator()
|
82 |
+
|
83 |
+
image_input_shape = train_generator.get_data_shape()[1][:-1]
|
84 |
+
image_output_shape = [64] * 3
|
85 |
+
|
86 |
+
nb_labels = len(train_generator.get_segmentation_labels())
|
87 |
+
|
88 |
+
# Config the training sessions
|
89 |
+
config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
|
90 |
+
config.gpu_options.allow_growth = True
|
91 |
+
config.log_device_placement = False ## to log device placement (on which device the operation ran)
|
92 |
+
config.allow_soft_placement = True # https://github.com/tensorflow/tensorflow/issues/30782
|
93 |
+
sess = tf.Session(config=config)
|
94 |
+
tf.keras.backend.set_session(sess)
|
95 |
+
|
96 |
+
# Build model
|
97 |
+
input_layer_augm = Input(shape=train_generator.get_data_shape()[0], name='input_augmentation')
|
98 |
+
augm_layer = AugmentationLayer(max_displacement=C.MAX_AUG_DISP, # Max 30 mm in isotropic space
|
99 |
+
max_deformation=C.MAX_AUG_DEF, # Max 6 mm in isotropic space
|
100 |
+
max_rotation=C.MAX_AUG_ANGLE, # Max 10 deg in isotropic space
|
101 |
+
num_control_points=C.NUM_CONTROL_PTS_AUG,
|
102 |
+
num_augmentations=C.NUM_AUGMENTATIONS,
|
103 |
+
gamma_augmentation=C.GAMMA_AUGMENTATION,
|
104 |
+
brightness_augmentation=C.BRIGHTNESS_AUGMENTATION,
|
105 |
+
in_img_shape=image_input_shape,
|
106 |
+
out_img_shape=image_output_shape,
|
107 |
+
only_image=False,
|
108 |
+
only_resize=False,
|
109 |
+
trainable=False)
|
110 |
+
augm_model = Model(inputs=input_layer_augm, outputs=augm_layer(input_layer_augm))
|
111 |
+
|
112 |
+
enc_features = [16, 32, 32, 32]# const.ENCODER_FILTERS
|
113 |
+
dec_features = [32, 32, 32, 32, 32, 16, 16]# const.ENCODER_FILTERS[::-1]
|
114 |
+
nb_features = [enc_features, dec_features]
|
115 |
+
|
116 |
+
network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
|
117 |
+
nb_labels=nb_labels,
|
118 |
+
nb_unet_features=nb_features,
|
119 |
+
int_steps=0,
|
120 |
+
int_downsize=1,
|
121 |
+
seg_downsize=1)
|
122 |
+
|
123 |
+
# Compile the model
|
124 |
+
SSIM_KER_SIZE = 5
|
125 |
+
MS_SSIM_WEIGHTS = _MSSSIM_WEIGHTS[:3]
|
126 |
+
MS_SSIM_WEIGHTS /= np.sum(MS_SSIM_WEIGHTS)
|
127 |
+
if simil=='ssim':
|
128 |
+
loss_simil = StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss
|
129 |
+
elif simil=='ms_ssim':
|
130 |
+
loss_simil = MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).loss
|
131 |
+
elif simil=='ncc':
|
132 |
+
loss_simil = NCC(image_input_shape).loss
|
133 |
+
elif simil=='ms_ssim__ncc' or simil=='ncc__ms_ssim':
|
134 |
+
@function_decorator('MS_SSIM_NCC__loss')
|
135 |
+
def loss_simil(y_true, y_pred):
|
136 |
+
return MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).loss(y_true, y_pred) + NCC(image_input_shape).loss(y_true, y_pred)
|
137 |
+
elif simil=='ms_ssim__mse' or simil=='mse__ms_ssim':
|
138 |
+
@function_decorator('MS_SSIM_MSE__loss')
|
139 |
+
def loss_simil(y_true, y_pred):
|
140 |
+
return MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).loss(y_true, y_pred) + vxm.losses.MSE().loss(y_true, y_pred)
|
141 |
+
elif simil=='ssim__ncc' or simil=='ncc__ssim' :
|
142 |
+
@function_decorator('SSIM_NCC__loss')
|
143 |
+
def loss_simil(y_true, y_pred):
|
144 |
+
return StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss(y_true, y_pred) + NCC(image_input_shape).loss(y_true, y_pred)
|
145 |
+
elif simil=='ssim__mse' or simil=='mse__ssim':
|
146 |
+
@function_decorator('SSIM_MSE__loss')
|
147 |
+
def loss_simil(y_true, y_pred):
|
148 |
+
return StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss(y_true, y_pred) + vxm.losses.MSE().loss(y_true, y_pred)
|
149 |
+
else:
|
150 |
+
loss_simil = vxm.losses.MSE().loss
|
151 |
+
|
152 |
+
if segm == 'hd':
|
153 |
+
loss_segm = HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [nb_labels]).loss
|
154 |
+
elif segm == 'dice':
|
155 |
+
loss_segm = GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).loss
|
156 |
+
else:
|
157 |
+
raise ValueError('No valid value for segm')
|
158 |
+
|
159 |
+
losses = {'transformer': loss_simil,
|
160 |
+
'seg_transformer': loss_segm,
|
161 |
+
'flow': vxm.losses.Grad('l2').loss}
|
162 |
+
loss_weights = {'transformer': 1,
|
163 |
+
'seg_transformer': 1.,
|
164 |
+
'flow': 5e-3}
|
165 |
+
metrics = {'transformer': [vxm.losses.MSE().loss, NCC(image_input_shape).metric, MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).metric],
|
166 |
+
'seg_transformer': [GeneralizedDICEScore(image_output_shape + [train_generator.get_data_shape()[2][-1]], num_labels=nb_labels).metric,
|
167 |
+
HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [train_generator.get_data_shape()[2][-1]]).metric
|
168 |
+
]}
|
169 |
+
metrics_weights = {'transformer': 1,
|
170 |
+
'seg_transformer': 1,
|
171 |
+
'flow': rw}
|
172 |
+
|
173 |
+
# Train
|
174 |
+
os.makedirs(output_folder, exist_ok=True)
|
175 |
+
os.makedirs(os.path.join(output_folder, 'checkpoints'), exist_ok=True)
|
176 |
+
os.makedirs(os.path.join(output_folder, 'tensorboard'), exist_ok=True)
|
177 |
+
os.makedirs(os.path.join(output_folder, 'history'), exist_ok=True)
|
178 |
+
|
179 |
+
callback_best_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'),
|
180 |
+
save_best_only=True, monitor='val_loss', verbose=1, mode='min')
|
181 |
+
# callback_save_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'weights.{epoch:05d}-{val_loss:.2f}.h5'),
|
182 |
+
# save_best_only=True, save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
|
183 |
+
# CSVLogger(train_log_name, ';'),
|
184 |
+
# UpdateLossweights([haus_weight, dice_weight], [const.MODEL+'_resampler_seg', const.MODEL+'_resampler_seg'])
|
185 |
+
callback_tensorboard = TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
|
186 |
+
batch_size=C.BATCH_SIZE, write_images=False, histogram_freq=0,
|
187 |
+
update_freq='epoch', # or 'batch' or integer
|
188 |
+
write_graph=True, write_grads=True
|
189 |
+
)
|
190 |
+
callback_early_stop = EarlyStopping(monitor='val_loss', verbose=1, patience=C.EARLY_STOP_PATIENCE, min_delta=0.00001)
|
191 |
+
|
192 |
+
# Compile the model
|
193 |
+
optimizer = AdamAccumulated(C.ACCUM_GRADIENT_STEP, lr=C.LEARNING_RATE)
|
194 |
+
network.compile(optimizer=optimizer,
|
195 |
+
loss=losses,
|
196 |
+
loss_weights=loss_weights,
|
197 |
+
metrics=metrics)
|
198 |
+
|
199 |
+
callback_tensorboard.set_model(network)
|
200 |
+
callback_best_model.set_model(network)
|
201 |
+
# callback_save_model.set_model(network)
|
202 |
+
callback_early_stop.set_model(network)
|
203 |
+
|
204 |
+
summary = SummaryDictionary(network, C.BATCH_SIZE, C.ACCUM_GRADIENT_STEP)
|
205 |
+
names = network.metrics_names # It give both the loss and metric names
|
206 |
+
log_file.write('\n\n[{}]\tINFO:\tStart training\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y')))
|
207 |
+
with sess.as_default():
|
208 |
+
#sess.run(tf.global_variables_initializer())
|
209 |
+
callback_tensorboard.on_train_begin()
|
210 |
+
callback_early_stop.on_train_begin()
|
211 |
+
callback_best_model.on_train_begin()
|
212 |
+
# callback_save_model.on_train_begin()
|
213 |
+
for epoch in range(C.EPOCHS):
|
214 |
+
callback_tensorboard.on_epoch_begin(epoch)
|
215 |
+
callback_early_stop.on_epoch_begin(epoch)
|
216 |
+
callback_best_model.on_epoch_begin(epoch)
|
217 |
+
# callback_save_model.on_epoch_begin(epoch)
|
218 |
+
|
219 |
+
print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
|
220 |
+
print('TRAINING')
|
221 |
+
|
222 |
+
log_file.write('\n\n[{}]\tINFO:\tTraining epoch {}\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), epoch))
|
223 |
+
progress_bar = Progbar(len(train_generator), width=30, verbose=1)
|
224 |
+
t0 = time.time()
|
225 |
+
for step, (in_batch, _) in enumerate(train_generator, 1):
|
226 |
+
#print('Loaded in {} s'.format(time.time() - t0))
|
227 |
+
# callback_tensorboard.on_train_batch_begin(step)
|
228 |
+
callback_best_model.on_train_batch_begin(step)
|
229 |
+
# callback_save_model.on_train_batch_begin(step)
|
230 |
+
callback_early_stop.on_train_batch_begin(step)
|
231 |
+
|
232 |
+
try:
|
233 |
+
t0 = time.time()
|
234 |
+
fix_img, mov_img, fix_seg, mov_seg = augm_model.predict(in_batch)
|
235 |
+
#print('Augmented in {} s'.format(time.time() - t0))
|
236 |
+
np.nan_to_num(fix_img, copy=False)
|
237 |
+
np.nan_to_num(mov_img, copy=False)
|
238 |
+
if np.isnan(np.sum(mov_img)) or np.isnan(np.sum(fix_img)) or np.isinf(np.sum(mov_img)) or np.isinf(np.sum(fix_img)):
|
239 |
+
msg = 'CORRUPTED DATA!! Unique: Fix: {}\tMoving: {}'.format(np.unique(fix_img),
|
240 |
+
np.unique(mov_img))
|
241 |
+
print(msg)
|
242 |
+
log_file.write('\n\n[{}]\tWAR: {}'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), msg))
|
243 |
+
|
244 |
+
except InvalidArgumentError as err:
|
245 |
+
print('TF Error : {}'.format(str(err)))
|
246 |
+
continue
|
247 |
+
|
248 |
+
t0 = time.time()
|
249 |
+
ret = network.train_on_batch(x=(mov_img, fix_img, mov_seg),
|
250 |
+
y=(fix_img, fix_img, fix_seg))
|
251 |
+
# print("Trained on batch in {} s".format(time.time() - t0))
|
252 |
+
|
253 |
+
if np.isnan(ret).any():
|
254 |
+
os.makedirs(os.path.join(output_folder, 'corrupted'), exist_ok=True)
|
255 |
+
save_nifti(mov_img, os.path.join(output_folder, 'corrupted', 'mov_img_nan.nii.gz'))
|
256 |
+
save_nifti(fix_img, os.path.join(output_folder, 'corrupted', 'fix_img_nan.nii.gz'))
|
257 |
+
pred_img, dm = network((mov_img, fix_img))
|
258 |
+
save_nifti(pred_img, os.path.join(output_folder, 'corrupted', 'pred_img_nan.nii.gz'))
|
259 |
+
save_nifti(dm, os.path.join(output_folder, 'corrupted', 'dm_nan.nii.gz'))
|
260 |
+
log_file.write('\n\n[{}]\tERR: Corruption error'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y')))
|
261 |
+
raise ValueError('CORRUPTION ERROR: Halting training')
|
262 |
+
|
263 |
+
summary.on_train_batch_end(ret)
|
264 |
+
# callback_tensorboard.on_train_batch_end(step, named_logs(network, ret))
|
265 |
+
callback_best_model.on_train_batch_end(step, named_logs(network, ret))
|
266 |
+
# callback_save_model.on_train_batch_end(step, named_logs(network, ret))
|
267 |
+
callback_early_stop.on_train_batch_end(step, named_logs(network, ret))
|
268 |
+
progress_bar.update(step, zip(names, ret))
|
269 |
+
log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
|
270 |
+
t0 = time.time()
|
271 |
+
print('End of epoch{}: '.format(step), ret, '\n')
|
272 |
+
val_values = progress_bar._values.copy()
|
273 |
+
ret = [val_values[x][0]/val_values[x][1] for x in names]
|
274 |
+
|
275 |
+
print('\nVALIDATION')
|
276 |
+
log_file.write('\n\n[{}]\tINFO:\tValidation epoch {}\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), epoch))
|
277 |
+
progress_bar = Progbar(len(validation_generator), width=30, verbose=1)
|
278 |
+
for step, (in_batch, _) in enumerate(validation_generator, 1):
|
279 |
+
# callback_tensorboard.on_test_batch_begin(step) # This is cursed, don't do it again
|
280 |
+
# callback_early_stop.on_test_batch_begin(step)
|
281 |
+
try:
|
282 |
+
fix_img, mov_img, fix_seg, mov_seg = augm_model.predict(in_batch)
|
283 |
+
except InvalidArgumentError as err:
|
284 |
+
print('TF Error : {}'.format(str(err)))
|
285 |
+
continue
|
286 |
+
|
287 |
+
ret = network.test_on_batch(x=(mov_img, fix_img, mov_seg),
|
288 |
+
y=(fix_img, fix_img, fix_seg))
|
289 |
+
# pred_segm = network.register(mov_segm, fix_segm)
|
290 |
+
summary.on_validation_batch_end(ret)
|
291 |
+
# callback_early_stop.on_test_batch_end(step, named_logs(network, ret))
|
292 |
+
# callback_tensorboard.on_test_batch_end(step, named_logs(network, ret)) # This is cursed, don't do it again
|
293 |
+
progress_bar.update(step, zip(names, ret))
|
294 |
+
log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
|
295 |
+
val_values = progress_bar._values.copy()
|
296 |
+
ret = [val_values[x][0]/val_values[x][1] for x in names]
|
297 |
+
|
298 |
+
train_generator.on_epoch_end()
|
299 |
+
validation_generator.on_epoch_end()
|
300 |
+
epoch_summary = summary.on_epoch_end()
|
301 |
+
callback_tensorboard.on_epoch_end(epoch, epoch_summary)
|
302 |
+
callback_early_stop.on_epoch_end(epoch, epoch_summary)
|
303 |
+
callback_best_model.on_epoch_end(epoch, epoch_summary)
|
304 |
+
# callback_save_model.on_epoch_end(epoch, named_logs(network, ret, True))
|
305 |
+
|
306 |
+
callback_tensorboard.on_train_end()
|
307 |
+
# callback_save_model.on_train_end()
|
308 |
+
callback_best_model.on_train_end()
|
309 |
+
callback_early_stop.on_train_end()
|
310 |
+
|
311 |
+
|
312 |
+
if __name__ == '__main__':
|
313 |
+
os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
|
314 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '0' # Check availability before running using 'nvidia-smi'
|
315 |
+
|
316 |
+
config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
|
317 |
+
config.gpu_options.allow_growth = True
|
318 |
+
config.log_device_placement = False ## to log device placement (on which device the operation ran)
|
319 |
+
tf.keras.backend.set_session(tf.Session(config=config))
|
320 |
+
|
321 |
+
launch_train('/mnt/EncryptedData1/Users/javier/Brain_study/ERASE',
|
322 |
+
'TrainOutput/THESIS/UW_None_mse_ssim_haus',
|
323 |
+
0)
|
Brain_study/Train_UncertaintyWeighted.py
ADDED
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
currentdir = os.path.dirname(os.path.realpath(__file__))
|
3 |
+
parentdir = os.path.dirname(currentdir)
|
4 |
+
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import tensorflow as tf
|
8 |
+
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
|
9 |
+
from tensorflow.keras import Input
|
10 |
+
from tensorflow.keras.models import Model
|
11 |
+
from tensorflow.python.keras.utils import Progbar
|
12 |
+
from tensorflow.python.framework.errors import InvalidArgumentError
|
13 |
+
import voxelmorph as vxm
|
14 |
+
import neurite as ne
|
15 |
+
import h5py
|
16 |
+
from datetime import datetime
|
17 |
+
import pickle
|
18 |
+
|
19 |
+
import DeepDeformationMapRegistration.utils.constants as C
|
20 |
+
from DeepDeformationMapRegistration.utils.misc import try_mkdir, DatasetCopy, function_decorator
|
21 |
+
from DeepDeformationMapRegistration.networks import WeaklySupervised
|
22 |
+
from DeepDeformationMapRegistration.losses import HausdorffDistanceErosion, NCC, StructuralSimilarity_simplified, GeneralizedDICEScore
|
23 |
+
from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity, _MSSSIM_WEIGHTS
|
24 |
+
from DeepDeformationMapRegistration.layers import UncertaintyWeighting, AugmentationLayer
|
25 |
+
from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
|
26 |
+
|
27 |
+
from Brain_study.data_generator import BatchGenerator
|
28 |
+
from Brain_study.utils import SummaryDictionary, named_logs
|
29 |
+
|
30 |
+
|
31 |
+
def launch_train(dataset_folder, validation_folder, output_folder, prior_reg_w=5e-3, lr=1e-4,
|
32 |
+
gpu_num=0, simil=['mse'], segm=['dice']):
|
33 |
+
assert dataset_folder is not None and output_folder is not None
|
34 |
+
|
35 |
+
os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
|
36 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_num) # Check availability before running using 'nvidia-smi'
|
37 |
+
C.GPU_NUM = str(gpu_num)
|
38 |
+
|
39 |
+
output_folder = os.path.join(output_folder + '_' + datetime.now().strftime("%H%M%S-%d%m%Y"))
|
40 |
+
# dataset_copy = DatasetCopy(dataset_folder, os.path.join(output_folder, 'temp'))
|
41 |
+
os.makedirs(output_folder, exist_ok=True)
|
42 |
+
log_file = open(os.path.join(output_folder, 'log.txt'), 'w')
|
43 |
+
C.TRAINING_DATASET = dataset_folder #dataset_copy.copy_dataset()
|
44 |
+
C.VALIDATION_DATASET = validation_folder
|
45 |
+
C.ACCUM_GRADIENT_STEP = 16
|
46 |
+
C.BATCH_SIZE = 2 if C.ACCUM_GRADIENT_STEP == 1 else C.ACCUM_GRADIENT_STEP
|
47 |
+
C.EARLY_STOP_PATIENCE = 10 * (C.ACCUM_GRADIENT_STEP/2 if C.ACCUM_GRADIENT_STEP != 1 else 1)
|
48 |
+
C.LEARNING_RATE = lr
|
49 |
+
C.LIMIT_NUM_SAMPLES = None
|
50 |
+
C.EPOCHS = 10000
|
51 |
+
|
52 |
+
aux = "[{}]\tINFO:\nTRAIN DATASET: {}\nVALIDATION DATASET: {}\n" \
|
53 |
+
"GPU: {}\n" \
|
54 |
+
"BATCH SIZE: {}\n" \
|
55 |
+
"LR: {}\n" \
|
56 |
+
"SIMILARITY {:d}: {}\n" \
|
57 |
+
"SEGMENTATION {:d}: {}\n" \
|
58 |
+
"EPOCHS: {:d}" \
|
59 |
+
"ACCUM. GRAD: {}" \
|
60 |
+
"EARLY STOP PATIENCE: {}".format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'),
|
61 |
+
C.TRAINING_DATASET,
|
62 |
+
C.VALIDATION_DATASET,
|
63 |
+
C.GPU_NUM,
|
64 |
+
C.BATCH_SIZE,
|
65 |
+
C.LEARNING_RATE,
|
66 |
+
len(simil), ', '.join(simil),
|
67 |
+
len(segm), ', '.join(segm),
|
68 |
+
C.EPOCHS,
|
69 |
+
C.ACCUM_GRADIENT_STEP,
|
70 |
+
C.EARLY_STOP_PATIENCE)
|
71 |
+
log_file.write(aux)
|
72 |
+
print(aux)
|
73 |
+
|
74 |
+
# Load data
|
75 |
+
# Build data generator
|
76 |
+
data_generator = BatchGenerator(C.TRAINING_DATASET, C.BATCH_SIZE if C.ACCUM_GRADIENT_STEP == 1 else 1, True,
|
77 |
+
C.TRAINING_PERC, labels=['all'], combine_segmentations=False,
|
78 |
+
directory_val=C.VALIDATION_DATASET)
|
79 |
+
|
80 |
+
train_generator = data_generator.get_train_generator()
|
81 |
+
validation_generator = data_generator.get_validation_generator()
|
82 |
+
|
83 |
+
image_input_shape = train_generator.get_data_shape()[-1][:-1]
|
84 |
+
image_output_shape = [64] * 3
|
85 |
+
|
86 |
+
nb_labels = len(train_generator.get_segmentation_labels())
|
87 |
+
|
88 |
+
# Config the training sessions
|
89 |
+
config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
|
90 |
+
config.gpu_options.allow_growth = True
|
91 |
+
config.log_device_placement = False ## to log device placement (on which device the operation ran)
|
92 |
+
config.allow_soft_placement = True
|
93 |
+
sess = tf.Session(config=config)
|
94 |
+
tf.keras.backend.set_session(sess)
|
95 |
+
|
96 |
+
# Losses and loss weights
|
97 |
+
SSIM_KER_SIZE = 5
|
98 |
+
MS_SSIM_WEIGHTS = _MSSSIM_WEIGHTS[:3]
|
99 |
+
MS_SSIM_WEIGHTS /= np.sum(MS_SSIM_WEIGHTS)
|
100 |
+
|
101 |
+
loss_simil = []
|
102 |
+
prior_loss_w = []
|
103 |
+
for s in simil:
|
104 |
+
if s=='ssim':
|
105 |
+
loss_simil.append(StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss)
|
106 |
+
prior_loss_w.append(1.)
|
107 |
+
elif s=='ms_ssim':
|
108 |
+
loss_simil.append(MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).loss)
|
109 |
+
prior_loss_w.append(1.)
|
110 |
+
elif s=='ncc':
|
111 |
+
loss_simil.append(NCC(image_input_shape).loss)
|
112 |
+
prior_loss_w.append(1.)
|
113 |
+
elif s=='mse':
|
114 |
+
loss_simil.append(vxm.losses.MSE().loss)
|
115 |
+
prior_loss_w.append(1.)
|
116 |
+
else:
|
117 |
+
raise ValueError('Unknown similarity function: ', s)
|
118 |
+
|
119 |
+
loss_segm = []
|
120 |
+
for s in segm:
|
121 |
+
if s=='dice':
|
122 |
+
loss_segm.append(GeneralizedDICEScore(image_output_shape + [train_generator.get_data_shape()[2][-1]], num_labels=nb_labels).loss)
|
123 |
+
prior_loss_w.append(1.)
|
124 |
+
elif s=='hd':
|
125 |
+
loss_segm.append(HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [train_generator.get_data_shape()[2][-1]]).loss)
|
126 |
+
prior_loss_w.append(1.)
|
127 |
+
else:
|
128 |
+
raise ValueError('Unknown similarity function: ', s)
|
129 |
+
|
130 |
+
# Build augmentation layer model
|
131 |
+
input_layer_augm = Input(shape=train_generator.get_data_shape()[0], name='input_augmentation')
|
132 |
+
augm_layer = AugmentationLayer(max_displacement=C.MAX_AUG_DISP, # Max 30 mm in isotropic space
|
133 |
+
max_deformation=C.MAX_AUG_DEF, # Max 6 mm in isotropic space
|
134 |
+
max_rotation=C.MAX_AUG_ANGLE, # Max 10 deg in isotropic space
|
135 |
+
num_control_points=C.NUM_CONTROL_PTS_AUG,
|
136 |
+
num_augmentations=C.NUM_AUGMENTATIONS,
|
137 |
+
gamma_augmentation=C.GAMMA_AUGMENTATION,
|
138 |
+
in_img_shape=image_input_shape,
|
139 |
+
out_img_shape=image_output_shape,
|
140 |
+
only_image=False,
|
141 |
+
only_resize=False,
|
142 |
+
trainable=False)
|
143 |
+
augmentation_model = Model(inputs=input_layer_augm, outputs=augm_layer(input_layer_augm))
|
144 |
+
|
145 |
+
enc_features = [16, 32, 32, 32] # const.ENCODER_FILTERS
|
146 |
+
dec_features = [32, 32, 32, 32, 32, 16, 16] # const.ENCODER_FILTERS[::-1]
|
147 |
+
nb_features = [enc_features, dec_features]
|
148 |
+
network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
|
149 |
+
nb_labels=nb_labels,
|
150 |
+
nb_unet_features=nb_features,
|
151 |
+
int_steps=0,
|
152 |
+
int_downsize=1,
|
153 |
+
seg_downsize=1)
|
154 |
+
# Network inputs: mov_img, fix_img, mov_seg
|
155 |
+
# Network outputs: pred_img, disp_map, pred_seg
|
156 |
+
grad = tf.keras.Input(shape=(*image_output_shape, 3), name='multiLoss_grad_input', dtype=tf.float32)
|
157 |
+
fix_seg = tf.keras.Input(shape=(*image_output_shape, len(train_generator.get_segmentation_labels())),
|
158 |
+
name='multiLoss_fix_seg_input', dtype=tf.float32)
|
159 |
+
|
160 |
+
multiLoss = UncertaintyWeighting(num_loss_fns=len(loss_simil) + len(loss_segm),
|
161 |
+
num_reg_fns=1,
|
162 |
+
loss_fns=[*loss_simil,
|
163 |
+
*loss_segm],
|
164 |
+
reg_fns=[vxm.losses.Grad('l2').loss],
|
165 |
+
prior_loss_w=prior_loss_w,
|
166 |
+
# prior_loss_w=[1., 0.1, 1., 1.],
|
167 |
+
prior_reg_w=[prior_reg_w],
|
168 |
+
name='MultiLossLayer')
|
169 |
+
loss = multiLoss([*[network.inputs[1]]*len(loss_simil), *[fix_seg]*len(loss_segm),
|
170 |
+
*[network.outputs[0]]*len(loss_simil), *[network.outputs[2]]*len(loss_simil),
|
171 |
+
grad,
|
172 |
+
network.outputs[1]])
|
173 |
+
|
174 |
+
# inputs = [mov_img, fix_img, mov_segm, fix_segm, zero_grads]
|
175 |
+
# outputs = [pred_img, flow, pred_segm, loss]
|
176 |
+
full_model = tf.keras.Model(inputs=network.inputs + [fix_seg, grad],
|
177 |
+
outputs=network.outputs + [loss])
|
178 |
+
|
179 |
+
# Train
|
180 |
+
os.makedirs(output_folder, exist_ok=True)
|
181 |
+
os.makedirs(os.path.join(output_folder, 'checkpoints'), exist_ok=True)
|
182 |
+
os.makedirs(os.path.join(output_folder, 'tensorboard'), exist_ok=True)
|
183 |
+
os.makedirs(os.path.join(output_folder, 'history'), exist_ok=True)
|
184 |
+
|
185 |
+
callback_best_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'),
|
186 |
+
save_best_only=True, monitor='val_loss', verbose=1, mode='min')
|
187 |
+
# callback_save_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'weights.{epoch:05d}-{val_loss:.2f}.h5'),
|
188 |
+
# save_best_only=True, save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
|
189 |
+
# CSVLogger(train_log_name, ';'),
|
190 |
+
# UpdateLossweights([haus_weight, dice_weight], [const.MODEL+'_resampler_seg', const.MODEL+'_resampler_seg'])
|
191 |
+
callback_tensorboard = TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
|
192 |
+
batch_size=C.BATCH_SIZE, write_images=False, histogram_freq=0,
|
193 |
+
update_freq='epoch', # or 'batch' or integer
|
194 |
+
write_graph=True, write_grads=True
|
195 |
+
)
|
196 |
+
callback_early_stop = EarlyStopping(monitor='val_loss', verbose=1, patience=C.EARLY_STOP_PATIENCE, min_delta=0.00001)
|
197 |
+
|
198 |
+
# Compile the model
|
199 |
+
optimizer = AdamAccumulated(C.ACCUM_GRADIENT_STEP, lr=C.LEARNING_RATE)
|
200 |
+
full_model.compile(optimizer=optimizer, loss=None)
|
201 |
+
|
202 |
+
callback_tensorboard.set_model(full_model)
|
203 |
+
callback_best_model.set_model(network) # ONLY SAVE THE NETWORK!!!
|
204 |
+
# callback_save_model.set_model(network)
|
205 |
+
callback_early_stop.set_model(full_model)
|
206 |
+
# TODO: https://towardsdatascience.com/writing-tensorflow-2-custom-loops-438b1ab6eb6c
|
207 |
+
|
208 |
+
summary = SummaryDictionary(full_model, C.BATCH_SIZE)
|
209 |
+
names = full_model.metrics_names # It give both the loss and metric names
|
210 |
+
zero_grads = tf.zeros_like(network.references.pos_flow, name='dummy_zero_grads') # Dummy zeros-tensor
|
211 |
+
log_file.write('\n\n[{}]\tINFO:\tStart training\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y')))
|
212 |
+
with sess.as_default():
|
213 |
+
callback_tensorboard.on_train_begin()
|
214 |
+
callback_early_stop.on_train_begin()
|
215 |
+
callback_best_model.on_train_begin()
|
216 |
+
# callback_save_model.on_train_begin()
|
217 |
+
for epoch in range(C.EPOCHS):
|
218 |
+
callback_tensorboard.on_epoch_begin(epoch)
|
219 |
+
callback_early_stop.on_epoch_begin(epoch)
|
220 |
+
callback_best_model.on_epoch_begin(epoch)
|
221 |
+
# callback_save_model.on_epoch_begin(epoch)
|
222 |
+
|
223 |
+
print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
|
224 |
+
print('TRAINING')
|
225 |
+
|
226 |
+
log_file.write('\n\n[{}]\tINFO:\tTraining epoch {}\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), epoch))
|
227 |
+
progress_bar = Progbar(len(train_generator), width=30, verbose=1)
|
228 |
+
for step, (in_batch, _) in enumerate(train_generator, 1):
|
229 |
+
# callback_tensorboard.on_train_batch_begin(step)
|
230 |
+
callback_best_model.on_train_batch_begin(step)
|
231 |
+
# callback_save_model.on_train_batch_begin(step)
|
232 |
+
callback_early_stop.on_train_batch_begin(step)
|
233 |
+
|
234 |
+
try:
|
235 |
+
fix_img, mov_img, fix_seg, mov_seg = augmentation_model.predict(in_batch)
|
236 |
+
np.nan_to_num(fix_img, copy=False)
|
237 |
+
np.nan_to_num(mov_img, copy=False)
|
238 |
+
except InvalidArgumentError as err:
|
239 |
+
print('TF Error : {}'.format(str(err)))
|
240 |
+
continue
|
241 |
+
# inputs = [mov_img, fix_img, mov_segm, fix_segm, zero_grads]
|
242 |
+
# outputs = [pred_img, flow, pred_segm, loss]
|
243 |
+
ret = full_model.train_on_batch(x=(mov_img, fix_img, mov_seg, fix_seg, zero_grads))
|
244 |
+
|
245 |
+
summary.on_train_batch_end(ret)
|
246 |
+
# callback_tensorboard.on_train_batch_end(step, named_logs(full_model, ret))
|
247 |
+
callback_best_model.on_train_batch_end(step, named_logs(full_model, ret))
|
248 |
+
# callback_save_model.on_train_batch_end(step, named_logs(network, ret))
|
249 |
+
callback_early_stop.on_train_batch_end(step, named_logs(full_model, ret))
|
250 |
+
log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
|
251 |
+
# print(ret, '\n')
|
252 |
+
progress_bar.update(step, zip(names, ret))
|
253 |
+
print('End of epoch{}: '.format(step), ret, '\n')
|
254 |
+
val_values = progress_bar._values.copy()
|
255 |
+
ret = [val_values[x][0]/val_values[x][1] for x in names]
|
256 |
+
|
257 |
+
print('\nVALIDATION')
|
258 |
+
log_file.write('\n\n[{}]\tINFO:\tValidation epoch {}\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), epoch))
|
259 |
+
progress_bar = Progbar(len(validation_generator), width=30, verbose=1)
|
260 |
+
for step, (in_batch, _) in enumerate(validation_generator, 1):
|
261 |
+
# callback_tensorboard.on_test_batch_begin(step) # This is cursed, don't do it again
|
262 |
+
# callback_early_stop.on_test_batch_begin(step)
|
263 |
+
try:
|
264 |
+
fix_img, mov_img, fix_seg, mov_seg = augmentation_model.predict(in_batch)
|
265 |
+
except InvalidArgumentError as err:
|
266 |
+
print('TF Error : {}'.format(str(err)))
|
267 |
+
continue
|
268 |
+
|
269 |
+
ret = full_model.test_on_batch(x=(mov_img, fix_img, mov_seg, fix_seg, zero_grads))
|
270 |
+
# pred_segm = network.register(mov_segm, fix_segm)
|
271 |
+
summary.on_validation_batch_end(ret)
|
272 |
+
# callback_early_stop.on_test_batch_end(step, named_logs(full_model, ret))
|
273 |
+
# callback_tensorboard.on_test_batch_end(step, named_logs(network, ret)) # This is cursed, don't do it again
|
274 |
+
progress_bar.update(step, zip(names, ret))
|
275 |
+
log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
|
276 |
+
val_values = progress_bar._values.copy()
|
277 |
+
ret = [val_values[x][0]/val_values[x][1] for x in names]
|
278 |
+
|
279 |
+
train_generator.on_epoch_end()
|
280 |
+
validation_generator.on_epoch_end()
|
281 |
+
epoch_summary = summary.on_epoch_end() # summary resets after on_epoch_end() call
|
282 |
+
callback_tensorboard.on_epoch_end(epoch, epoch_summary)
|
283 |
+
callback_best_model.on_epoch_end(epoch, epoch_summary)
|
284 |
+
callback_early_stop.on_epoch_end(epoch, epoch_summary)
|
285 |
+
# callback_save_model.on_train_end(epoch, epoch_summary)
|
286 |
+
|
287 |
+
callback_tensorboard.on_train_end()
|
288 |
+
# callback_save_model.on_train_end()
|
289 |
+
callback_best_model.on_train_end()
|
290 |
+
callback_early_stop.on_train_end()
|
291 |
+
|
292 |
+
|
293 |
+
if __name__ == '__main__':
|
294 |
+
os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
|
295 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '0' # Check availability before running using 'nvidia-smi'
|
296 |
+
|
297 |
+
config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
|
298 |
+
config.gpu_options.allow_growth = True
|
299 |
+
config.log_device_placement = False ## to log device placement (on which device the operation ran)
|
300 |
+
tf.keras.backend.set_session(tf.Session(config=config))
|
301 |
+
|
302 |
+
launch_train('/mnt/EncryptedData1/Users/javier/Brain_study/ERASE',
|
303 |
+
'TrainOutput/THESIS/UW_None_mse_ssim_haus',
|
304 |
+
0)
|
Brain_study/__init__.py
ADDED
File without changes
|
Brain_study/data_generator.py
ADDED
@@ -0,0 +1,453 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
from tensorflow import keras
|
5 |
+
import os
|
6 |
+
import h5py
|
7 |
+
import random
|
8 |
+
from PIL import Image
|
9 |
+
import nibabel as nib
|
10 |
+
from nilearn.image import resample_img
|
11 |
+
from skimage.exposure import equalize_adapthist
|
12 |
+
from scipy.ndimage import zoom
|
13 |
+
import tensorflow as tf
|
14 |
+
|
15 |
+
import DeepDeformationMapRegistration.utils.constants as C
|
16 |
+
from DeepDeformationMapRegistration.utils.operators import min_max_norm
|
17 |
+
from DeepDeformationMapRegistration.utils.thin_plate_splines import ThinPlateSplines
|
18 |
+
from voxelmorph.tf.layers import SpatialTransformer
|
19 |
+
from Brain_study.format_dataset import SEGMENTATION_NR2LBL_LUT, SEGMENTATION_LBL2NR_LUT
|
20 |
+
|
21 |
+
from tensorflow.python.keras.preprocessing.image import Iterator
|
22 |
+
from tensorflow.python.keras.utils import Sequence
|
23 |
+
import sys
|
24 |
+
|
25 |
+
#import concurrent.futures
|
26 |
+
#import multiprocessing as mp
|
27 |
+
import time
|
28 |
+
|
29 |
+
class BatchGenerator:
|
30 |
+
def __init__(self,
|
31 |
+
directory,
|
32 |
+
batch_size,
|
33 |
+
shuffle=True,
|
34 |
+
split=0.7,
|
35 |
+
combine_segmentations=True,
|
36 |
+
labels=['all'],
|
37 |
+
directory_val=None):
|
38 |
+
self.file_directory = directory
|
39 |
+
self.batch_size = batch_size
|
40 |
+
self.combine_segmentations = combine_segmentations
|
41 |
+
self.labels = labels
|
42 |
+
self.shuffle = shuffle
|
43 |
+
self.split = split
|
44 |
+
|
45 |
+
if directory_val is None:
|
46 |
+
self.file_list = [os.path.join(directory, f) for f in os.listdir(directory) if f.endswith(('h5', 'hd5'))]
|
47 |
+
random.shuffle(self.file_list) if self.shuffle else self.file_list.sort()
|
48 |
+
self.num_samples = len(self.file_list)
|
49 |
+
training_samples = self.file_list[:int(self.num_samples * self.split)]
|
50 |
+
|
51 |
+
self.train_iter = BatchIterator(training_samples, batch_size, shuffle, combine_segmentations, labels)
|
52 |
+
if self.split < 1.:
|
53 |
+
validation_samples = list(set(self.file_list) - set(training_samples))
|
54 |
+
self.validation_iter = BatchIterator(validation_samples, batch_size, shuffle, combine_segmentations, ['all'],
|
55 |
+
validation=True)
|
56 |
+
else:
|
57 |
+
self.validation_iter = None
|
58 |
+
else:
|
59 |
+
training_samples = [os.path.join(directory, f) for f in os.listdir(directory) if f.endswith(('h5', 'hd5'))]
|
60 |
+
random.shuffle(training_samples) if self.shuffle else training_samples.sort()
|
61 |
+
|
62 |
+
validation_samples = [os.path.join(directory_val, f) for f in os.listdir(directory_val) if f.endswith(('h5', 'hd5'))]
|
63 |
+
random.shuffle(validation_samples) if self.shuffle else validation_samples.sort()
|
64 |
+
|
65 |
+
self.num_samples = len(training_samples) + len(validation_samples)
|
66 |
+
self.file_list = training_samples + validation_samples
|
67 |
+
|
68 |
+
self.train_iter = BatchIterator(training_samples, batch_size, shuffle, combine_segmentations, labels)
|
69 |
+
self.validation_iter = BatchIterator(validation_samples, batch_size, shuffle, combine_segmentations, ['all'],
|
70 |
+
validation=True)
|
71 |
+
|
72 |
+
def get_train_generator(self):
|
73 |
+
return self.train_iter
|
74 |
+
|
75 |
+
def get_validation_generator(self):
|
76 |
+
if self.validation_iter is not None:
|
77 |
+
return self.validation_iter
|
78 |
+
else:
|
79 |
+
raise ValueError('No validation iterator. Split must be < 1.0')
|
80 |
+
|
81 |
+
def get_file_list(self):
|
82 |
+
return self.file_list
|
83 |
+
|
84 |
+
def get_data_shape(self):
|
85 |
+
return self.train_iter.get_data_shape()
|
86 |
+
|
87 |
+
|
88 |
+
ALL_LABELS = {2., 3., 4., 6., 8., 9., 11., 12., 14., 16., 20., 23., 29., 33., 39., 53., 67., 76., 102., 203., 210.,
|
89 |
+
211., 218., 219., 232., 233., 254., 255.}
|
90 |
+
ALL_LABELS_LOC = {label: loc for label, loc in zip(ALL_LABELS, range(0, len(ALL_LABELS)))}
|
91 |
+
|
92 |
+
|
93 |
+
class BatchIterator(Sequence):
|
94 |
+
def __init__(self, file_list, batch_size, shuffle, combine_segmentations=True, labels=['all'],
|
95 |
+
zero_grads=[64, 64, 64, 3], validation=False, **kwargs):
|
96 |
+
# super(BatchIterator, self).__init__(n=len(file_list),
|
97 |
+
# batch_size=batch_size,
|
98 |
+
# shuffle=shuffle,
|
99 |
+
# seed=None,
|
100 |
+
# **kwargs)
|
101 |
+
self.batch_size = batch_size
|
102 |
+
self.shuffle = shuffle
|
103 |
+
self.file_list = file_list
|
104 |
+
self.combine_segmentations = combine_segmentations
|
105 |
+
self.labels = labels
|
106 |
+
self.zero_grads = zero_grads
|
107 |
+
self.idx_list = np.arange(0, len(self.file_list))
|
108 |
+
self.validation = validation
|
109 |
+
self._initialize()
|
110 |
+
self.shuffle_samples()
|
111 |
+
|
112 |
+
def _initialize(self):
|
113 |
+
with h5py.File(self.file_list[0], 'r') as f:
|
114 |
+
self.image_shape = list(f['image'][:].shape)
|
115 |
+
self.segm_shape = list(f['segmentation'][:].shape)
|
116 |
+
if not self.combine_segmentations:
|
117 |
+
self.segm_shape[-1] = len(f['segmentation_labels'][:]) if self.labels[0].lower() == 'all' else len(self.labels)
|
118 |
+
|
119 |
+
self.batch_shape = self.image_shape.copy()
|
120 |
+
if self.labels[0].lower() != 'none':
|
121 |
+
self.batch_shape[-1] = 2 if self.combine_segmentations else 1 + self.segm_shape[-1] # +1 because we have the fix and the moving images
|
122 |
+
|
123 |
+
if self.labels[0] != 'all':
|
124 |
+
if isinstance(self.labels[0], str):
|
125 |
+
self.labels = [SEGMENTATION_LBL2NR_LUT[lbl] for lbl in self.labels]
|
126 |
+
|
127 |
+
self.num_steps = len(self.file_list) // self.batch_size + (1 if len(self.file_list) % self.batch_size else 0)
|
128 |
+
#self.executor = concurrent.futures.ProcessPoolExecutor(max_workers=self.batch_size)
|
129 |
+
#self.mp_pool = mp.Pool(self.batch_size)
|
130 |
+
|
131 |
+
def shuffle_samples(self):
|
132 |
+
np.random.shuffle(self.idx_list)
|
133 |
+
|
134 |
+
def __len__(self):
|
135 |
+
return self.num_steps
|
136 |
+
|
137 |
+
def _filter_segmentations(self, segm, segm_labels):
|
138 |
+
if self.combine_segmentations:
|
139 |
+
# TODO
|
140 |
+
warnings.warn('Cannot select labels when combinine_segmentations options is active')
|
141 |
+
if self.labels[0] != 'all':
|
142 |
+
if set(self.labels).issubset(set(segm_labels)):
|
143 |
+
# If labels in self.labels are in segm
|
144 |
+
idx = [ALL_LABELS_LOC[l] for l in self.labels]
|
145 |
+
segm = segm[..., idx]
|
146 |
+
else:
|
147 |
+
# Else we have to collect those labels that are contained and complete with zeros
|
148 |
+
idx = [ALL_LABELS_LOC[l] for l in list(set(self.labels).intersection(set(segm_labels)))]
|
149 |
+
aux = segm.copy()
|
150 |
+
segm = np.zeros(self.segm_shape)
|
151 |
+
segm[..., :len(idx)] = aux[..., idx]
|
152 |
+
# TODO: leave the zero-ed segmentations before or after the selected labels based on the order
|
153 |
+
return segm
|
154 |
+
|
155 |
+
def _load_sample(self, file_path):
|
156 |
+
with h5py.File(file_path, 'r') as f:
|
157 |
+
img = f['image'][:]
|
158 |
+
segm_labels = f['segmentation_labels'][:]
|
159 |
+
if self.combine_segmentations:
|
160 |
+
segm = f['segmentation'][:]
|
161 |
+
else:
|
162 |
+
segm = f['segmentation_expanded'][:]
|
163 |
+
if segm.shape[-1] != self.segm_shape[-1]:
|
164 |
+
aux = np.zeros(self.segm_shape)
|
165 |
+
aux[..., :segm.shape[-1]] = segm # Ensure the same shape in case there are missing labels in aux
|
166 |
+
segm = aux
|
167 |
+
# TODO: selection label segm = aux[..., self.labels] but:
|
168 |
+
# what if aux does not have a label in self.labels??
|
169 |
+
|
170 |
+
if self.labels[0].lower() != 'none' or self.validation: # I expect to ask for the segmentations during val
|
171 |
+
segm = self._filter_segmentations(segm, segm_labels)
|
172 |
+
|
173 |
+
if self.validation:
|
174 |
+
ret_val = np.concatenate([img, segm], axis=-1), (img, segm, np.zeros(self.zero_grads))
|
175 |
+
else:
|
176 |
+
ret_val = np.concatenate([img, segm], axis=-1), (img, np.zeros(self.zero_grads))
|
177 |
+
else:
|
178 |
+
ret_val = img, (img, np.zeros(self.zero_grads))
|
179 |
+
return ret_val
|
180 |
+
|
181 |
+
def __getitem__(self, idx):
|
182 |
+
in_batch = list()
|
183 |
+
# out_batch = list()
|
184 |
+
|
185 |
+
batch_idxs = self.idx_list[idx * self.batch_size:(idx + 1) * self.batch_size]
|
186 |
+
file_list = [self.file_list[i] for i in batch_idxs]
|
187 |
+
# if self.batch_size > 1:
|
188 |
+
# # Multiprocessing to speed up laoding
|
189 |
+
#
|
190 |
+
# for ret in self.executor.map(self._load_sample, file_list):
|
191 |
+
# b, i = ret
|
192 |
+
# in_batch.append(b)
|
193 |
+
# # out_batch.append(i)
|
194 |
+
# else:
|
195 |
+
# No need for multithreading, we are loading a single file
|
196 |
+
for f in file_list:
|
197 |
+
b, i = self._load_sample(f)
|
198 |
+
in_batch.append(b)
|
199 |
+
# out_batch.append(i)
|
200 |
+
|
201 |
+
in_batch = np.asarray(in_batch)
|
202 |
+
# out_batch = np.asarray(out_batch)
|
203 |
+
return in_batch, in_batch
|
204 |
+
|
205 |
+
def __iter__(self):
|
206 |
+
"""Create a generator that iterate over the Sequence."""
|
207 |
+
for item in (self[i] for i in range(len(self))):
|
208 |
+
yield item
|
209 |
+
|
210 |
+
def get_data_shape(self):
|
211 |
+
return self.batch_shape, self.image_shape, self.segm_shape
|
212 |
+
|
213 |
+
def on_epoch_end(self):
|
214 |
+
self.shuffle_samples()
|
215 |
+
|
216 |
+
def get_segmentation_labels(self):
|
217 |
+
if self.combine_segmentations:
|
218 |
+
labels = [1]
|
219 |
+
else:
|
220 |
+
with h5py.File(self.file_list[0], 'r') as f:
|
221 |
+
labels = np.unique(f['segmentation'][:])
|
222 |
+
labels = np.sort(labels)[1:] # Ignore the background
|
223 |
+
return labels
|
224 |
+
|
225 |
+
|
226 |
+
|
227 |
+
|
228 |
+
|
229 |
+
|
230 |
+
'''
|
231 |
+
def get_size(obj, seen=None):
|
232 |
+
"""Recursively finds size of objects"""
|
233 |
+
size = sys.getsizeof(obj)
|
234 |
+
if seen is None:
|
235 |
+
seen = set()
|
236 |
+
obj_id = id(obj)
|
237 |
+
if obj_id in seen:
|
238 |
+
return 0
|
239 |
+
# Important mark as seen *before* entering recursion to gracefully handle
|
240 |
+
# self-referential objects
|
241 |
+
seen.add(obj_id)
|
242 |
+
if isinstance(obj, dict):
|
243 |
+
size += sum([get_size(v, seen) for v in obj.values()])
|
244 |
+
size += sum([get_size(k, seen) for k in obj.keys()])
|
245 |
+
elif hasattr(obj, '__dict__'):
|
246 |
+
size += get_size(obj.__dict__, seen)
|
247 |
+
elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)):
|
248 |
+
size += sum([get_size(i, seen) for i in obj])
|
249 |
+
return size
|
250 |
+
|
251 |
+
|
252 |
+
class BatchIterator(Iterator):
|
253 |
+
def __init__(self, generator, file_list, input_shape, output_shape, batch_size, shuffle, all_files_in_batch):
|
254 |
+
self.file_list = file_list
|
255 |
+
self.generator = generator
|
256 |
+
self.input_shape = input_shape
|
257 |
+
self.nr_of_inputs = len(input_shape)
|
258 |
+
self.output_shape = output_shape
|
259 |
+
self.nr_of_outputs = len(output_shape)
|
260 |
+
self.all_files_in_batch = all_files_in_batch
|
261 |
+
self.preload_to_memory = False
|
262 |
+
self.file_cache = {}
|
263 |
+
self.max_cache_size = 10*1024
|
264 |
+
self.verbose = False
|
265 |
+
if self.preload_to_memory:
|
266 |
+
for filename, file_index in self.file_list:
|
267 |
+
file = h5py.File(filename, 'r')
|
268 |
+
inputs = {}
|
269 |
+
for name, data in file['input'].items():
|
270 |
+
inputs[name] = np.copy(data)
|
271 |
+
self.file_cache[filename] = {'input': inputs, 'output': np.copy(file['output'])}
|
272 |
+
file.close()
|
273 |
+
if get_size(self.file_cache) / (1024*1024) >= self.max_cache_size:
|
274 |
+
print('File cache has reached limit of', self.max_cache_size, 'MBs')
|
275 |
+
break
|
276 |
+
epoch_size = len(file_list)
|
277 |
+
if all_files_in_batch:
|
278 |
+
epoch_size = len(file_list) * 10
|
279 |
+
super(BatchIterator, self).__init__(epoch_size, batch_size, shuffle, None)
|
280 |
+
|
281 |
+
def _get_sample(self, index):
|
282 |
+
filename, file_index = self.file_list[index]
|
283 |
+
if filename in self.file_cache:
|
284 |
+
file = self.file_cache[filename]
|
285 |
+
else:
|
286 |
+
file = h5py.File(filename, 'r')
|
287 |
+
inputs = []
|
288 |
+
outputs = []
|
289 |
+
for name, data in file['input'].items():
|
290 |
+
inputs.append(data[file_index, :])
|
291 |
+
for name, data in file['output'].items():
|
292 |
+
if len(data.shape) > 1:
|
293 |
+
outputs.append(data[file_index, :])
|
294 |
+
else:
|
295 |
+
outputs.append(data[file_index])
|
296 |
+
#outputs.append(file['output'][file_index, :]) # TODO fix
|
297 |
+
if filename not in self.file_cache:
|
298 |
+
file.close()
|
299 |
+
return inputs, outputs
|
300 |
+
|
301 |
+
def _get_random_sample_in_file(self, file_index):
|
302 |
+
filename = self.file_list[file_index]
|
303 |
+
file = h5py.File(filename, 'r')
|
304 |
+
x = file['output/0']
|
305 |
+
sample = np.random.randint(0, x.shape[0])
|
306 |
+
#print('Sampling image', sample, 'from file', filename)
|
307 |
+
inputs = []
|
308 |
+
outputs = []
|
309 |
+
for name, data in file['input'].items():
|
310 |
+
inputs.append(data[sample, :])
|
311 |
+
for name, data in file['output'].items():
|
312 |
+
outputs.append(data[file_index, :])
|
313 |
+
#outputs.append(file['output'][sample, :]) # TODO FIX output
|
314 |
+
file.close()
|
315 |
+
return inputs, outputs
|
316 |
+
|
317 |
+
def next(self):
|
318 |
+
|
319 |
+
with self.lock:
|
320 |
+
index_array = next(self.index_generator)
|
321 |
+
|
322 |
+
#print(len(index_array))
|
323 |
+
return self._get_batches_of_transformed_samples(index_array)
|
324 |
+
|
325 |
+
def _get_batches_of_transformed_samples(self, index_array):
|
326 |
+
start_batch = time.time()
|
327 |
+
batches_x = []
|
328 |
+
batches_y = []
|
329 |
+
for input_index in range(self.nr_of_inputs):
|
330 |
+
batches_x.append(np.zeros(tuple([len(index_array)] + list(self.input_shape[input_index]))))
|
331 |
+
for output_index in range(self.nr_of_outputs):
|
332 |
+
batches_y.append(np.zeros(tuple([len(index_array)] + list(self.output_shape[output_index]))))
|
333 |
+
|
334 |
+
timings_sampling = np.zeros((len(index_array,)))
|
335 |
+
timings_transform = np.zeros((len(index_array,)))
|
336 |
+
for batch_index, sample_index in enumerate(index_array):
|
337 |
+
# Have to copy here in order to not modify original data
|
338 |
+
start = time.time()
|
339 |
+
if self.all_files_in_batch:
|
340 |
+
input, output = self._get_random_sample_in_file(batch_index)
|
341 |
+
else:
|
342 |
+
input, output = self._get_sample(sample_index)
|
343 |
+
timings_sampling[batch_index] = time.time() - start
|
344 |
+
start = time.time()
|
345 |
+
input, output = self.generator.transform(input, output)
|
346 |
+
timings_transform[batch_index] = time.time() - start
|
347 |
+
|
348 |
+
#print('inputs', self.nr_of_inputs, len(input))
|
349 |
+
for input_index in range(self.nr_of_inputs):
|
350 |
+
batches_x[input_index][batch_index] = input[input_index]
|
351 |
+
for output_index in range(self.nr_of_outputs):
|
352 |
+
batches_y[output_index][batch_index] = output[output_index]
|
353 |
+
|
354 |
+
elapsed = time.time() - start_batch
|
355 |
+
if self.verbose:
|
356 |
+
print('Time to prepare batch:', round(elapsed,3), 'seconds')
|
357 |
+
print('Sampling mean:', round(timings_sampling.mean(), 3), 'seconds')
|
358 |
+
print('Transform mean:', round(timings_transform.mean(), 3), 'seconds')
|
359 |
+
|
360 |
+
return batches_x, batches_y
|
361 |
+
|
362 |
+
|
363 |
+
CLASSIFICATION = 'classification'
|
364 |
+
SEGMENTATION = 'segmentation'
|
365 |
+
|
366 |
+
|
367 |
+
class BatchGenerator():
|
368 |
+
def __init__(self, filelist, all_files_in_batch=False):
|
369 |
+
self.methods = []
|
370 |
+
self.args = []
|
371 |
+
self.crop_width_to = None
|
372 |
+
self.image_list = []
|
373 |
+
self.input_shape = []
|
374 |
+
self.output_shape = []
|
375 |
+
self.all_files_in_batch = all_files_in_batch
|
376 |
+
self.transforms = []
|
377 |
+
|
378 |
+
if all_files_in_batch:
|
379 |
+
file = h5py.File(filelist[0], 'r')
|
380 |
+
for name, data in file['input'].items():
|
381 |
+
self.input_shape.append(data.shape[1:])
|
382 |
+
for name, data in file['output'].items():
|
383 |
+
self.output_shape.append(data.shape[1:])
|
384 |
+
# TODO fix
|
385 |
+
#self.output_shape.append(file['output'].shape[1:])
|
386 |
+
file.close()
|
387 |
+
self.image_list = filelist
|
388 |
+
return
|
389 |
+
|
390 |
+
# Go through filelist
|
391 |
+
first = True
|
392 |
+
for filename in filelist:
|
393 |
+
samples = None
|
394 |
+
# Open file to see how many samples it has
|
395 |
+
file = h5py.File(filename, 'r')
|
396 |
+
for name, data in file['input'].items():
|
397 |
+
if first:
|
398 |
+
self.input_shape.append(data.shape[1:])
|
399 |
+
samples = data.shape[0]
|
400 |
+
# TODO fix
|
401 |
+
for name, data in file['output'].items():
|
402 |
+
if first:
|
403 |
+
self.output_shape.append(data.shape[1:])
|
404 |
+
if samples != data.shape[0]:
|
405 |
+
raise ValueError()
|
406 |
+
#self.output_shape.append(file['output'].shape[1:])
|
407 |
+
if len(self.output_shape) == 1:
|
408 |
+
self.problem_type = CLASSIFICATION
|
409 |
+
else:
|
410 |
+
self.problem_type = SEGMENTATION
|
411 |
+
|
412 |
+
file.close()
|
413 |
+
if samples is None:
|
414 |
+
raise ValueError()
|
415 |
+
# Append a tuple to image_list for each image consisting of filename and index
|
416 |
+
print(filename, samples)
|
417 |
+
for i in range(samples):
|
418 |
+
self.image_list.append((filename, i))
|
419 |
+
first = False
|
420 |
+
|
421 |
+
print('Image generator with', len(self.image_list), ' image samples created')
|
422 |
+
|
423 |
+
def flow(self, batch_size, shuffle=True):
|
424 |
+
|
425 |
+
return BatchIterator(self, self.image_list, self.input_shape, self.output_shape, batch_size, shuffle, self.all_files_in_batch)
|
426 |
+
|
427 |
+
def transform(self, inputs, outputs):
|
428 |
+
#input = input.astype(np.float32) # TODO
|
429 |
+
#output = output.astype(np.float32)
|
430 |
+
for input_indices, output_indices, transform in self.transforms:
|
431 |
+
transform.randomize()
|
432 |
+
inputs, outputs = transform.transform_all(inputs, outputs, input_indices, output_indices)
|
433 |
+
return inputs, outputs
|
434 |
+
|
435 |
+
def add_transform(self, input_indices: Union[int, List[int], None], output_indices: Union[int, List[int], None], transform: Transform):
|
436 |
+
if type(input_indices) is int:
|
437 |
+
input_indices = [input_indices]
|
438 |
+
if type(output_indices) is int:
|
439 |
+
output_indices = [output_indices]
|
440 |
+
|
441 |
+
self.transforms.append((
|
442 |
+
input_indices,
|
443 |
+
output_indices,
|
444 |
+
transform
|
445 |
+
))
|
446 |
+
|
447 |
+
def get_size(self):
|
448 |
+
if self.all_files_in_batch:
|
449 |
+
return 10*len(self.image_list)
|
450 |
+
else:
|
451 |
+
return len(self.image_list)
|
452 |
+
|
453 |
+
'''
|
Brain_study/format_dataset.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import h5py
|
2 |
+
import nibabel as nib
|
3 |
+
from nilearn.image import resample_img
|
4 |
+
import os
|
5 |
+
import re
|
6 |
+
import numpy as np
|
7 |
+
from scipy.ndimage import zoom
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
SEGMENTATION_NR2LBL_LUT = {0: 'background',
|
11 |
+
2: 'parietal-right-gm',
|
12 |
+
3: 'lateral-ventricle-left',
|
13 |
+
4: 'occipital-right-gm',
|
14 |
+
6: 'parietal-left-gm',
|
15 |
+
8: 'occipital-left-gm',
|
16 |
+
9: 'lateral-ventricle-right',
|
17 |
+
11: 'globus-pallidus-right',
|
18 |
+
12: 'globus-pallidus-left',
|
19 |
+
14: 'putamen-left',
|
20 |
+
16: 'putamen-right',
|
21 |
+
20: 'brain-stem',
|
22 |
+
23: 'subthalamic-nucleus-right',
|
23 |
+
29: 'fornix-left',
|
24 |
+
33: 'subthalamic-nucleus-left',
|
25 |
+
39: 'caudate-left',
|
26 |
+
53: 'caudate-right',
|
27 |
+
67: 'cerebellum-left',
|
28 |
+
76: 'cerebellum-right',
|
29 |
+
102: 'thalamus-left',
|
30 |
+
203: 'thalamus-right',
|
31 |
+
210: 'frontal-left-gm',
|
32 |
+
211: 'frontal-right-gm',
|
33 |
+
218: 'temporal-left-gm',
|
34 |
+
219: 'temporal-right-gm',
|
35 |
+
232: '3rd-ventricle',
|
36 |
+
233: '4th-ventricle',
|
37 |
+
254: 'fornix-right',
|
38 |
+
255: 'csf'}
|
39 |
+
SEGMENTATION_LBL2NR_LUT = {v: k for k, v in SEGMENTATION_NR2LBL_LUT.items()}
|
40 |
+
|
41 |
+
IMG_DIRECTORY = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1'
|
42 |
+
SEG_DIRECTORY = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/anatomical_masks'
|
43 |
+
|
44 |
+
IMG_NAME_PATTERN = '(.*).nii.gz'
|
45 |
+
SEG_NAME_PATTERN = '(.*)_lobes.nii.gz'
|
46 |
+
|
47 |
+
OUT_DIRECTORY = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/test'
|
48 |
+
|
49 |
+
if __name__ == '__main__':
|
50 |
+
img_list = [os.path.join(IMG_DIRECTORY, f) for f in os.listdir(IMG_DIRECTORY) if f.endswith('.nii.gz')]
|
51 |
+
img_list.sort()
|
52 |
+
|
53 |
+
seg_list = [os.path.join(SEG_DIRECTORY, f) for f in os.listdir(SEG_DIRECTORY) if f.endswith('.nii.gz')]
|
54 |
+
seg_list.sort()
|
55 |
+
|
56 |
+
os.makedirs(OUT_DIRECTORY, exist_ok=True)
|
57 |
+
for seg_file in tqdm(seg_list):
|
58 |
+
img_name = re.match(SEG_NAME_PATTERN, os.path.split(seg_file)[-1])[1]
|
59 |
+
img_file = os.path.join(IMG_DIRECTORY, img_name + '.nii.gz')
|
60 |
+
|
61 |
+
img = resample_img(nib.load(img_file), np.eye(3))
|
62 |
+
seg = resample_img(nib.load(seg_file), np.eye(3), interpolation='nearest')
|
63 |
+
|
64 |
+
isot_shape = img.shape
|
65 |
+
|
66 |
+
# Resize to 128x128x128
|
67 |
+
img = np.asarray(img.dataobj)
|
68 |
+
img = zoom(img, np.asarray([128]*3) / np.asarray(isot_shape), order=3)
|
69 |
+
|
70 |
+
seg = np.asarray(seg.dataobj)
|
71 |
+
seg = zoom(seg, np.asarray([128]*3) / np.asarray(isot_shape), order=0)
|
72 |
+
|
73 |
+
unique_lbls = np.unique(seg)[1:] # Omit background
|
74 |
+
seg_expanded = np.tile(np.zeros_like(seg)[..., np.newaxis], (1, 1, 1, len(unique_lbls)))
|
75 |
+
for ch, lbl in enumerate(unique_lbls):
|
76 |
+
seg_expanded[seg == lbl, ch] = 1
|
77 |
+
|
78 |
+
h5_file = h5py.File(os.path.join(OUT_DIRECTORY, img_name + '.h5'), 'w')
|
79 |
+
|
80 |
+
h5_file.create_dataset('image', data=img[..., np.newaxis], dtype=np.float32)
|
81 |
+
h5_file.create_dataset('segmentation', data=seg[..., np.newaxis].astype(np.uint8), dtype=np.uint8)
|
82 |
+
h5_file.create_dataset('segmentation_expanded', data=seg_expanded.astype(np.uint8), dtype=np.uint8)
|
83 |
+
h5_file.create_dataset('segmentation_labels', data=unique_lbls)
|
84 |
+
h5_file.create_dataset('isotropic_shape', data=isot_shape)
|
85 |
+
|
86 |
+
h5_file.close()
|
87 |
+
|
88 |
+
|
89 |
+
|
90 |
+
|
91 |
+
|
Brain_study/split_dataset.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import random
|
4 |
+
import warnings
|
5 |
+
|
6 |
+
import math
|
7 |
+
from shutil import copyfile
|
8 |
+
from tqdm import tqdm
|
9 |
+
import concurrent.futures
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
|
13 |
+
def copy_file(s_d):
|
14 |
+
s, d = s_d
|
15 |
+
file_name = os.path.split(s)[-1]
|
16 |
+
copyfile(s, os.path.join(d, file_name))
|
17 |
+
return int(os.path.exists(d))
|
18 |
+
|
19 |
+
|
20 |
+
if __name__ == '__main__':
|
21 |
+
parser = argparse.ArgumentParser()
|
22 |
+
parser.add_argument('--train', '-t', type=float, default=.70, help='Train percentage. Default: 0.70')
|
23 |
+
parser.add_argument('--validation', '-v', type=float, default=0.15, help='Validation percentage. Default: 0.15')
|
24 |
+
parser.add_argument('--test', '-s', type=float, default=0.15, help='Test percentage. Default: 0.15')
|
25 |
+
parser.add_argument('-d', '--dir', type=str, help='Directory where the data is')
|
26 |
+
parser.add_argument('-f', '--format', type=str, help='Format of the data files. Default: h5', default='h5')
|
27 |
+
parser.add_argument('-r', '--random', type=bool, help='Randomly split the dataset or not. Default: True', default=True)
|
28 |
+
|
29 |
+
args = parser.parse_args()
|
30 |
+
|
31 |
+
assert args.train + args.validation + args.test == 1.0, 'Train+Validation+Test != 1 (100%)'
|
32 |
+
|
33 |
+
file_set = [os.path.join(args.dir, f) for f in os.listdir(args.dir) if f.endswith(args.format)]
|
34 |
+
random.shuffle(file_set) if args.random else file_set.sort()
|
35 |
+
|
36 |
+
num_files = len(file_set)
|
37 |
+
num_validation = math.floor(num_files * args.validation)
|
38 |
+
num_test = math.floor(num_files * args.test)
|
39 |
+
num_train = num_files - num_test - num_validation
|
40 |
+
|
41 |
+
dataset_root, dataset_name = os.path.split(args.dir)
|
42 |
+
dst_train = os.path.join(dataset_root, 'SPLIT_'+dataset_name, 'train_set')
|
43 |
+
dst_validation = os.path.join(dataset_root, 'SPLIT_'+dataset_name, 'validation_set')
|
44 |
+
dst_test = os.path.join(dataset_root, 'SPLIT_'+dataset_name, 'test_set')
|
45 |
+
|
46 |
+
print('OUTPUT INFORMATION\n=============')
|
47 |
+
print('Train:\t\t{}'.format(num_train))
|
48 |
+
print('Validation:\t{}'.format(num_validation))
|
49 |
+
print('Test:\t\t{}'.format(num_test))
|
50 |
+
print('Num. samples\t{}'.format(num_files))
|
51 |
+
print('Path:\t\t', os.path.join(dataset_root, 'SPLIT_'+dataset_name))
|
52 |
+
|
53 |
+
dest = [dst_train] * num_train + [dst_validation] * num_validation + [dst_test] * num_test
|
54 |
+
|
55 |
+
os.makedirs(dst_train, exist_ok=True)
|
56 |
+
os.makedirs(dst_validation, exist_ok=True)
|
57 |
+
os.makedirs(dst_test, exist_ok=True)
|
58 |
+
|
59 |
+
progress_bar = tqdm(zip(file_set, dest), desc='Copying files', total=num_files)
|
60 |
+
with concurrent.futures.ProcessPoolExecutor(max_workers=10) as ex:
|
61 |
+
results = list(tqdm(ex.map(copy_file, zip(file_set, dest)), desc='Copying files', total=num_files))
|
62 |
+
|
63 |
+
num_copies = np.sum(results)
|
64 |
+
if num_copies == num_files:
|
65 |
+
print('Done successfully')
|
66 |
+
else:
|
67 |
+
warnings.warn('Missing files: {}'.format(num_files - num_copies))
|
68 |
+
|
Brain_study/test_datagenerator.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from Brain_study.data_generator import BatchGenerator
|
2 |
+
|
3 |
+
import DeepDeformationMapRegistration.utils.constants as C
|
4 |
+
from tqdm import tqdm
|
5 |
+
|
6 |
+
from tensorflow import keras
|
7 |
+
import tensorflow as tf
|
8 |
+
from tensorflow.keras.callbacks import TensorBoard
|
9 |
+
import os
|
10 |
+
import voxelmorph as vxm
|
11 |
+
from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
|
12 |
+
from DeepDeformationMapRegistration.losses import NCC, StructuralSimilarity, StructuralSimilarity_simplified
|
13 |
+
|
14 |
+
|
15 |
+
def named_logs(model, logs, validation=False):
|
16 |
+
result = {'size': C.BATCH_SIZE} # https://gist.github.com/erenon/91f526302cd8e9d21b73f24c0f9c4bb8#gistcomment-3041181
|
17 |
+
for l in zip(model.metrics_names, logs):
|
18 |
+
k = ('val_' if validation else '') + l[0]
|
19 |
+
result[k] = l[1]
|
20 |
+
return result
|
21 |
+
|
22 |
+
if __name__ == '__main__':
|
23 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = str(1)
|
24 |
+
|
25 |
+
C.BATCH_SIZE = 12
|
26 |
+
C.TRAINING_DATASET = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/training'
|
27 |
+
output_folder = "/mnt/EncryptedData1/Users/javier/train_output/Brain_study/ERASE"
|
28 |
+
|
29 |
+
data_generator = BatchGenerator(C.TRAINING_DATASET, C.BATCH_SIZE, True, C.TRAINING_PERC, True, ['none'])
|
30 |
+
train_generator = data_generator.get_train_generator()
|
31 |
+
val_generator = data_generator.get_validation_generator()
|
32 |
+
|
33 |
+
e_iter = tqdm(range(100))
|
34 |
+
t_iter = tqdm(train_generator)
|
35 |
+
v_iter = tqdm(val_generator)
|
36 |
+
|
37 |
+
e_iter.set_description('Epoch')
|
38 |
+
t_iter.set_description('Train')
|
39 |
+
v_iter.set_description('Val')
|
40 |
+
#
|
41 |
+
# for s in e_iter:
|
42 |
+
# for b in t_iter:
|
43 |
+
# continue
|
44 |
+
#
|
45 |
+
# for b in v_iter:
|
46 |
+
# continue
|
47 |
+
|
48 |
+
# Build model
|
49 |
+
enc_features = [16, 32, 32, 32] # const.ENCODER_FILTERS
|
50 |
+
dec_features = [32, 32, 32, 32, 32, 16, 16] # const.ENCODER_FILTERS[::-1]
|
51 |
+
nb_features = [enc_features, dec_features]
|
52 |
+
network = vxm.networks.VxmDense(inshape=(64, 64, 64),
|
53 |
+
nb_unet_features=nb_features,
|
54 |
+
int_steps=0)
|
55 |
+
|
56 |
+
d = os.path.join(os.getcwd(), 'tensorboard_test')
|
57 |
+
os.makedirs(d, exist_ok=True)
|
58 |
+
callback_tensorboard = TensorBoard(log_dir=d,
|
59 |
+
batch_size=C.BATCH_SIZE, write_images=False, histogram_freq=0, update_freq='epoch',
|
60 |
+
write_graph=True,
|
61 |
+
write_grads=True)
|
62 |
+
|
63 |
+
losses = {'transformer': StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).loss,
|
64 |
+
'flow': vxm.losses.Grad('l2').loss}
|
65 |
+
metrics = {'transformer': [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).metric,
|
66 |
+
tf.keras.losses.MSE],
|
67 |
+
# 'flow': vxm.losses.Grad('l2').loss
|
68 |
+
}
|
69 |
+
loss_weights = {'transformer': 1.,
|
70 |
+
'flow': 5e-3}
|
71 |
+
|
72 |
+
optimizer = AdamAccumulated(C.ACCUM_GRADIENT_STEP, C.LEARNING_RATE)
|
73 |
+
|
74 |
+
network.compile(optimizer=optimizer,
|
75 |
+
loss=losses,
|
76 |
+
loss_weights=loss_weights,
|
77 |
+
metrics=metrics)
|
78 |
+
callback_tensorboard.set_model(network)
|
79 |
+
dummy = lambda x: named_logs(network, [x, 0, x, 0, 0])
|
80 |
+
|
81 |
+
callback_tensorboard.on_train_begin()
|
82 |
+
for s in e_iter:
|
83 |
+
callback_tensorboard.on_epoch_begin(s)
|
84 |
+
|
85 |
+
for n in range(100):
|
86 |
+
callback_tensorboard.on_train_batch_begin(n)
|
87 |
+
input('Press enter')
|
88 |
+
callback_tensorboard.on_train_batch_end(n, dummy(n))
|
89 |
+
|
90 |
+
callback_tensorboard.on_epoch_end(s, dummy)
|
91 |
+
callback_tensorboard.on_train_end()
|
Brain_study/utils.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import DeepDeformationMapRegistration.utils.constants as C
|
3 |
+
|
4 |
+
class SummaryDictionary:
|
5 |
+
def __init__(self, model, batch_size, accumulative_gradients_step=None):
|
6 |
+
self.train_names = model.metrics_names
|
7 |
+
self.val_names = ['val_'+n for n in self.train_names]
|
8 |
+
self.batch_size = batch_size
|
9 |
+
self.acc_grad_step = accumulative_gradients_step
|
10 |
+
self._reset()
|
11 |
+
|
12 |
+
def _reset(self):
|
13 |
+
self.summary_dict = {'size': self.batch_size}
|
14 |
+
if self.acc_grad_step is not None:
|
15 |
+
self.summary_dict = {'accumulative_grad_step': self.acc_grad_step}
|
16 |
+
for k in self.train_names + self.val_names:
|
17 |
+
self.summary_dict[k] = list()
|
18 |
+
|
19 |
+
def on_train_batch_end(self, values):
|
20 |
+
for k, v in zip(self.train_names, values):
|
21 |
+
self.summary_dict[k].append(v)
|
22 |
+
|
23 |
+
def on_validation_batch_end(self, values):
|
24 |
+
for k, v in zip(self.val_names, values):
|
25 |
+
self.summary_dict[k].append(v)
|
26 |
+
|
27 |
+
def on_epoch_end(self):
|
28 |
+
for k, v in self.summary_dict.items():
|
29 |
+
self.summary_dict[k] = np.asarray(v).mean()
|
30 |
+
|
31 |
+
ret_val = self.summary_dict.copy()
|
32 |
+
self._reset()
|
33 |
+
return ret_val
|
34 |
+
|
35 |
+
|
36 |
+
def named_logs(model, logs, validation=False):
|
37 |
+
result = {'size': C.BATCH_SIZE} # https://gist.github.com/erenon/91f526302cd8e9d21b73f24c0f9c4bb8#gistcomment-3041181
|
38 |
+
for l in zip(model.metrics_names, logs):
|
39 |
+
k = ('val_' if validation else '') + l[0]
|
40 |
+
result[k] = l[1]
|
41 |
+
return result
|
EvaluationScripts/Evaluate_class.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from scipy.signal import correlate as cc
|
2 |
+
from scipy.spatial.distance import euclidean
|
3 |
+
from skimage.metrics import mean_squared_error as mse
|
4 |
+
from skimage.metrics import structural_similarity as ssim
|
5 |
+
from medpy.metric.binary import dc, hd95
|
6 |
+
import numpy as np
|
7 |
+
import pandas as pd
|
8 |
+
import os
|
9 |
+
from DeepDeformationMapRegistration.utils.constants import EPS
|
10 |
+
from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
|
11 |
+
from skimage.transform import resize
|
12 |
+
from skimage.measure import regionprops, label
|
13 |
+
|
14 |
+
|
15 |
+
def ncc(y_true, y_pred, eps=EPS):
|
16 |
+
f_yt = np.reshape(y_true, [-1])
|
17 |
+
f_yp = np.reshape(y_pred, [-1])
|
18 |
+
mean_yt = np.mean(f_yt)
|
19 |
+
mean_yp = np.mean(f_yp)
|
20 |
+
|
21 |
+
n_f_yt = f_yt - mean_yt
|
22 |
+
n_f_yp = f_yp - mean_yp
|
23 |
+
norm_yt = np.linalg.norm(f_yt, ord=2)
|
24 |
+
norm_yp = np.linalg.norm(f_yp, ord=2)
|
25 |
+
numerator = np.sum(np.multiply(n_f_yt, n_f_yp))
|
26 |
+
denominator = norm_yt * norm_yp + eps
|
27 |
+
return np.divide(numerator, denominator)
|
28 |
+
|
29 |
+
|
30 |
+
class EvaluationFigures:
|
31 |
+
def __init__(self, output_folder):
|
32 |
+
pd.set_option('display.max_columns', None)
|
33 |
+
self.__metrics_df = pd.DataFrame(columns=['Name', 'Train_ds', 'Eval_ds', 'MSE', 'NCC', 'SSIM',
|
34 |
+
'DICE_PAR', 'DICE_TUM', 'DICE_VES',
|
35 |
+
'HD95_PAR', 'HD95_TUM', 'HD95_VES', 'TRE'])
|
36 |
+
self.__output_folder = output_folder
|
37 |
+
|
38 |
+
def add_sample(self, name, train_ds, eval_ds, fix_t, par_t, tum_t, ves_t, centroid_t, fix_p, par_p, tum_p, ves_p, centroid_p, scale_transform=None):
|
39 |
+
|
40 |
+
n_fix_t = self.__mean_centred_img(fix_t)
|
41 |
+
n_fix_p = self.__mean_centred_img(fix_p)
|
42 |
+
|
43 |
+
if scale_transform is not None:
|
44 |
+
s_centroid_t = self.__scale_point(centroid_t, scale_transform)
|
45 |
+
s_centroid_p = self.__scale_point(centroid_p, scale_transform)
|
46 |
+
else:
|
47 |
+
s_centroid_t = centroid_t
|
48 |
+
s_centroid_p = centroid_p
|
49 |
+
|
50 |
+
new_row = {'Name': name,
|
51 |
+
'Train_ds': train_ds,
|
52 |
+
'Eval_ds': eval_ds,
|
53 |
+
'MSE': mse(fix_t, fix_p),
|
54 |
+
'NCC': ncc(n_fix_t, n_fix_p),
|
55 |
+
'SSIM': ssim(fix_t, fix_p, multichannel=True),
|
56 |
+
'DICE_PAR': dc(par_p, par_t),
|
57 |
+
'DICE_TUM': dc(tum_p, tum_t),
|
58 |
+
'DICE_VES': dc(ves_p, ves_t),
|
59 |
+
'HD95_PAR': hd95(par_p, par_t) if np.sum(par_p) else 64,
|
60 |
+
'HD95_TUM': hd95(tum_p, tum_t) if np.sum(tum_p) else 64,
|
61 |
+
'HD95_VES': hd95(ves_p, ves_t) if np.sum(ves_p) else 64,
|
62 |
+
'TRE': euclidean(s_centroid_t, s_centroid_p)}
|
63 |
+
|
64 |
+
self.__metrics_df = self.__metrics_df.append(new_row, ignore_index=True)
|
65 |
+
|
66 |
+
@staticmethod
|
67 |
+
def __mean_centred_img(img):
|
68 |
+
return img - np.mean(img)
|
69 |
+
|
70 |
+
@staticmethod
|
71 |
+
def __scale_point(point, scale_matrix):
|
72 |
+
assert scale_matrix.shape == (4, 4), 'Transformation matrix is expected to have shape (4, 4)'
|
73 |
+
aux_aug = np.ones((4,))
|
74 |
+
aux_aug[:3] = point
|
75 |
+
return np.matmul(scale_matrix, aux_aug)[:1]
|
76 |
+
|
77 |
+
def save_metrics(self, dest_folder=None):
|
78 |
+
if dest_folder is None:
|
79 |
+
dest_folder = self.__output_folder
|
80 |
+
self.__metrics_df.to_csv(os.path.join(dest_folder, 'metrics.csv'))
|
81 |
+
self.__metrics_df.to_latex(os.path.join(dest_folder, 'table.txt'), sparsify=True)
|
82 |
+
print('Metrics saved in: ' + os.path.join(dest_folder))
|
83 |
+
|
84 |
+
def print_summary(self):
|
85 |
+
print(self.__metrics_df[['MSE', 'NCC', 'SSIM',
|
86 |
+
'DICE_PAR', 'DICE_TUM', 'DICE_VES',
|
87 |
+
'HD95_PAR', 'HD95_TUM', 'HD95_VES', 'TRE']].describe())
|
88 |
+
|
89 |
+
|
90 |
+
def resize_img_to_original_space(img, bb, first_reshape, original_shape, clip_img=False, flow=False):
|
91 |
+
first_reshape = first_reshape.astype(int)
|
92 |
+
bb = bb.astype(int)
|
93 |
+
original_shape = original_shape.astype(int)
|
94 |
+
|
95 |
+
if flow:
|
96 |
+
# Multiply before resizing to reduce the number of multiplications
|
97 |
+
img = _rescale_flow_values(img, bb, img.shape, first_reshape, original_shape)
|
98 |
+
|
99 |
+
min_i, min_j, min_k, bb_i, bb_j, bb_k = bb
|
100 |
+
max_i = min_i + bb_i
|
101 |
+
max_j = min_j + bb_j
|
102 |
+
max_k = min_k + bb_k
|
103 |
+
|
104 |
+
img_bb = resize(img, (bb_i, bb_j, bb_k)) # Get the original bounding box shape
|
105 |
+
|
106 |
+
# Place the bounding box again in the cubic volume
|
107 |
+
img_copy = np.zeros((*first_reshape, img.shape[-1]) if len(img.shape) > 3 else first_reshape) # Get channels if any
|
108 |
+
img_copy[min_i:max_i, min_j:max_j, min_k:max_k, ...] = img_bb
|
109 |
+
|
110 |
+
# Now resize to the original shape
|
111 |
+
resized_img = resize(img_copy, original_shape, preserve_range=True, anti_aliasing=False)
|
112 |
+
if clip_img or flow:
|
113 |
+
# clip_mask = np.zeros(img_copy.shape[:3], np.int)
|
114 |
+
# clip_mask[min_i:max_i, min_j:max_j, min_k:max_k] = 1
|
115 |
+
# clip_mask = resize(clip_mask, original_shape, preserve_range=True, anti_aliasing=False)
|
116 |
+
# clip_mask[clip_mask > 0.5] = 1
|
117 |
+
# clip_mask[clip_mask < 1] = 0
|
118 |
+
#
|
119 |
+
# [min_i, min_j, min_k, max_i, max_j, max_k] = regionprops(label(clip_mask))[0].bbox
|
120 |
+
#
|
121 |
+
# resized_img = resized_img[min_i:max_i, min_j:max_j, min_k:max_k, ...]
|
122 |
+
|
123 |
+
# Compute the coordinates of the boundix box in the upsampled volume, instead of resizing a mask image
|
124 |
+
S = resize_transformation(img.shape, bb=None, first_reshape=first_reshape, original_shape=original_shape, translate=True)
|
125 |
+
bb_coords = np.asarray([[min_i, min_j, min_k], [max_i, max_j, max_k]])
|
126 |
+
bb_coords = np.hstack([bb_coords, np.ones((2, 1))])
|
127 |
+
|
128 |
+
upsamp_bbox_coords = np.around(np.matmul(S, bb_coords.T)[:-1, :].T).astype(np.int)
|
129 |
+
min_i = upsamp_bbox_coords[0][0]
|
130 |
+
min_j = upsamp_bbox_coords[0][1]
|
131 |
+
min_k = upsamp_bbox_coords[0][2]
|
132 |
+
max_i = upsamp_bbox_coords[1][0]
|
133 |
+
max_j = upsamp_bbox_coords[1][1]
|
134 |
+
max_k = upsamp_bbox_coords[1][2]
|
135 |
+
resized_img = resized_img[min_i:max_i, min_j:max_j, min_k:max_k, ...]
|
136 |
+
|
137 |
+
if flow:
|
138 |
+
# Return also the origin of the bb in the resized volume for the following interpolation
|
139 |
+
return resized_img, np.asarray([min_i, min_j, min_k])
|
140 |
+
return resized_img
|
141 |
+
# This is supposed to be an isotropic image with voxel size 1 mm
|
142 |
+
|
143 |
+
|
144 |
+
def _rescale_flow_values(flow, bb, current_img_shape, first_reshape, original_shape):
|
145 |
+
S = resize_transformation(current_img_shape, bb, first_reshape, original_shape, translate=False)
|
146 |
+
|
147 |
+
[si, sj, sk] = np.diag(S[:3, :3])
|
148 |
+
flow[..., 0] *= si
|
149 |
+
flow[..., 1] *= sj
|
150 |
+
flow[..., 2] *= sk
|
151 |
+
|
152 |
+
return flow
|
153 |
+
|
154 |
+
|
155 |
+
def resize_pts_to_original_space(pt, bb, current_img_shape, first_reshape, original_shape):
|
156 |
+
T = resize_transformation(current_img_shape, bb, first_reshape, original_shape)
|
157 |
+
if len(pt.shape) > 1:
|
158 |
+
pt_aug = np.ones((4, pt.shape[0]))
|
159 |
+
pt_aug[0:3, :] = pt.T
|
160 |
+
else:
|
161 |
+
pt_aug = np.ones((4,))
|
162 |
+
pt_aug[0:3] = pt
|
163 |
+
trf_pt = np.matmul(T, pt_aug)[:-1, ...].T
|
164 |
+
|
165 |
+
return trf_pt
|
166 |
+
|
167 |
+
|
168 |
+
def resize_transformation(current_img_shape, bb=None, first_reshape=None, original_shape=None, translate=True):
|
169 |
+
first_reshape = first_reshape.astype(int)
|
170 |
+
original_shape = original_shape.astype(int)
|
171 |
+
|
172 |
+
first_resize_trf = np.eye(4)
|
173 |
+
if bb is not None:
|
174 |
+
bb = bb.astype(int)
|
175 |
+
min_i, min_j, min_k, bb_i, bb_j, bb_k = bb
|
176 |
+
np.fill_diagonal(first_resize_trf, [bb_i / current_img_shape[0], bb_j / current_img_shape[1], bb_k / current_img_shape[2], 1])
|
177 |
+
if translate:
|
178 |
+
first_resize_trf[:3, -1] = np.asarray([min_i, min_j, min_k])
|
179 |
+
|
180 |
+
original_resize_trf = np.eye(4)
|
181 |
+
np.fill_diagonal(original_resize_trf, [original_shape[0] / first_reshape[0], original_shape[1] / first_reshape[1], original_shape[2] / first_reshape[2], 1])
|
182 |
+
|
183 |
+
return np.matmul(original_resize_trf, first_resize_trf)
|
184 |
+
|