|
import sys, os |
|
import numpy as np |
|
import nibabel as nib |
|
from scipy import ndimage as ndi |
|
from scipy.signal import convolve |
|
from numpy.linalg import norm |
|
import networkx as nx |
|
import logging |
|
import traceback |
|
import timeit |
|
import time |
|
import math |
|
from ast import literal_eval as make_tuple |
|
from skimage.measure import label |
|
import subprocess |
|
import platform |
|
import glob |
|
|
|
|
|
def loadVolume(volumeFolderPath, volumeName):
    """
    Load a NIfTI file (*.nii or *.nii.gz).

    Parameters
    ----------
    volumeFolderPath : str
        Folder of the volume file.
    volumeName : str
        Name of the volume file.

    Returns
    -------
    volume : ndarray
        Volume data in the form of numpy ndarray.
    affine : ndarray
        Associated affine transformation matrix in the form of numpy ndarray.
    """
    volumeFilePath = os.path.join(volumeFolderPath, volumeName)
    volumeImg = nib.load(volumeFilePath)
    # `get_data()` was deprecated in nibabel 3 and removed in nibabel 5. `np.asanyarray(img.dataobj)`
    # is the documented replacement; unlike `get_fdata()` it does not force a float dtype.
    volume = np.asanyarray(volumeImg.dataobj)
    shape = volume.shape
    affine = volumeImg.affine
    print('Volume loaded from {} with shape = {}.'.format(volumeFilePath, shape))

    return volume, affine
|
|
|
|
|
def saveVolume(volume, affine, path, astype=None):
    """
    Write `volume` to `path` as a NIfTI image after casting it to `astype`.

    Parameters
    ----------
    volume : ndarray
        Voxel data to write.
    affine : ndarray
        Affine transformation matrix stored alongside the voxel data.
    path : str
        Absolute destination path of the image file.
    astype : numpy dtype, optional
        Data type the voxels are cast to before saving. `np.uint8` when omitted.
    """
    targetType = np.uint8 if astype is None else astype
    image = nib.Nifti1Image(volume.astype(targetType), affine)
    nib.save(image, path)
    print('Volume saved to {} as type {}.'.format(path, targetType))
|
|
|
|
|
def labelVolume(volume, minSize=1, maxHop=3):
    """
    Partition the volume into several connected components and attach labels.

    Parameters
    ----------
    volume : ndarray
        Volume to be partitioned.
    minSize : int, optional
        Connected components smaller than this size (in voxels) are left out of `labelResult`.
    maxHop : int, optional
        Controls how neighboring voxels are defined. See `skimage.measure.label` (`connectivity`)
        for details.

    Returns
    -------
    labeled : ndarray
        The partitioned and labeled volume. Each connected component has a label (a positive
        integer) and the background is labeled as 0.
    labelResult : list
        In the form of [(label1, size1), (label2, size2), ...]. Only foreground components
        (label > 0) with size >= `minSize` are listed.
    """
    labeled = label(volume, connectivity=maxHop)
    counts = np.bincount(labeled.ravel())
    presentLabels = np.nonzero(counts)[0]
    # Label 0 is the background, not a connected component, so it is excluded from the result.
    componentLabels = presentLabels[presentLabels != 0]
    componentSizes = counts[componentLabels]
    isLargeEnough = componentSizes >= minSize
    labelResult = list(zip(componentLabels[isLargeEnough], componentSizes[isLargeEnough]))

    return labeled, labelResult
