# DDMR / Datasets / ircad_dataset.py
# Last commit by andreped: "Renamed module to ddmr" (a27d55f)
import os, sys
os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = "1" # Check availability before running using 'nvidia-smi'
currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir)
import multiprocessing as mp
mp.set_start_method('spawn')
import tensorflow as tf
# tf.enable_eager_execution()
import numpy as np
import nibabel as nib
from skimage.transform import resize
from skimage.filters import median
from scipy.ndimage import binary_dilation, generate_binary_structure
from nilearn.image import math_img
import h5py
from tqdm import tqdm
import re
currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
from ddmr.utils.cmd_args_parser import parse_arguments
import ddmr.utils.constants as const
from tools.thinPlateSplines_tf import ThinPlateSplines
from keras_model.ext.neuron.layers import SpatialTransformer
from tools.voxelMorph import interpn
from generate_dataset.utils import plot_central_slices, plot_def_map, single_img_gif, two_img_gif, plot_slices, \
crop_images, plot_displacement_map, bbox_3D
from generate_dataset import utils
from tools.misc import try_mkdir
from generate_dataset.utils import unzip_file, delete_temp
# Directory listed in __main__ to find the raw NIfTI CT volumes and their segmentations
DATASTE_RAW_FILES = '/mnt/EncryptedData1/Users/javier/vessel_registration/3Dirca/nifti'
# File-name prefixes that tell segmentation files apart from CT volume files in DATASTE_RAW_FILES
LITS_SEGMENTATION_FILE = 'segmentations'
LITS_CT_FILE = 'volume'
# Spatial size of the generated samples: const.IMG_SHAPE without its last axis (presumably the channel axis)
IMG_SIZE_LARGE = const.IMG_SHAPE[:-1]
# Working resolution: volumes are resized to twice IMG_SIZE_LARGE before being deformed and later cropped back
IMG_SIZE_LARGE_x2 = [2 * x for x in const.IMG_SHAPE[:-1]]
# Shape of the dense grid at which the TPS is evaluated; `x // 1` leaves integer sizes unchanged,
# so for integer shapes this is IMG_SIZE_LARGE_x2 as a tuple
FINE_GRID_SHAPE = tuple(x // 1 for x in IMG_SIZE_LARGE_x2) # tuple(np.asarray(IMG_SIZE_LARGE) // 10)
# Coarse TPS control grid: TPS_NUM_CTRL_PTS_PER_AXIS points per axis spanning IMG_SIZE_LARGE_x2
# (built with batches=False and norm=False)
CTRL_GRID = const.CoordinatesGrid()
CTRL_GRID.set_coords_grid(IMG_SIZE_LARGE_x2, [const.TPS_NUM_CTRL_PTS_PER_AXIS, const.TPS_NUM_CTRL_PTS_PER_AXIS,
const.TPS_NUM_CTRL_PTS_PER_AXIS], batches=False, norm=False,
img_type=tf.float32)
# Dense grid of FINE_GRID_SHAPE points over IMG_SIZE_LARGE_x2, used to sample the TPS / translation field
FULL_FINE_GRID = const.CoordinatesGrid()
FULL_FINE_GRID.set_coords_grid(IMG_SIZE_LARGE_x2, FINE_GRID_SHAPE, batches=False, norm=False)
# Added to the volume number parsed from the input file name when naming the output .hd5 file
OFFSET_NAME_NUM = 0
# Threshold used to re-binarise the interpolated (resized / warped) masks
TH_BIN = 0.50
# 3-D structuring element with connectivity 1 (voxel plus its 6 face neighbours), used to dilate the crop mask
DILATION_STRUCT = generate_binary_structure(3, 1)
# Flattened control-grid shape plus 9 extra rows (point + 8 box corners) or 1 extra row (single point);
# presumably (rows, 3) - TODO confirm the type of CTRL_GRID.shape_grid_flat
LARGE_PT_DIM = CTRL_GRID.shape_grid_flat + np.asarray([9, 0])
SINGLE_PT_DIM = CTRL_GRID.shape_grid_flat + np.asarray([1, 0])
# Values fed to the 'large_point' and 'add_affine' placeholders of the deformation graph
USE_LARGE_PT = False
ADD_AFFINE_TRF = False
# TF1 session configuration: grow GPU memory on demand instead of reserving it all up front
config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
config.gpu_options.allow_growth = True
config.log_device_placement = False ## to log device placement (on which device the operation ran)
def tf_graph_translation():
    """Build a TF1 graph that shifts the fixed volumes by one random rigid translation.

    Creates the float32 placeholders 'fix_img', 'fix_tumors' and 'fix_parenchyma' (shape IMG_SIZE_LARGE_x2)
    in the default graph. The translation vector is drawn with NumPy while the graph is being built and
    wrapped in a tf.constant, so every sess.run on the same graph applies the same shift.

    Returns:
        (mov_img, mov_parenchyma, mov_tumors, disp_map, w): the three shifted volumes reshaped to
        IMG_SIZE_LARGE_x2, the displacement field reshaped to (*FINE_GRID_SHAPE, -1) as float32,
        and the translation vector itself.
    """
    fix_img, fix_tumors, fix_parenchyma = (tf.placeholder(tf.float32, IMG_SIZE_LARGE_x2, ph_name)
                                           for ph_name in ('fix_img', 'fix_tumors', 'fix_parenchyma'))

    # Shift of up to MAX_DISP_DM_PERC of the first image dimension along each axis
    shift = tf.constant(np.random.uniform(-1, 1, 3) * const.MAX_DISP_DM_PERC * IMG_SIZE_LARGE_x2[0],
                        dtype=tf.float32)
    # Zero-pad every axis by |shift| + 1 voxels so the shifted sampling points stay inside the volume
    margin = tf.cast(tf.abs(shift) + 1., tf.int32)
    paddings = tf.stack([margin, margin], 1)

    fine_grid = FULL_FINE_GRID.grid_flat()
    shifted_grid = tf.add(fine_grid, shift)
    field = tf.reshape(shifted_grid - fine_grid, (*FINE_GRID_SHAPE, -1))
    # Padding moves the origin, so the sampling points move by the pad width as well
    sample_pts = tf.add(shifted_grid, tf.cast(margin, shifted_grid.dtype))

    def shift_volume(volume):
        # Pad, append a channel axis, resample at the shifted grid and restore the image shape
        padded = tf.expand_dims(tf.pad(volume, paddings, "CONSTANT", constant_values=0.), -1)
        sampled = interpn(padded, sample_pts, interp_method='linear')
        return tf.squeeze(tf.reshape(sampled, IMG_SIZE_LARGE_x2))

    mov_img, mov_tumors, mov_parenchyma = (shift_volume(vol) for vol in (fix_img, fix_tumors, fix_parenchyma))
    return mov_img, mov_parenchyma, mov_tumors, tf.cast(field, tf.float32), shift
def build_affine_trf(img_size, alpha, beta, gamma, ti, tj, tk):
    """Build a 4x4 homogeneous matrix that rotates about the image centre and then translates.

    M = T(ti, tj, tk) @ T(c) @ R @ T(-c), with c = img_size / 2 and
    R = Rz(gamma) @ Ry(beta) @ Rx(alpha), i.e. alpha, beta and gamma rotate about the first, second and
    third axis respectively. Angles are in radians (tf.cos / tf.sin expect radians).

    Args:
        img_size: tensor with the three image dimensions; the rotation centre is img_size / 2.
        alpha, beta, gamma: rotation angles. The matrix entries are mixed with (1,)-shaped zeros/ones,
            so these are expected to have shape (1,), as produced by transform_points.
        ti, tj, tk: translation along each axis, also expected with shape (1,).

    Returns:
        float32 tensor of shape (4, 4).
    """
    img_centre = tf.expand_dims(tf.divide(img_size, 2.), -1)  # (3, 1): img_centre[i] has shape (1,)
    zero = tf.zeros((1,))
    one = tf.ones((1,))

    def translation(t0, t1, t2):
        # Homogeneous translation; squeeze drops the trailing length-1 axis of the (1,)-shaped entries
        return tf.squeeze(tf.convert_to_tensor([[one, zero, zero, t0],
                                                [zero, one, zero, t1],
                                                [zero, zero, one, t2],
                                                [zero, zero, zero, one]], tf.float32))

    ca, sa = tf.math.cos(alpha), tf.math.sin(alpha)
    cb, sb = tf.math.cos(beta), tf.math.sin(beta)
    cg, sg = tf.math.cos(gamma), tf.math.sin(gamma)
    # Rz(gamma) @ Ry(beta) @ Rx(alpha): every Rx term uses alpha and every Rz term uses gamma; mixing them
    # (as the middle row once did) makes R non-orthogonal, i.e. a shear rather than a rotation.
    R = tf.squeeze(tf.convert_to_tensor([[cg * cb, cg * sb * sa - sg * ca, cg * sb * ca + sg * sa, zero],
                                         [sg * cb, sg * sb * sa + cg * ca, sg * sb * ca - cg * sa, zero],
                                         [-sb, cb * sa, cb * ca, zero],
                                         [zero, zero, zero, one]], tf.float32))

    Tc = translation(img_centre[0], img_centre[1], img_centre[2])
    Tc_ = translation(-img_centre[0], -img_centre[1], -img_centre[2])
    T = translation(ti, tj, tk)
    # R* = T(p) R T(-p): rotate about the image centre instead of the origin, then translate
    return tf.matmul(T, tf.matmul(Tc, tf.matmul(R, Tc_)))
def transform_points(points: tf.Tensor):
    """Apply a random rotation about the image centre plus a random translation to 3-D points.

    Rotation angles are drawn uniformly in [-const.MAX_ANGLE_RAD, const.MAX_ANGLE_RAD) and each translation
    component in [-const.MAX_DISP_DM / 2, const.MAX_DISP_DM / 2). All six values are TF random ops, so a graph
    built once still yields a new transform on every sess.run. (A tf.constant built from np.random would be
    frozen at graph-construction time.)

    Args:
        points: 3-D points, either as rows (N, 3) or as columns (3, N). A last dimension of size 3 is taken
            to mean rows.

    Returns:
        (new_pts, M): the transformed points as rows, and the 4x4 matrix from build_affine_trf for an image
        of size IMG_SIZE_LARGE_x2.
    """
    max_angle = const.MAX_ANGLE_RAD
    alpha = tf.random.uniform((1,), -max_angle, max_angle)
    beta = tf.random.uniform((1,), -max_angle, max_angle)
    gamma = tf.random.uniform((1,), -max_angle, max_angle)
    max_shift = const.MAX_DISP_DM / 2
    ti = tf.random.uniform((1,), -max_shift, max_shift, dtype=tf.float32)
    tj = tf.random.uniform((1,), -max_shift, max_shift, dtype=tf.float32)
    tk = tf.random.uniform((1,), -max_shift, max_shift, dtype=tf.float32)
    M = build_affine_trf(tf.convert_to_tensor(IMG_SIZE_LARGE_x2, tf.float32), alpha, beta, gamma, ti, tj, tk)
    if points.shape.as_list()[-1] == 3:
        points = tf.transpose(points)  # work on (3, N) columns
    # Apply the linear part, then add the translation column of the homogeneous matrix
    new_pts = tf.matmul(M[:3, :3], points) + tf.expand_dims(M[:3, -1], -1)
    return tf.transpose(new_pts), M  # back to (N, 3) rows
def tf_graph_deform():
    """Build the TF1 graph that synthesises a moving volume by TPS-warping the fixed one.

    Placeholders created in the default graph (fed by name in __main__):
        'fix_img', 'fix_tumors', 'fix_vessels', 'fix_parenchyma': float32, shape IMG_SIZE_LARGE_x2.
        'large_point': bool, default False. If True, a point plus 8 box corners is displaced instead of one point.
        'add_affine': bool, default False. If True, a random affine is also applied to the target grid.

    A voxel is picked at random inside the tumour mask, or inside the parenchyma mask if the tumour mask sums
    to 0. It is displaced by a random vector with components in [-MAX_DISP_DM, MAX_DISP_DM), resampled on
    every sess.run. The voxel (or its box) and its displaced copy are appended to CTRL_GRID as TPS control and
    target points. The TPS field is evaluated on FULL_FINE_GRID and used to warp every fixed volume.

    NOTE(review): if both masks are empty, tf.random.uniform gets maxval == 0 and presumably fails at run time.
    NOTE: the static grid shape comes from the USE_LARGE_PT global, so the value fed to 'large_point'
    must agree with it.

    Returns:
        (mov_img, mov_parenchyma, mov_tumors, mov_vessels, disp_map, ctrl_pts, rand_disp, aff), where ctrl_pts
        are the added control points ((1, 3) or (9, 3)) and aff is the applied 4x4 affine (identity if none).
    """
    fix_img = tf.placeholder(tf.float32, IMG_SIZE_LARGE_x2, 'fix_img')
    fix_tumors = tf.placeholder(tf.float32, IMG_SIZE_LARGE_x2, 'fix_tumors')
    fix_vessels = tf.placeholder(tf.float32, IMG_SIZE_LARGE_x2, 'fix_vessels')
    fix_parenchyma = tf.placeholder(tf.float32, IMG_SIZE_LARGE_x2, 'fix_parenchyma')
    large_point = tf.placeholder_with_default(input=tf.constant(False, tf.bool), shape=(), name='large_point')
    add_affine = tf.placeholder_with_default(input=tf.constant(False, tf.bool), shape=(), name='add_affine')

    # 1. Pick the point of application inside the tumours, falling back to the parenchyma
    search_voxels = tf.cond(tf.equal(tf.reduce_sum(fix_tumors), 0.0),
                            lambda: fix_parenchyma,
                            lambda: fix_tumors)
    idx_points_in_label = tf.where(tf.greater(search_voxels, 0.0))  # (num_voxels, 3) indices of labelled voxels
    random_idx = tf.random.uniform([], minval=0, maxval=tf.shape(idx_points_in_label)[0], dtype=tf.int32)
    disp_location = tf.cast(tf.gather_nd(idx_points_in_label, tf.expand_dims(random_idx, 0)), tf.float32)

    # Random displacement, resampled on every sess.run. A tf.constant built from np.random would be frozen at
    # graph-construction time and give every processed volume the same displacement vector.
    rand_disp = tf.random.uniform((3,), -1., 1., dtype=tf.float32) * const.MAX_DISP_DM
    warped_location = disp_location + rand_disp

    def get_box_neighbours(location, radius=3):
        """Return `location` stacked with the 8 corners of the cube of half-side `radius` around it: (9, 3)."""
        # Corner order: x flips fastest, then y, then z, starting from (+r, +r, +r)
        corners = np.asarray([[sx, sy, sz] for sz in (1, -1) for sy in (1, -1) for sx in (1, -1)]) * radius
        return tf.concat([tf.expand_dims(location, 0), location + tf.constant(corners, location.dtype)], 0)

    # The control points sit at the sampled location; the targets sit at its displaced copy
    ctrl_pts, trg_pts = tf.cond(large_point,
                                lambda: (get_box_neighbours(disp_location, 3),
                                         get_box_neighbours(warped_location, 3)),
                                lambda: (tf.expand_dims(disp_location, 0), tf.expand_dims(warped_location, 0)))

    # 2. Add the new point(s) to the control grid and the target grid
    control_grid = tf.concat([CTRL_GRID.grid_flat(), ctrl_pts], axis=0)
    trg_grid = tf.concat([CTRL_GRID.grid_flat(), trg_pts], axis=0)
    trg_grid, aff = tf.cond(add_affine,
                            lambda: transform_points(trg_grid),
                            lambda: (trg_grid, tf.eye(4, 4)))

    # ThinPlateSplines needs static shapes: control grid plus 9 (box) or 1 (single) extra points
    grid_shape = [int(d) for d in (LARGE_PT_DIM if USE_LARGE_PT else SINGLE_PT_DIM)]
    control_grid.set_shape(grid_shape)
    trg_grid.set_shape(grid_shape)
    tps = ThinPlateSplines(control_grid, trg_grid)
    def_grid = tps.interpolate(FULL_FINE_GRID.grid_flat())
    disp_map = def_grid - FULL_FINE_GRID.grid_flat()
    disp_map = tf.reshape(disp_map, (*FINE_GRID_SHAPE, -1))

    # Add the batch and channel dimensions expected by SpatialTransformer
    fix_img = tf.expand_dims(tf.expand_dims(fix_img, -1), 0)
    fix_tumors = tf.expand_dims(tf.expand_dims(fix_tumors, -1), 0)
    fix_vessels = tf.expand_dims(tf.expand_dims(fix_vessels, -1), 0)
    fix_parenchyma = tf.expand_dims(tf.expand_dims(fix_parenchyma, -1), 0)
    disp_map = tf.cast(tf.expand_dims(disp_map, 0), tf.float32)

    mov_tumors = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([fix_tumors, disp_map])
    mov_vessels = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([fix_vessels, disp_map])
    mov_parenchyma = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([fix_parenchyma, disp_map])
    mov_img = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([fix_img, disp_map])

    return (tf.squeeze(mov_img),
            tf.squeeze(mov_parenchyma),
            tf.squeeze(mov_tumors),
            tf.squeeze(mov_vessels),
            tf.squeeze(disp_map),
            ctrl_pts,
            rand_disp,
            aff)
def _binarize(mask):
    """Threshold `mask` in place at TH_BIN and return it.

    Values above TH_BIN become 1.0; every value still below 1.0 afterwards becomes 0.0.
    """
    mask[mask > TH_BIN] = 1.0
    mask[mask < 1.0] = 0.0
    return mask


if __name__ == '__main__':
    parse_arguments(sys.argv[1:])

    # Volumes and segmentations are paired by position after sorting their file names
    volume_list = sorted(os.path.join(DATASTE_RAW_FILES, f) for f in os.listdir(DATASTE_RAW_FILES)
                         if f.startswith(LITS_CT_FILE))
    segmentation_list = sorted(os.path.join(DATASTE_RAW_FILES, f) for f in os.listdir(DATASTE_RAW_FILES)
                               if f.startswith(LITS_SEGMENTATION_FILE))
    if len(volume_list) != len(segmentation_list):
        # zip() pairs by position, so one missing file would shift every following pair
        print('[WRN] Found {} volumes but {} segmentations; pairing by sorted order may be wrong'.format(
            len(volume_list), len(segmentation_list)))
    file_path_pairs = [[v, s] for v, s in zip(volume_list, segmentation_list)]
    print('Generating HD5 files at {} ...'.format(const.DESTINATION_FOLDER))

    # Slicer range for abdominal CT: window width 350, level 40 -> [-135, 215]
    intensity_window_w = 350
    intensity_window_l = 40
    # Builtin int: the np.int alias was removed in NumPy 1.24
    intensity_clipping_range = intensity_window_l + np.asarray([-intensity_window_w // 2, intensity_window_w // 2],
                                                               int)
    try_mkdir(const.DESTINATION_FOLDER)

    print('PART 1: Deformation')
    get_mov_img = tf_graph_deform()
    # Created after the graph is built so it covers any variables the graph defines
    init = tf.global_variables_initializer()
    with tf.Session(config=config) as sess:  # also the default session; closed even if a volume fails
        sess.run(init)
        sess.graph.finalize()
        for img_path, labels_path in tqdm(file_path_pairs):
            if img_path is None or labels_path is None:
                continue
            fix_img = nib.load(img_path)  # By convention, nibabel world axes are always in RAS+ orientation
            img_header = fix_img.header
            fix_labels = nib.load(labels_path)
            fix_img = np.asarray(fix_img.dataobj)
            fix_labels = np.asarray(fix_labels.dataobj)
            if fix_labels.shape[-1] < 4:
                print('[INF] ' + img_path + ' has no tumor segmentations')
                continue
            # Label channels: 0 artery (unused), 1 vessels, 2 parenchyma, 3 tumors
            fix_vessels = fix_labels[..., 1]
            fix_parenchyma = fix_labels[..., 2]
            fix_tumors = fix_labels[..., 3]

            # Clip intensities and move everything to the x2 working resolution
            fix_img = utils.intesity_clipping(fix_img, intensity_clipping_range, augment=True)
            fix_img = resize(fix_img, IMG_SIZE_LARGE_x2)
            fix_parenchyma = resize(fix_parenchyma, IMG_SIZE_LARGE_x2)
            fix_tumors = resize(fix_tumors, IMG_SIZE_LARGE_x2)
            fix_vessels = resize(fix_vessels, IMG_SIZE_LARGE_x2)
            fix_parenchyma = median(fix_parenchyma, np.ones((5, 5, 5)))

            # Compute the deformation
            mov_img, mov_parenchyma, mov_tumors, mov_vessels, disp_map, disp_loc, disp_vec, aff = sess.run(
                get_mov_img,
                feed_dict={'fix_img:0': fix_img,
                           'fix_tumors:0': fix_tumors,
                           'fix_vessels:0': fix_vessels,
                           'fix_parenchyma:0': fix_parenchyma,
                           'large_point:0': USE_LARGE_PT,
                           'add_affine:0': ADD_AFFINE_TRF})

            mov_img = utils.intesity_clipping(mov_img, intensity_clipping_range)
            if USE_LARGE_PT:
                disp_loc = disp_loc[0, ...]  # keep the centre point; the other rows are its box corners

            # Bbox around the union of both parenchymas (dilated twice), so neither liver falls outside
            bbox_mask = np.sign(mov_parenchyma + fix_parenchyma)
            bbox_mask = binary_dilation(bbox_mask, DILATION_STRUCT, iterations=2).astype(np.float32)
            min_i, _, min_j, _, min_k, _ = bbox_3D(bbox_mask)
            # The point of application is in whole-image coordinates; make it relative to the bbox origin.
            # The crop also resizes by a factor of 2, hence the division.
            disp_loc = (disp_loc - np.asarray([min_i, min_j, min_k])) / 2
            # NOTE(review): disp_map and disp_vec are not divided by 2 here; confirm crop_images or the
            # consumers account for the factor-2 resize.

            # Crop to the liver (and downsample to IMG_SIZE_LARGE); later stages crop further
            (fix_img, fix_tumors, fix_vessels, disp_map, fix_parenchyma,
             mov_img, mov_tumors, mov_vessels, mov_parenchyma) = [
                crop_images(vol, bbox_mask, IMG_SIZE_LARGE)[0]
                for vol in (fix_img, fix_tumors, fix_vessels, disp_map, fix_parenchyma,
                            mov_img, mov_tumors, mov_vessels, mov_parenchyma)]

            # Interpolation leaves fractional values: make sure the masks are binary again
            for mask in (fix_tumors, fix_vessels, fix_parenchyma, mov_tumors, mov_vessels, mov_parenchyma):
                _binarize(mask)

            # Add a trailing channel axis
            fix_img, fix_tumors, fix_vessels, fix_parenchyma = [
                np.expand_dims(vol, -1) for vol in (fix_img, fix_tumors, fix_vessels, fix_parenchyma)]
            mov_img, mov_tumors, mov_vessels, mov_parenchyma = [
                np.expand_dims(vol, -1) for vol in (mov_img, mov_tumors, mov_vessels, mov_parenchyma)]
            # NOTE(review): stacking (..., 1)-shaped masks yields (..., 1, 3) arrays; kept for compatibility
            fix_segmentations = np.stack([fix_parenchyma, fix_vessels, fix_tumors], -1)
            mov_segmentations = np.stack([mov_parenchyma, mov_vessels, mov_tumors], -1)

            # Save everything
            file_name = os.path.split(img_path)[-1].split('.')[0]
            vol_num = int(re.split('-|_', file_name)[-1])
            hd5_filename = 'volume-{:04d}'.format(vol_num + OFFSET_NAME_NUM)
            hd5_filename = os.path.join(const.DESTINATION_FOLDER, hd5_filename + '.hd5')
            with h5py.File(hd5_filename, 'w') as hd5_file:
                hd5_file.create_dataset(const.H5_FIX_IMG, data=fix_img, dtype='float32')
                hd5_file.create_dataset(const.H5_FIX_PARENCHYMA_MASK, data=fix_parenchyma, dtype='float32')
                hd5_file.create_dataset(const.H5_FIX_VESSELS_MASK, data=fix_vessels, dtype='float32')
                hd5_file.create_dataset(const.H5_FIX_TUMORS_MASK, data=fix_tumors, dtype='float32')
                hd5_file.create_dataset(const.H5_FIX_SEGMENTATIONS, data=fix_segmentations, dtype='float32')
                hd5_file.create_dataset(const.H5_PARAMS_INTENSITY_RANGE, (2,), data=intensity_clipping_range,
                                        dtype='float32')
                hd5_file.create_dataset(const.H5_MOV_IMG, const.IMG_SHAPE, data=mov_img, dtype='float32')
                hd5_file.create_dataset(const.H5_MOV_PARENCHYMA_MASK, const.IMG_SHAPE, data=mov_parenchyma,
                                        dtype='float32')
                hd5_file.create_dataset(const.H5_MOV_VESSELS_MASK, const.IMG_SHAPE, data=mov_vessels,
                                        dtype='float32')
                hd5_file.create_dataset(const.H5_MOV_TUMORS_MASK, const.IMG_SHAPE, data=mov_tumors,
                                        dtype='float32')
                # The moving segmentations, not the fixed ones
                hd5_file.create_dataset(const.H5_MOV_SEGMENTATIONS, data=mov_segmentations, dtype='float32')
                hd5_file.create_dataset(const.H5_GT_DISP, const.DISP_MAP_SHAPE, data=disp_map, dtype='float32')
                hd5_file.create_dataset(const.H5_GT_DISP_VECT_LOC, data=disp_loc, dtype='float32')
                hd5_file.create_dataset(const.H5_GT_DISP_VECT, data=disp_vec, dtype='float32')
                hd5_file.create_dataset(const.H5_GT_AFFINE_M, data=aff, dtype='float32')
                hd5_file.create_dataset('params/voxel_size', data=img_header.get_zooms()[:3])
                hd5_file.create_dataset('params/original_shape', data=img_header.get_data_shape())
                hd5_file.create_dataset('params/bbox_origin', data=[min_i, min_j, min_k])
                hd5_file.create_dataset('params/first_reshape', data=IMG_SIZE_LARGE_x2)
    print('...Done generating HD5 files')