jpdefrutos committed on
Commit
ed5ac4a
·
1 Parent(s): 476daa5

3D IRCA dataset formatting scripts

Browse files
Datasets/check_dataset.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import h5py
3
+ import numpy as np
4
+ from tqdm import tqdm
5
+ import DeepDeformationMapRegistration.utils.constants as C
6
+
7
+
8
# Expose only the first CUDA device to any GPU-aware library loaded by the imports above.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
LITS_NONE = '/mnt/EncryptedData1/Users/javier/vessel_registration/LiTS/None'
LITS_TRANS = '/mnt/EncryptedData1/Users/javier/vessel_registration/LiTS/Translation'
LITS_AFFINE = '/mnt/EncryptedData1/Users/javier/vessel_registration/LiTS/Affine'


# Shape every checked HDF5 dataset must have.
IMG_SHAPE = (64, 64, 64, 1)

# (label printed in the report, HDF5 key) of every dataset validated in each file.
CHECKED_DATASETS = (
    ('FIX IMG', C.H5_FIX_IMG),
    ('MOV IMG', C.H5_MOV_IMG),
    ('FIX PARENCHYMA', C.H5_FIX_PARENCHYMA_MASK),
    ('MOV PARENCHYMA', C.H5_MOV_PARENCHYMA_MASK),
    ('FIX TUMORS', C.H5_FIX_TUMORS_MASK),
    ('MOV TUMORS', C.H5_MOV_TUMORS_MASK),
)

for dataset in [LITS_NONE, LITS_AFFINE, LITS_TRANS]:
    dataset_files = [os.path.join(dataset, d) for d in os.listdir(dataset)
                     if os.path.isfile(os.path.join(dataset, d))]
    f_iter = tqdm(dataset_files)
    f_iter.set_description('Analyzing ' + dataset)
    inv_shape_count = 0
    inv_type_count = 0
    for d in f_iter:
        # The context manager closes the file even if a key is missing.
        with h5py.File(d, 'r') as f:
            # Shape and dtype are read from the dataset metadata, so the
            # volumes themselves are never loaded into memory.
            for label, key in CHECKED_DATASETS:
                shape = f[key].shape
                if shape != IMG_SHAPE:
                    print(d + ' Invalid ' + label + '. Shape: ' + str(shape))
                    inv_shape_count += 1

            for label, key in CHECKED_DATASETS:
                dtype = f[key].dtype
                if dtype != np.float32:
                    print(d + ' Invalid ' + label + '. Type: ' + str(dtype))
                    inv_type_count += 1

    print('\n\n>>>>SUMMARY ' + dataset)
    print('\t\tInvalid shape: ' + str(inv_shape_count) + '\n\t\tInvalid type: ' + str(inv_type_count))
Datasets/irca_pre_processing.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ currentdir = os.path.dirname(os.path.realpath(__file__))
3
+ parentdir = os.path.dirname(currentdir)
4
+ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
5
+
6
+ import h5py
7
+ import numpy as np
8
+ import zipfile
9
+ import re
10
+ import dicom2nifti as d2n
11
+ import nibabel as nib
12
+ from DeepDeformationMapRegistration.utils.nifti_utils import save_nifti
13
+ from DeepDeformationMapRegistration.utils.misc import try_mkdir
14
+ from tqdm import tqdm
15
+ import shutil
16
+
17
+
18
# Folder that is scanned for the '3Dircadb1.<n>' patient directories.
IRCA_PATH = '/mnt/EncryptedData1/Users/javier/vessel_registration/3Dirca/'

# Folder-name prefixes of the mask folders, grouped by the structure they are merged into.
SEGS_VESSELS = ('venoussystem', 'venacava', 'portalvein')
SEGS_PARENCH = ('liver',)
SEGS_TUMOR = ('livertumor', 'tumor', 'liverkyst', 'livercyst')

# Every prefix that is converted. (An earlier, shorter definition of this tuple
# was immediately overwritten by this one and has been dropped.)
SEGMENTATIONS = (SEGS_PARENCH + SEGS_VESSELS + SEGS_TUMOR)

DEST_FOLDER = os.path.join(IRCA_PATH, 'nifti3')

ZIP_EXT = '.zip'
NIFTI_EXT = '.nii.gz'
H5_EXT = '.hd5'

# Names of the archives/folders found inside each patient directory.
PATIENT_DICOM = 'PATIENT_DICOM'
TEMP = 'temp'
SEGS_DICOM = 'MASKS_DICOM'

# File moved away after every dicom2nifti conversion, and the output name templates.
CONVERTED_FILE = 'none' + NIFTI_EXT
VOL_FILE = 'volume-{:04d}' + NIFTI_EXT
SEG_FILE = 'segmentation-{:04d}' + NIFTI_EXT