|
|
|
|
|
def _skeletonizationCommand(directory, currentPlatform):
    """
    Build the `docker run` argument list that skeletonizes the files in `directory`.

    The docker executable and container image depend on `currentPlatform` (the value of
    `platform.system()`). Raises RuntimeError for platforms without a known command.
    """
    if currentPlatform == 'Windows':
        launcher = ['C:/Program Files/Docker/Docker/Resources/bin/docker.exe']
        image = 'amytabb/curveskel-tabb-medeiros-2018-docker'
    elif currentPlatform == 'Darwin':
        launcher = ['docker']
        image = 'amytabb/curveskel-tabb-nih-aug2018-docker2'
    elif currentPlatform == 'Linux':
        # NOTE(review): runs docker through sudo; drop `sudo` if the user is in the docker group.
        launcher = ['sudo', 'docker']
        image = 'amytabb/curveskel-tabb-nih-aug2018-docker2'
    else:
        raise RuntimeError('Unsupported platform for skeletonization: {}'.format(currentPlatform))

    return launcher + ['run', '-v', directory + ':/write_directory',
                       '-e', 'THRESH=1e-12', '-e', 'CC_FLAG=1', '-e', 'CONVERSION_TYPE=1',
                       image]


def analyze(vesselVolumeMask, baseFolder):
    """
    Main function to provoke the skeletonization process. Note that here I am using the docker
    version of the code. If you have already downloaded the original C++ code and successfully
    compiled it, then you can run that compiled code instead of this one.

    The function writes, under `<baseFolder>/skeletonizationResult`:
    `vesselVolumeMaskLabelInfo.npz` (connected-component labels), `BB.txt` (bounding box) and
    `xyz.txt` (voxel count followed by one voxel coordinate per line), then runs the docker
    container on that directory.

    Parameters
    ----------
    vesselVolumeMask : ndarray
        3D vessel mask; every non-zero voxel is treated as vessel.
    baseFolder : str
        Folder under which the `skeletonizationResult` working directory is created.

    Returns
    -------
    returnCode : int
        Exit status of the docker command.
    """
    # Threshold before casting: casting first would wrap or truncate values (e.g. 256 or 0.5 -> 0).
    vesselVolumeMask = (np.asarray(vesselVolumeMask) != 0).astype(np.uint8)
    # Reverse the axis order (x, y, z) -> (z, y, x) for the skeletonization input.
    vesselVolumeMask = np.swapaxes(vesselVolumeMask, 0, 2)
    shape = vesselVolumeMask.shape

    vesselVolumeMaskLabeled, vesselVolumeMaskLabelResult = labelVolume(vesselVolumeMask, minSize=1)
    directory = os.path.join(baseFolder, 'skeletonizationResult')
    if not os.path.exists(directory):
        os.makedirs(directory)
        print('Directory {} created.'.format(directory))

    vesselVolumeMaskLabelInfoFilename = 'vesselVolumeMaskLabelInfo.npz'
    vesselVolumeMaskLabelInfoFilePath = os.path.join(directory, vesselVolumeMaskLabelInfoFilename)
    np.savez_compressed(vesselVolumeMaskLabelInfoFilePath, vesselVolumeMaskLabeled=vesselVolumeMaskLabeled,
                        vesselVolumeMaskLabelResult=vesselVolumeMaskLabelResult)
    print('{} saved to {}.'.format(vesselVolumeMaskLabelInfoFilename, vesselVolumeMaskLabelInfoFilePath))

    # Bounding box file: number of boxes, then the min corner, then the max corner.
    BBFilePath = os.path.join(directory, 'BB.txt')
    with open(BBFilePath, 'w') as f:
        f.write('1\n')
        f.write('{} {} {}\n'.format(0, 0, 0))
        f.write('{} {} {}'.format(*shape))

    # Voxel file: first line is the number of voxels, followed by one "i j k" line per voxel.
    vesselCoords = np.array(np.where(vesselVolumeMask)).T
    xyzFilePath = os.path.join(directory, 'xyz.txt')
    np.savetxt(xyzFilePath, vesselCoords, fmt='%1u', header='{}'.format(len(vesselCoords)), comments='')

    currentPlatform = platform.system()
    print('Current platform is {}.'.format(currentPlatform))
    cmd = _skeletonizationCommand(directory, currentPlatform)
    print('cmd={}'.format(cmd))
    # An argument list without a shell keeps paths containing spaces or quotes intact.
    return subprocess.call(cmd)
