Commit 476daa5
Parent(s): 6a4f823
Scripts for training on the COMET CT Dataset
- COMET/Build_test_set.py +135 -0
- COMET/COMET_train.py +426 -0
- COMET/COMET_train_UW.py +342 -0
- COMET/Evaluate_network.py +264 -0
- COMET/MultiTrain_cli.py +61 -0
- COMET/MultiTrain_config.py +71 -0
- COMET/augmentation_constants.py +34 -0
- COMET/format_dataset.py +114 -0
- COMET/spit_dataset.py +40 -0
- COMET/train_config_files/Config_BASELINE_None_froozen.ini +21 -0
- COMET/train_config_files/Config_BASELINE_bottom_froozen.ini +21 -0
- COMET/train_config_files/Config_BASELINE_top_froozen.ini +21 -0
- COMET/train_config_files/Config_SEGGUIDED_None_froozen.ini +21 -0
- COMET/train_config_files/Config_SEGGUIDED_bottom_froozen.ini +21 -0
- COMET/train_config_files/Config_SEGGUIDED_top_froozen.ini +21 -0
- COMET/train_config_files/Config_UW_None_froozen.ini +21 -0
- COMET/train_config_files/Config_UW_bottom_froozen.ini +21 -0
- COMET/train_config_files/Config_UW_top_froozen.ini +21 -0
COMET/Build_test_set.py
ADDED
@@ -0,0 +1,135 @@
import os, sys

import shutil

import matplotlib.pyplot as plt

currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir)  # PYTHON > 3.3 does not allow relative referencing

import tensorflow as tf
# tf.enable_eager_execution(config=config)

import numpy as np
import h5py

import DeepDeformationMapRegistration.utils.constants as C
from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
from DeepDeformationMapRegistration.layers import AugmentationLayer
from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
from DeepDeformationMapRegistration.utils.misc import get_segmentations_centroids
from tqdm import tqdm

from Brain_study.data_generator import BatchGenerator

from skimage.measure import regionprops
from scipy.interpolate import griddata

import argparse


DATASET = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/test'

POINTS = None
MISSING_CENTROID = np.asarray([[np.nan]*3])


def get_mov_centroids(fix_seg, disp_map, nb_labels=28, brain_study=True):
    fix_centroids, _ = get_segmentations_centroids(fix_seg[0, ...], ohe=True, expected_lbls=range(0, nb_labels), brain_study=brain_study)
    disp = griddata(POINTS, disp_map.reshape([-1, 3]), fix_centroids, method='linear')
    return fix_centroids, fix_centroids + disp, disp


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--dir', type=str, help='Directory where to store the files', default='')
    parser.add_argument('--reldir', type=str, help='Relative path to dataset, in where to store the files', default='')
    parser.add_argument('--gpu', type=int, help='GPU', default=0)
    parser.add_argument('--dataset', type=str, help='Dataset to build the test set', default='')
    parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
    args = parser.parse_args()

    assert args.dataset != '', "Missing original dataset dataset"
    if args.dir == '' and args.reldir != '':
        OUTPUT_FOLDER_DIR = os.path.join(args.dataset, 'test_dataset')
    elif args.dir != '' and args.reldir == '':
        OUTPUT_FOLDER_DIR = args.dir
    else:
        raise ValueError("Either provide 'dir' or 'reldir'")

    if args.erase:
        shutil.rmtree(OUTPUT_FOLDER_DIR, ignore_errors=True)
    os.makedirs(OUTPUT_FOLDER_DIR, exist_ok=True)
    print('DESTINATION FOLDER: ', OUTPUT_FOLDER_DIR)

    DATASET = args.dataset

    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)  # Check availability before running using 'nvidia-smi'

    data_generator = BatchGenerator(DATASET, 1, False, 1.0, False, ['all'])

    img_generator = data_generator.get_train_generator()
    nb_labels = len(img_generator.get_segmentation_labels())
    image_input_shape = img_generator.get_data_shape()[-1][:-1]
    image_output_shape = [64] * 3
    # Build model

    xx = np.linspace(0, image_output_shape[0], image_output_shape[0], endpoint=False)
    yy = np.linspace(0, image_output_shape[1], image_output_shape[2], endpoint=False)
    zz = np.linspace(0, image_output_shape[2], image_output_shape[1], endpoint=False)

    xx, yy, zz = np.meshgrid(xx, yy, zz)

    POINTS = np.stack([xx.flatten(), yy.flatten(), zz.flatten()], axis=0).T

    input_augm = tf.keras.Input(shape=img_generator.get_data_shape()[0], name='input_augm')
    augm_layer = AugmentationLayer(max_displacement=C.MAX_AUG_DISP,  # Max 30 mm in isotropic space
                                   max_deformation=C.MAX_AUG_DEF,  # Max 6 mm in isotropic space
                                   max_rotation=C.MAX_AUG_ANGLE,  # Max 10 deg in isotropic space
                                   num_control_points=C.NUM_CONTROL_PTS_AUG,
                                   num_augmentations=C.NUM_AUGMENTATIONS,
                                   gamma_augmentation=C.GAMMA_AUGMENTATION,
                                   brightness_augmentation=C.BRIGHTNESS_AUGMENTATION,
                                   in_img_shape=image_input_shape,
                                   out_img_shape=image_output_shape,
                                   only_image=False,
                                   only_resize=False,
                                   trainable=False,
                                   return_displacement_map=True)
    augm_model = tf.keras.Model(inputs=input_augm, outputs=augm_layer(input_augm))

    config = tf.compat.v1.ConfigProto()  # device_count={'GPU':0})
    config.gpu_options.allow_growth = True
    config.log_device_placement = False  ## to log device placement (on which device the operation ran)

    sess = tf.Session(config=config)
    tf.keras.backend.set_session(sess)
    with sess.as_default():
        sess.run(tf.global_variables_initializer())
        progress_bar = tqdm(enumerate(img_generator, 1), desc='Generating samples', total=len(img_generator))
        for step, (in_batch, _) in progress_bar:
            fix_img, mov_img, fix_seg, mov_seg, disp_map = augm_model.predict(in_batch)

            fix_centroids, mov_centroids, disp_centroids = get_mov_centroids(fix_seg, disp_map, nb_labels, False)

            out_file = os.path.join(OUTPUT_FOLDER_DIR, 'test_sample_{:04d}.h5'.format(step))
            out_file_dm = os.path.join(OUTPUT_FOLDER_DIR, 'test_sample_dm_{:04d}.h5'.format(step))
            img_shape = fix_img.shape
            segm_shape = fix_seg.shape
            disp_shape = disp_map.shape
            centroids_shape = fix_centroids.shape
            with h5py.File(out_file, 'w') as f:
                f.create_dataset('fix_image', shape=img_shape[1:], dtype=np.float32, data=fix_img[0, ...])
                f.create_dataset('mov_image', shape=img_shape[1:], dtype=np.float32, data=mov_img[0, ...])
                f.create_dataset('fix_segmentations', shape=segm_shape[1:], dtype=np.uint8, data=fix_seg[0, ...])
                f.create_dataset('mov_segmentations', shape=segm_shape[1:], dtype=np.uint8, data=mov_seg[0, ...])
                f.create_dataset('fix_centroids', shape=centroids_shape, dtype=np.float32, data=fix_centroids)
                f.create_dataset('mov_centroids', shape=centroids_shape, dtype=np.float32, data=mov_centroids)

            with h5py.File(out_file_dm, 'w') as f:
                f.create_dataset('disp_map', shape=disp_shape[1:], dtype=np.float32, data=disp_map)
                f.create_dataset('disp_centroids', shape=centroids_shape, dtype=np.float32, data=disp_centroids)

    print('Done')
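For reference, the generated samples can be read back with h5py. A minimal sketch (the file name is hypothetical; the dataset keys are the ones written above):

import h5py
import numpy as np

# Hypothetical path to one sample written by Build_test_set.py
sample_path = 'test_sample_0001.h5'

with h5py.File(sample_path, 'r') as f:
    fix_img = f['fix_image'][:]            # fixed CT volume
    mov_img = f['mov_image'][:]            # moving (augmented) CT volume
    fix_seg = f['fix_segmentations'][:]    # one-hot fixed segmentations
    fix_centroids = f['fix_centroids'][:]  # label centroids in the fixed image
    mov_centroids = f['mov_centroids'][:]  # centroids displaced by the augmentation field

print(fix_img.shape, mov_img.shape, fix_seg.shape)
# NaN-aware mean, since centroids of missing labels are stored as NaN
print('Mean centroid displacement:', np.nanmean(np.linalg.norm(mov_centroids - fix_centroids, axis=-1)))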
COMET/COMET_train.py
ADDED
@@ -0,0 +1,426 @@
import os, sys

import keras

currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir)  # PYTHON > 3.3 does not allow relative referencing

from datetime import datetime

from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
from tensorflow.python.keras.utils import Progbar
from tensorflow.keras import Input
from tensorflow.keras.models import Model
from tensorflow.python.framework.errors import InvalidArgumentError

import DeepDeformationMapRegistration.utils.constants as C
from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplified, NCC, GeneralizedDICEScore, HausdorffDistanceErosion
from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
from DeepDeformationMapRegistration.ms_ssim_tf import _MSSSIM_WEIGHTS
from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
from DeepDeformationMapRegistration.utils.misc import function_decorator
from DeepDeformationMapRegistration.layers import AugmentationLayer
from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti

from Brain_study.data_generator import BatchGenerator
from Brain_study.utils import SummaryDictionary, named_logs

import COMET.augmentation_constants as COMET_C

import numpy as np
import tensorflow as tf
import voxelmorph as vxm
import h5py
import re
import itertools


def launch_train(dataset_folder, validation_folder, output_folder, model_file, gpu_num=0, lr=1e-4, rw=5e-3, simil='ssim',
                 segm='dice', max_epochs=C.EPOCHS, freeze_layers=None,
                 acc_gradients=1, batch_size=16):
    # 0. Input checks
    assert dataset_folder is not None and output_folder is not None and model_file is not None
    assert '.h5' in model_file, 'The model must be an H5 file'

    USE_SEGMENTATIONS = bool(re.search('SEGGUIDED', model_file))

    # 1. Load variables
    os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_num)  # Check availability before running using 'nvidia-smi'
    C.GPU_NUM = str(gpu_num)

    if freeze_layers is not None:
        assert all(s in ['INPUT', 'OUTPUT', 'ENCODER', 'DECODER', 'TOP', 'BOTTOM'] for s in freeze_layers), \
            'Invalid option for "freeze". Expected one or several of: INPUT, OUTPUT, ENCODER, DECODER, TOP, BOTTOM'
        freeze_layers = [list(COMET_C.LAYER_RANGES[l]) for l in list(set(freeze_layers))]
        if len(freeze_layers) > 1:
            freeze_layers = list(itertools.chain.from_iterable(freeze_layers))

    os.makedirs(output_folder, exist_ok=True)
    # dataset_copy = DatasetCopy(dataset_folder, os.path.join(output_folder, 'temp'))
    log_file = open(os.path.join(output_folder, 'log.txt'), 'w')
    C.TRAINING_DATASET = dataset_folder  # dataset_copy.copy_dataset()
    C.VALIDATION_DATASET = validation_folder
    C.ACCUM_GRADIENT_STEP = acc_gradients
    C.BATCH_SIZE = batch_size if C.ACCUM_GRADIENT_STEP == 1 else C.ACCUM_GRADIENT_STEP
    C.EARLY_STOP_PATIENCE = 5 * (C.ACCUM_GRADIENT_STEP / 2 if C.ACCUM_GRADIENT_STEP != 1 else 1)
    C.LEARNING_RATE = lr
    C.LIMIT_NUM_SAMPLES = None
    C.EPOCHS = max_epochs

    aux = "[{}]\tINFO:\nTRAIN DATASET: {}\nVALIDATION DATASET: {}\n" \
          "GPU: {}\n" \
          "BATCH SIZE: {}\n" \
          "LR: {}\n" \
          "SIMILARITY: {}\n" \
          "SEGMENTATION: {}\n"\
          "REG. WEIGHT: {}\n" \
          "EPOCHS: {:d}\n" \
          "ACCUM. GRAD: {}\n" \
          "EARLY STOP PATIENCE: {}\n" \
          "FROZEN LAYERS: {}".format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'),
                                     C.TRAINING_DATASET,
                                     C.VALIDATION_DATASET,
                                     C.GPU_NUM,
                                     C.BATCH_SIZE,
                                     C.LEARNING_RATE,
                                     simil,
                                     segm,
                                     rw,
                                     C.EPOCHS,
                                     C.ACCUM_GRADIENT_STEP,
                                     C.EARLY_STOP_PATIENCE,
                                     freeze_layers)

    log_file.write(aux)
    print(aux)

    # 2. Data generator
    used_labels = 'all' if USE_SEGMENTATIONS else 'none'
    data_generator = BatchGenerator(C.TRAINING_DATASET, C.BATCH_SIZE if C.ACCUM_GRADIENT_STEP == 1 else 1, True,
                                    C.TRAINING_PERC, labels=[used_labels], combine_segmentations=not USE_SEGMENTATIONS,
                                    directory_val=C.VALIDATION_DATASET)

    train_generator = data_generator.get_train_generator()
    validation_generator = data_generator.get_validation_generator()

    image_input_shape = train_generator.get_data_shape()[-1][:-1]
    image_output_shape = [64] * 3
    nb_labels = len(train_generator.get_segmentation_labels())

    # 3. Load model
    # IMPORTANT: the mode MUST be loaded AFTER setting up the session configuration
    config = tf.compat.v1.ConfigProto()  # device_count={'GPU':0})
    config.gpu_options.allow_growth = True
    config.log_device_placement = False  ## to log device placement (on which device the operation ran)
    sess = tf.Session(config=config)
    tf.keras.backend.set_session(sess)

    loss_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).loss,
                 NCC(image_input_shape).loss,
                 vxm.losses.MSE().loss,
                 MultiScaleStructuralSimilarity(max_val=1., filter_size=3).loss,
                 HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [nb_labels]).loss,
                 GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).loss,
                 GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).loss_macro
                 ]

    metric_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).metric,
                   NCC(image_input_shape).metric,
                   vxm.losses.MSE().loss,
                   MultiScaleStructuralSimilarity(max_val=1., filter_size=3).metric,
                   GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric,
                   HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [nb_labels]).metric,
                   GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric_macro,]
    print('MODEL LOCATION: ', model_file)

    try:
        network = tf.keras.models.load_model(model_file, {#'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
                                                          'VxmDense': vxm.networks.VxmDense,
                                                          'AdamAccumulated': AdamAccumulated,
                                                          'loss': loss_fncs,
                                                          'metric': metric_fncs},
                                             compile=False)
    except ValueError as e:
        enc_features = [16, 32, 32, 32]  # const.ENCODER_FILTERS
        dec_features = [32, 32, 32, 32, 32, 16, 16]  # const.ENCODER_FILTERS[::-1]
        nb_features = [enc_features, dec_features]
        if USE_SEGMENTATIONS:
            network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
                                                             nb_labels=nb_labels,
                                                             nb_unet_features=nb_features,
                                                             int_steps=0,
                                                             int_downsize=1,
                                                             seg_downsize=1)
        else:
            network = vxm.networks.VxmDense(inshape=image_output_shape,
                                            nb_unet_features=nb_features,
                                            int_steps=0)
        network.load_weights(model_file, by_name=True)
    # 4. Freeze/unfreeze model layers
    # freeze_layers = range(0, len(network.layers) - 8)  # Do not freeze the last layers after the UNet (8 last layers)
    # for l in freeze_layers:
    #     network.layers[l].trainable = False
    # msg = "[INF]: Frozen layers {} to {}".format(0, len(network.layers) - 8)
    # print(msg)
    # log_file.write("INF: Frozen layers {} to {}".format(0, len(network.layers) - 8))
    if freeze_layers is not None:
        aux = list()
        for r in freeze_layers:
            for l in range(*r):
                network.layers[l].trainable = False
                aux.append(l)
        aux.sort()
        msg = "[INF]: Frozen layers {}".format(', '.join([str(a) for a in aux]))
    else:
        msg = "[INF] None frozen layers"
    print(msg)
    log_file.write(msg)
    # network.trainable = False  # Freeze the base model
    # # Create a new model on top
    # input_new_model = keras.Input(network.input_shape)
    # x = base_model(input_new_model, training=False)
    # x =
    # network = keras.Model(input_new_model, x)

    network.summary()
    network.summary(print_fn=log_file.writelines)
    # Complete the model with the augmentation layer
    augm_train_input_shape = train_generator.get_data_shape()[0] if USE_SEGMENTATIONS else train_generator.get_data_shape()[-1]
    input_layer_train = Input(shape=augm_train_input_shape, name='input_train')
    augm_layer_train = AugmentationLayer(max_displacement=COMET_C.MAX_AUG_DISP,  # Max 30 mm in isotropic space
                                         max_deformation=COMET_C.MAX_AUG_DEF,  # Max 6 mm in isotropic space
                                         max_rotation=COMET_C.MAX_AUG_ANGLE,  # Max 10 deg in isotropic space
                                         num_control_points=COMET_C.NUM_CONTROL_PTS_AUG,
                                         num_augmentations=COMET_C.NUM_AUGMENTATIONS,
                                         gamma_augmentation=COMET_C.GAMMA_AUGMENTATION,
                                         brightness_augmentation=COMET_C.BRIGHTNESS_AUGMENTATION,
                                         in_img_shape=image_input_shape,
                                         out_img_shape=image_output_shape,
                                         only_image=not USE_SEGMENTATIONS,  # If baseline then True
                                         only_resize=False,
                                         trainable=False)
    augm_model_train = Model(inputs=input_layer_train, outputs=augm_layer_train(input_layer_train))

    input_layer_valid = Input(shape=validation_generator.get_data_shape()[0], name='input_valid')
    augm_layer_valid = AugmentationLayer(max_displacement=COMET_C.MAX_AUG_DISP,  # Max 30 mm in isotropic space
                                         max_deformation=COMET_C.MAX_AUG_DEF,  # Max 6 mm in isotropic space
                                         max_rotation=COMET_C.MAX_AUG_ANGLE,  # Max 10 deg in isotropic space
                                         num_control_points=COMET_C.NUM_CONTROL_PTS_AUG,
                                         num_augmentations=COMET_C.NUM_AUGMENTATIONS,
                                         gamma_augmentation=COMET_C.GAMMA_AUGMENTATION,
                                         brightness_augmentation=COMET_C.BRIGHTNESS_AUGMENTATION,
                                         in_img_shape=image_input_shape,
                                         out_img_shape=image_output_shape,
                                         only_image=False,
                                         only_resize=False,
                                         trainable=False)
    augm_model_valid = Model(inputs=input_layer_valid, outputs=augm_layer_valid(input_layer_valid))

    # 5. Setup training environment: loss, optimizer, callbacks, evaluation

    # Losses and loss weights
    SSIM_KER_SIZE = 5
    MS_SSIM_WEIGHTS = _MSSSIM_WEIGHTS[:3]
    MS_SSIM_WEIGHTS /= np.sum(MS_SSIM_WEIGHTS)
    if simil.lower() == 'mse':
        loss_fnc = vxm.losses.MSE().loss
    elif simil.lower() == 'ncc':
        loss_fnc = NCC(image_input_shape).loss
    elif simil.lower() == 'ssim':
        loss_fnc = StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss
    elif simil.lower() == 'ms_ssim':
        loss_fnc = MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).loss
    elif simil.lower() == 'mse__ms_ssim' or simil.lower() == 'ms_ssim__mse':
        @function_decorator('MSSSIM_MSE__loss')
        def loss_fnc(y_true, y_pred):
            return vxm.losses.MSE().loss(y_true, y_pred) + \
                   MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).loss(y_true, y_pred)
    elif simil.lower() == 'ncc__ms_ssim' or simil.lower() == 'ms_ssim__ncc':
        @function_decorator('MSSSIM_NCC__loss')
        def loss_fnc(y_true, y_pred):
            return NCC(image_input_shape).loss(y_true, y_pred) + \
                   MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).loss(y_true, y_pred)
    elif simil.lower() == 'mse__ssim' or simil.lower() == 'ssim__mse':
        @function_decorator('SSIM_MSE__loss')
        def loss_fnc(y_true, y_pred):
            return vxm.losses.MSE().loss(y_true, y_pred) + \
                   StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss(y_true, y_pred)
    elif simil.lower() == 'ncc__ssim' or simil.lower() == 'ssim__ncc':
        @function_decorator('SSIM_NCC__loss')
        def loss_fnc(y_true, y_pred):
            return NCC(image_input_shape).loss(y_true, y_pred) + \
                   StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss(y_true, y_pred)
    else:
        raise ValueError('Unknown similarity metric: ' + simil)

    if USE_SEGMENTATIONS:
        if segm == 'hd':
            loss_segm = HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [nb_labels]).loss
        elif segm == 'dice':
            loss_segm = GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).loss
        elif segm == 'dice_macro':
            loss_segm = GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).loss_macro
        else:
            raise ValueError('No valid value for segm')

    os.makedirs(os.path.join(output_folder, 'checkpoints'), exist_ok=True)
    os.makedirs(os.path.join(output_folder, 'tensorboard'), exist_ok=True)
    callback_tensorboard = TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
                                       batch_size=C.BATCH_SIZE, write_images=False, histogram_freq=0,
                                       update_freq='epoch',  # or 'batch' or integer
                                       write_graph=True, write_grads=True
                                       )
    callback_early_stop = EarlyStopping(monitor='val_loss', verbose=1, patience=C.EARLY_STOP_PATIENCE, min_delta=0.00001)

    callback_best_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'),
                                          save_best_only=True, monitor='val_loss', verbose=1, mode='min')
    callback_save_checkpoint = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'checkpoint.h5'),
                                               save_weights_only=True, monitor='val_loss', verbose=0, mode='min')
    if USE_SEGMENTATIONS:
        losses = {'transformer': loss_fnc,
                  'seg_transformer': loss_segm,
                  'flow': vxm.losses.Grad('l2').loss}
        metrics = {'transformer': [StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).metric,
                                   MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).metric,
                                   tf.keras.losses.MSE,
                                   NCC(image_input_shape).metric],
                   'seg_transformer': [GeneralizedDICEScore(image_output_shape + [train_generator.get_data_shape()[2][-1]], num_labels=nb_labels).metric,
                                       HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [train_generator.get_data_shape()[2][-1]]).metric,
                                       GeneralizedDICEScore(image_output_shape + [train_generator.get_data_shape()[2][-1]], num_labels=nb_labels).metric_macro,
                                       ],
                   #'flow': vxm.losses.Grad('l2').loss
                   }
        loss_weights = {'transformer': 1.,
                        'seg_transformer': 1.,
                        'flow': rw}
    else:
        losses = {'transformer': loss_fnc,
                  'flow': vxm.losses.Grad('l2').loss}
        metrics = {'transformer': [StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).metric,
                                   MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE, power_factors=MS_SSIM_WEIGHTS).metric,
                                   tf.keras.losses.MSE,
                                   NCC(image_input_shape).metric],
                   #'flow': vxm.losses.Grad('l2').loss
                   }
        loss_weights = {'transformer': 1.,
                        'flow': rw}

    optimizer = AdamAccumulated(C.ACCUM_GRADIENT_STEP, C.LEARNING_RATE)
    network.compile(optimizer=optimizer,
                    loss=losses,
                    loss_weights=loss_weights,
                    metrics=metrics)

    # 6. Training loop
    callback_tensorboard.set_model(network)
    callback_early_stop.set_model(network)
    callback_best_model.set_model(network)
    callback_save_checkpoint.set_model(network)

    summary = SummaryDictionary(network, C.BATCH_SIZE)
    names = network.metrics_names
    log_file.write('\n\n[{}]\tINFO:\tStart training\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y')))

    with sess.as_default():
        # tf.global_variables_initializer()
        callback_tensorboard.on_train_begin()
        callback_early_stop.on_train_begin()
        callback_best_model.on_train_begin()
        callback_save_checkpoint.on_train_begin()

        for epoch in range(C.EPOCHS):
            callback_tensorboard.on_epoch_begin(epoch)
            callback_early_stop.on_epoch_begin(epoch)
            callback_best_model.on_epoch_begin(epoch)
            callback_save_checkpoint.on_epoch_begin(epoch)
            print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
            print("TRAIN")

            log_file.write('\n\n[{}]\tINFO:\tTraining epoch {}\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), epoch))
            progress_bar = Progbar(len(train_generator), width=30, verbose=1)
            for step, (in_batch, _) in enumerate(train_generator, 1):
                callback_best_model.on_train_batch_begin(step)
                callback_save_checkpoint.on_train_batch_begin(step)
                callback_early_stop.on_train_batch_begin(step)

                try:
                    fix_img, mov_img, fix_seg, mov_seg = augm_model_train.predict(in_batch)
                    np.nan_to_num(fix_img, copy=False)
                    np.nan_to_num(mov_img, copy=False)
                    if np.isnan(np.sum(mov_img)) or np.isnan(np.sum(fix_img)) or np.isinf(np.sum(mov_img)) or np.isinf(np.sum(fix_img)):
                        msg = 'CORRUPTED DATA!! Unique: Fix: {}\tMoving: {}'.format(np.unique(fix_img),
                                                                                    np.unique(mov_img))
                        print(msg)
                        log_file.write('\n\n[{}]\tWAR: {}'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), msg))

                except InvalidArgumentError as err:
                    print('TF Error : {}'.format(str(err)))
                    continue

                if USE_SEGMENTATIONS:
                    in_data = (mov_img, fix_img, mov_seg)
                    out_data = (fix_img, fix_img, fix_seg)
                else:
                    in_data = (mov_img, fix_img)
                    out_data = (fix_img, fix_img)
                ret = network.train_on_batch(x=in_data, y=out_data)  # The second element doesn't matter
                if np.isnan(ret).any():
                    os.makedirs(os.path.join(output_folder, 'corrupted'), exist_ok=True)
                    save_nifti(mov_img, os.path.join(output_folder, 'corrupted', 'mov_img_nan.nii.gz'))
                    save_nifti(fix_img, os.path.join(output_folder, 'corrupted', 'fix_img_nan.nii.gz'))
                    pred_img, dm = network((mov_img, fix_img))
                    save_nifti(pred_img, os.path.join(output_folder, 'corrupted', 'pred_img_nan.nii.gz'))
                    save_nifti(dm, os.path.join(output_folder, 'corrupted', 'dm_nan.nii.gz'))
                    log_file.write('\n\n[{}]\tERR: Corruption error'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y')))
                    raise ValueError('CORRUPTION ERROR: Halting training')

                summary.on_train_batch_end(ret)
                callback_best_model.on_train_batch_end(step, named_logs(network, ret))
                callback_save_checkpoint.on_train_batch_end(step, named_logs(network, ret))
                callback_early_stop.on_train_batch_end(step, named_logs(network, ret))
                progress_bar.update(step, zip(names, ret))
                log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
            val_values = progress_bar._values.copy()
            ret = [val_values[x][0]/val_values[x][1] for x in names]

            print('\nVALIDATION')
            log_file.write('\n\n[{}]\tINFO:\tValidation epoch {}\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), epoch))
            progress_bar = Progbar(len(validation_generator), width=30, verbose=1)
            for step, (in_batch, _) in enumerate(validation_generator, 1):
                try:
                    fix_img, mov_img, fix_seg, mov_seg = augm_model_valid.predict(in_batch)
                except InvalidArgumentError as err:
                    print('TF Error : {}'.format(str(err)))
                    continue

                if USE_SEGMENTATIONS:
                    in_data = (mov_img, fix_img, mov_seg)
                    out_data = (fix_img, fix_img, fix_seg)
                else:
                    in_data = (mov_img, fix_img)
                    out_data = (fix_img, fix_img)
                ret = network.test_on_batch(x=in_data,
                                            y=out_data)

                summary.on_validation_batch_end(ret)
                progress_bar.update(step, zip(names, ret))
                log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
            val_values = progress_bar._values.copy()
            ret = [val_values[x][0]/val_values[x][1] for x in names]

            train_generator.on_epoch_end()
            validation_generator.on_epoch_end()
            epoch_summary = summary.on_epoch_end()  # summary resets after on_epoch_end() call
            callback_tensorboard.on_epoch_end(epoch, epoch_summary)
            callback_best_model.on_epoch_end(epoch, epoch_summary)
            callback_save_checkpoint.on_epoch_end(epoch, epoch_summary)
            callback_early_stop.on_epoch_end(epoch, epoch_summary)
            print('End of epoch {}: '.format(epoch), ret, '\n')

        callback_tensorboard.on_train_end()
        callback_best_model.on_train_end()
        callback_save_checkpoint.on_train_end()
        callback_early_stop.on_train_end()
    # 7. Wrap up
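As a usage reference, a minimal sketch of calling launch_train directly, assuming the package is importable; every path and numeric value below is hypothetical and only illustrates the parameters defined in the signature above:

from COMET.COMET_train import launch_train

launch_train(dataset_folder='/data/COMET/train',                 # hypothetical training set folder
             validation_folder='/data/COMET/validation',         # hypothetical validation set folder
             output_folder='/output/COMET/BASELINE_ssim',        # hypothetical output folder
             model_file='/models/pretrained_brain_model.h5',     # hypothetical pre-trained weights to fine-tune
             gpu_num=0,
             lr=1e-4,
             rw=5e-3,                    # weight of the Grad('l2') loss on the 'flow' output
             simil='ssim',               # one of the similarity options handled above
             segm='dice',                # only used when the model file name contains 'SEGGUIDED'
             max_epochs=100,
             freeze_layers=['BOTTOM'],   # any subset of INPUT/OUTPUT/ENCODER/DECODER/TOP/BOTTOM, or None
             acc_gradients=1,
             batch_size=8)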
COMET/COMET_train_UW.py
ADDED
@@ -0,0 +1,342 @@
import os, sys

currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir)  # PYTHON > 3.3 does not allow relative referencing

from datetime import datetime

from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
from tensorflow.python.keras.utils import Progbar
from tensorflow.keras import Input
from tensorflow.keras.models import Model
from tensorflow.python.framework.errors import InvalidArgumentError

import DeepDeformationMapRegistration.utils.constants as C
from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplified, NCC, GeneralizedDICEScore, \
    HausdorffDistanceErosion
from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
from DeepDeformationMapRegistration.ms_ssim_tf import _MSSSIM_WEIGHTS
from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
from DeepDeformationMapRegistration.utils.misc import function_decorator
from DeepDeformationMapRegistration.layers import AugmentationLayer, UncertaintyWeighting
from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti

from Brain_study.data_generator import BatchGenerator
from Brain_study.utils import SummaryDictionary, named_logs

import COMET.augmentation_constants as COMET_C

import numpy as np
import tensorflow as tf
import voxelmorph as vxm
import h5py
import re
import itertools


def launch_train(dataset_folder, validation_folder, output_folder, model_file, gpu_num=0, lr=1e-4, rw=5e-3,
                 simil=['ssim'], segm=['dice'], max_epochs=C.EPOCHS, prior_reg_w=5e-3, freeze_layers=None,
                 acc_gradients=1, batch_size=16):
    # 0. Input checks
    assert dataset_folder is not None and output_folder is not None and model_file is not None
    assert '.h5' in model_file, 'The model must be an H5 file'

    # 1. Load variables
    os.environ['CUDA_DEVICE_ORDER'] = C.DEV_ORDER
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_num)  # Check availability before running using 'nvidia-smi'
    C.GPU_NUM = str(gpu_num)

    if freeze_layers is not None:
        assert all(s in ['INPUT', 'OUTPUT', 'ENCODER', 'DECODER', 'TOP', 'BOTTOM'] for s in freeze_layers), \
            'Invalid option for "freeze". Expected one or several of: INPUT, OUTPUT, ENCODER, DECODER, TOP, BOTTOM'
        multiple_ranges = 'TOP' in freeze_layers
        freeze_layers = [list(COMET_C.LAYER_RANGES[l]) for l in list(set(freeze_layers))]
        freeze_layers = freeze_layers[0] if multiple_ranges else freeze_layers

    # if len(freeze_layers) > 1:
    #     freeze_layers = list(itertools.chain.from_iterable(freeze_layers))

    os.makedirs(output_folder, exist_ok=True)
    # dataset_copy = DatasetCopy(dataset_folder, os.path.join(output_folder, 'temp'))
    log_file = open(os.path.join(output_folder, 'log.txt'), 'w')
    C.TRAINING_DATASET = dataset_folder  # dataset_copy.copy_dataset()
    C.VALIDATION_DATASET = validation_folder
    C.ACCUM_GRADIENT_STEP = acc_gradients
    C.BATCH_SIZE = batch_size if C.ACCUM_GRADIENT_STEP == 1 else C.ACCUM_GRADIENT_STEP
    C.EARLY_STOP_PATIENCE = 5 * (C.ACCUM_GRADIENT_STEP / 2 if C.ACCUM_GRADIENT_STEP != 1 else 1)
    C.LEARNING_RATE = lr
    C.LIMIT_NUM_SAMPLES = None
    C.EPOCHS = max_epochs

    aux = "[{}]\tINFO:\nTRAIN DATASET: {}\nVALIDATION DATASET: {}\n" \
          "GPU: {}\n" \
          "BATCH SIZE: {}\n" \
          "LR: {}\n" \
          "SIMILARITY: {}\n" \
          "SEGMENTATION: {}\n" \
          "REG. WEIGHT: {}\n" \
          "EPOCHS: {:d}\n" \
          "ACCUM. GRAD: {}\n" \
          "EARLY STOP PATIENCE: {}\n" \
          "FROZEN LAYERS: {}".format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'),
                                     C.TRAINING_DATASET,
                                     C.VALIDATION_DATASET,
                                     C.GPU_NUM,
                                     C.BATCH_SIZE,
                                     C.LEARNING_RATE,
                                     simil,
                                     segm,
                                     rw,
                                     C.EPOCHS,
                                     C.ACCUM_GRADIENT_STEP,
                                     C.EARLY_STOP_PATIENCE,
                                     freeze_layers)

    log_file.write(aux)
    print(aux)

    # 2. Data generator
    data_generator = BatchGenerator(C.TRAINING_DATASET, C.BATCH_SIZE if C.ACCUM_GRADIENT_STEP == 1 else 1, True,
                                    C.TRAINING_PERC, labels=['all'], combine_segmentations=False,
                                    directory_val=C.VALIDATION_DATASET)

    train_generator = data_generator.get_train_generator()
    validation_generator = data_generator.get_validation_generator()

    image_input_shape = train_generator.get_data_shape()[-1][:-1]
    image_output_shape = [64] * 3
    nb_labels = len(train_generator.get_segmentation_labels())

    # 3. Load model
    # IMPORTANT: the mode MUST be loaded AFTER setting up the session configuration
    config = tf.compat.v1.ConfigProto()  # device_count={'GPU':0})
    config.gpu_options.allow_growth = True
    config.log_device_placement = False  ## to log device placement (on which device the operation ran)
    sess = tf.Session(config=config)
    tf.keras.backend.set_session(sess)

    print('MODEL LOCATION: ', model_file)

    enc_features = [16, 32, 32, 32]  # const.ENCODER_FILTERS
    dec_features = [32, 32, 32, 32, 32, 16, 16]  # const.ENCODER_FILTERS[::-1]
    nb_features = [enc_features, dec_features]
    network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
                                                     nb_labels=nb_labels,
                                                     nb_unet_features=nb_features,
                                                     int_steps=0,
                                                     int_downsize=1,
                                                     seg_downsize=1)
    network.load_weights(model_file, by_name=True)

    # 4. Freeze/unfreeze model layers
    if freeze_layers is not None:
        aux = list()
        for r in freeze_layers:
            for l in range(*r):
                network.layers[l].trainable = False
                aux.append(l)
        aux.sort()
        msg = "[INF]: Frozen layers {}".format(', '.join([str(a) for a in aux]))
    else:
        msg = "[INF] None frozen layers"
    print(msg)
    log_file.write(msg)

    network.summary()
    network.summary(print_fn=log_file.write)
    # Complete the model with the augmentation layer
    input_layer_train = Input(shape=train_generator.get_data_shape()[0], name='input_train')
    augm_layer = AugmentationLayer(max_displacement=COMET_C.MAX_AUG_DISP,  # Max 30 mm in isotropic space
                                   max_deformation=COMET_C.MAX_AUG_DEF,  # Max 6 mm in isotropic space
                                   max_rotation=COMET_C.MAX_AUG_ANGLE,  # Max 10 deg in isotropic space
                                   num_control_points=COMET_C.NUM_CONTROL_PTS_AUG,
                                   num_augmentations=COMET_C.NUM_AUGMENTATIONS,
                                   gamma_augmentation=COMET_C.GAMMA_AUGMENTATION,
                                   brightness_augmentation=COMET_C.BRIGHTNESS_AUGMENTATION,
                                   in_img_shape=image_input_shape,
                                   out_img_shape=image_output_shape,
                                   only_image=False,  # If baseline then True
                                   only_resize=False,
                                   trainable=False)
    augm_model = Model(inputs=input_layer_train, outputs=augm_layer(input_layer_train))

    # 5. Setup training environment: loss, optimizer, callbacks, evaluation

    # Losses and loss weights
    SSIM_KER_SIZE = 5
    MS_SSIM_WEIGHTS = _MSSSIM_WEIGHTS[:3]
    MS_SSIM_WEIGHTS /= np.sum(MS_SSIM_WEIGHTS)
    loss_simil = []
    prior_loss_w = []
    for s in simil:
        if s == 'ssim':
            loss_simil.append(StructuralSimilarity_simplified(patch_size=SSIM_KER_SIZE, dim=3, dynamic_range=1.).loss)
            prior_loss_w.append(1.)
        elif s == 'ms_ssim':
            loss_simil.append(MultiScaleStructuralSimilarity(max_val=1., filter_size=SSIM_KER_SIZE,
                                                             power_factors=MS_SSIM_WEIGHTS).loss)
            prior_loss_w.append(1.)
        elif s == 'ncc':
            loss_simil.append(NCC(image_input_shape).loss)
            prior_loss_w.append(1.)
        elif s == 'mse':
            loss_simil.append(vxm.losses.MSE().loss)
            prior_loss_w.append(1.)
        else:
            raise ValueError('Unknown similarity function: ', s)

    loss_segm = []
    for s in segm:
        if s == 'dice':
            loss_segm.append(GeneralizedDICEScore(image_output_shape + [train_generator.get_data_shape()[2][-1]],
                                                  num_labels=nb_labels).loss)
            prior_loss_w.append(1.)
        elif s == 'dice_macro':
            loss_segm.append(GeneralizedDICEScore(image_output_shape + [train_generator.get_data_shape()[2][-1]],
                                                  num_labels=nb_labels).loss_macro)
            prior_loss_w.append(1.)
        elif s == 'hd':
            loss_segm.append(HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [
                train_generator.get_data_shape()[2][-1]]).loss)
            prior_loss_w.append(1.)
        else:
            raise ValueError('Unknown similarity function: ', s)

    # Uncertainty weigthing module
    grad = tf.keras.Input(shape=(*image_output_shape, 3), name='multiLoss_grad_input', dtype=tf.float32)
    fix_seg = tf.keras.Input(shape=(*image_output_shape, len(train_generator.get_segmentation_labels())),
                             name='multiLoss_fix_seg_input', dtype=tf.float32)

    multiLoss = UncertaintyWeighting(num_loss_fns=len(loss_simil) + len(loss_segm),
                                     num_reg_fns=1,
                                     loss_fns=[*loss_simil,
                                               *loss_segm],
                                     reg_fns=[vxm.losses.Grad('l2').loss],
                                     prior_loss_w=prior_loss_w,
                                     # prior_loss_w=[1., 0.1, 1., 1.],
                                     prior_reg_w=[prior_reg_w],
                                     name='MultiLossLayer')
    loss = multiLoss([*[network.inputs[1]] * len(loss_simil), *[fix_seg] * len(loss_segm),
                      *[network.outputs[0]] * len(loss_simil), *[network.outputs[2]] * len(loss_simil),
                      grad,
                      network.outputs[1]])

    # inputs = [mov_img, fix_img, mov_segm, fix_segm, zero_grads]
    # outputs = [pred_img, flow, pred_segm, loss]
    full_model = tf.keras.Model(inputs=network.inputs + [fix_seg, grad],
                                outputs=network.outputs + [loss])

    os.makedirs(os.path.join(output_folder, 'checkpoints'), exist_ok=True)
    os.makedirs(os.path.join(output_folder, 'tensorboard'), exist_ok=True)
    callback_tensorboard = TensorBoard(log_dir=os.path.join(output_folder, 'tensorboard'),
                                       batch_size=C.BATCH_SIZE, write_images=False, histogram_freq=0,
                                       update_freq='epoch',  # or 'batch' or integer
                                       write_graph=True, write_grads=True
                                       )
    callback_early_stop = EarlyStopping(monitor='val_loss', verbose=1, patience=C.EARLY_STOP_PATIENCE,
                                        min_delta=0.00001)

    callback_best_model = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'best_model.h5'),
                                          save_best_only=True, monitor='val_loss', verbose=1, mode='min')
    callback_save_checkpoint = ModelCheckpoint(os.path.join(output_folder, 'checkpoints', 'checkpoint.h5'),
                                               save_weights_only=True, monitor='val_loss', verbose=0, mode='min')

    optimizer = AdamAccumulated(C.ACCUM_GRADIENT_STEP, C.LEARNING_RATE)
    full_model.compile(optimizer=optimizer,
                       loss=None, )

    # 6. Training loop
    callback_tensorboard.set_model(full_model)
    callback_early_stop.set_model(full_model)
    callback_best_model.set_model(network)  # ONLY SAVE THE NETWORK!!!
    callback_save_checkpoint.set_model(network)  # ONLY SAVE THE NETWORK!!!

    summary = SummaryDictionary(full_model, C.BATCH_SIZE)
    names = full_model.metrics_names
    zero_grads = tf.zeros_like(network.references.pos_flow, name='dummy_zero_grads')  # Dummy zeros-tensor
    log_file.write('\n\n[{}]\tINFO:\tStart training\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y')))

    with sess.as_default():
        # tf.global_variables_initializer()
        callback_tensorboard.on_train_begin()
        callback_early_stop.on_train_begin()
        callback_best_model.on_train_begin()
        callback_save_checkpoint.on_train_begin()

        for epoch in range(C.EPOCHS):
            callback_tensorboard.on_epoch_begin(epoch)
            callback_early_stop.on_epoch_begin(epoch)
            callback_best_model.on_epoch_begin(epoch)
            callback_save_checkpoint.on_epoch_begin(epoch)
            print("\nEpoch {}/{}".format(epoch, C.EPOCHS))
            print("TRAIN")

            log_file.write(
                '\n\n[{}]\tINFO:\tTraining epoch {}\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), epoch))
            progress_bar = Progbar(len(train_generator), width=30, verbose=1)
            for step, (in_batch, _) in enumerate(train_generator, 1):
                callback_best_model.on_train_batch_begin(step)
                callback_save_checkpoint.on_train_batch_begin(step)
                callback_early_stop.on_train_batch_begin(step)

                try:
                    fix_img, mov_img, fix_seg, mov_seg = augm_model.predict(in_batch)
                    np.nan_to_num(fix_img, copy=False)
                    np.nan_to_num(mov_img, copy=False)
                    if np.isnan(np.sum(mov_img)) or np.isnan(np.sum(fix_img)) or np.isinf(np.sum(mov_img)) or np.isinf(
                            np.sum(fix_img)):
                        msg = 'CORRUPTED DATA!! Unique: Fix: {}\tMoving: {}'.format(np.unique(fix_img),
                                                                                    np.unique(mov_img))
                        print(msg)
                        log_file.write('\n\n[{}]\tWAR: {}'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), msg))

                except InvalidArgumentError as err:
                    print('TF Error : {}'.format(str(err)))
                    continue

                ret = full_model.train_on_batch(
                    x=(mov_img, fix_img, mov_seg, fix_seg, zero_grads))  # The second element doesn't matter

                summary.on_train_batch_end(ret)
                callback_best_model.on_train_batch_end(step, named_logs(full_model, ret))
                callback_save_checkpoint.on_train_batch_end(step, named_logs(full_model, ret))
                callback_early_stop.on_train_batch_end(step, named_logs(full_model, ret))
                progress_bar.update(step, zip(names, ret))
                log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
            val_values = progress_bar._values.copy()
            ret = [val_values[x][0] / val_values[x][1] for x in names]

            print('\nVALIDATION')
            log_file.write(
                '\n\n[{}]\tINFO:\tValidation epoch {}\n\n'.format(datetime.now().strftime('%H:%M:%S\t%d/%m/%Y'), epoch))
            progress_bar = Progbar(len(validation_generator), width=30, verbose=1)
            for step, (in_batch, _) in enumerate(validation_generator, 1):
                try:
                    fix_img, mov_img, fix_seg, mov_seg = augm_model.predict(in_batch)
                except InvalidArgumentError as err:
                    print('TF Error : {}'.format(str(err)))
                    continue

                ret = full_model.test_on_batch(x=(mov_img, fix_img, mov_seg, fix_seg, zero_grads))

                summary.on_validation_batch_end(ret)
                progress_bar.update(step, zip(names, ret))
                log_file.write('\t\tStep {:03d}: {}'.format(step, ret))
            val_values = progress_bar._values.copy()
            ret = [val_values[x][0] / val_values[x][1] for x in names]

            train_generator.on_epoch_end()
            validation_generator.on_epoch_end()
            epoch_summary = summary.on_epoch_end()  # summary resets after on_epoch_end() call
            callback_tensorboard.on_epoch_end(epoch, epoch_summary)
            callback_best_model.on_epoch_end(epoch, epoch_summary)
            callback_save_checkpoint.on_epoch_end(epoch, epoch_summary)
            callback_early_stop.on_epoch_end(epoch, epoch_summary)
            print('End of epoch {}: '.format(epoch), ret, '\n')

        callback_tensorboard.on_train_end()
        callback_best_model.on_train_end()
        callback_save_checkpoint.on_train_end()
        callback_early_stop.on_train_end()
    # 7. Wrap up
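For the uncertainty-weighted variant, a minimal sketch of a direct call; note that simil and segm are lists here, and every path and numeric value below is hypothetical:

from COMET.COMET_train_UW import launch_train

launch_train(dataset_folder='/data/COMET/train',              # hypothetical training set folder
             validation_folder='/data/COMET/validation',      # hypothetical validation set folder
             output_folder='/output/COMET/UW_ssim_dice',      # hypothetical output folder
             model_file='/models/pretrained_brain_model.h5',  # hypothetical pre-trained weights
             gpu_num=0,
             lr=1e-4,
             simil=['ssim', 'ncc'],   # similarity losses combined by the UncertaintyWeighting layer
             segm=['dice'],           # segmentation losses combined by the same layer
             max_epochs=100,
             prior_reg_w=5e-3,        # prior weight on the Grad('l2') regularizer
             freeze_layers=None,
             acc_gradients=1,
             batch_size=8)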
COMET/Evaluate_network.py
ADDED
@@ -0,0 +1,264 @@
import os, sys

import shutil
import time
import tkinter

import h5py
import matplotlib.pyplot as plt

currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir)  # PYTHON > 3.3 does not allow relative referencing

import tensorflow as tf
# tf.enable_eager_execution(config=config)

import numpy as np
import pandas as pd
import voxelmorph as vxm

import DeepDeformationMapRegistration.utils.constants as C
from DeepDeformationMapRegistration.utils.operators import min_max_norm
from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
from DeepDeformationMapRegistration.layers import AugmentationLayer, UncertaintyWeighting
from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplified, NCC, GeneralizedDICEScore, HausdorffDistanceErosion, target_registration_error
from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
from DeepDeformationMapRegistration.utils.misc import DisplacementMapInterpolator, get_segmentations_centroids, segmentation_ohe_to_cardinal, segmentation_cardinal_to_ohe
from EvaluationScripts.Evaluate_class import EvaluationFigures, resize_pts_to_original_space, resize_img_to_original_space, resize_transformation
from scipy.interpolate import RegularGridInterpolator
from tqdm import tqdm

import h5py
import re
from Brain_study.data_generator import BatchGenerator

import argparse

from skimage.transform import warp
import neurite as ne

# from tkinter import filedialog as fd


DATASET = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/test_fixed'
MODEL_FILE = '/mnt/EncryptedData1/Users/javier/train_output/COMET/ERASE/COMET_L_ssim__MET_mse_ncc_ssim_141343-01122021/checkpoints/best_model.h5'
DATA_ROOT_DIR = '/mnt/EncryptedData1/Users/javier/train_output/COMET/ERASE/COMET_L_ssim__MET_mse_ncc_ssim_141343-01122021/'


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-m', '--model', nargs='+', type=str, help='.h5 of the model', default=None)
    parser.add_argument('-d', '--dir', nargs='+', type=str, help='Directory where ./checkpoints/best_model.h5 is located', default=None)
    parser.add_argument('--gpu', type=int, help='GPU', default=0)
    parser.add_argument('--dataset', type=str, help='Dataset to run predictions on', default=DATASET)
    parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False)
    parser.add_argument('--outdirname', type=str, default='Evaluate')
    args = parser.parse_args()
    if args.model is not None:
        assert '.h5' in args.model[0], 'No checkpoint file provided, use -d/--dir instead'
        MODEL_FILE_LIST = args.model
        DATA_ROOT_DIR_LIST = [os.path.split(model_path)[0] for model_path in args.model]
    elif args.dir is not None:
        assert '.h5' not in args.dir[0], 'Provided checkpoint file, user -m/--model instead'
        MODEL_FILE_LIST = [os.path.join(dir_path, 'checkpoints', 'best_model.h5') for dir_path in args.dir]
        DATA_ROOT_DIR_LIST = args.dir
    else:
        # try:
        #     MODEL_FILE_LIST = fd.askopenfilenames(title='Select .h model file',
        #                                           initialdir='/mnt/EncryptedData1/Users/javier/train_output/COMET',
        #                                           filetypes=(('Model', '*.h5'),))
        #     if len(MODEL_FILE_LIST):
        #         DATA_ROOT_DIR_LIST = [os.path.split(model_path)[0] for model_path in MODEL_FILE_LIST]
        #     else:
        #         raise ValueError('No model selected')
        # except tkinter.TclError as e:
        #     raise ValueError('Cannot launch TkInter file explorer. User the -m or -d arguments')
        pass
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)  # Check availability before running using 'nvidia-smi'
    DATASET = args.dataset
    list_test_files = [os.path.join(DATASET, f) for f in os.listdir(DATASET) if f.endswith('h5') and 'dm' not in f]
    list_test_files.sort()

    with h5py.File(list_test_files[0], 'r') as f:
        image_input_shape = image_output_shape = list(f['fix_image'][:].shape[:-1])
        nb_labels = f['fix_segmentations'][:].shape[-1]

    # Header of the metrics csv file
    csv_header = ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'DICE_MACRO', 'HD', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']

    # TF stuff
    config = tf.compat.v1.ConfigProto()  # device_count={'GPU':0})
    config.gpu_options.allow_growth = True
    config.log_device_placement = False  ## to log device placement (on which device the operation ran)
    config.allow_soft_placement = True

    sess = tf.Session(config=config)
    tf.keras.backend.set_session(sess)

    # Loss and metric functions. Common to all models
    loss_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).loss,
                 NCC(image_input_shape).loss,
                 vxm.losses.MSE().loss,
                 MultiScaleStructuralSimilarity(max_val=1., filter_size=3).loss]

    metric_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).metric,
                   NCC(image_input_shape).metric,
                   vxm.losses.MSE().loss,
                   MultiScaleStructuralSimilarity(max_val=1., filter_size=3).metric,
                   GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric,
                   HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [nb_labels]).metric,
                   GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric_macro]

    ### METRICS GRAPH ###
    fix_img_ph = tf.placeholder(tf.float32, (1, *image_output_shape, 1), name='fix_img')
    pred_img_ph = tf.placeholder(tf.float32, (1, *image_output_shape, 1), name='pred_img')
    fix_seg_ph = tf.placeholder(tf.float32, (1, *image_output_shape, nb_labels), name='fix_seg')
    pred_seg_ph = tf.placeholder(tf.float32, (1, *image_output_shape, nb_labels), name='pred_seg')

    ssim_tf = metric_fncs[0](fix_img_ph, pred_img_ph)
    ncc_tf = metric_fncs[1](fix_img_ph, pred_img_ph)
    mse_tf = metric_fncs[2](fix_img_ph, pred_img_ph)
    ms_ssim_tf = metric_fncs[3](fix_img_ph, pred_img_ph)
    dice_tf = metric_fncs[4](fix_seg_ph, pred_seg_ph)
    hd_tf = metric_fncs[5](fix_seg_ph, pred_seg_ph)
    dice_macro_tf = metric_fncs[6](fix_seg_ph, pred_seg_ph)
    # hd_exact_tf = HausdorffDistance_exact(fix_seg_ph, pred_seg_ph, ohe=True)

    # Needed for VxmDense type of network
    warp_segmentation = vxm.networks.Transform(image_output_shape, interp_method='nearest', nb_feats=nb_labels)

    dm_interp = DisplacementMapInterpolator(image_output_shape, 'griddata')

    for MODEL_FILE, DATA_ROOT_DIR in zip(MODEL_FILE_LIST, DATA_ROOT_DIR_LIST):
        print('MODEL LOCATION: ', MODEL_FILE)

        # data_folder = '/mnt/EncryptedData1/Users/javier/train_output/DDMR/THESIS/BASELINE_Affine_ncc___mse_ncc_160606-25022021'
        output_folder = os.path.join(DATA_ROOT_DIR, args.outdirname)  # '/mnt/EncryptedData1/Users/javier/train_output/DDMR/THESIS/eval/BASELINE_TRAIN_Affine_ncc_EVAL_Affine'
        # os.makedirs(os.path.join(output_folder, 'images'), exist_ok=True)
        if args.erase:
            shutil.rmtree(output_folder, ignore_errors=True)
        os.makedirs(output_folder, exist_ok=True)
        print('DESTINATION FOLDER: ', output_folder)

        try:
            network = tf.keras.models.load_model(MODEL_FILE, {'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
                                                              'VxmDense': vxm.networks.VxmDense,
                                                              'AdamAccumulated': AdamAccumulated,
                                                              'loss': loss_fncs,
                                                              'metric': metric_fncs},
                                                 compile=False)
|
154 |
+
except ValueError as e:
|
155 |
+
enc_features = [16, 32, 32, 32] # const.ENCODER_FILTERS
|
156 |
+
dec_features = [32, 32, 32, 32, 32, 16, 16] # const.ENCODER_FILTERS[::-1]
|
157 |
+
nb_features = [enc_features, dec_features]
|
158 |
+
if re.search('^UW|SEGGUIDED_', MODEL_FILE):
|
159 |
+
network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
|
160 |
+
nb_labels=nb_labels,
|
161 |
+
nb_unet_features=nb_features,
|
162 |
+
int_steps=0,
|
163 |
+
int_downsize=1,
|
164 |
+
seg_downsize=1)
|
165 |
+
else:
|
166 |
+
network = vxm.networks.VxmDense(inshape=image_output_shape,
|
167 |
+
nb_unet_features=nb_features,
|
168 |
+
int_steps=0)
|
169 |
+
network.load_weights(MODEL_FILE, by_name=True)
|
170 |
+
# Record metrics
|
171 |
+
metrics_file = os.path.join(output_folder, 'metrics.csv')
|
172 |
+
with open(metrics_file, 'w') as f:
|
173 |
+
f.write(';'.join(csv_header)+'\n')
|
174 |
+
|
175 |
+
ssim = ncc = mse = ms_ssim = dice = hd = 0
|
176 |
+
with sess.as_default():
|
177 |
+
sess.run(tf.global_variables_initializer())
|
178 |
+
network.load_weights(MODEL_FILE, by_name=True)
|
179 |
+
progress_bar = tqdm(enumerate(list_test_files, 1), desc='Evaluation', total=len(list_test_files))
|
180 |
+
for step, in_batch in progress_bar:
|
181 |
+
with h5py.File(in_batch, 'r') as f:
|
182 |
+
fix_img = f['fix_image'][:][np.newaxis, ...] # Add batch axis
|
183 |
+
mov_img = f['mov_image'][:][np.newaxis, ...]
|
184 |
+
fix_seg = f['fix_segmentations'][:][np.newaxis, ...].astype(np.float32)
|
185 |
+
mov_seg = f['mov_segmentations'][:][np.newaxis, ...].astype(np.float32)
|
186 |
+
fix_centroids = f['fix_centroids'][:]
|
187 |
+
|
188 |
+
if network.name == 'vxm_dense_semi_supervised_seg':
|
189 |
+
t0 = time.time()
|
190 |
+
pred_img, disp_map, pred_seg = network.predict([mov_img, fix_img, mov_seg, fix_seg]) # predict([source, target])
|
191 |
+
t1 = time.time()
|
192 |
+
else:
|
193 |
+
t0 = time.time()
|
194 |
+
pred_img, disp_map = network.predict([mov_img, fix_img])
|
195 |
+
pred_seg = warp_segmentation.predict([mov_seg, disp_map])
|
196 |
+
t1 = time.time()
|
197 |
+
|
198 |
+
pred_img = min_max_norm(pred_img)
|
199 |
+
mov_centroids, missing_lbls = get_segmentations_centroids(mov_seg[0, ...], ohe=True, expected_lbls=range(0, nb_labels), brain_study=False)
|
200 |
+
# pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) # with tps, it returns the pred_centroids directly
|
201 |
+
pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) + mov_centroids
|
202 |
+
|
203 |
+
# I need the labels to be OHE to compute the segmentation metrics.
|
204 |
+
dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf], {'fix_seg:0': fix_seg, 'pred_seg:0': pred_seg})
|
205 |
+
|
206 |
+
pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
|
207 |
+
mov_seg_card = segmentation_ohe_to_cardinal(mov_seg).astype(np.float32)
|
208 |
+
fix_seg_card = segmentation_ohe_to_cardinal(fix_seg).astype(np.float32)
|
209 |
+
|
210 |
+
ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf], {'fix_img:0': fix_img, 'pred_img:0': pred_img})
|
211 |
+
ms_ssim = ms_ssim[0]
|
212 |
+
|
213 |
+
# Rescale the points back to isotropic space, where we have a correspondence voxel <-> mm
|
214 |
+
upsample_scale = 128 / 64
|
215 |
+
fix_centroids_isotropic = fix_centroids * upsample_scale
|
216 |
+
# mov_centroids_isotropic = mov_centroids * upsample_scale
|
217 |
+
pred_centroids_isotropic = pred_centroids * upsample_scale
|
218 |
+
|
219 |
+
fix_centroids_isotropic = np.divide(fix_centroids_isotropic, C.IXI_DATASET_iso_to_cubic_scales)
|
220 |
+
# mov_centroids_isotropic = np.divide(mov_centroids_isotropic, C.IXI_DATASET_iso_to_cubic_scales)
|
221 |
+
pred_centroids_isotropic = np.divide(pred_centroids_isotropic, C.IXI_DATASET_iso_to_cubic_scales)
|
222 |
+
# Now we can measure the TRE in mm
|
223 |
+
tre_array = target_registration_error(fix_centroids_isotropic, pred_centroids_isotropic, False).eval()
|
224 |
+
tre = np.mean([v for v in tre_array if not np.isnan(v)])
|
225 |
+
# ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'HD', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']
|
226 |
+
|
227 |
+
new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, t1-t0, tre, len(missing_lbls), missing_lbls]
|
228 |
+
with open(metrics_file, 'a') as f:
|
229 |
+
f.write(';'.join(map(str, new_line))+'\n')
|
230 |
+
|
231 |
+
save_nifti(fix_img[0, ...], os.path.join(output_folder, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
232 |
+
save_nifti(mov_img[0, ...], os.path.join(output_folder, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
233 |
+
save_nifti(pred_img[0, ...], os.path.join(output_folder, '{:03d}_pred_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
234 |
+
save_nifti(fix_seg_card[0, ...], os.path.join(output_folder, '{:03d}_fix_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
235 |
+
save_nifti(mov_seg_card[0, ...], os.path.join(output_folder, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
236 |
+
save_nifti(pred_seg_card[0, ...], os.path.join(output_folder, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
237 |
+
|
238 |
+
# with h5py.File(os.path.join(output_folder, '{:03d}_centroids.h5'.format(step)), 'w') as f:
|
239 |
+
# f.create_dataset('fix_centroids', dtype=np.float32, data=fix_centroids)
|
240 |
+
# f.create_dataset('mov_centroids', dtype=np.float32, data=mov_centroids)
|
241 |
+
# f.create_dataset('pred_centroids', dtype=np.float32, data=pred_centroids)
|
242 |
+
# f.create_dataset('fix_centroids_isotropic', dtype=np.float32, data=fix_centroids_isotropic)
|
243 |
+
# f.create_dataset('mov_centroids_isotropic', dtype=np.float32, data=mov_centroids_isotropic)
|
244 |
+
|
245 |
+
# magnitude = np.sqrt(np.sum(disp_map[0, ...] ** 2, axis=-1))
|
246 |
+
# _ = plt.hist(magnitude.flatten())
|
247 |
+
# plt.title('Histogram of disp. magnitudes')
|
248 |
+
# plt.show(block=False)
|
249 |
+
# plt.savefig(os.path.join(output_folder, '{:03d}_hist_mag_ssim_{:.03f}_dice_{:.03f}.png'.format(step, ssim, dice)))
|
250 |
+
# plt.close()
|
251 |
+
|
252 |
+
plot_predictions(fix_img, mov_img, disp_map, pred_img, os.path.join(output_folder, '{:03d}_figures_img.png'.format(step)), show=False)
|
253 |
+
plot_predictions(fix_seg, mov_seg, disp_map, pred_seg, os.path.join(output_folder, '{:03d}_figures_seg.png'.format(step)), show=False)
|
254 |
+
save_disp_map_img(disp_map, 'Displacement map', os.path.join(output_folder, '{:03d}_disp_map_fig.png'.format(step)), show=False)
|
255 |
+
|
256 |
+
progress_bar.set_description('SSIM {:.04f}\tG_DICE: {:.04f}\tM_DICE: {:.04f}'.format(ssim, dice, dice_macro))
|
257 |
+
|
258 |
+
print('Summary\n=======\n')
|
259 |
+
print('\nAVG:\n' + str(pd.read_csv(metrics_file, sep=';', header=0).mean(axis=0)) + '\nSTD:\n' + str(pd.read_csv(metrics_file, sep=';', header=0).std(axis=0)))
|
260 |
+
print('\n=======\n')
|
261 |
+
tf.keras.backend.clear_session()
|
262 |
+
# sess.close()
|
263 |
+
del network
|
264 |
+
print('Done')
|
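For reference, the TRE logged above reduces to a per-landmark Euclidean distance once the fixed and predicted centroids are expressed in millimetres. A minimal NumPy sketch with made-up landmark coordinates (not values from the COMET dataset):

# Illustrative sketch only: target registration error (TRE) as the Euclidean
# distance between corresponding landmarks in mm space. Coordinates are placeholders.
import numpy as np

fix_centroids_mm = np.array([[10.0, 12.0, 30.0], [40.0, 42.0, 25.0]])   # fixed-image landmarks (mm)
pred_centroids_mm = np.array([[11.0, 12.5, 29.0], [39.0, 41.0, 26.5]])  # landmarks after registration (mm)

tre_per_landmark = np.linalg.norm(fix_centroids_mm - pred_centroids_mm, axis=1)
mean_tre = np.nanmean(tre_per_landmark)  # missing labels would contribute NaNs, as in the loop above
print(tre_per_landmark, mean_tre)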
COMET/MultiTrain_cli.py
ADDED
@@ -0,0 +1,61 @@
import os, sys
currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir)  # PYTHON > 3.3 does not allow relative referencing

import argparse
from datetime import datetime

TRAIN_DATASET = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/train'

err = list()

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, help='Path to the model file', required=True)
    parser.add_argument('--dataset', type=str, help='Location of the training data', default=TRAIN_DATASET)
    parser.add_argument('--validation', type=str, help='Location of the validation data', default=None)
    parser.add_argument('--similarity', nargs='+', help='Similarity metric: mse, ncc, ssim', default=['ncc'])
    parser.add_argument('--segmentation', nargs='+', help='Segmentation loss function: hd, dice', default=['dice'])
    parser.add_argument('--output', type=str, help='Output directory', default=TRAIN_DATASET)
    parser.add_argument('--gpu', type=str, help='GPU number', default='0')
    parser.add_argument('--lr', type=float, help='Learning rate', default=1e-4)
    parser.add_argument('--rw', type=float, help='Regularization weight', default=5e-3)
    parser.add_argument('--epochs', type=int, help='Max number of epochs', default=1000)
    parser.add_argument('--name', type=str, default='COMET')
    parser.add_argument('--uw', type=bool, default=False)
    parser.add_argument('--freeze', nargs='+', help='What layers to freeze: INPUT, OUTPUT, ENCODER, DECODER, TOP, BOTTOM', default=None)
    parser.add_argument('--batch', default=16)
    args = parser.parse_args()

    print('TRAIN ' + args.dataset)

    if args.uw:
        from COMET.COMET_train_UW import launch_train
        simil = args.similarity
        segm = args.segmentation
        output_folder = os.path.join(args.output, '{}_Lsim_{}__Lseg_{}'.format(args.name, '_'.join(simil), '_'.join(segm)))
    else:
        from COMET.COMET_train import launch_train
        simil = args.similarity[0]
        segm = args.segmentation[0]
        output_folder = os.path.join(args.output, '{}_Lsim_{}__Lseg_{}'.format(args.name, simil, segm))
    output_folder = output_folder + '_' + datetime.now().strftime("%H%M%S-%d%m%Y")

    if args.freeze is not None:
        assert all(s in ['INPUT', 'OUTPUT', 'ENCODER', 'DECODER', 'TOP', 'BOTTOM'] for s in args.freeze),\
            'Invalid option for "freeze". Expected one or several of: INPUT, OUTPUT, ENCODER, DECODER, TOP, BOTTOM'
        args.freeze = list(set(args.freeze))

    launch_train(dataset_folder=args.dataset,
                 validation_folder=args.validation,
                 output_folder=output_folder,
                 gpu_num=args.gpu,
                 lr=args.lr,
                 rw=args.rw,
                 simil=simil,
                 segm=segm,
                 max_epochs=args.epochs,
                 model_file=args.model,
                 freeze_layers=args.freeze)
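Note that `--uw` is declared with `type=bool`, so argparse passes the raw command-line string through `bool()`: any non-empty string, including `--uw False`, enables the UW branch, and only an empty string maps to False. A quick check:

# Quick demonstration of why 'type=bool' is a weak on/off switch in argparse:
# bool() applied to a non-empty string is always True.
print(bool('True'), bool('False'), bool(''))  # True True False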
COMET/MultiTrain_config.py
ADDED
@@ -0,0 +1,71 @@
import os, sys
currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir)  # PYTHON > 3.3 does not allow relative referencing

import argparse
from configparser import ConfigParser
from shutil import copy2
import os
from datetime import datetime

TRAIN_DATASET = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/train'

err = list()

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--ini', help='Configuration file')
    args = parser.parse_args()

    configFile = ConfigParser()
    configFile.read(args.ini)
    print('Loaded configuration file: ' + args.ini)
    print({section: dict(configFile[section]) for section in configFile.sections()})
    print('\n\n')

    trainConfig = configFile['TRAIN']
    lossesConfig = configFile['LOSSES']
    datasetConfig = configFile['DATASETS']
    othersConfig = configFile['OTHERS']

    print('TRAIN MODEL IN ' + trainConfig['model'])

    simil = lossesConfig['similarity'].split(',')
    segm = lossesConfig['segmentation'].split(',')
    if trainConfig['name'].lower() == 'uw':
        from COMET.COMET_train_UW import launch_train
        output_folder = os.path.join(othersConfig['outputFolder'], '{}_Lsim_{}__Lseg_{}'.format(trainConfig['name'], '_'.join(simil), '_'.join(segm)))
    else:
        from COMET.COMET_train import launch_train
        simil = simil[0]
        segm = segm[0]
        output_folder = os.path.join(othersConfig['outputFolder'], '{}_Lsim_{}__Lseg_{}'.format(trainConfig['name'], simil, segm))
    output_folder = output_folder + '_' + datetime.now().strftime("%H%M%S-%d%m%Y")

    try:
        froozen_layers = eval(trainConfig['freeze'])
    except NameError as err:
        froozen_layers = [trainConfig['freeze'].upper()]
    if froozen_layers is not None:
        assert all(s in ['INPUT', 'OUTPUT', 'ENCODER', 'DECODER', 'TOP', 'BOTTOM'] for s in froozen_layers),\
            'Invalid option for "freeze". Expected one or several of: INPUT, OUTPUT, ENCODER, DECODER, TOP, BOTTOM'
        froozen_layers = list(set(froozen_layers))  # Unique elements

    # copy the configuration file to the destination folder
    os.makedirs(output_folder, exist_ok=True)
    copy2(args.ini, os.path.join(output_folder, os.path.split(args.ini)[-1]))

    launch_train(dataset_folder=datasetConfig['train'],
                 validation_folder=datasetConfig['validation'],
                 output_folder=output_folder,
                 gpu_num=eval(trainConfig['gpu']),
                 lr=eval(trainConfig['learningRate']),
                 rw=eval(trainConfig['regularizationWeight']),
                 simil=simil,
                 segm=segm,
                 max_epochs=eval(trainConfig['epochs']),
                 model_file=trainConfig['model'],
                 freeze_layers=froozen_layers,
                 acc_gradients=eval(trainConfig['accumulativeGradients']),
                 batch_size=eval(trainConfig['batchSize']))
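The try/except around `eval(trainConfig['freeze'])` is what lets the .ini files below write either `freeze = None` or a bare layer name: `eval('None')` yields None, while a bare name such as `BOTTOM` is not a valid expression, raises NameError, and is wrapped in a single-element list. A minimal sketch of that behaviour:

# Minimal sketch of how the 'freeze' field is interpreted by the try/except above.
for raw in ('None', 'BOTTOM', "['TOP', 'BOTTOM']"):
    try:
        parsed = eval(raw)
    except NameError:
        parsed = [raw.upper()]
    print(raw, '->', parsed)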
COMET/augmentation_constants.py
ADDED
@@ -0,0 +1,34 @@
import numpy as np

# Constants for augmentation layer
# .../T1/training/zoom_factors.csv contains the scale factors of all the training samples from isotropic to 128x128x128
# The augmentation values will be scaled using the average+std
ZOOM_FACTORS = np.asarray([0.5032864535069749, 0.5363100665659675, 0.6292598243796296])
MAX_AUG_DISP_ISOT = 30
MAX_AUG_DEF_ISOT = 6
MAX_AUG_DISP = np.max(MAX_AUG_DISP_ISOT * ZOOM_FACTORS)  # Scaled displacements
MAX_AUG_DEF = np.max(MAX_AUG_DEF_ISOT * ZOOM_FACTORS)  # Scaled deformations
MAX_AUG_ANGLE = np.max([np.arctan(np.tan(10*np.pi/180) * ZOOM_FACTORS[1] / ZOOM_FACTORS[0]) * 180 / np.pi,
                        np.arctan(np.tan(10*np.pi/180) * ZOOM_FACTORS[2] / ZOOM_FACTORS[1]) * 180 / np.pi,
                        np.arctan(np.tan(10*np.pi/180) * ZOOM_FACTORS[2] / ZOOM_FACTORS[0]) * 180 / np.pi])  # Scaled angles
GAMMA_AUGMENTATION = False
BRIGHTNESS_AUGMENTATION = False
NUM_CONTROL_PTS_AUG = 10
NUM_AUGMENTATIONS = 5

IN_LAYERS = (0, 3)
OUT_LAYERS = (33, 39)

ENCONDER_LAYERS = (3, 17)
DECODER_LAYERS = (17, 33)

TOP_LAYERS_ENC = (3, 9)
TOP_LAYERS_DEC = (22, 29)
BOTTOM_LAYERS = (9, 22)

LAYER_RANGES = {'INPUT': (IN_LAYERS),
                'OUTPUT': (OUT_LAYERS),
                'ENCODER': (ENCONDER_LAYERS),
                'DECODER': (DECODER_LAYERS),
                'TOP': (TOP_LAYERS_ENC, TOP_LAYERS_DEC),
                'BOTTOM': (BOTTOM_LAYERS)}
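The (start, stop) pairs above read as layer-index ranges over the registration U-Net. A hedged sketch of how such ranges could be applied to a Keras model; the actual consumption happens in COMET_train.py / COMET_train_UW.py and may differ:

# Hedged sketch: one way the (start, stop) index pairs in LAYER_RANGES could be
# used to freeze a slice of a Keras model's layers. Not the scripts' own helper.
def freeze_layer_ranges(model, ranges):
    # 'ranges' is either a single (start, stop) tuple or a tuple of such tuples,
    # mirroring the LAYER_RANGES values above (e.g. 'TOP' holds two ranges).
    if isinstance(ranges[0], int):
        ranges = (ranges,)
    for start, stop in ranges:
        for layer in model.layers[start:stop]:
            layer.trainable = False

# freeze_layer_ranges(network, LAYER_RANGES['BOTTOM'])  # illustrative call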
COMET/format_dataset.py
ADDED
@@ -0,0 +1,114 @@
import os, sys
currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir)  # PYTHON > 3.3 does not allow relative referencing

import h5py
import nibabel as nib
from nilearn.image import resample_img
import re
import numpy as np
from scipy.ndimage import zoom
from skimage.measure import regionprops
from tqdm import tqdm
from argparse import ArgumentParser
from scipy.ndimage.morphology import binary_dilation, generate_binary_structure

import pandas as pd

from DeepDeformationMapRegistration.utils import constants as C
from DeepDeformationMapRegistration.utils.misc import segmentation_cardinal_to_ohe, segmentation_ohe_to_cardinal

SEGMENTATION_NR2LBL_LUT = {0: 'background',
                           1: 'parenchyma',
                           2: 'vessel'}
SEGMENTATION_LBL2NR_LUT = {v: k for k, v in SEGMENTATION_NR2LBL_LUT.items()}

IMG_DIRECTORY = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Volumes'
SEG_DIRECTORY = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Segmentations'  # '/home/jpdefrutos/workspace/LiverSegmentation_UNet3D/data/prediction'

IMG_NAME_PATTERN = '(.*).nii.gz'
SEG_NAME_PATTERN = '(.*).nii.gz'

OUT_DIRECTORY = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128'


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--crop', action='store_true')  # If present, args.crop = True, else args.crop = False
    parser.add_argument('--offset', type=int, default=C.MAX_AUG_DISP_ISOT + 10, help='Crop offset in mm')
    parser.add_argument('--dilate-segmentations', type=bool, default=False)
    args = parser.parse_args()

    img_list = [os.path.join(IMG_DIRECTORY, f) for f in os.listdir(IMG_DIRECTORY) if f.endswith('.nii.gz')]
    img_list.sort()

    seg_list = [os.path.join(SEG_DIRECTORY, f) for f in os.listdir(SEG_DIRECTORY) if f.endswith('.nii.gz')]
    seg_list.sort()

    zoom_file = pd.DataFrame(columns=['scale_i', 'scale_j', 'scale_k'])
    os.makedirs(OUT_DIRECTORY, exist_ok=True)
    binary_ball = generate_binary_structure(3, 1)
    for seg_file in tqdm(seg_list):
        img_name = re.match(SEG_NAME_PATTERN, os.path.split(seg_file)[-1])[1]
        img_file = os.path.join(IMG_DIRECTORY, img_name + '.nii.gz')

        img = resample_img(nib.load(img_file), np.eye(3))
        seg = resample_img(nib.load(seg_file), np.eye(3), interpolation='nearest')

        img = np.asarray(img.dataobj)
        seg = np.asarray(seg.dataobj)

        segs_are_ohe = bool(len(seg.shape) > 3 and seg.shape[3] > 1)
        if args.crop:
            parenchyma = regionprops(seg[..., 0])[0]
            bbox = np.asarray(parenchyma.bbox) + [*[-args.offset]*3, *[args.offset]*3]
            # check that the new bbox is within the image limits!
            bbox[:3] = np.maximum(bbox[:3], [0, 0, 0])
            bbox[3:] = np.minimum(bbox[3:], img.shape)
            img = img[bbox[0]:bbox[3], bbox[1]:bbox[4], bbox[2]:bbox[5]]
            seg = seg[bbox[0]:bbox[3], bbox[1]:bbox[4], bbox[2]:bbox[5], ...]
        # Resize to 128x128x128
        isot_shape = img.shape

        zoom_factors = (np.asarray([128]*3) / np.asarray(img.shape)).tolist()

        img = zoom(img, zoom_factors, order=3)
        if args.dilate_segmentations:
            seg = binary_dilation(seg, binary_ball, iterations=1)
        seg = zoom(seg, zoom_factors + [1]*(len(seg.shape) - len(img.shape)), order=0)
        zoom_file = zoom_file.append({'scale_i': zoom_factors[0],
                                      'scale_j': zoom_factors[1],
                                      'scale_k': zoom_factors[2]}, ignore_index=True)

        # seg -> cardinal
        # seg_expanded -> OHE
        if segs_are_ohe:
            seg_expanded = seg.copy()
            seg = segmentation_ohe_to_cardinal(seg)  # Ordinal encoded. argmax returns the first occurrence of the maximum. Hence the previous multiplication operation
        else:
            seg_expanded = segmentation_cardinal_to_ohe(seg)

        h5_file = h5py.File(os.path.join(OUT_DIRECTORY, img_name + '.h5'), 'w')

        h5_file.create_dataset('image', data=img[..., np.newaxis], dtype=np.float32)
        h5_file.create_dataset('segmentation', data=seg.astype(np.uint8), dtype=np.uint8)
        h5_file.create_dataset('segmentation_expanded', data=seg_expanded.astype(np.uint8), dtype=np.uint8)
        h5_file.create_dataset('segmentation_labels', data=np.unique(seg)[1:])  # Remove the 0 (background label)
        h5_file.create_dataset('isotropic_shape', data=isot_shape)

        print('{}: Segmentation labels {}'.format(img_name, np.unique(seg)[1:]))
        h5_file.close()

    zoom_file.to_csv(os.path.join(OUT_DIRECTORY, 'zoom_factors.csv'))
    print("Average")
    print(zoom_file.mean().to_list())

    print("Standard deviation")
    print(zoom_file.std().to_list())

    print("Average + STD")
    print((zoom_file.mean() + zoom_file.std()).to_list())
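Each output .h5 then carries the datasets created above. A small sketch of reading one back; the file name is illustrative:

# Small sketch of reading back one of the files written above. Keys mirror the
# create_dataset calls: 'image' is the resized volume with a channel axis,
# 'segmentation' is label-encoded and 'segmentation_expanded' is one-hot encoded.
import h5py

with h5py.File('/path/to/Formatted_128x128x128/some_case.h5', 'r') as f:
    image = f['image'][:]                     # float32, shape (128, 128, 128, 1)
    seg = f['segmentation'][:]                # uint8 label map
    seg_ohe = f['segmentation_expanded'][:]   # uint8 one-hot channels
    labels = f['segmentation_labels'][:]      # non-background labels present
    original_shape = f['isotropic_shape'][:]  # shape before resizing to 128^3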
COMET/spit_dataset.py
ADDED
@@ -0,0 +1,40 @@
from shutil import move, copy2
import os

OR_DIR = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128'
val_split = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/For_validation.txt'
test_split = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/For_testing.txt'

# Create out dirs
os.makedirs(os.path.join(OR_DIR, 'train'), exist_ok=True)
os.makedirs(os.path.join(OR_DIR, 'validation'), exist_ok=True)
os.makedirs(os.path.join(OR_DIR, 'test'), exist_ok=True)

# Copy all to train and then split into validation and test
list_of_files = [os.path.join(OR_DIR, f) for f in os.listdir(OR_DIR) if f.endswith('.h5')]
list_of_files.sort()
for f in list_of_files:
    copy2(f, os.path.join(OR_DIR, 'train'))

# Get the indices for the validation and test subsets
with open(val_split, 'r') as f:
    val_idcs = f.readlines()[0]
    val_idcs = [int(e) for e in val_idcs.split(',')]

with open(test_split, 'r') as f:
    test_indcs = f.readlines()[0]
    test_indcs = [int(e) for e in test_indcs.split(',')]

# move the files from train to validation and test
for i in val_idcs:
    move(os.path.join(OR_DIR, 'train', '{:05d}_CT.h5'.format(i)), os.path.join(OR_DIR, 'validation'))
print('Done moving the validation subset.')

for i in test_indcs:
    move(os.path.join(OR_DIR, 'train', '{:05d}_CT.h5'.format(i)), os.path.join(OR_DIR, 'test'))
print('Done moving the test subset.')

print('Done splitting the data')
print('Training samples: '+str(len(os.listdir(os.path.join(OR_DIR, 'train')))))
print('Validation samples: '+str(len(val_idcs)))
print('Test samples: '+str(len(test_indcs)))
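Given the parsing above, For_validation.txt and For_testing.txt are expected to hold a single line of comma-separated integer sample indices. A tiny sketch of writing one such file, with made-up indices:

# Tiny sketch: the split files read above are one line of comma-separated integers.
# The indices and file name here are illustrative only.
val_indices = [3, 7, 12, 25]
with open('For_validation.txt', 'w') as f:
    f.write(','.join(str(i) for i in val_indices))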
COMET/train_config_files/Config_BASELINE_None_froozen.ini
ADDED
@@ -0,0 +1,21 @@
[LOSSES]
similarity = ncc
segmentation =

[TRAIN]
model = /mnt/EncryptedData1/Users/javier/train_output/Brain_study/No_gamma/BASELINE_L_ncc__MET_mse_ncc_ssim_232329-01092021/checkpoints/best_model.h5
batchSize = 8
learningRate = 1e-5
accumulativeGradients = 1
gpu = 1
regularizationWeight = 1e-5
epochs = 10000
name = BASELINE
freeze = None

[DATASETS]
train = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/train
validation = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/validation

[OTHERS]
outputFolder = /mnt/EncryptedData1/Users/javier/train_output/COMET/NONE_FROZEN
COMET/train_config_files/Config_BASELINE_bottom_froozen.ini
ADDED
@@ -0,0 +1,21 @@
[LOSSES]
similarity = ncc
segmentation =

[TRAIN]
model = /mnt/EncryptedData1/Users/javier/train_output/Brain_study/No_gamma/BASELINE_L_ncc__MET_mse_ncc_ssim_232329-01092021/checkpoints/best_model.h5
batchSize = 8
learningRate = 1e-5
accumulativeGradients = 1
gpu = 0
regularizationWeight = 1e-5
epochs = 10000
name = BASELINE
freeze = BOTTOM

[DATASETS]
train = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/train
validation = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/validation

[OTHERS]
outputFolder = /mnt/EncryptedData1/Users/javier/train_output/COMET/BOTTOM_FROZEN
COMET/train_config_files/Config_BASELINE_top_froozen.ini
ADDED
@@ -0,0 +1,21 @@
[LOSSES]
similarity = ncc
segmentation =

[TRAIN]
model = /mnt/EncryptedData1/Users/javier/train_output/Brain_study/No_gamma/BASELINE_L_ncc__MET_mse_ncc_ssim_232329-01092021/checkpoints/best_model.h5
batchSize = 8
learningRate = 1e-5
accumulativeGradients = 1
gpu = 1
regularizationWeight = 1e-5
epochs = 10000
name = BASELINE
freeze = TOP

[DATASETS]
train = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/train
validation = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/validation

[OTHERS]
outputFolder = /mnt/EncryptedData1/Users/javier/train_output/COMET/TOP_FROZEN
COMET/train_config_files/Config_SEGGUIDED_None_froozen.ini
ADDED
@@ -0,0 +1,21 @@
[LOSSES]
segmentation = dice
similarity = ssim

[TRAIN]
model = /mnt/EncryptedData1/Users/javier/train_output/Brain_study/No_gamma/SEGGUIDED_Lsim_ncc__Lseg_dice__MET_mse_ncc_ssim_013319-11092021/checkpoints/best_model.h5
batchSize = 8
learningRate = 1e-5
accumulativeGradients = 1
gpu = 1
regularizationWeight = 1e-5
epochs = 10000
name = SEGGUIDED
freeze = None

[DATASETS]
train = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/train
validation = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/validation

[OTHERS]
outputFolder = /mnt/EncryptedData1/Users/javier/train_output/COMET/NONE_FROZEN
COMET/train_config_files/Config_SEGGUIDED_bottom_froozen.ini
ADDED
@@ -0,0 +1,21 @@
[LOSSES]
segmentation = dice
similarity = ssim

[TRAIN]
model = /mnt/EncryptedData1/Users/javier/train_output/Brain_study/No_gamma/SEGGUIDED_Lsim_ncc__Lseg_dice__MET_mse_ncc_ssim_013319-11092021/checkpoints/best_model.h5
batchSize = 8
learningRate = 1e-5
accumulativeGradients = 1
gpu = 2
regularizationWeight = 1e-5
epochs = 10000
name = SEGGUIDED
freeze = BOTTOM

[DATASETS]
train = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/train
validation = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/validation

[OTHERS]
outputFolder = /mnt/EncryptedData1/Users/javier/train_output/COMET/BOTTOM_FROZEN
COMET/train_config_files/Config_SEGGUIDED_top_froozen.ini
ADDED
@@ -0,0 +1,21 @@
[LOSSES]
segmentation = dice
similarity = ssim

[TRAIN]
model = /mnt/EncryptedData1/Users/javier/train_output/Brain_study/No_gamma/SEGGUIDED_Lsim_ncc__Lseg_dice__MET_mse_ncc_ssim_013319-11092021/checkpoints/best_model.h5
batchSize = 8
learningRate = 1e-5
accumulativeGradients = 1
gpu = 1
regularizationWeight = 1e-5
epochs = 10000
name = SEGGUIDED
freeze = TOP

[DATASETS]
train = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/train
validation = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/validation

[OTHERS]
outputFolder = /mnt/EncryptedData1/Users/javier/train_output/COMET/TOP_FROZEN
COMET/train_config_files/Config_UW_None_froozen.ini
ADDED
@@ -0,0 +1,21 @@
[LOSSES]
segmentation = dice,hd
similarity = ncc,ssim

[TRAIN]
model = /mnt/EncryptedData1/Users/javier/train_output/Brain_study/No_gamma/UW_Lsim_ncc__ssim__Lseg_dice__MET_mse_ncc_ssim_204557-02092021/checkpoints/best_model.h5
batchSize = 8
learningRate = 1e-5
accumulativeGradients = 1
gpu = 0
regularizationWeight = 1e-5
epochs = 10000
name = UW
freeze = None

[DATASETS]
train = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/train
validation = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/validation

[OTHERS]
outputFolder = /mnt/EncryptedData1/Users/javier/train_output/COMET/NONE_FROZEN
COMET/train_config_files/Config_UW_bottom_froozen.ini
ADDED
@@ -0,0 +1,21 @@
[LOSSES]
segmentation = dice,hd
similarity = ncc,ssim

[TRAIN]
model = /mnt/EncryptedData1/Users/javier/train_output/Brain_study/No_gamma/UW_Lsim_ncc__ssim__Lseg_dice__MET_mse_ncc_ssim_204557-02092021/checkpoints/best_model.h5
batchSize = 8
learningRate = 1e-5
accumulativeGradients = 1
gpu = 2
regularizationWeight = 1e-5
epochs = 10000
name = UW
freeze = BOTTOM

[DATASETS]
train = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/train
validation = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/validation

[OTHERS]
outputFolder = /mnt/EncryptedData1/Users/javier/train_output/COMET/BOTTOM_FROZEN
COMET/train_config_files/Config_UW_top_froozen.ini
ADDED
@@ -0,0 +1,21 @@
[LOSSES]
segmentation = dice,hd
similarity = ncc,ssim

[TRAIN]
model = /mnt/EncryptedData1/Users/javier/train_output/Brain_study/No_gamma/UW_Lsim_ncc__ssim__Lseg_dice__MET_mse_ncc_ssim_204557-02092021/checkpoints/best_model.h5
batchSize = 8
learningRate = 1e-5
accumulativeGradients = 1
gpu = 0
regularizationWeight = 1e-5
epochs = 10000
name = UW
freeze = TOP

[DATASETS]
train = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/train
validation = /mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/validation

[OTHERS]
outputFolder = /mnt/EncryptedData1/Users/javier/train_output/COMET/TOP_FROZEN