|
import os, sys |
|
|
|
import shutil |
|
|
|
import matplotlib.pyplot as plt |
|
|
|
currentdir = os.path.dirname(os.path.realpath(__file__)) |
|
parentdir = os.path.dirname(currentdir) |
|
sys.path.append(parentdir) |
|
|
|
import tensorflow as tf |
|
|
|
|
|
import numpy as np |
|
import h5py |
|
|
|
import ddmr.utils.constants as C |
|
from ddmr.utils.nifti_utils import save_nifti |
|
from ddmr.layers import AugmentationLayer |
|
from ddmr.utils.visualization import save_disp_map_img, plot_predictions |
|
from ddmr.utils.misc import get_segmentations_centroids, DisplacementMapInterpolator |
|
|
|
from tqdm import tqdm |
|
|
|
from Brain_study.data_generator import BatchGenerator |
|
|
|
from skimage.measure import regionprops |
|
from scipy.interpolate import griddata |
|
|
|
import argparse |
|
|
|
|
|
DATASET = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/training' |
|
MODEL_FILE = '/mnt/EncryptedData1/Users/javier/train_output/Brain_study/ERASE/MS_SSIM/BASELINE_L_ssim__MET_mse_ncc_ssim_162756-29062021/checkpoints/best_model.h5' |
|
DATA_ROOT_DIR = '/mnt/EncryptedData1/Users/javier/train_output/Brain_study/ERASE/MS_SSIM/BASELINE_L_ssim__MET_mse_ncc_ssim_162756-29062021/' |
|
|
|
POINTS = None |
|
MISSING_CENTROID = np.asarray([[np.nan]*3]) |
|
|
|
|
|
def get_mov_centroids(fix_seg, disp_map, nb_labels=28, exclude_background_lbl=False, brain_study=True, dm_interp=None): |
|
if exclude_background_lbl: |
|
fix_centroids, _ = get_segmentations_centroids(fix_seg[0, ..., 1:], ohe=True, expected_lbls=range(1, nb_labels), brain_study=brain_study) |
|
else: |
|
fix_centroids, _ = get_segmentations_centroids(fix_seg[0, ...], ohe=True, expected_lbls=range(1, nb_labels), brain_study=brain_study) |
|
if dm_interp is None: |
|
disp = griddata(POINTS, disp_map.reshape([-1, 3]), fix_centroids, method='linear') |
|
else: |
|
disp = dm_interp(disp_map, fix_centroids, backwards=False) |
|
return fix_centroids, fix_centroids + disp, disp |
|
|
|
|
|
if __name__ == '__main__': |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument('-d', '--dir', type=str, help='Directory where to store the files', default='') |
|
parser.add_argument('--reldir', type=str, help='Relative path to dataset, in where to store the files', default='') |
|
parser.add_argument('--gpu', type=int, help='GPU', default=0) |
|
parser.add_argument('--dataset', type=str, help='Dataset to build the test set', default='') |
|
parser.add_argument('--erase', type=bool, help='Erase the content of the output folder', default=False) |
|
parser.add_argument('--output_shape', help='If an int, a cubic shape is presumed. Otherwise provide it as a space separated sequence', nargs='+', default=128) |
|
args = parser.parse_args() |
|
|
|
assert args.dataset != '', "Missing original dataset dataset" |
|
if args.dir == '' and args.reldir != '': |
|
OUTPUT_FOLDER_DIR = os.path.join(args.dataset, 'test_dataset') |
|
elif args.dir != '' and args.reldir == '': |
|
OUTPUT_FOLDER_DIR = args.dir |
|
else: |
|
raise ValueError("Either provide 'dir' or 'reldir'") |
|
|
|
if args.erase: |
|
shutil.rmtree(OUTPUT_FOLDER_DIR, ignore_errors=True) |
|
os.makedirs(OUTPUT_FOLDER_DIR, exist_ok=True) |
|
print('DESTINATION FOLDER: ', OUTPUT_FOLDER_DIR) |
|
|
|
DATASET = args.dataset |
|
|
|
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' |
|
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) |
|
|
|
data_generator = BatchGenerator(DATASET, 1, False, 1.0, False, ['all'], return_isotropic_shape=True) |
|
|
|
img_generator = data_generator.get_train_generator() |
|
nb_labels = len(img_generator.get_segmentation_labels()) |
|
image_input_shape = img_generator.get_data_shape()[-1][:-1] |
|
|
|
if isinstance(args.output_shape, int): |
|
image_output_shape = [args.output_shape] * 3 |
|
elif isinstance(args.output_shape, list): |
|
assert len(args.output_shape) == 3, 'Invalid output shape, expected three values and got {}'.format(len(args.output_shape)) |
|
image_output_shape = [int(s) for s in args.output_shape] |
|
else: |
|
raise ValueError('Invalid output_shape. Must be an int or a space-separated sequence of ints') |
|
print('Scaling to: ', image_output_shape) |
|
|
|
|
|
xx = np.linspace(0, image_output_shape[0], image_output_shape[0], endpoint=False) |
|
yy = np.linspace(0, image_output_shape[1], image_output_shape[2], endpoint=False) |
|
zz = np.linspace(0, image_output_shape[2], image_output_shape[1], endpoint=False) |
|
|
|
xx, yy, zz = np.meshgrid(xx, yy, zz) |
|
|
|
POINTS = np.stack([xx.flatten(), yy.flatten(), zz.flatten()], axis=0).T |
|
|
|
input_augm = tf.keras.Input(shape=img_generator.get_data_shape()[0], name='input_augm') |
|
augm_layer = AugmentationLayer(max_displacement=C.MAX_AUG_DISP, |
|
max_deformation=C.MAX_AUG_DEF, |
|
max_rotation=C.MAX_AUG_ANGLE, |
|
num_control_points=C.NUM_CONTROL_PTS_AUG, |
|
num_augmentations=C.NUM_AUGMENTATIONS, |
|
gamma_augmentation=C.GAMMA_AUGMENTATION, |
|
brightness_augmentation=C.BRIGHTNESS_AUGMENTATION, |
|
in_img_shape=image_input_shape, |
|
out_img_shape=image_output_shape, |
|
only_image=False, |
|
only_resize=False, |
|
trainable=False, |
|
return_displacement_map=True) |
|
augm_model = tf.keras.Model(inputs=input_augm, outputs=augm_layer(input_augm)) |
|
|
|
fix_img_ph = tf.placeholder(dtype=tf.float32, shape=[1,] + image_input_shape + [1+nb_labels,], name='fix_image') |
|
|
|
augmentation_pipeline = augm_model(fix_img_ph) |
|
|
|
config = tf.compat.v1.ConfigProto() |
|
config.gpu_options.allow_growth = True |
|
config.log_device_placement = False |
|
|
|
dm_interp = DisplacementMapInterpolator(image_output_shape, 'griddata', step=8) |
|
|
|
sess = tf.Session(config=config) |
|
tf.keras.backend.set_session(sess) |
|
with sess.as_default(): |
|
sess.run(tf.global_variables_initializer()) |
|
progress_bar = tqdm(enumerate(img_generator, 1), desc='Generating samples', total=len(img_generator)) |
|
for step, (in_batch, _, isotropic_shape) in progress_bar: |
|
|
|
fix_img, mov_img, fix_seg, mov_seg, disp_map = sess.run(augmentation_pipeline, |
|
feed_dict={'fix_image:0': in_batch}) |
|
|
|
fix_centroids, mov_centroids, disp_centroids = get_mov_centroids(fix_seg, disp_map, nb_labels, dm_interp=dm_interp) |
|
|
|
out_file = os.path.join(OUTPUT_FOLDER_DIR, 'test_sample_{:04d}.h5'.format(step)) |
|
out_file_dm = os.path.join(OUTPUT_FOLDER_DIR, 'test_sample_dm_{:04d}.h5'.format(step)) |
|
img_shape = fix_img.shape |
|
segm_shape = fix_seg.shape |
|
disp_shape = disp_map.shape |
|
centroids_shape = fix_centroids.shape |
|
with h5py.File(out_file, 'w') as f: |
|
f.create_dataset('fix_image', shape=img_shape[1:], dtype=np.float32, data=fix_img[0, ...]) |
|
f.create_dataset('mov_image', shape=img_shape[1:], dtype=np.float32, data=mov_img[0, ...]) |
|
f.create_dataset('fix_segmentations', shape=segm_shape[1:], dtype=np.uint8, data=fix_seg[0, ...]) |
|
f.create_dataset('mov_segmentations', shape=segm_shape[1:], dtype=np.uint8, data=mov_seg[0, ...]) |
|
f.create_dataset('fix_centroids', shape=centroids_shape, dtype=np.float32, data=fix_centroids) |
|
f.create_dataset('mov_centroids', shape=centroids_shape, dtype=np.float32, data=mov_centroids) |
|
f.create_dataset('isotropic_shape', data=np.squeeze(isotropic_shape)) |
|
with h5py.File(out_file_dm, 'w') as f: |
|
f.create_dataset('disp_map', shape=disp_shape[1:], dtype=np.float32, data=disp_map) |
|
f.create_dataset('disp_centroids', shape=centroids_shape, dtype=np.float32, data=disp_centroids) |
|
|
|
print('Done') |
|
|