|
|
|
|
|
|
|
def combineSkeletonSegments(skeletonSegmentFolderPath):
    """
    Collect and combine the results from the skeletonization.

    Parameters
    ----------
    skeletonSegmentFolderPath : str
        The folder that contains the segments information (result_segments_xyz*.txt).

    Returns
    -------
    segmentList : list
        A list containing the segment information. Each sublist represents a segment and each
        element in the sublist represents a centerpoint coordinates. Files are read in sorted
        filename order.
    """
    # Escape the folder so glob metacharacters in it (e.g. '[') are matched literally.
    pattern = os.path.join(glob.escape(skeletonSegmentFolderPath), 'result_segments_xyz*.txt')
    segmentList = []
    # glob returns files in filesystem-dependent order; sorting makes the combined segment order
    # (and hence downstream segment indices) reproducible.
    for segmentFile in sorted(glob.glob(pattern)):
        segmentList.extend(readSegmentFile(segmentFile))

    return segmentList
|
|
|
|
|
def readSegmentFile(segmentFile):
    """
    Parse a segment file (result_segments_xyz*.txt) and return its segments as a list.

    Expected layout: a first line with the number of segments (skipped; segments are counted
    while parsing), then for every segment a line holding its length L, L lines of
    space-separated integer coordinates, and one separator line whose content is ignored.
    The separator after the final segment may be missing. Blank lines where a length line is
    expected are skipped.

    Parameters
    ----------
    segmentFile : str
        Path to the segment file.

    Returns
    -------
    segmentList : list
        A list containing the segment information. Each sublist represents a segment and each
        element in the sublist is a centerpoint coordinate tuple, with the axis order of the
        file reversed.
    """
    segmentList = []
    segment = None  # segment being filled; None while a length line is expected
    segmentLength = 0
    with open(segmentFile) as f:
        next(f, None)  # header line: declared number of segments
        for line in f:
            if segment is None:
                if not line.strip():
                    continue
                segmentLength = int(line)
                segment = []
            elif len(segment) < segmentLength:
                voxel = tuple(int(x) for x in line.split())
                # Reverse the coordinate order (the mask is axis-swapped before skeletonization).
                segment.append(voxel[::-1])
            else:
                # Separator line that follows the coordinates of a segment.
                segmentList.append(segment)
                segment = None

    # Keep the last segment even when the file ends without a separator line.
    if segment is not None and len(segment) == segmentLength:
        segmentList.append(segment)

    return segmentList
|
|
|
|
|
|
|
|
|
|
|
def _segmentDegrees(G, segment):
    """Return the degree in `G` of every voxel of `segment`, in segment order, as an ndarray."""
    return np.array([degree for _, degree in G.degree(segment)])


def _removeDuplicateSegments(segmentList):
    """
    Remove segments identical to an earlier segment, read forwards or backwards.

    Returns ``(uniqueSegments, numberOfRemovedSegments)``. The first occurrence of each segment
    is kept and the original order is preserved.
    """
    seenSegments = set()
    uniqueSegments = []
    for segment in segmentList:
        key = tuple(segment)
        if key in seenSegments or key[::-1] in seenSegments:
            continue
        seenSegments.add(key)
        uniqueSegments.append(segment)

    return uniqueSegments, len(segmentList) - len(uniqueSegments)


def _splitSegmentsAtBranchPoints(G, segmentList):
    """
    Cut every segment (of >= 3 voxels) that is not a simple branch at its non-degree-2 voxels.

    Segments that need no split are returned first (in order), followed by the new pieces.
    """
    keptSegments = []
    extraSegments = []
    for segment in segmentList:
        voxelDegrees = _segmentDegrees(G, segment)
        needsSplit = len(voxelDegrees) >= 3 and (
            voxelDegrees[0] == 2 or voxelDegrees[-1] == 2 or not np.all(voxelDegrees[1:-1] == 2))
        if not needsSplit:
            keptSegments.append(segment)
            continue

        locs = np.nonzero(voxelDegrees != 2)[0]
        # A degree-2 end is still a piece boundary, so add the end positions explicitly.
        if voxelDegrees[0] == 2:
            locs = np.hstack((0, locs))
        if voxelDegrees[-1] == 2:
            locs = np.hstack((locs, len(voxelDegrees)))

        # Consecutive pieces share their boundary voxel.
        for start, end in zip(locs[:-1], locs[1:]):
            extraSegments.append(segment[start:(end + 1)])

    return keptSegments + extraSegments


