Commit
·
a290524
1
Parent(s):
dc36465
Working on the clean repo
Browse files- DeepDeformationMapRegistration/main.py +410 -0
- setup.py +26 -0
DeepDeformationMapRegistration/main.py
ADDED
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# 1. Image files generator
|
3 |
+
|
4 |
+
# timer start
|
5 |
+
# 2. Preprocess the image
|
6 |
+
# 3. Predict the displacement
|
7 |
+
# timer stop
|
8 |
+
|
9 |
+
# 4. Evaluate the registration: NCC; SSIM; DICE; HD95
|
10 |
+
|
11 |
+
import os, sys
|
12 |
+
|
13 |
+
import shutil
|
14 |
+
import time
|
15 |
+
import tkinter
|
16 |
+
|
17 |
+
import h5py
|
18 |
+
import matplotlib.pyplot as plt
|
19 |
+
|
20 |
+
currentdir = os.path.dirname(os.path.realpath(__file__))
|
21 |
+
parentdir = os.path.dirname(currentdir)
|
22 |
+
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
|
23 |
+
|
24 |
+
import tensorflow as tf
|
25 |
+
# tf.enable_eager_execution(config=config)
|
26 |
+
|
27 |
+
import numpy as np
|
28 |
+
import pandas as pd
|
29 |
+
import voxelmorph as vxm
|
30 |
+
from voxelmorph.tf.layers import SpatialTransformer
|
31 |
+
|
32 |
+
import DeepDeformationMapRegistration.utils.constants as C
|
33 |
+
from DeepDeformationMapRegistration.utils.operators import min_max_norm, safe_medpy_metric
|
34 |
+
from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
|
35 |
+
from DeepDeformationMapRegistration.layers import AugmentationLayer, UncertaintyWeighting
|
36 |
+
from DeepDeformationMapRegistration.losses import StructuralSimilarity_simplified, NCC, GeneralizedDICEScore, HausdorffDistanceErosion, target_registration_error
|
37 |
+
from DeepDeformationMapRegistration.ms_ssim_tf import MultiScaleStructuralSimilarity
|
38 |
+
from DeepDeformationMapRegistration.utils.acummulated_optimizer import AdamAccumulated
|
39 |
+
from DeepDeformationMapRegistration.utils.visualization import save_disp_map_img, plot_predictions
|
40 |
+
from DeepDeformationMapRegistration.utils.misc import DisplacementMapInterpolator, get_segmentations_centroids, segmentation_ohe_to_cardinal, segmentation_cardinal_to_ohe
|
41 |
+
from DeepDeformationMapRegistration.utils.misc import resize_displacement_map, scale_transformation, GaussianFilter
|
42 |
+
import medpy.metric as medpy_metrics
|
43 |
+
from EvaluationScripts.Evaluate_class import EvaluationFigures, resize_pts_to_original_space, resize_img_to_original_space, resize_transformation
|
44 |
+
from scipy.interpolate import RegularGridInterpolator
|
45 |
+
from tqdm import tqdm
|
46 |
+
import nibabel as nib
|
47 |
+
from scipy.ndimage import gaussian_filter, zoom
|
48 |
+
|
49 |
+
import h5py
|
50 |
+
import re
|
51 |
+
from Brain_study.data_generator import BatchGenerator
|
52 |
+
|
53 |
+
import argparse
|
54 |
+
|
55 |
+
from skimage.transform import warp
|
56 |
+
import neurite as ne
|
57 |
+
|
58 |
+
import tempfile
|
59 |
+
|
60 |
+
import logging
|
61 |
+
|
62 |
+
MODELS_FILE = {'liver': {'BL-N': './models/liver/bl_ncc.h5',
|
63 |
+
'BL-S': './models/liver/bl_ssim.h5',
|
64 |
+
'SG-ND': './models/liver/sg_ncc_dsc.h5',
|
65 |
+
'SD-NSD': './models/liver/sg_ncc_ssim_dsc.h5',
|
66 |
+
'UW-NSD': './models/liver/uw_ncc_ssim_dsc.h5',
|
67 |
+
'UW-NSDH': './models/liver/uw_ncc_ssim_dsc_hd.h5',
|
68 |
+
},
|
69 |
+
'brain': {'BL-N': './models/brain/bl_ncc.h5',
|
70 |
+
'BL-S': './models/brain/bl_ssim.h5',
|
71 |
+
'SG-ND': './models/brain/sg_ncc_dsc.h5',
|
72 |
+
'SD-NSD': './models/brain/sg_ncc_ssim_dsc.h5',
|
73 |
+
'UW-NSD': './models/brain/uw_ncc_ssim_dsc.h5',
|
74 |
+
'UW-NSDH': './models/brain/uw_ncc_ssim_dsc_hd.h5',
|
75 |
+
}
|
76 |
+
}
|
77 |
+
|
78 |
+
if __name__ == '__main__':
|
79 |
+
parser = argparse.ArgumentParser()
|
80 |
+
parser.add_argument('-f', '--fixed', type=str, help='Path to fixed image file (NIfTI)')
|
81 |
+
parser.add_argument('-m', '--moving', type=str, help='Path to oving image file (NIfTI)')
|
82 |
+
parser.add_argument('-o', '--outputdir', type=str, help='Output directory', default='./Registration_output')
|
83 |
+
parser.add_argument('--gpu', type=int, help='In case of multi-GPU systems, limits the execution to the defined GPU number', default=None)
|
84 |
+
parser.add_argument('--model', type=str, help='Which model to use: BL-N, BL-S, SG-ND, SG-NSD, UW-NSD, UW-NSDH', default='UW-NSD')
|
85 |
+
# parser.add_argument('--brain', type=bool, action='store_true', help='Perform brain MRi registration', default=False)
|
86 |
+
args = parser.parse_args()
|
87 |
+
logger = logging.getLogger(__name__)
|
88 |
+
|
89 |
+
assert os.path.exists(args.fixed), 'Fixed image not found'
|
90 |
+
assert os.path.exists(args.moving), 'Moving image not found'
|
91 |
+
assert args.model in ['BL-N', 'BL-S', 'SG-ND', 'SG-NSD', 'UW-NSD', 'UW-NSDH'], 'Invalid model type'
|
92 |
+
|
93 |
+
if isinstance(args.gpu, int):
|
94 |
+
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
|
95 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) # Check availability before running using 'nvidia-smi'
|
96 |
+
|
97 |
+
# Load the file and preprocess it
|
98 |
+
fixed_image = nib.load(args.fixed)
|
99 |
+
moving_image = nib.load(args.moving)
|
100 |
+
|
101 |
+
# TF stuff
|
102 |
+
config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
|
103 |
+
config.gpu_options.allow_growth = True
|
104 |
+
config.log_device_placement = False ## to log device placement (on which device the operation ran)
|
105 |
+
config.allow_soft_placement = True
|
106 |
+
|
107 |
+
sess = tf.compat.v1.Session(config=config)
|
108 |
+
tf.compat.v1.keras.backend.set_session(sess)
|
109 |
+
|
110 |
+
# Preprocess data
|
111 |
+
if args.erase:
|
112 |
+
shutil.rmtree(args.outputdir, ignore_errors=True)
|
113 |
+
os.makedirs(args.outputdir, exist_ok=True)
|
114 |
+
lm_output_dir = os.path.join(args.outputdir, 'livermask')
|
115 |
+
os.makedirs(lm_output_dir, exist_ok=True)
|
116 |
+
|
117 |
+
# 1. Run Livermask to get the mask around the liver in both the fixed and moving image
|
118 |
+
logger.info('Getting parenchyma segmentations...')
|
119 |
+
livermask_cmd = "python -m livermaks.livermask --input {} --output {}".format(args.fixed, os.path.join(lm_output_dir, 'fixed.nii.gz'))
|
120 |
+
os.system(livermask_cmd)
|
121 |
+
logger.info('... fixed image done...')
|
122 |
+
livermask_cmd = "python -m livermaks.livermask --input {} --output {}".format(args.moving, os.path.join(lm_output_dir, 'moving.nii.gz'))
|
123 |
+
os.system(livermask_cmd)
|
124 |
+
logger.info('... moving image done.')
|
125 |
+
|
126 |
+
# 2. Crop around the liver
|
127 |
+
# 2.1 Load the segmentations
|
128 |
+
# 2.2 Find the outermost box containing both boxes
|
129 |
+
# 2.3 Crop the fixed and moving images using such boxes
|
130 |
+
# 2.4 Resize the images to the expected input size
|
131 |
+
|
132 |
+
# 3. Build the whole graph
|
133 |
+
|
134 |
+
|
135 |
+
# Loss and metric functions. Common to all models
|
136 |
+
# loss_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).loss,
|
137 |
+
# NCC(image_input_shape).loss,
|
138 |
+
# vxm.losses.MSE().loss,
|
139 |
+
# MultiScaleStructuralSimilarity(max_val=1., filter_size=3).loss]
|
140 |
+
#
|
141 |
+
# metric_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).metric,
|
142 |
+
# NCC(image_input_shape).metric,
|
143 |
+
# vxm.losses.MSE().loss,
|
144 |
+
# MultiScaleStructuralSimilarity(max_val=1., filter_size=3).metric,
|
145 |
+
# GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric,
|
146 |
+
# HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [nb_labels]).metric,
|
147 |
+
# GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric_macro]
|
148 |
+
|
149 |
+
### METRICS GRAPH ###
|
150 |
+
# fix_img_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, 1), name='fix_img')
|
151 |
+
# pred_img_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, 1), name='pred_img')
|
152 |
+
# fix_seg_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, nb_labels), name='fix_seg')
|
153 |
+
# pred_seg_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, nb_labels), name='pred_seg')
|
154 |
+
#
|
155 |
+
# ssim_tf = metric_fncs[0](fix_img_ph, pred_img_ph)
|
156 |
+
# ncc_tf = metric_fncs[1](fix_img_ph, pred_img_ph)
|
157 |
+
# mse_tf = metric_fncs[2](fix_img_ph, pred_img_ph)
|
158 |
+
# ms_ssim_tf = metric_fncs[3](fix_img_ph, pred_img_ph)
|
159 |
+
# dice_tf = metric_fncs[4](fix_seg_ph, pred_seg_ph)
|
160 |
+
# hd_tf = metric_fncs[5](fix_seg_ph, pred_seg_ph)
|
161 |
+
# dice_macro_tf = metric_fncs[6](fix_seg_ph, pred_seg_ph)
|
162 |
+
# hd_exact_tf = HausdorffDistance_exact(fix_seg_ph, pred_seg_ph, ohe=True)
|
163 |
+
|
164 |
+
# Needed for VxmDense type of network
|
165 |
+
warp_segmentation = vxm.networks.Transform(image_output_shape, interp_method='nearest', nb_feats=nb_labels)
|
166 |
+
|
167 |
+
dm_interp = DisplacementMapInterpolator(image_output_shape, 'griddata', step=4)
|
168 |
+
|
169 |
+
for MODEL_FILE, DATA_ROOT_DIR in zip(MODEL_FILE_LIST, DATA_ROOT_DIR_LIST):
|
170 |
+
print('MODEL LOCATION: ', MODEL_FILE)
|
171 |
+
|
172 |
+
# data_folder = '/mnt/EncryptedData1/Users/javier/train_output/DDMR/THESIS/BASELINE_Affine_ncc___mse_ncc_160606-25022021'
|
173 |
+
output_folder = os.path.join(DATA_ROOT_DIR, args.outdirname) # '/mnt/EncryptedData1/Users/javier/train_output/DDMR/THESIS/eval/BASELINE_TRAIN_Affine_ncc_EVAL_Affine'
|
174 |
+
# os.makedirs(os.path.join(output_folder, 'images'), exist_ok=True)
|
175 |
+
|
176 |
+
print('DESTINATION FOLDER: ', output_folder)
|
177 |
+
|
178 |
+
if args.fullres:
|
179 |
+
output_folder_fr = os.path.join(DATA_ROOT_DIR, args.outdirname, 'full_resolution') # '/mnt/EncryptedData1/Users/javier/train_output/DDMR/THESIS/eval/BASELINE_TRAIN_Affine_ncc_EVAL_Affine'
|
180 |
+
# os.makedirs(os.path.join(output_folder, 'images'), exist_ok=True)
|
181 |
+
if args.erase:
|
182 |
+
shutil.rmtree(output_folder_fr, ignore_errors=True)
|
183 |
+
os.makedirs(output_folder_fr, exist_ok=True)
|
184 |
+
print('DESTINATION FOLDER FULL RESOLUTION: ', output_folder_fr)
|
185 |
+
|
186 |
+
try:
|
187 |
+
network = tf.keras.models.load_model(MODEL_FILE, {'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
|
188 |
+
'VxmDense': vxm.networks.VxmDense,
|
189 |
+
'AdamAccumulated': AdamAccumulated,
|
190 |
+
'loss': loss_fncs,
|
191 |
+
'metric': metric_fncs},
|
192 |
+
compile=False)
|
193 |
+
except ValueError as e:
|
194 |
+
enc_features = [16, 32, 32, 32] # const.ENCODER_FILTERS
|
195 |
+
dec_features = [32, 32, 32, 32, 32, 16, 16] # const.ENCODER_FILTERS[::-1]
|
196 |
+
nb_features = [enc_features, dec_features]
|
197 |
+
if re.search('^UW|SEGGUIDED_', MODEL_FILE):
|
198 |
+
network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
|
199 |
+
nb_labels=nb_labels,
|
200 |
+
nb_unet_features=nb_features,
|
201 |
+
int_steps=0,
|
202 |
+
int_downsize=1,
|
203 |
+
seg_downsize=1)
|
204 |
+
else:
|
205 |
+
network = vxm.networks.VxmDense(inshape=image_output_shape,
|
206 |
+
nb_unet_features=nb_features,
|
207 |
+
int_steps=0)
|
208 |
+
network.load_weights(MODEL_FILE, by_name=True)
|
209 |
+
# Record metrics
|
210 |
+
metrics_file = os.path.join(output_folder, 'metrics.csv')
|
211 |
+
with open(metrics_file, 'w') as f:
|
212 |
+
f.write(';'.join(csv_header)+'\n')
|
213 |
+
|
214 |
+
if args.fullres:
|
215 |
+
metrics_file_fr = os.path.join(output_folder_fr, 'metrics.csv')
|
216 |
+
with open(metrics_file_fr, 'w') as f:
|
217 |
+
f.write(';'.join(csv_header) + '\n')
|
218 |
+
|
219 |
+
ssim = ncc = mse = ms_ssim = dice = hd = 0
|
220 |
+
with sess.as_default():
|
221 |
+
sess.run(tf.global_variables_initializer())
|
222 |
+
network.load_weights(MODEL_FILE, by_name=True)
|
223 |
+
network.summary(line_length=C.SUMMARY_LINE_LENGTH)
|
224 |
+
progress_bar = tqdm(enumerate(list_test_files, 1), desc='Evaluation', total=len(list_test_files))
|
225 |
+
for step, in_batch in progress_bar:
|
226 |
+
with h5py.File(in_batch, 'r') as f:
|
227 |
+
fix_img = f['fix_image'][:][np.newaxis, ...] # Add batch axis
|
228 |
+
mov_img = f['mov_image'][:][np.newaxis, ...]
|
229 |
+
fix_seg = f['fix_segmentations'][:][np.newaxis, ...].astype(np.float32)
|
230 |
+
mov_seg = f['mov_segmentations'][:][np.newaxis, ...].astype(np.float32)
|
231 |
+
fix_centroids = f['fix_centroids'][:]
|
232 |
+
isotropic_shape = f['isotropic_shape'][:]
|
233 |
+
voxel_size = np.divide(fix_img.shape[1:-1], isotropic_shape)
|
234 |
+
|
235 |
+
if network.name == 'vxm_dense_semi_supervised_seg':
|
236 |
+
t0 = time.time()
|
237 |
+
pred_img, disp_map, pred_seg = network.predict([mov_img, fix_img, mov_seg, fix_seg]) # predict([source, target])
|
238 |
+
t1 = time.time()
|
239 |
+
else:
|
240 |
+
t0 = time.time()
|
241 |
+
pred_img, disp_map = network.predict([mov_img, fix_img])
|
242 |
+
pred_seg = warp_segmentation.predict([mov_seg, disp_map])
|
243 |
+
t1 = time.time()
|
244 |
+
|
245 |
+
pred_img = min_max_norm(pred_img)
|
246 |
+
mov_centroids, missing_lbls = get_segmentations_centroids(mov_seg[0, ...], ohe=True, expected_lbls=range(1, nb_labels+1), brain_study=False)
|
247 |
+
# pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) # with tps, it returns the pred_centroids directly
|
248 |
+
pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) + mov_centroids
|
249 |
+
|
250 |
+
# Up sample the segmentation masks to isotropic resolution
|
251 |
+
zoom_factors = np.diag(scale_transformation(image_output_shape, isotropic_shape))
|
252 |
+
pred_seg_isot = zoom(pred_seg[0, ...], zoom_factors, order=0)[np.newaxis, ...]
|
253 |
+
fix_seg_isot = zoom(fix_seg[0, ...], zoom_factors, order=0)[np.newaxis, ...]
|
254 |
+
|
255 |
+
pred_img_isot = zoom(pred_img[0, ...], zoom_factors, order=3)[np.newaxis, ...]
|
256 |
+
fix_img_isot = zoom(fix_img[0, ...], zoom_factors, order=3)[np.newaxis, ...]
|
257 |
+
|
258 |
+
# I need the labels to be OHE to compute the segmentation metrics.
|
259 |
+
# dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf], {'fix_seg:0': fix_seg, 'pred_seg:0': pred_seg})
|
260 |
+
dice = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) / np.sum(fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
|
261 |
+
hd = np.mean(safe_medpy_metric(pred_seg_isot[0, ...], fix_seg_isot[0, ...], nb_labels, medpy_metrics.hd, {'voxelspacing': voxel_size}))
|
262 |
+
dice_macro = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
|
263 |
+
|
264 |
+
pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
|
265 |
+
mov_seg_card = segmentation_ohe_to_cardinal(mov_seg).astype(np.float32)
|
266 |
+
fix_seg_card = segmentation_ohe_to_cardinal(fix_seg).astype(np.float32)
|
267 |
+
|
268 |
+
ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf], {'fix_img:0': fix_img, 'pred_img:0': pred_img})
|
269 |
+
ssim = np.mean(ssim)
|
270 |
+
ms_ssim = ms_ssim[0]
|
271 |
+
|
272 |
+
# Rescale the points back to isotropic space, where we have a correspondence voxel <-> mm
|
273 |
+
fix_centroids_isotropic = fix_centroids * voxel_size
|
274 |
+
# mov_centroids_isotropic = mov_centroids * voxel_size
|
275 |
+
pred_centroids_isotropic = pred_centroids * voxel_size
|
276 |
+
|
277 |
+
# fix_centroids_isotropic = np.divide(fix_centroids_isotropic, C.COMET_DATASET_iso_to_cubic_scales)
|
278 |
+
# # mov_centroids_isotropic = np.divide(mov_centroids_isotropic, C.COMET_DATASET_iso_to_cubic_scales)
|
279 |
+
# pred_centroids_isotropic = np.divide(pred_centroids_isotropic, C.COMET_DATASET_iso_to_cubic_scales)
|
280 |
+
# Now we can measure the TRE in mm
|
281 |
+
tre_array = target_registration_error(fix_centroids_isotropic, pred_centroids_isotropic, False).eval()
|
282 |
+
tre = np.mean([v for v in tre_array if not np.isnan(v)])
|
283 |
+
# ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'HD', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']
|
284 |
+
|
285 |
+
new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, t1-t0, tre, len(missing_lbls), missing_lbls]
|
286 |
+
with open(metrics_file, 'a') as f:
|
287 |
+
f.write(';'.join(map(str, new_line))+'\n')
|
288 |
+
|
289 |
+
save_nifti(fix_img[0, ...], os.path.join(output_folder, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
290 |
+
save_nifti(mov_img[0, ...], os.path.join(output_folder, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
291 |
+
save_nifti(pred_img[0, ...], os.path.join(output_folder, '{:03d}_pred_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
292 |
+
save_nifti(fix_seg_card[0, ...], os.path.join(output_folder, '{:03d}_fix_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
293 |
+
save_nifti(mov_seg_card[0, ...], os.path.join(output_folder, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
294 |
+
save_nifti(pred_seg_card[0, ...], os.path.join(output_folder, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
295 |
+
|
296 |
+
# with h5py.File(os.path.join(output_folder, '{:03d}_centroids.h5'.format(step)), 'w') as f:
|
297 |
+
# f.create_dataset('fix_centroids', dtype=np.float32, data=fix_centroids)
|
298 |
+
# f.create_dataset('mov_centroids', dtype=np.float32, data=mov_centroids)
|
299 |
+
# f.create_dataset('pred_centroids', dtype=np.float32, data=pred_centroids)
|
300 |
+
# f.create_dataset('fix_centroids_isotropic', dtype=np.float32, data=fix_centroids_isotropic)
|
301 |
+
# f.create_dataset('mov_centroids_isotropic', dtype=np.float32, data=mov_centroids_isotropic)
|
302 |
+
|
303 |
+
# magnitude = np.sqrt(np.sum(disp_map[0, ...] ** 2, axis=-1))
|
304 |
+
# _ = plt.hist(magnitude.flatten())
|
305 |
+
# plt.title('Histogram of disp. magnitudes')
|
306 |
+
# plt.show(block=False)
|
307 |
+
# plt.savefig(os.path.join(output_folder, '{:03d}_hist_mag_ssim_{:.03f}_dice_{:.03f}.png'.format(step, ssim, dice)))
|
308 |
+
# plt.close()
|
309 |
+
|
310 |
+
plot_predictions(img_batches=[fix_img, mov_img, pred_img], disp_map_batch=disp_map, seg_batches=[fix_seg_card, mov_seg_card, pred_seg_card], filename=os.path.join(output_folder, '{:03d}_figures_seg.png'.format(step)), show=False, step=16)
|
311 |
+
plot_predictions(img_batches=[fix_img, mov_img, pred_img], disp_map_batch=disp_map, filename=os.path.join(output_folder, '{:03d}_figures_img.png'.format(step)), show=False, step=16)
|
312 |
+
save_disp_map_img(disp_map, 'Displacement map', os.path.join(output_folder, '{:03d}_disp_map_fig.png'.format(step)), show=False, step=16)
|
313 |
+
|
314 |
+
progress_bar.set_description('SSIM {:.04f}\tM_DICE: {:.04f}'.format(ssim, dice_macro))
|
315 |
+
|
316 |
+
if args.fullres:
|
317 |
+
with h5py.File(list_test_fr_files[step - 1], 'r') as f:
|
318 |
+
fix_img = f['fix_image'][:][np.newaxis, ...] # Add batch axis
|
319 |
+
mov_img = f['mov_image'][:][np.newaxis, ...]
|
320 |
+
fix_seg = f['fix_segmentations'][:][np.newaxis, ...].astype(np.float32)
|
321 |
+
mov_seg = f['mov_segmentations'][:][np.newaxis, ...].astype(np.float32)
|
322 |
+
fix_centroids = f['fix_centroids'][:]
|
323 |
+
|
324 |
+
# Up sample the displacement map to the full res
|
325 |
+
trf = scale_transformation(image_output_shape, fix_img.shape[1:-1])
|
326 |
+
disp_map_fr = resize_displacement_map(np.squeeze(disp_map), None, trf)[np.newaxis, ...]
|
327 |
+
disp_map_fr = gaussian_filter(disp_map_fr, 5)
|
328 |
+
# disp_mad_fr = sess.run(smooth_filter, feed_dict={'dm:0': disp_map_fr})
|
329 |
+
|
330 |
+
# Predicted image
|
331 |
+
pred_img_fr = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([mov_img, disp_map_fr]).eval()
|
332 |
+
pred_seg_fr = SpatialTransformer(interp_method='nearest', indexing='ij', single_transform=False)([mov_seg, disp_map_fr]).eval()
|
333 |
+
|
334 |
+
# Predicted centroids
|
335 |
+
dm_interp_fr = DisplacementMapInterpolator(fix_img.shape[1:-1], 'griddata', step=2)
|
336 |
+
pred_centroids = dm_interp_fr(disp_map_fr, mov_centroids, backwards=True) + mov_centroids
|
337 |
+
|
338 |
+
# Metrics - segmentation
|
339 |
+
dice = np.mean([medpy_metrics.dc(pred_seg_fr[..., l], fix_seg[..., l]) / np.sum(fix_seg[..., l]) for l in range(nb_labels)])
|
340 |
+
hd = np.mean(safe_medpy_metric(pred_seg[0, ...], fix_seg[0, ...], nb_labels, medpy_metrics.hd, {'voxelspacing': voxel_size}))
|
341 |
+
dice_macro = np.mean([medpy_metrics.dc(pred_seg_fr[..., l], fix_seg[..., l]) for l in range(nb_labels)])
|
342 |
+
|
343 |
+
pred_seg_card_fr = segmentation_ohe_to_cardinal(pred_seg_fr).astype(np.float32)
|
344 |
+
mov_seg_card_fr = segmentation_ohe_to_cardinal(mov_seg).astype(np.float32)
|
345 |
+
fix_seg_card_fr = segmentation_ohe_to_cardinal(fix_seg).astype(np.float32)
|
346 |
+
|
347 |
+
# Metrics - image
|
348 |
+
ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf],
|
349 |
+
{'fix_img:0': fix_img, 'pred_img:0': pred_img_fr})
|
350 |
+
ssim = np.mean(ssim)
|
351 |
+
ms_ssim = ms_ssim[0]
|
352 |
+
|
353 |
+
# Metrics - registration
|
354 |
+
tre_array = target_registration_error(fix_centroids, pred_centroids, False).eval()
|
355 |
+
|
356 |
+
new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, t1 - t0, tre, len(missing_lbls),
|
357 |
+
missing_lbls]
|
358 |
+
with open(metrics_file_fr, 'a') as f:
|
359 |
+
f.write(';'.join(map(str, new_line)) + '\n')
|
360 |
+
|
361 |
+
save_nifti(fix_img[0, ...], os.path.join(output_folder_fr, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
362 |
+
save_nifti(mov_img[0, ...], os.path.join(output_folder_fr, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
363 |
+
save_nifti(pred_img[0, ...], os.path.join(output_folder_fr, '{:03d}_pred_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
364 |
+
save_nifti(fix_seg_card[0, ...], os.path.join(output_folder_fr, '{:03d}_fix_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
365 |
+
save_nifti(mov_seg_card[0, ...], os.path.join(output_folder_fr, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
366 |
+
save_nifti(pred_seg_card[0, ...], os.path.join(output_folder_fr, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
|
367 |
+
|
368 |
+
# with h5py.File(os.path.join(output_folder, '{:03d}_centroids.h5'.format(step)), 'w') as f:
|
369 |
+
# f.create_dataset('fix_centroids', dtype=np.float32, data=fix_centroids)
|
370 |
+
# f.create_dataset('mov_centroids', dtype=np.float32, data=mov_centroids)
|
371 |
+
# f.create_dataset('pred_centroids', dtype=np.float32, data=pred_centroids)
|
372 |
+
# f.create_dataset('fix_centroids_isotropic', dtype=np.float32, data=fix_centroids_isotropic)
|
373 |
+
# f.create_dataset('mov_centroids_isotropic', dtype=np.float32, data=mov_centroids_isotropic)
|
374 |
+
|
375 |
+
# magnitude = np.sqrt(np.sum(disp_map[0, ...] ** 2, axis=-1))
|
376 |
+
# _ = plt.hist(magnitude.flatten())
|
377 |
+
# plt.title('Histogram of disp. magnitudes')
|
378 |
+
# plt.show(block=False)
|
379 |
+
# plt.savefig(os.path.join(output_folder, '{:03d}_hist_mag_ssim_{:.03f}_dice_{:.03f}.png'.format(step, ssim, dice)))
|
380 |
+
# plt.close()
|
381 |
+
|
382 |
+
plot_predictions(img_batches=[fix_img, mov_img, pred_img_fr], disp_map_batch=disp_map_fr, seg_batches=[fix_seg_card_fr, mov_seg_card_fr, pred_seg_card_fr], filename=os.path.join(output_folder_fr, '{:03d}_figures_seg.png'.format(step)), show=False, step=10)
|
383 |
+
plot_predictions(img_batches=[fix_img, mov_img, pred_img_fr], disp_map_batch=disp_map_fr, filename=os.path.join(output_folder_fr, '{:03d}_figures_img.png'.format(step)), show=False, step=10)
|
384 |
+
# save_disp_map_img(disp_map_fr, 'Displacement map', os.path.join(output_folder_fr, '{:03d}_disp_map_fig.png'.format(step)), show=False, step=10)
|
385 |
+
|
386 |
+
progress_bar.set_description('[FR] SSIM {:.04f}\tM_DICE: {:.04f}'.format(ssim, dice_macro))
|
387 |
+
|
388 |
+
print('Summary\n=======\n')
|
389 |
+
metrics_df = pd.read_csv(metrics_file, sep=';', header=0)
|
390 |
+
print('\nAVG:\n')
|
391 |
+
print(metrics_df.mean(axis=0))
|
392 |
+
print('\nSTD:\n')
|
393 |
+
print(metrics_df.std(axis=0))
|
394 |
+
print('\nHD95perc:\n')
|
395 |
+
print(metrics_df['HD'].describe(percentiles=[.95]))
|
396 |
+
print('\n=======\n')
|
397 |
+
if args.fullres:
|
398 |
+
print('Summary full resolution\n=======\n')
|
399 |
+
metrics_df = pd.read_csv(metrics_file_fr, sep=';', header=0)
|
400 |
+
print('\nAVG:\n')
|
401 |
+
print(metrics_df.mean(axis=0))
|
402 |
+
print('\nSTD:\n')
|
403 |
+
print(metrics_df.std(axis=0))
|
404 |
+
print('\nHD95perc:\n')
|
405 |
+
print(metrics_df['HD'].describe(percentiles=[.95]))
|
406 |
+
print('\n=======\n')
|
407 |
+
tf.keras.backend.clear_session()
|
408 |
+
# sess.close()
|
409 |
+
del network
|
410 |
+
print('Done')
|
setup.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from setuptools import find_packages, setup
|
2 |
+
import os
|
3 |
+
|
4 |
+
entry_points = {'console_script':['DeepDeformationMapRegistration=DeepDeformationMapRegistration.main:main']}
|
5 |
+
|
6 |
+
setup(
|
7 |
+
name='DeepDeformationMapRegistration',
|
8 |
+
py_modules=['DeepDeformationMapRegistration'],
|
9 |
+
packages=find_packages(include=['DeepDeformationMapRegistration', 'DeepDeformationMapRegistration.*']),
|
10 |
+
version='1.0',
|
11 |
+
description='Deep-registration training toolkit',
|
12 |
+
author='Javier Pérez de Frutos',
|
13 |
+
classifiers=[
|
14 |
+
'Programming language :: Python :: 3',
|
15 |
+
'License :: OSI Approveed :: MIT License',
|
16 |
+
'Operating System :: OS Independent'
|
17 |
+
],
|
18 |
+
python_requires='>=3.6',
|
19 |
+
install_requires=[
|
20 |
+
'tensorflow-gpu==1.14.0',
|
21 |
+
'tensorboard==1.14.0',
|
22 |
+
'nibabel==3.2.1',
|
23 |
+
'numpy==1.18.5',
|
24 |
+
'livermask'
|
25 |
+
]
|
26 |
+
)
|