# Homogeneous 4x4 matrix that negates the first axis only. Indexing with [0]
# alone would overwrite the entire first row (including the translation term).
# NOTE(review): not referenced elsewhere in this script; the axis-flip intent is
# inferred from the name - confirm before relying on it.
SEG_TO_CT_ORIENTATION_MAT = np.eye(4)
SEG_TO_CT_ORIENTATION_MAT[0, 0] = -1
44
+
45
+
46
def merge_segmentations(file_list):
    """Fuse several segmentation images into a single label volume.

    The images in ``file_list`` are concatenated with ``nib.concat_images`` and
    summed over the last axis. Every voxel takes the sign of that sum multiplied
    by the maximum value found in the concatenated data.

    :param file_list: images (or file names) accepted by ``nib.concat_images``.
    :return: ``nib.Nifti1Image`` holding the merged volume, with the affine of
             the concatenated image.
    """
    stacked = nib.concat_images(file_list)
    voxels = np.asarray(stacked.dataobj)
    # Sign of the per-voxel sum over the concatenation axis.
    coverage = np.sign(voxels.sum(axis=-1))
    merged = coverage * voxels.max()
    return nib.Nifti1Image(merged, stacked.affine)
51
+
52
+
53
if __name__ == '__main__':
    # 1. List of folders
    folder_list = [os.path.join(IRCA_PATH, d) for d in os.listdir(IRCA_PATH) if d.lower().startswith('3dircadb1.')]
    folder_list.sort()

    try_mkdir(DEST_FOLDER)
    folder_iter = tqdm(folder_list)
    for pat_dir in folder_iter:
        # Patient number = text after the last dot of the folder path.
        # (A debugging override that pinned pat_dir to folder_list[13] was
        # removed; it made every iteration process the same patient.)
        i = int(pat_dir.split('.')[-1])
        temp_folder = os.path.join(pat_dir, TEMP)
        dicom_dir = os.path.join(temp_folder, PATIENT_DICOM)
        masks_dir = os.path.join(temp_folder, SEGS_DICOM)
        vol_path = os.path.join(DEST_FOLDER, VOL_FILE.format(i))
        seg_path = os.path.join(DEST_FOLDER, SEG_FILE.format(i))

        try:
            # 2. Unzip PATIENT_DICOM.zip
            folder_iter.set_description('Volume DICOM: Unzipping PATIENT_DICOM.zip')
            with zipfile.ZipFile(os.path.join(pat_dir, PATIENT_DICOM + ZIP_EXT)) as archive:
                archive.extractall(temp_folder)

            folder_iter.set_description('Volume DICOM: Converting DICOM to Nifti')
            d2n.convert_directory(dicom_dir, dicom_dir)
            # os.rename also moves the file to the destination path. So the original one ceases to exist
            os.rename(os.path.join(dicom_dir, CONVERTED_FILE), vol_path)
            folder_iter.set_description('Volume DICOM: CT stored in: ' + vol_path)

            # 3. Unzip MASKS_DICOM.zip
            folder_iter.set_description('Segmentations DICOM: Unzipping MASKS_DICOM.zip')
            with zipfile.ZipFile(os.path.join(pat_dir, SEGS_DICOM + ZIP_EXT)) as archive:
                archive.extractall(temp_folder)

            # Mask folders whose name starts with one of the known prefixes.
            seg_dirs = sorted(fold
                              for _, dir_list, _ in os.walk(masks_dir)
                              for fold in dir_list
                              if fold.startswith(SEGMENTATIONS))

            seg_ves = list()
            seg_par = list()
            seg_tumor = list()
            for fold in seg_dirs:
                folder_iter.set_description('Segmentations DICOM: Converting ' + fold)
                d2n.convert_directory(os.path.join(masks_dir, fold), masks_dir)
                nifti_path = os.path.join(masks_dir, fold + '_nifti_' + NIFTI_EXT)
                os.rename(os.path.join(masks_dir, CONVERTED_FILE), nifti_path)
                # Tumour prefixes are tested before the parenchyma prefix because
                # 'livertumor', 'liverkyst' and 'livercyst' also start with 'liver'.
                if fold.startswith(SEGS_VESSELS):
                    seg_ves.append(nifti_path)
                elif fold.startswith(SEGS_TUMOR):
                    seg_tumor.append(nifti_path)
                elif fold.startswith(SEGS_PARENCH):
                    seg_par.append(nifti_path)

            folder_iter.set_description('Segmentations DICOM: Concatenating segmentations')
            # One entry per structure, in the order parenchyma, vessels, tumors.
            # A group with several files is fused into one image; a group with a
            # single file contributes its path; an empty group contributes nothing.
            segs_to_merge = list()
            for group in (seg_par, seg_ves, seg_tumor):
                if len(group) > 1:
                    segs_to_merge.append(merge_segmentations(group))
                else:
                    segs_to_merge.extend(group)

            # Merge with the parenchyma and save
            folder_iter.set_description('Segmentations DICOM: Saving segmentations')
            if not segs_to_merge:
                print('[WARN] No segmentations found in ' + pat_dir)
                continue  # the finally clause still removes the temp folder

            if len(segs_to_merge) > 1:
                seg_img = nib.concat_images(segs_to_merge, check_affines=True)
            else:
                seg_img = segs_to_merge[0]
                if isinstance(seg_img, str):
                    # nib.save needs an image object, not a file name.
                    seg_img = nib.load(seg_img)
            nib.save(seg_img, seg_path)

            folder_iter.set_description('Segmentations DICOM: Segmentation stored in ' + seg_path)
        finally:
            # Always discard the extracted DICOM files, also when a step failed.
            shutil.rmtree(temp_folder, ignore_errors=True)
            folder_iter.set_description('Temporal file deleted')
        # TODO: 4. Load DICOM and transform to nifti
        # TODO: 5. Store as nifty in hd5