def _findInvalidSegments(G, segmentList, verbose=False):
    """
    Return the segments that are not simple branches.

    A segment is invalid if one of its end voxels has degree 2, or (for >= 3 voxels) an interior
    voxel does not have degree 2. Segments with fewer than 2 voxels are never reported.
    """
    invalidSegments = []
    for segment in segmentList:
        voxelDegrees = _segmentDegrees(G, segment)
        if len(voxelDegrees) == 2:
            if voxelDegrees[0] == 2 or voxelDegrees[-1] == 2:
                if verbose:
                    print('Degrees on either end is 2: {}'.format(voxelDegrees))
                invalidSegments.append(segment)
        elif len(voxelDegrees) > 2:
            if voxelDegrees[0] == 2 or voxelDegrees[-1] == 2 or np.any(voxelDegrees[1:-1] != 2):
                if verbose:
                    print('Degrees not correct: {}'.format(voxelDegrees))
                invalidSegments.append(segment)

    return invalidSegments


def _mergeAtDegreeTwoEnd(G, segmentList, keepList, idx, segment, voxelDegrees, atHead):
    """
    Try to merge `segmentList[idx]` with a segment that continues through its degree-2 end.

    `atHead` selects the end (True: ``segment[0]``, False: ``segment[-1]``). `keepList` is
    updated in place to drop segments that are consumed by the merge or found to be redundant.
    Returns the merged segment, or None when no merge happened.
    """
    endVoxel = segment[0] if atHead else segment[-1]
    otherSegmentInfo = [(idx2, seg) for idx2, seg in enumerate(segmentList)
                        if (seg[0] == endVoxel or seg[-1] == endVoxel) and keepList[idx2] and idx != idx2]
    if len(otherSegmentInfo) == 0:
        print('Could not find other segments for segment({}) {} with degrees {}'.format(idx, segment,
                                                                                       voxelDegrees))
        possibleSegmentsInfo = [(idx2, seg) for idx2, seg in enumerate(segmentList)
                                if (seg[0] == endVoxel or seg[-1] == endVoxel) and idx != idx2]
        print('Possible segments: {}'.format(len(possibleSegmentsInfo)))
        return None

    if len(otherSegmentInfo) > 1:
        # Several candidates: discard those that merely overlap the current segment.
        otherSegmentInfoTemp = []
        for idx2, seg in otherSegmentInfo:
            if contains(segment, seg) or contains(segment[::-1], seg):
                keepList[idx] = False
                continue
            elif contains(seg, segment) or contains(seg[::-1], segment):
                keepList[idx2] = False
            otherSegmentInfoTemp.append((idx2, seg))
        otherSegmentInfo = otherSegmentInfoTemp

        if len(otherSegmentInfo) > 1:
            print('More than one other segments found!')
            print('Current segment ({}) is {} ({})'.format(idx, segment, voxelDegrees))
            for otherSegmentIdx, otherSegment in otherSegmentInfo:
                otherSegmentVoxelDegrees = _segmentDegrees(G, otherSegment)
                print('Idx = {}: {} ({})'.format(otherSegmentIdx, otherSegment, otherSegmentVoxelDegrees))
            # NOTE(review): merging proceeds with the last candidate listed above; confirm intended.
        elif len(otherSegmentInfo) == 1:
            otherSegmentIdx, otherSegment = otherSegmentInfo[0]
        else:
            print('No valid other segments found!')
            return None
    else:
        otherSegmentIdx, otherSegment = otherSegmentInfo[0]
        if contains(segment, otherSegment) or contains(segment[::-1], otherSegment):
            keepList[idx] = False
            return None
        elif contains(otherSegment, segment) or contains(otherSegment[::-1], segment):
            keepList[otherSegmentIdx] = False
            return None

    # Join the two segments at the shared voxel (kept once), orienting `otherSegment` as needed.
    if atHead:
        if otherSegment[-1] == segment[0]:
            newSegment = otherSegment + segment[1:]
        else:
            newSegment = otherSegment[::-1] + segment[1:]
    else:
        if otherSegment[0] == segment[-1]:
            newSegment = segment[:-1] + otherSegment
        else:
            newSegment = segment[:-1] + otherSegment[::-1]

    if not validateSegment(G, newSegment):
        newSegmentVoxelDegrees = _segmentDegrees(G, newSegment)
        print('Old degree is {} () and new degree is {} ()'.format(voxelDegrees, newSegmentVoxelDegrees))
    else:
        print('Two segments ({} and {}) merged together!'.format(idx, otherSegmentIdx))

    keepList[idx] = False
    keepList[otherSegmentIdx] = False
    return newSegment


def processSegments(segmentList, shape):
    """
    Re-partition the segments so that each segment is a simple branch, i.e., it does not contain
    a bifurcation point except at its two ends.

    Note that this function might be replaced by another more concise function `getSegmentList`.

    Steps: build a graph from all segments, drop duplicate segments, split segments at branch
    points, drop duplicates again, then repeatedly merge segments through degree-2 end voxels
    until every segment is valid or no further merge is possible. If every segment is already
    valid before merging, the segments are plotted with `drawSegments`.

    Parameters
    ----------
    segmentList : list
        A list containing the segment information. Each sublist represents a segment and each
        element in the sublist represents a centerpoint coordinates (hashable, e.g. a tuple).
    shape : tuple
        Shape of the vessel volume (used for plotting).

    Returns
    -------
    G : NetworkX graph
        A graph in which each node represents a centerpoint and each edge represents a portion of
        a vessel branch.
    segmentList : list
        A list containing the segment information. Each sublist represents a segment and each
        element in the sublist represents a centerpoint coordinates.
    errorSegments : list
        A list that contains segments that cannot be fixed (removed from `segmentList`).
    """
    G = nx.Graph()
    for segment in segmentList:
        # `Graph.add_path` was removed in NetworkX 2.4; `nx.add_path` is the supported form.
        nx.add_path(G, segment)

    if G.number_of_nodes() == 0:
        print('No segments to process!')
        return G, [], []

    voxelDegrees = np.array([degree for _, degree in G.degree()])
    voxelDegreesZippedResult = list(zip(np.arange(voxelDegrees.max() + 1), np.bincount(voxelDegrees)))
    print('Voxel degree distribution is \n{}'.format(voxelDegreesZippedResult))
    print('Number of cycles is {}'.format(len(nx.cycle_basis(G))))

    segmentList, duplicateCounter = _removeDuplicateSegments(segmentList)
    print('{} duplicate segments removed!'.format(duplicateCounter))

    segmentList = _splitSegmentsAtBranchPoints(G, segmentList)

    segmentList, duplicateCounter = _removeDuplicateSegments(segmentList)
    print('{} duplicate segments removed in the second stage!'.format(duplicateCounter))

    if not _findInvalidSegments(G, segmentList):
        drawSegments(segmentList, shape)
        print('No errors!')
        return G, segmentList, []

    errorSegments = []
    hasInvalidSegments = True
    iterCounter = 1
    while hasInvalidSegments:
        print('\n\nIter={}'.format(iterCounter))
        keepList = np.full((len(segmentList),), True)
        extraSegments = []
        for idx, segment in enumerate(segmentList):
            if not keepList[idx]:
                continue
            voxelDegrees = _segmentDegrees(G, segment)
            if voxelDegrees[0] == 2 and voxelDegrees[-1] == 2:
                print('Both end have 2 neighbours')
            elif voxelDegrees[0] == 2 or voxelDegrees[-1] == 2:
                newSegment = _mergeAtDegreeTwoEnd(G, segmentList, keepList, idx, segment, voxelDegrees,
                                                  atHead=voxelDegrees[0] == 2)
                if newSegment is not None:
                    extraSegments.append(newSegment)

        segmentList = [segment for keep, segment in zip(keepList, segmentList) if keep]
        segmentList += extraSegments
        errorSegments = _findInvalidSegments(G, segmentList, verbose=True)
        hasInvalidSegments = len(errorSegments) > 0
        print('hasInvalidSegments = {}'.format(hasInvalidSegments))
        iterCounter += 1
        # Every merge removes at least one segment, so an iteration without merges is a fixed point.
        if len(extraSegments) == 0:
            hasInvalidSegments = False
            print('While loop aborted because there is no change in segments!')

    for errorSegment in errorSegments:
        segmentList.remove(errorSegment)

    print(errorSegments)

    return G, segmentList, errorSegments