Datasets/ircad_dataset.py ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+
3
+ os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
4
+ os.environ['CUDA_VISIBLE_DEVICES'] = "1" # Check availability before running using 'nvidia-smi'
5
+ currentdir = os.path.dirname(os.path.realpath(__file__))
6
+ parentdir = os.path.dirname(currentdir)
7
+ sys.path.append(parentdir)
8
+ import multiprocessing as mp
9
+
10
+ mp.set_start_method('spawn')
11
+
12
+ import tensorflow as tf
13
+
14
+ # tf.enable_eager_execution()
15
+
16
+ import numpy as np
17
+ import nibabel as nib
18
+ from skimage.transform import resize
19
+ from skimage.filters import median
20
+ from scipy.ndimage import binary_dilation, generate_binary_structure
21
+ from nilearn.image import math_img
22
+ import h5py
23
+ from tqdm import tqdm
24
+ import re
25
+
26
+ currentdir = os.path.dirname(os.path.realpath(__file__))
27
+ parentdir = os.path.dirname(currentdir)
28
+ sys.path.append(parentdir) # PYTHON > 3.3 does not allow relative referencing
29
+ from DeepDeformationMapRegistration.utils.cmd_args_parser import parse_arguments
30
+ import DeepDeformationMapRegistration.utils.constants as const
31
+ from tools.thinPlateSplines_tf import ThinPlateSplines
32
+ from keras_model.ext.neuron.layers import SpatialTransformer
33
+ from tools.voxelMorph import interpn
34
+ from generate_dataset.utils import plot_central_slices, plot_def_map, single_img_gif, two_img_gif, plot_slices, \
35
+ crop_images, plot_displacement_map, bbox_3D
36
+ from generate_dataset import utils
37
+ from tools.misc import try_mkdir
38
+ from generate_dataset.utils import unzip_file, delete_temp
39
+
40
# Folder listed by the main script to find the input files.
DATASTE_RAW_FILES = '/mnt/EncryptedData1/Users/javier/vessel_registration/3Dirca/nifti'
# File-name prefixes used by the main script to pick segmentation and CT files.
LITS_SEGMENTATION_FILE = 'segmentations'
LITS_CT_FILE = 'volume'