|
|
|
|
|
def getSegmentList(G, nodeInfoDict):
    """
    Build the list of simple-branch segments of graph `G`.

    Every node whose ``nodeInfoDict[node]['parentNodeID']`` equals -1 is used as a root, and the
    segments reachable from each root are collected, root by root, by `getSegmentListDetail`.

    Parameters
    ----------
    G : NetworkX graph
        The graph representation of the network.
    nodeInfoDict : dict
        Per-node information; each value must provide a 'parentNodeID' entry.

    Returns
    -------
    segmentList : list
        A list of segments in which each segment is a simple branch.
    """
    startNodeIDList = []
    for nodeID, nodeInfo in nodeInfoDict.items():
        if nodeInfo['parentNodeID'] == -1:
            startNodeIDList.append(nodeID)
    print('startNodeIDList = {}'.format(startNodeIDList))

    collected = []
    for rootID in startNodeIDList:
        collected = getSegmentListDetail(G, nodeInfoDict, collected, rootID)

    print('There are {} segments in segmentList'.format(len(collected)))
    print(collected)
    return collected
|
|
|
|
|
def getSegmentListDetail(G, nodeInfoDict, segmentList, startNodeID):
    """
    Implementation of `getSegmentList`. Use DFS to traverse all the segments.

    Starting from `startNodeID`, every not-yet-visited incident edge starts a new segment that is
    extended through degree-2 nodes; the search then recurses from the segment's last node.
    Traversed edges are marked with a ``'visited'`` attribute in `G` (the graph is modified).

    Parameters
    ----------
    G : NetworkX graph
        The graph representation of the network.
    nodeInfoDict : dict
        A dictionary containing the information about nodes (passed through to recursive calls).
    segmentList : list
        A list of segments in which each segment is a simple branch. New segments are appended
        to it in place.
    startNodeID : int
        The index of the start point of a segment.

    Returns
    -------
    segmentList : list
        A list of segments in which each segment is a simple branch.
    """
    neighborNodeIDList = [nodeID for nodeID in G[startNodeID] if 'visited' not in G[startNodeID][nodeID]]
    for neighborNodeID in neighborNodeIDList:
        # An earlier neighbor's walk may already have come back over this edge (a loop through
        # `startNodeID`); walking it again would produce a duplicate segment.
        if 'visited' in G[startNodeID][neighborNodeID]:
            continue

        newSegment = [startNodeID, neighborNodeID]
        G[startNodeID][neighborNodeID]['visited'] = True
        currentNodeID = neighborNodeID
        while G.degree(currentNodeID) == 2:
            unvisitedNodeIDs = [nodeID for nodeID in G[currentNodeID]
                                if 'visited' not in G[currentNodeID][nodeID]]
            if not unvisitedNodeIDs:
                # Both edges were already walked: the path closed a loop.
                break
            newNodeID = unvisitedNodeIDs[0]
            G[currentNodeID][newNodeID]['visited'] = True
            newSegment.append(newNodeID)
            currentNodeID = newNodeID

        segmentList.append(newSegment)
        segmentList = getSegmentListDetail(G, nodeInfoDict, segmentList, currentNodeID)

    return segmentList
|
|
|
|
|
def sublist(ls1, ls2):
    '''
    Return True when the elements shared by `ls1` and `ls2` occur in the same relative order.

    The elements of `ls1` that occur in `ls2` are compared pairwise with the elements of `ls2`
    that occur in `ls1`, up to the shorter of the two runs.

    >>> sublist([], [1,2,3])
    True
    >>> sublist([1,2,3,4], [2,5,3])
    True
    >>> sublist([1,2,3,4], [0,3,2])
    False
    >>> sublist([1,2,3,4], [1,2,5,6,7,8,5,76,4,3])
    False
    '''
    sharedFromFirst = (item for item in ls1 if item in ls2)
    sharedFromSecond = (item for item in ls2 if item in ls1)
    return not any(left != right for left, right in zip(sharedFromFirst, sharedFromSecond))
|
|
|
|
|
def contains(lst1, lst2):
    """
    Check whether the shorter list appears contiguously in the longer one, forwards or backwards.

    The arguments are swapped when `lst1` is longer, so the test does not depend on argument
    order except for equal lengths (then `lst1` is the candidate sub-list). Every occurrence of
    the candidate's first element is tried, so repeated elements are handled. An empty list is
    contained in any list.

    Parameters
    ----------
    lst1, lst2 : list
        Sequences to compare (e.g. segments, lists of voxel coordinate tuples).

    Returns
    -------
    result : bool
        True if the shorter list equals a contiguous slice of the longer one, possibly reversed.
    """
    shorter, longer = (lst2, lst1) if len(lst1) > len(lst2) else (lst1, lst2)
    size = len(shorter)
    if size == 0:
        return True

    first = shorter[0]
    for start, element in enumerate(longer):
        if element != first:
            continue
        # Forward match beginning at `start`.
        if longer[start:start + size] == shorter:
            return True
        # Backward match: `shorter` read from `start` towards the front of `longer`.
        if start + 1 >= size and longer[start + 1 - size:start + 1][::-1] == shorter:
            return True

    return False
|
|
|
|
|
def validateSegment(G, segment):
    """
    Check whether a segment is a simple branch.

    A segment is simple when neither end voxel has degree 2 and, for segments longer than two
    voxels, every interior voxel has degree exactly 2. A one-voxel segment is never simple
    (an error message is printed when its voxel does not have degree 2).

    Parameters
    ----------
    G : NetworkX graph
        A graph in which each node represents a centerpoint and each edge represents a portion of
        a vessel branch.
    segment : list
        A list containing the coordinates of the centerpoints of a segment.

    Returns
    -------
    result : bool
        If True, the segment is a simple branch.
    """
    degrees = np.array([degree for _, degree in G.degree(segment)])
    if degrees[0] == 2 or degrees[-1] == 2:
        return False

    count = len(degrees)
    if count == 2:
        return True
    if count > 2:
        return bool(np.all(degrees[1:-1] == 2))

    print('Error! Segment with length 1 found!')
    return False
|
|
|
|
|
def drawSegments(segmentList, shape):
    """
    Plot all the segments in `segmentList`. Try to assign different colors to the segments
    connected to the same node.

    Opens a pyqtgraph OpenGL window and blocks in the Qt event loop until it is closed.

    Parameters
    ----------
    segmentList : list
        A list containing the segment information. Each sublist represents a segment and each
        element in the sublist represents a centerpoint coordinates.
    shape : tuple
        Shape of the vessel volume (used to center the plot).
    """
    import pyqtgraph as pg
    import pyqtgraph.opengl as gl

    # mkQApp() reuses an existing QApplication instead of constructing a second one (Qt allows
    # only one) and does not rely on the `QtGui.QApplication` alias dropped by newer pyqtgraph.
    app = pg.mkQApp()
    w = gl.GLViewWidget()
    w.opts['distance'] = 800
    w.setGeometry(0, 110, 1600, 900)
    offset = np.array(shape) / (-2.0)

    colorList = [pg.glColor(code) for code in ('r', 'g', 'b', 'c', 'm', 'y')]
    numOfColors = len(colorList)
    # voxel -> list of [otherEndVoxel, colorCode] for every segment ending at that voxel
    nodeColorDict = {}
    for segment in segmentList:
        startVoxel = segment[0]
        endVoxel = segment[-1]
        nodeColorDict.setdefault(startVoxel, []).append([endVoxel, -1])
        nodeColorDict.setdefault(endVoxel, []).append([startVoxel, -1])

        usedColors = ({colorCode for _, colorCode in nodeColorDict[startVoxel]} |
                      {colorCode for _, colorCode in nodeColorDict[endVoxel]})
        availableColors = [colorCode for colorCode in range(numOfColors) if colorCode not in usedColors]
        # Fall back to the first color when every color is already used at either end.
        chosenColor = availableColors[0] if availableColors else 0
        nodeColorDict[startVoxel][-1][1] = chosenColor
        nodeColorDict[endVoxel][-1][1] = chosenColor

        segmentCoords = np.array(segment)
        lineItem = gl.GLLinePlotItem(pos=segmentCoords, color=colorList[chosenColor], width=3)
        lineItem.translate(*offset)
        w.addItem(lineItem)

    w.show()
    # Qt5 bindings expose `exec_`; Qt6 bindings only provide `exec`.
    runEventLoop = getattr(app, 'exec_', None) or app.exec
    runEventLoop()