# Spatial size taken from const.IMG_SHAPE without its last entry, and twice that size.
IMG_SIZE_LARGE = const.IMG_SHAPE[:-1]
IMG_SIZE_LARGE_x2 = [2 * x for x in const.IMG_SHAPE[:-1]]
FINE_GRID_SHAPE = tuple(x // 1 for x in IMG_SIZE_LARGE_x2)  # tuple(np.asarray(IMG_SIZE_LARGE) // 10)
# Grid with const.TPS_NUM_CTRL_PTS_PER_AXIS points per axis over the double-size volume.
# NOTE(review): exact semantics of CoordinatesGrid are defined in the project constants module.
CTRL_GRID = const.CoordinatesGrid()
CTRL_GRID.set_coords_grid(IMG_SIZE_LARGE_x2, [const.TPS_NUM_CTRL_PTS_PER_AXIS, const.TPS_NUM_CTRL_PTS_PER_AXIS,
                                              const.TPS_NUM_CTRL_PTS_PER_AXIS], batches=False, norm=False,
                          img_type=tf.float32)

# Grid of shape FINE_GRID_SHAPE over the double-size volume.
FULL_FINE_GRID = const.CoordinatesGrid()
FULL_FINE_GRID.set_coords_grid(IMG_SIZE_LARGE_x2, FINE_GRID_SHAPE, batches=False, norm=False)

# Added to the volume number parsed from the input name when naming the output file.
OFFSET_NAME_NUM = 0

# Threshold used by the main script to binarise the masks.
TH_BIN = 0.50

DILATION_STRUCT = generate_binary_structure(3, 1)

# Flat control-grid dimensions plus 9 extra points and plus 1 extra point.
# NOTE(review): neither constant is referenced elsewhere in this file.
LARGE_PT_DIM = CTRL_GRID.shape_grid_flat + np.asarray([9, 0])
SINGLE_PT_DIM = CTRL_GRID.shape_grid_flat + np.asarray([1, 0])
# Values fed to the 'large_point' and 'add_affine' graph inputs by the main script.
USE_LARGE_PT = False
ADD_AFFINE_TRF = False

config = tf.compat.v1.ConfigProto()  # device_count={'GPU':0})
config.gpu_options.allow_growth = True
config.log_device_placement = False  ## to log device placement (on which device the operation ran)
69
+
70
+
71
+
72
def tf_graph_translation():
    """Build the graph that shifts the fixed volumes by one random translation.

    Creates the placeholders 'fix_img', 'fix_tumors' and 'fix_parenchyma'
    (float32, shape IMG_SIZE_LARGE_x2). The translation vector is drawn with
    NumPy when the graph is built, so it is a constant of the graph.

    :return: tuple (mov_img, mov_parenchyma, mov_tumors, disp_map, w) with the
             resampled volumes, the float32 displacement map reshaped to
             (*FINE_GRID_SHAPE, -1) and the translation vector.
    """
    # Place holders (created in the same order as their names are listed)
    inputs = {name: tf.placeholder(tf.float32, IMG_SIZE_LARGE_x2, name)
              for name in ('fix_img', 'fix_tumors', 'fix_parenchyma')}

    # Random translation, one component per axis
    shift = tf.constant(np.random.uniform(-1, 1, 3) * const.MAX_DISP_DM_PERC * IMG_SIZE_LARGE_x2[0],
                        dtype=tf.float32)
    # Zero margin added on both sides of every axis: |shift| + 1, truncated to int
    margin = tf.cast(tf.abs(shift) + 1., tf.int32)
    paddings = tf.stack([margin, margin], 1)

    # Pure translation: every point of the fine grid is moved by 'shift'
    shifted_grid = tf.add(FULL_FINE_GRID.grid_flat(), shift)
    displacement = tf.reshape(shifted_grid - FULL_FINE_GRID.grid_flat(), (*FINE_GRID_SHAPE, -1))

    # Because of the padding, the sampling points are now translated 'margin' units
    sampling_points = tf.add(shifted_grid, tf.cast(margin, shifted_grid.dtype))

    def _translate(volume):
        # Pad with zeros, add the channel axis, sample and restore the volume shape
        padded = tf.pad(volume, paddings, "CONSTANT", constant_values=0.)
        padded = tf.expand_dims(padded, -1)
        sampled = interpn(padded, sampling_points, interp_method='linear')
        return tf.squeeze(tf.reshape(sampled, IMG_SIZE_LARGE_x2))

    mov_img = _translate(inputs['fix_img'])
    mov_tumors = _translate(inputs['fix_tumors'])
    mov_parenchyma = _translate(inputs['fix_parenchyma'])

    return mov_img, mov_parenchyma, mov_tumors, tf.cast(displacement, tf.float32), shift
117
+
118
+
119
def build_affine_trf(img_size, alpha, beta, gamma, ti, tj, tk):
    """Build a 4x4 homogeneous matrix: rotation about the image centre, then translation.

    The rotation is R = Rz(gamma) * Ry(beta) * Rx(alpha). It is applied around
    the image centre (R* = Tc * R * Tc^-1) and followed by the translation T,
    so the returned matrix is T * Tc * R * Tc^-1.

    :param img_size: tensor with the three image dimensions; the centre is img_size / 2.
    :param alpha: rotation angle of the Rx factor, in radians.
    :param beta: rotation angle of the Ry factor, in radians.
    :param gamma: rotation angle of the Rz factor, in radians.
    :param ti: translation along the first axis.
    :param tj: translation along the second axis.
    :param tk: translation along the third axis.
    :return: float32 tensor with the composed transformation.

    Angles and translations are stacked with the shape-(1,) ``zero``/``one``
    fillers below, so they are expected to have shape (1,) as well.
    """
    img_centre = tf.expand_dims(tf.divide(img_size, 2.), -1)

    # Rotation matrix around the image centre
    # R* = T(p) R(ang) T(-p)
    # tf.cos and tf.sin expect radians
    zero = tf.zeros((1,))
    one = tf.ones((1,))

    sin_a, cos_a = tf.math.sin(alpha), tf.math.cos(alpha)
    sin_b, cos_b = tf.math.sin(beta), tf.math.cos(beta)
    sin_g, cos_g = tf.math.sin(gamma), tf.math.cos(gamma)

    # The second row previously used gamma where alpha belongs, which did not
    # yield an orthonormal rotation matrix.
    R = tf.convert_to_tensor([[cos_g * cos_b,
                               cos_g * sin_b * sin_a - sin_g * cos_a,
                               cos_g * sin_b * cos_a + sin_g * sin_a,
                               zero],
                              [sin_g * cos_b,
                               sin_g * sin_b * sin_a + cos_g * cos_a,
                               sin_g * sin_b * cos_a - cos_g * sin_a,
                               zero],
                              [-sin_b,
                               cos_b * sin_a,
                               cos_b * cos_a,
                               zero],
                              [zero, zero, zero, one]], tf.float32)
    R = tf.squeeze(R)

    # Translation of the origin to the image centre ...
    Tc = tf.convert_to_tensor([[one, zero, zero, img_centre[0]],
                               [zero, one, zero, img_centre[1]],
                               [zero, zero, one, img_centre[2]],
                               [zero, zero, zero, one]], tf.float32)
    Tc = tf.squeeze(Tc)
    # ... and back
    Tc_ = tf.convert_to_tensor([[one, zero, zero, -img_centre[0]],
                                [zero, one, zero, -img_centre[1]],
                                [zero, zero, one, -img_centre[2]],
                                [zero, zero, zero, one]], tf.float32)
    Tc_ = tf.squeeze(Tc_)

    T = tf.convert_to_tensor([[one, zero, zero, ti],
                              [zero, one, zero, tj],
                              [zero, zero, one, tk],
                              [zero, zero, zero, one]], tf.float32)
    T = tf.squeeze(T)

    return tf.matmul(T, tf.matmul(Tc, tf.matmul(R, Tc_)))
160
+
161
+
162
def transform_points(points: tf.Tensor):
    """Apply a random affine transformation to a set of 3D points.

    Three angles are drawn uniformly in [-const.MAX_ANGLE_RAD, const.MAX_ANGLE_RAD]
    with TensorFlow. Three translations are drawn with NumPy in
    [-const.MAX_DISP_DM / 2, const.MAX_DISP_DM / 2] when the graph is built, so
    they are constants of the graph.

    :param points: point coordinates; transposed first when the last dimension is 3.
    :return: tuple (transformed points, transposed after the transformation;
             the matrix returned by build_affine_trf).
    """
    alpha, beta, gamma = (tf.random.uniform((1,), -const.MAX_ANGLE_RAD, const.MAX_ANGLE_RAD)
                          for _ in range(3))
    ti, tj, tk = (tf.constant(np.random.uniform(-1, 1, 1) * const.MAX_DISP_DM / 2, dtype=tf.float32)
                  for _ in range(3))

    size = tf.convert_to_tensor(IMG_SIZE_LARGE_x2, tf.float32)
    M = build_affine_trf(size, alpha, beta, gamma, ti, tj, tk)

    if points.shape.as_list()[-1] == 3:
        points = tf.transpose(points)

    # Upper-left 3x3 block first, then the first three entries of the last column
    linear_part = tf.matmul(M[:3, :3], points)
    moved = tf.expand_dims(M[:3, -1], -1) + linear_part
    return tf.transpose(moved), M
177
+
178
+
179
def tf_graph_deform():
    """Build the graph that deforms the fixed volumes with a thin-plate-spline warp.

    Placeholders created (fed by name): 'fix_img', 'fix_tumors', 'fix_vessels',
    'fix_parenchyma' (float32, shape IMG_SIZE_LARGE_x2) and the optional boolean
    scalars 'large_point' and 'add_affine' (both default to False).

    One voxel with value > 0 is picked at random from the tumour mask (from the
    parenchyma mask when the tumour mask sums to zero). That point is appended
    to the control grid and its displaced copy to the target grid, a
    ThinPlateSplines model is interpolated on the fine grid, and the resulting
    displacement map is applied to all the volumes.

    :return: tuple (mov_img, mov_parenchyma, mov_tumors, mov_vessels, disp_map,
             disp_location, rand_disp, aff): the warped volumes, the
             displacement map, the control point(s) added to the grid, the
             random displacement vector and the affine matrix (identity when
             'add_affine' is False).
    """
    # Place holders
    fix_img = tf.placeholder(tf.float32, IMG_SIZE_LARGE_x2, 'fix_img')
    fix_tumors = tf.placeholder(tf.float32, IMG_SIZE_LARGE_x2, 'fix_tumors')
    fix_vessels = tf.placeholder(tf.float32, IMG_SIZE_LARGE_x2, 'fix_vessels')
    fix_parenchyma = tf.placeholder(tf.float32, IMG_SIZE_LARGE_x2, 'fix_parenchyma')
    large_point = tf.placeholder_with_default(input=tf.constant(False, tf.bool), shape=(), name='large_point')
    add_affine = tf.placeholder_with_default(input=tf.constant(False, tf.bool), shape=(), name='add_affine')

    # Search in the tumours, or in the parenchyma when there are no tumour voxels
    search_voxels = tf.cond(tf.equal(tf.reduce_sum(fix_tumors), 0.0),
                            lambda: fix_parenchyma,
                            lambda: fix_tumors)

    # Apply TPS deformation
    # 1. get a point in the label img and add it to the control grid and target grid
    idx_points_in_label = tf.where(tf.greater(search_voxels, 0.0))  # Indices of the points in the label image with intensity greater than 0

    random_idx = tf.random.uniform([], minval=0, maxval=tf.shape(idx_points_in_label)[0],
                                   dtype=tf.int32)  # Randomly select one of the points
    disp_location = tf.gather_nd(idx_points_in_label, tf.expand_dims(random_idx, 0))  # And get the coordinates
    disp_location = tf.cast(disp_location, tf.float32)
    # Get the coordinates of the displaced control point. The displacement is
    # drawn with NumPy at graph-construction time, so it is a graph constant.
    rand_disp = tf.constant(np.random.uniform(-1, 1, 3) * const.MAX_DISP_DM, dtype=tf.float32)
    warped_location = disp_location + rand_disp

    def get_box_neighbours(location, radius=3):
        """Return `location` followed by the 8 corners of the cube of half-side `radius` around it."""
        corner_offsets = ([radius, radius, radius],
                          [-radius, radius, radius],
                          [radius, -radius, radius],
                          [-radius, -radius, radius],
                          [radius, radius, -radius],
                          [-radius, radius, -radius],
                          [radius, -radius, -radius],
                          [-radius, -radius, -radius])
        # The offsets are added to `location`. (They used to be added to
        # rand_disp, which placed the corners around the displacement vector.)
        corners = [tf.add(location, tf.constant(np.asarray(offset), location.dtype))
                   for offset in corner_offsets]
        return tf.stack([location] + corners, 0)

    # The single-point branch must add the selected location itself to the
    # control grid. (It used to add rand_disp, the displacement vector.)
    disp_location, warped_location = tf.cond(large_point,
                                             lambda: (get_box_neighbours(disp_location, 3), get_box_neighbours(warped_location, 3)),
                                             lambda: (tf.expand_dims(disp_location, 0), tf.expand_dims(warped_location, 0)))

    # 2. Add the new point to the control grid and the target grid
    control_grid = tf.concat([CTRL_GRID.grid_flat(), disp_location], axis=0)
    trg_grid = tf.concat([CTRL_GRID.grid_flat(), warped_location], axis=0)

    trg_grid, aff = tf.cond(add_affine,
                            lambda: transform_points(trg_grid),
                            lambda: (trg_grid, tf.eye(4, 4)))

    # I need to know the shape before running TPS
    # NOTE(review): 73 and 65 presumably are a 64-point control grid plus 9 or 1
    # extra points - confirm against const.TPS_NUM_CTRL_PTS_PER_AXIS.
    control_grid.set_shape([73, 3] if USE_LARGE_PT else [65, 3])
    trg_grid.set_shape([73, 3] if USE_LARGE_PT else [65, 3])

    tps = ThinPlateSplines(control_grid, trg_grid)
    def_grid = tps.interpolate(FULL_FINE_GRID.grid_flat())

    disp_map = def_grid - FULL_FINE_GRID.grid_flat()
    disp_map = tf.reshape(disp_map, (*FINE_GRID_SHAPE, -1))

    # add the batch and channel dimensions
    fix_img = tf.expand_dims(tf.expand_dims(fix_img, -1), 0)
    fix_tumors = tf.expand_dims(tf.expand_dims(fix_tumors, -1), 0)
    fix_vessels = tf.expand_dims(tf.expand_dims(fix_vessels, -1), 0)
    fix_parenchyma = tf.expand_dims(tf.expand_dims(fix_parenchyma, -1), 0)
    disp_map = tf.cast(tf.expand_dims(disp_map, 0), tf.float32)

    mov_tumors = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([fix_tumors, disp_map])
    mov_vessels = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([fix_vessels, disp_map])
    mov_parenchyma = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([fix_parenchyma, disp_map])
    mov_img = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([fix_img, disp_map])

    return tf.squeeze(mov_img),\
           tf.squeeze(mov_parenchyma),\
           tf.squeeze(mov_tumors),\
           tf.squeeze(mov_vessels),\
           tf.squeeze(disp_map),\
           disp_location,\
           rand_disp,\
           aff
259
+
260
+
261
if __name__ == '__main__':
    parse_arguments(sys.argv[1:])
    volume_list = [os.path.join(DATASTE_RAW_FILES, f) for f in os.listdir(DATASTE_RAW_FILES) if f.startswith(LITS_CT_FILE)]
    volume_list.sort()
    segmentation_list = [os.path.join(DATASTE_RAW_FILES, f) for f in os.listdir(DATASTE_RAW_FILES) if
                         f.startswith(LITS_SEGMENTATION_FILE)]
    segmentation_list.sort()

    # Volumes and segmentations are paired by their position in the sorted listings
    file_path_pairs = [[v, s] for v, s in zip(volume_list, segmentation_list)]

    # '.format' (not ', format') so the destination is inserted into the message
    print('Generating HD5 files at {} ...'.format(const.DESTINATION_FOLDER))
    intensity_window_w = 350
    intensity_window_l = 40
    # Window level -/+ half the window width. The builtin int replaces the
    # np.int alias, which no longer exists in recent NumPy releases.
    intensity_clipping_range = intensity_window_l + np.asarray([-intensity_window_w // 2, intensity_window_w // 2],
                                                               int)  # Slicer range for abdominal CT

    try_mkdir(const.DESTINATION_FOLDER)

    print('PART 1: Deformation')
    # Then do the fancy stuff
    init = tf.initialize_all_variables()
    get_mov_img = tf_graph_deform()
    sess = tf.Session(config=config)
    with sess.as_default():
        sess.run(init)
        sess.graph.finalize()
        for img_path, labels_path in tqdm(file_path_pairs):
            if img_path is None or labels_path is None:
                continue

            fix_img = nib.load(img_path)  # By convention, nibabel world axes are always in RAS+ orientation
            img_header = fix_img.header
            fix_labels = nib.load(labels_path)
            fix_img = np.asarray(fix_img.dataobj)
            fix_labels = np.asarray(fix_labels.dataobj)
            if fix_labels.shape[-1] < 4:
                print('[INF] ' + img_path + ' has no tumor segmentations')
                continue
            # NOTE(review): channel 0 is skipped and channels 1/2/3 are read as
            # vessels/parenchyma/tumors - confirm against the layout of the
            # segmentation files.
            fix_vessels = fix_labels[..., 1]
            fix_parenchyma = fix_labels[..., 2]
            fix_tumors = fix_labels[..., 3]

            # Clip intensity values
            fix_img = utils.intesity_clipping(fix_img, intensity_clipping_range, augment=True)

            # Reshape
            fix_img = resize(fix_img, IMG_SIZE_LARGE_x2)
            fix_parenchyma = resize(fix_parenchyma, IMG_SIZE_LARGE_x2)
            fix_tumors = resize(fix_tumors, IMG_SIZE_LARGE_x2)
            fix_vessels = resize(fix_vessels, IMG_SIZE_LARGE_x2)

            fix_parenchyma = median(fix_parenchyma, np.ones((5, 5, 5)))

            # Compute deformation
            mov_img, mov_parenchyma, mov_tumors, mov_vessels, disp_map, disp_loc, disp_vec, aff = sess.run(
                get_mov_img,
                feed_dict={'fix_img:0': fix_img,
                           'fix_tumors:0': fix_tumors,
                           'fix_vessels:0': fix_vessels,
                           'fix_parenchyma:0': fix_parenchyma,
                           'large_point:0': USE_LARGE_PT,
                           'add_affine:0': ADD_AFFINE_TRF})
            # Cleaning
            mov_img = utils.intesity_clipping(mov_img, intensity_clipping_range)

            if USE_LARGE_PT:
                disp_loc = disp_loc[0, ...]

            # Define the bbox around the union of the parenchyma of both volumes, so none falls outside
            bbox_mask = np.sign(mov_parenchyma + fix_parenchyma)
            bbox_mask = binary_dilation(bbox_mask, DILATION_STRUCT)
            bbox_mask = binary_dilation(bbox_mask, DILATION_STRUCT).astype(np.float32)

            # The point of application is referred to the whole image coordinate, not to the local BB
            min_i, _, min_j, _, min_k, _ = bbox_3D(bbox_mask)
            disp_loc = (disp_loc - np.asarray([min_i, min_j, min_k])) / 2
            # Crop the image to only contain the liver
            # The origin moved according to the mask information. And the images will be resized in a factor of 2!!
            fix_img, _ = crop_images(fix_img, bbox_mask, IMG_SIZE_LARGE)
            fix_tumors, _ = crop_images(fix_tumors, bbox_mask, IMG_SIZE_LARGE)
            fix_vessels, _ = crop_images(fix_vessels, bbox_mask, IMG_SIZE_LARGE)
            disp_map, _ = crop_images(disp_map, bbox_mask, IMG_SIZE_LARGE)
            fix_parenchyma, _ = crop_images(fix_parenchyma, bbox_mask, IMG_SIZE_LARGE)

            mov_img, _ = crop_images(mov_img, bbox_mask, IMG_SIZE_LARGE)
            mov_tumors, _ = crop_images(mov_tumors, bbox_mask, IMG_SIZE_LARGE)
            mov_vessels, _ = crop_images(mov_vessels, bbox_mask, IMG_SIZE_LARGE)
            mov_parenchyma, _ = crop_images(mov_parenchyma, bbox_mask, IMG_SIZE_LARGE)

            # Just to be sure we have binary masks (thresholded in place)
            for mask in (fix_tumors, fix_vessels, fix_parenchyma, mov_tumors, mov_vessels, mov_parenchyma):
                mask[mask > TH_BIN] = 1.0
                mask[mask < 1.0] = 0.0

            # Add the channel dimension
            fix_img = np.expand_dims(fix_img, -1)
            fix_tumors = np.expand_dims(fix_tumors, -1)
            fix_vessels = np.expand_dims(fix_vessels, -1)
            fix_parenchyma = np.expand_dims(fix_parenchyma, -1)
            fix_segmentations = np.stack([fix_parenchyma, fix_vessels, fix_tumors], -1)

            mov_img = np.expand_dims(mov_img, -1)
            mov_tumors = np.expand_dims(mov_tumors, -1)
            mov_vessels = np.expand_dims(mov_vessels, -1)
            mov_parenchyma = np.expand_dims(mov_parenchyma, -1)
            # Stacked like fix_segmentations. The moving entry used to be
            # written with the fixed masks.
            mov_segmentations = np.stack([mov_parenchyma, mov_vessels, mov_tumors], -1)

            # Save everything
            file_name = os.path.split(img_path)[-1].split('.')[0]
            vol_num = int(re.split('-|_', file_name)[-1])
            hd5_filename = 'volume-{:04d}'.format(vol_num + OFFSET_NAME_NUM)
            hd5_filename = os.path.join(const.DESTINATION_FOLDER, hd5_filename + '.hd5')
            # The context manager closes the file even if a write fails
            with h5py.File(hd5_filename, 'w') as hd5_file:
                hd5_file.create_dataset(const.H5_FIX_IMG, data=fix_img, dtype='float32')
                hd5_file.create_dataset(const.H5_FIX_PARENCHYMA_MASK, data=fix_parenchyma, dtype='float32')
                hd5_file.create_dataset(const.H5_FIX_VESSELS_MASK, data=fix_vessels, dtype='float32')
                hd5_file.create_dataset(const.H5_FIX_TUMORS_MASK, data=fix_tumors, dtype='float32')
                hd5_file.create_dataset(const.H5_FIX_SEGMENTATIONS, data=fix_segmentations, dtype='float32')

                hd5_file.create_dataset(const.H5_PARAMS_INTENSITY_RANGE, (2,), data=intensity_clipping_range,
                                        dtype='float32')

                hd5_file.create_dataset(const.H5_MOV_IMG, const.IMG_SHAPE, data=mov_img, dtype='float32')
                hd5_file.create_dataset(const.H5_MOV_PARENCHYMA_MASK, const.IMG_SHAPE, data=mov_parenchyma,
                                        dtype='float32')
                hd5_file.create_dataset(const.H5_MOV_VESSELS_MASK, const.IMG_SHAPE, data=mov_vessels, dtype='float32')
                hd5_file.create_dataset(const.H5_MOV_TUMORS_MASK, const.IMG_SHAPE, data=mov_tumors, dtype='float32')
                hd5_file.create_dataset(const.H5_MOV_SEGMENTATIONS, data=mov_segmentations, dtype='float32')

                hd5_file.create_dataset(const.H5_GT_DISP, const.DISP_MAP_SHAPE, data=disp_map, dtype='float32')
                hd5_file.create_dataset(const.H5_GT_DISP_VECT_LOC, data=disp_loc, dtype='float32')
                hd5_file.create_dataset(const.H5_GT_DISP_VECT, data=disp_vec, dtype='float32')
                hd5_file.create_dataset(const.H5_GT_AFFINE_M, data=aff, dtype='float32')

                hd5_file.create_dataset('params/voxel_size', data=img_header.get_zooms()[:3])
                hd5_file.create_dataset('params/original_shape', data=img_header.get_data_shape())
                hd5_file.create_dataset('params/bbox_origin', data=[min_i, min_j, min_k])
                hd5_file.create_dataset('params/first_reshape', data=IMG_SIZE_LARGE_x2)

    sess.close()
    print('...Done generating HD5 files')