|
|
|
|
|
|
|
def main():
    """
    Run the post-skeletonization pipeline for the files next to this script.

    Loads `vesselVolumeMask.nii.gz`, combines the segment files from
    `skeletonizationResult/segments_by_cc`, cleans them with `processSegments`, and writes
    `graphRepresentation.graphml`, `segmentList.npz` and `skeleton.nii.gz` to the script folder.
    """
    start_time = timeit.default_timer()
    baseFolder = os.path.abspath(os.path.dirname(__file__))

    vesselVolumeMaskFolderPath = baseFolder
    vesselVolumeMaskFileName = 'vesselVolumeMask.nii.gz'
    vesselVolumeMask, vesselVolumeMaskAffine = loadVolume(vesselVolumeMaskFolderPath, vesselVolumeMaskFileName)

    skeletonSegmentFolderPath = os.path.join(baseFolder, 'skeletonizationResult', 'segments_by_cc')
    segmentListRough = combineSkeletonSegments(skeletonSegmentFolderPath)

    shape = vesselVolumeMask.shape
    G, segmentList, errorSegments = processSegments(segmentListRough, shape=shape)

    # Rebuild the graph from the cleaned segments, tagging each edge with its segment's index.
    # (`Graph.add_path` was removed in NetworkX 2.4; `nx.add_path` is the supported form.)
    G = nx.Graph()
    for segmentIndex, segment in enumerate(segmentList):
        nx.add_path(G, segment, segmentIndex=segmentIndex)

    graphFileName = 'graphRepresentation.graphml'
    graphFilePath = os.path.join(baseFolder, graphFileName)
    nx.write_graphml(G, graphFilePath)
    print('{} saved to {}.'.format(graphFileName, graphFilePath))

    # Segments have different lengths; NumPy >= 1.24 refuses to build an array from ragged nested
    # lists, so store them explicitly as a 1-D object array (load with allow_pickle=True).
    segmentArray = np.empty(len(segmentList), dtype=object)
    for idx, segment in enumerate(segmentList):
        segmentArray[idx] = segment
    segmentListFileName = 'segmentList.npz'
    segmentListFilePath = os.path.join(baseFolder, segmentListFileName)
    np.savez_compressed(segmentListFilePath, segmentList=segmentArray)
    print('{} saved to {}.'.format(segmentListFileName, segmentListFilePath))

    skeleton = np.zeros_like(vesselVolumeMask)
    for segment in segmentList:
        skeleton[tuple(np.array(segment).T)] = 1

    skeletonFileName = 'skeleton.nii.gz'
    skeletonFilePath = os.path.join(baseFolder, skeletonFileName)
    saveVolume(skeleton, vesselVolumeMaskAffine, skeletonFilePath, astype=np.uint8)

    elapsed = timeit.default_timer() - start_time
    print('Elapsed: {} sec'.format(elapsed))
|
|
|
|
|
if __name__ == '__main__':
    # Run the pipeline only when this file is executed as a script, not when imported.
    main()
|
|