"""
Vessel skeletonization pipeline.

A binarized vessel mask is handed to an external curve-skeletonization Docker
image (`analyze`), the resulting segment files are parsed
(`combineSkeletonSegments` / `readSegmentFile`), cleaned up into simple
branches (`processSegments`), and saved as a graph, a segment list and a
skeleton volume (`main`).
"""
import glob
import logging
import math
import os
import platform
import subprocess
import sys
import time
import timeit
import traceback
from ast import literal_eval as make_tuple

import networkx as nx
import nibabel as nib
import numpy as np
from numpy.linalg import norm
from scipy import ndimage as ndi
from scipy.signal import convolve
from skimage.measure import label


def loadVolume(volumeFolderPath, volumeName):
    """
    Load nifti files (*.nii or *.nii.gz).

    Parameters
    ----------
    volumeFolderPath : str
        Folder of the volume file.
    volumeName : str
        Name of the volume file.

    Returns
    -------
    volume : ndarray
        Volume data in the form of numpy ndarray.
    affine : ndarray
        Associated affine transformation matrix in the form of numpy ndarray.
    """
    volumeFilePath = os.path.join(volumeFolderPath, volumeName)
    volumeImg = nib.load(volumeFilePath)
    # `get_data()` was deprecated and then removed in nibabel 5;
    # `np.asanyarray(img.dataobj)` is its documented replacement (unlike
    # `get_fdata()`, it does not force a float dtype).
    volume = np.asanyarray(volumeImg.dataobj)
    shape = volume.shape
    affine = volumeImg.affine
    print('Volume loaded from {} with shape = {}.'.format(volumeFilePath, shape))

    return volume, affine


def saveVolume(volume, affine, path, astype=None):
    """
    Save the given volume to the specified location in specified data type.

    Parameters
    ----------
    volume : ndarray
        Volume data to be saved.
    affine : ndarray
        The affine transformation matrix associated with the volume.
    path : str
        The absolute path where the volume is going to be saved.
    astype : numpy dtype, optional
        The desired data type of the volume data. Defaults to np.uint8.
    """
    if astype is None:
        astype = np.uint8

    nib.save(nib.Nifti1Image(volume.astype(astype), affine), path)
    print('Volume saved to {} as type {}.'.format(path, astype))


def labelVolume(volume, minSize=1, maxHop=3):
    """
    Partition the volume into several connected components and attach labels.

    Parameters
    ----------
    volume : ndarray
        Volume to be partitioned.
    minSize : int, optional
        Connected components smaller than this size are left out of
        `labelResult` (they keep their label in `labeled`).
    maxHop : int, optional
        Controls how neighboring voxels are defined. Passed as `connectivity`
        to `skimage.measure.label`; see its doc for details.

    Returns
    -------
    labeled : ndarray
        The partitioned and labeled volume. Each connected component has a
        label (a positive integer) and the background is labeled as 0.
    labelResult : list
        In the form of [(label1, size1), (label2, size2), ...]. Note that the
        background label 0 is included (with its voxel count) whenever the
        volume contains background voxels.
    """
    labeled = label(volume, connectivity=maxHop)
    counts = np.bincount(labeled.ravel())
    countLoc = np.nonzero(counts)[0]
    sizeList = counts[countLoc]
    isLargeEnough = sizeList >= minSize
    labelResult = list(zip(countLoc[isLargeEnough], sizeList[isLargeEnough]))

    return labeled, labelResult


def analyze(vesselVolumeMask, baseFolder):
    """
    Main function to provoke the skeletonization process.

    Note that here I am using the docker version of the code. If you have
    already downloaded the original C++ code and successfully compiled it,
    then you can run that compiled code instead of this one.

    Writes, into `<baseFolder>/skeletonizationResult`:
    `vesselVolumeMaskLabelInfo.npz`, `BB.txt` (bounding box of the whole
    volume) and `xyz.txt` (point count followed by one voxel coordinate per
    line), then runs the skeletonization container on that folder.

    Parameters
    ----------
    vesselVolumeMask : ndarray
        Vessel mask; every non-zero voxel is treated as vessel.
    baseFolder : str
        Folder under which `skeletonizationResult` is created.

    Raises
    ------
    RuntimeError
        If the current platform has no known Docker command.
    """
    # Binarize *before* casting: casting first would turn e.g. 0.5 or 256
    # into 0 and silently drop those voxels.
    vesselVolumeMask = (np.asarray(vesselVolumeMask) != 0).astype(np.uint8)
    # Axes 0 and 2 are swapped for the external tool; `readSegmentFile`
    # reverses each coordinate tuple when reading results back.
    vesselVolumeMask = np.swapaxes(vesselVolumeMask, 0, 2)
    shape = vesselVolumeMask.shape
    vesselVolumeMaskLabeled, vesselVolumeMaskLabelResult = labelVolume(vesselVolumeMask, minSize=1)

    directory = os.path.join(baseFolder, 'skeletonizationResult')
    if not os.path.exists(directory):
        os.makedirs(directory)
        print('Directory {} created.'.format(directory))

    vesselVolumeMaskLabelInfoFilename = 'vesselVolumeMaskLabelInfo.npz'
    vesselVolumeMaskLabelInfoFilePath = os.path.join(directory, vesselVolumeMaskLabelInfoFilename)
    np.savez_compressed(vesselVolumeMaskLabelInfoFilePath, vesselVolumeMaskLabeled=vesselVolumeMaskLabeled,
                        vesselVolumeMaskLabelResult=vesselVolumeMaskLabelResult)
    print('{} saved to {}.'.format(vesselVolumeMaskLabelInfoFilename, vesselVolumeMaskLabelInfoFilePath))

    # Single bounding box covering the whole volume.
    BBFilePath = os.path.join(directory, 'BB.txt')
    with open(BBFilePath, 'w') as f:
        f.write('1\n')
        f.write('{} {} {}\n'.format(0, 0, 0))
        f.write('{} {} {}'.format(*shape))

    # Point count on the first line, then one "x y z" row per vessel voxel.
    vesselCoords = np.array(np.where(vesselVolumeMask)).T
    xyzFilePath = os.path.join(directory, 'xyz.txt')
    np.savetxt(xyzFilePath, vesselCoords, fmt='%1u', header=str(len(vesselCoords)), comments='')

    currentPlatform = platform.system()
    print('Current platform is {}.'.format(currentPlatform))
    dockerRunArgs = ['run', '-v', '{}:/write_directory'.format(directory),
                     '-e', 'THRESH=1e-12', '-e', 'CC_FLAG=1', '-e', 'CONVERSION_TYPE=1']
    if currentPlatform == 'Windows':
        cmd = (['C:/Program Files/Docker/Docker/Resources/bin/docker.exe'] + dockerRunArgs
               + ['amytabb/curveskel-tabb-medeiros-2018-docker'])
    elif currentPlatform == 'Darwin':
        cmd = ['docker'] + dockerRunArgs + ['amytabb/curveskel-tabb-nih-aug2018-docker2']
    elif currentPlatform == 'Linux':
        cmd = ['sudo', 'docker'] + dockerRunArgs + ['amytabb/curveskel-tabb-nih-aug2018-docker2']
    else:
        raise RuntimeError('No docker command configured for platform {}.'.format(currentPlatform))

    print('cmd={}'.format(cmd))
    # Argument list (no shell): paths with spaces or quotes need no escaping.
    returnCode = subprocess.call(cmd)
    if returnCode != 0:
        print('Skeletonization command exited with code {}.'.format(returnCode))


def combineSkeletonSegments(skeletonSegmentFolderPath):
    """
    Collect and combine the results from the skeletonization.

    Parameters
    ----------
    skeletonSegmentFolderPath : str
        The folder that contains the segments information (result_segments_xyz*.txt).

    Returns
    -------
    segmentList : list
        A list containing the segment information. Each sublist represents a
        segment and each element in the sublist represents a centerpoint
        coordinates. Files are read in sorted filename order so the result is
        deterministic.
    """
    segmentList = []
    files = sorted(glob.glob(os.path.join(skeletonSegmentFolderPath, 'result_segments_xyz*.txt')))
    for segmentFile in files:
        segmentList.extend(readSegmentFile(segmentFile))

    return segmentList


def readSegmentFile(segmentFile):
    """
    Parse the segment files (result_segments_xyz*.txt) and return segments information in a list.

    Expected layout: a first line holding an integer (number of segments,
    not otherwise used); then, for each segment, a line with its length `n`,
    `n` lines of whitespace-separated integer coordinates, and one more line
    that is skipped as a separator. A final segment that is complete but not
    followed by a separator line is still kept.

    Parameters
    ----------
    segmentFile : str
        Path to the segment file.

    Returns
    -------
    segmentList : list
        A list containing the segment information. Each sublist represents a
        segment and each element in the sublist represents a centerpoint
        coordinates (each coordinate tuple is reversed, presumably undoing the
        axis swap done in `analyze`).
    """
    segmentList = []
    expectingLength = True
    segment = []
    segmentLength = 0
    with open(segmentFile) as f:
        for lineNumber, line in enumerate(f):
            if lineNumber == 0:
                int(line)  # header: number of segments (validated, not used)
            elif expectingLength:
                segmentLength = int(line)
                segment = []
                expectingLength = False
            elif len(segment) < segmentLength:
                voxel = tuple(int(x) for x in line.split())
                segment.append(voxel[::-1])
            else:
                # Separator line after a complete segment.
                segmentList.append(segment)
                expectingLength = True

    if not expectingLength:
        if len(segment) == segmentLength:
            segmentList.append(segment)
        else:
            print('Truncated segment ({} of {} points) ignored in {}.'.format(len(segment), segmentLength, segmentFile))

    return segmentList


def _segmentDegrees(G, segment):
    """Return the graph degree of every node of `segment`, in segment order, as an ndarray."""
    return np.array([degree for _, degree in G.degree(segment)])


def _isInvalidSegment(voxelDegrees):
    """Return True if a segment with these node degrees is not a simple branch (segments shorter than 2 count as valid here)."""
    if len(voxelDegrees) == 2:
        return bool(voxelDegrees[0] == 2 or voxelDegrees[-1] == 2)
    if len(voxelDegrees) > 2:
        return bool(voxelDegrees[0] == 2 or voxelDegrees[-1] == 2 or np.any(voxelDegrees[1:-1] != 2))
    return False


def _removeDuplicateSegments(segmentList):
    """
    Drop every segment equal to an earlier one, read in either direction.

    Returns
    -------
    uniqueSegments : list
        First occurrences, in original order.
    numRemoved : int
        How many segments were dropped.
    """
    seen = set()
    uniqueSegments = []
    for segment in segmentList:
        # A segment and its reverse map to the same key.
        key = frozenset((tuple(segment), tuple(segment[::-1])))
        if key in seen:
            continue
        seen.add(key)
        uniqueSegments.append(segment)

    return uniqueSegments, len(segmentList) - len(uniqueSegments)


def _splitAtBifurcations(G, segmentList):
    """
    Cut segments (length >= 3) that have a degree-2 end or a non-degree-2 interior node into pieces delimited by non-degree-2 nodes.

    Untouched segments keep their order and the new pieces are appended after them.
    """
    keptSegments = []
    extraSegments = []
    for segment in segmentList:
        voxelDegrees = _segmentDegrees(G, segment)
        needsCut = len(voxelDegrees) >= 3 and (voxelDegrees[0] == 2 or voxelDegrees[-1] == 2
                                               or not np.all(voxelDegrees[1:-1] == 2))
        if not needsCut:
            keptSegments.append(segment)
            continue

        locs = np.nonzero(voxelDegrees != 2)[0]
        if voxelDegrees[0] == 2:
            locs = np.hstack((0, locs))
        if voxelDegrees[-1] == 2:
            locs = np.hstack((locs, len(voxelDegrees)))
        for start, end in zip(locs[:-1], locs[1:]):
            extraSegments.append(segment[start:(end + 1)])

    return keptSegments + extraSegments


def _mergeAtEnd(G, segmentList, keepList, extraSegments, idx, segment, voxelDegrees, atStart):
    """
    Try to extend `segment` through its degree-2 end by joining another segment that ends on the same voxel.

    Helper of `processSegments`; mutates `keepList` (segments to drop) and
    `extraSegments` (merged segments to add) in place.

    Parameters
    ----------
    atStart : bool
        If True the degree-2 end is `segment[0]`, otherwise `segment[-1]`.
    """
    endVoxel = segment[0] if atStart else segment[-1]
    otherSegmentInfo = [(idx2, seg) for idx2, seg in enumerate(segmentList)
                        if seg and (seg[0] == endVoxel or seg[-1] == endVoxel) and keepList[idx2] and idx != idx2]

    if len(otherSegmentInfo) == 0:
        print('Could not find other segments for segment({}) {} with degrees {}'.format(idx, segment, voxelDegrees))
        possibleSegmentsInfo = [(idx2, seg) for idx2, seg in enumerate(segmentList)
                                if seg and (seg[0] == endVoxel or seg[-1] == endVoxel) and idx != idx2]
        print('Possible segments: {}'.format(len(possibleSegmentsInfo)))
        return

    if len(otherSegmentInfo) > 1:
        # Discard candidates overlapping `segment` (one being a sub-run of the other).
        otherSegmentInfoTemp = []
        for idx2, seg in otherSegmentInfo:
            if contains(segment, seg) or contains(segment[::-1], seg):
                keepList[idx] = False
                continue
            elif contains(seg, segment) or contains(seg[::-1], segment):
                keepList[idx2] = False
            otherSegmentInfoTemp.append((idx2, seg))
        otherSegmentInfo = otherSegmentInfoTemp

        if len(otherSegmentInfo) == 0:
            print('No valid other segments found!')
            return
        if len(otherSegmentInfo) > 1:
            print('More than one other segments found!')
            print('Current segment ({}) is {} ({})'.format(idx, segment, voxelDegrees))
            for otherSegmentIdx, otherSegment in otherSegmentInfo:
                otherSegmentVoxelDegrees = _segmentDegrees(G, otherSegment)
                print('Idx = {}: {} ({})'.format(otherSegmentIdx, otherSegment, otherSegmentVoxelDegrees))
    # With several candidates left, the last one is used (this is what the
    # previous implementation did implicitly via a leaked loop variable).
    otherSegmentIdx, otherSegment = otherSegmentInfo[-1]

    if contains(segment, otherSegment) or contains(segment[::-1], otherSegment):
        keepList[idx] = False
        return
    elif contains(otherSegment, segment) or contains(otherSegment[::-1], segment):
        keepList[otherSegmentIdx] = False
        return

    # Join so that the shared voxel appears only once.
    if atStart:
        if otherSegment[-1] == segment[0]:
            newSegment = otherSegment + segment[1:]
        else:
            newSegment = otherSegment[::-1] + segment[1:]
    else:
        if otherSegment[0] == segment[-1]:
            newSegment = segment[:-1] + otherSegment
        else:
            newSegment = segment[:-1] + otherSegment[::-1]

    if not validateSegment(G, newSegment):
        newSegmentVoxelDegrees = _segmentDegrees(G, newSegment)
        print('Old degree is {} () and new degree is {} ()'.format(voxelDegrees, newSegmentVoxelDegrees))
    else:
        print('Two segments ({} and {}) merged together!'.format(idx, otherSegmentIdx))
        extraSegments.append(newSegment)
        keepList[idx] = False
        keepList[otherSegmentIdx] = False


def processSegments(segmentList, shape, draw=True):
    """
    Re-partition the segments so that each segment is a simple branch, i.e.,
    it does not contain bifurcation point unless at the two ends. Note that
    this function might be replaced by another more concise function
    `getSegmentList`.

    Steps: build a graph from all segments, remove duplicate segments, cut
    segments at interior bifurcations, remove duplicates again, then
    repeatedly merge segments that end on a degree-2 voxel until nothing
    changes. Segments that still are not simple branches are removed from
    the returned `segmentList` and returned as `errorSegments`.

    Parameters
    ----------
    segmentList : list
        A list containing the segment information. Each sublist represents a
        segment and each element in the sublist represents a centerpoint
        coordinates (must be hashable, e.g. tuples).
    shape : tuple
        Shape of the vessel volume (used for ploting).
    draw : bool, optional
        If True (default), plot the segments with `drawSegments` when no
        invalid segment is found after the cutting stage (this blocks until
        the plot window is closed).

    Returns
    -------
    G : NetworkX graph
        A graph in which each node represents a centerpoint and each edge
        represents a portion of a vessel branch.
    segmentList : list
        A list containing the segment information. Each sublist represents a
        segment and each element in the sublist represents a centerpoint
        coordinates.
    errorSegments : list
        A list that contains segments that cannot be fixed.
    """
    G = nx.Graph()
    for segment in segmentList:
        nx.add_path(G, segment)

    voxelDegrees = np.array([degree for _, degree in G.degree()])
    if voxelDegrees.size > 0:
        maxVoxelDegree = np.amax(voxelDegrees)
        voxelDegreesZippedResult = list(zip(np.arange(maxVoxelDegree + 1), np.bincount(voxelDegrees)))
        print('Voxel degree distribution is \n{}'.format(voxelDegreesZippedResult))
    print('Number of cycles is {}'.format(len(nx.cycle_basis(G))))

    segmentList, duplicateCounter = _removeDuplicateSegments(segmentList)
    print('{} duplicate segments removed!'.format(duplicateCounter))

    # Cut segments into sub-pieces if there are bifurcation points in the middle.
    segmentList = _splitAtBifurcations(G, segmentList)

    segmentList, duplicateCounter = _removeDuplicateSegments(segmentList)
    print('{} duplicate segments removed in the second stage!'.format(duplicateCounter))

    hasInvalidSegments = any(_isInvalidSegment(_segmentDegrees(G, segment)) for segment in segmentList)
    if not hasInvalidSegments:
        if draw:
            drawSegments(segmentList, shape)
        print('No errors!')
        errorSegments = []
        return G, segmentList, errorSegments

    # Merge segments through degree-2 ends until all are simple branches or
    # an iteration produces no merge.
    iterCounter = 1
    errorSegments = []
    while hasInvalidSegments:
        print('\n\nIter={}'.format(iterCounter))
        keepList = np.full((len(segmentList),), True)
        extraSegments = []
        for idx, segment in enumerate(segmentList):
            if not keepList[idx] or len(segment) == 0:
                continue
            voxelDegrees = _segmentDegrees(G, segment)
            if voxelDegrees[0] == 2 and voxelDegrees[-1] == 2:
                print('Both end have 2 neighbours')
            elif voxelDegrees[0] == 2:
                _mergeAtEnd(G, segmentList, keepList, extraSegments, idx, segment, voxelDegrees, atStart=True)
            elif voxelDegrees[-1] == 2:
                _mergeAtEnd(G, segmentList, keepList, extraSegments, idx, segment, voxelDegrees, atStart=False)

        segmentList = [segment for idx, segment in enumerate(segmentList) if keepList[idx]]
        segmentList += extraSegments

        hasInvalidSegments = False
        errorSegments = []
        for segment in segmentList:
            voxelDegrees = _segmentDegrees(G, segment)
            if _isInvalidSegment(voxelDegrees):
                if len(voxelDegrees) == 2:
                    print('Degrees on either end is 2: {}'.format(voxelDegrees))
                else:
                    print('Degrees not correct: {}'.format(voxelDegrees))
                hasInvalidSegments = True
                errorSegments.append(segment)

        print('hasInvalidSegments = {}'.format(hasInvalidSegments))
        iterCounter += 1
        if hasInvalidSegments and len(extraSegments) == 0:
            hasInvalidSegments = False
            print('While loop aborted because there is no change in segments!')

    for errorSegment in errorSegments:
        segmentList.remove(errorSegment)

    print(errorSegments)
    return G, segmentList, errorSegments


def getSegmentList(G, nodeInfoDict):
    """
    Generate segmentList from graph and nodeInfoDict.

    Parameters
    ----------
    G : NetworkX graph
        The graph representation of the network. Edges are marked with a
        'visited' attribute during traversal.
    nodeInfoDict : dict
        A dictionary containing the information about nodes; nodes whose
        'parentNodeID' entry is -1 are used as traversal starting points.

    Returns
    -------
    segmentList : list
        A list of segments in which each segment is a simple branch.
    """
    startNodeIDList = [nodeID for nodeID in nodeInfoDict if nodeInfoDict[nodeID]['parentNodeID'] == -1]
    print('startNodeIDList = {}'.format(startNodeIDList))
    segmentList = []
    for startNodeID in startNodeIDList:
        segmentList = getSegmentListDetail(G, nodeInfoDict, segmentList, startNodeID)

    print('There are {} segments in segmentList'.format(len(segmentList)))
    print(segmentList)
    return segmentList


def getSegmentListDetail(G, nodeInfoDict, segmentList, startNodeID):
    """
    Implementation of `getSegmentList`. Use DFS to traverse all the segments.

    From `startNodeID`, every not-yet-visited edge starts a new segment that
    is followed through degree-2 nodes; the walk then recurses from the node
    where the segment ended. Traversed edges are marked with a 'visited'
    attribute in `G`, so every edge ends up in at most one segment.

    Parameters
    ----------
    G : NetworkX graph
        The graph representation of the network (mutated: edge attribute 'visited').
    nodeInfoDict : dict
        A dictionary containing the information about nodes.
    segmentList : list
        A list of segments in which each segment is a simple branch (appended to).
    startNodeID : int
        The index of the start point of a segment.

    Returns
    -------
    segmentList : list
        A list of segments in which each segment is a simple branch.
    """
    neighborNodeIDList = [nodeID for nodeID in G[startNodeID] if 'visited' not in G[startNodeID][nodeID]]
    for neighborNodeID in neighborNodeIDList:
        # A deeper recursive call may already have walked this edge (e.g.
        # around a cycle); skip it so the same branch is not emitted twice.
        if 'visited' in G[startNodeID][neighborNodeID]:
            continue

        newSegment = [startNodeID, neighborNodeID]
        G[startNodeID][neighborNodeID]['visited'] = True
        currentNodeID = neighborNodeID
        while G.degree(currentNodeID) == 2:
            unvisitedNodeIDs = [nodeID for nodeID in G[currentNodeID] if 'visited' not in G[currentNodeID][nodeID]]
            if not unvisitedNodeIDs:
                break  # the walk closed a loop onto an already-visited edge
            newNodeID = unvisitedNodeIDs[0]
            G[currentNodeID][newNodeID]['visited'] = True
            newSegment.append(newNodeID)
            currentNodeID = newNodeID

        segmentList.append(newSegment)
        segmentList = getSegmentListDetail(G, nodeInfoDict, segmentList, currentNodeID)

    return segmentList


def sublist(ls1, ls2):
    '''
    >>> sublist([], [1,2,3])
    True
    >>> sublist([1,2,3,4], [2,5,3])
    True
    >>> sublist([1,2,3,4], [0,3,2])
    False
    >>> sublist([1,2,3,4], [1,2,5,6,7,8,5,76,4,3])
    False
    '''
    def get_all_in(one, another):
        for element in one:
            if element in another:
                yield element

    for x1, x2 in zip(get_all_in(ls1, ls2), get_all_in(ls2, ls1)):
        if x1 != x2:
            return False

    return True


def contains(lst1, lst2):
    """
    Return True if the shorter list is a contiguous run of the longer one, read forwards or backwards.

    The arguments are first ordered by length, so the check is symmetric in
    practice. Positions are located with `list.index` (first occurrence).
    Empty input returns False.
    """
    shorter, longer = (lst2, lst1) if len(lst1) > len(lst2) else (lst1, lst2)
    if not shorter:
        return False
    if shorter[0] not in longer or shorter[-1] not in longer:
        return False

    startLoc = longer.index(shorter[0])
    endLoc = longer.index(shorter[-1])
    if startLoc < endLoc:
        return shorter == longer[startLoc:(endLoc + 1)]
    return shorter == longer[endLoc:(startLoc + 1)][::-1]


def validateSegment(G, segment):
    """
    Check whether a segment is a simple branch.

    Parameters
    ----------
    G : NetworkX graph
        A graph in which each node represents a centerpoint and each edge
        represents a portion of a vessel branch.
    segment : list
        A list containing the coordinates of the centerpoints of a segment.

    Returns
    -------
    result : bool
        If True, the segment is a simple branch: both ends have degree != 2,
        all interior nodes have degree 2, and it has at least 2 nodes.
    """
    voxelDegrees = np.array([v for _, v in G.degree(segment)])
    if len(voxelDegrees) == 0:
        print('Error! Empty segment found!')
        return False
    if voxelDegrees[0] == 2 or voxelDegrees[-1] == 2:
        return False
    if len(voxelDegrees) == 1:
        print('Error! Segment with length 1 found!')
        return False
    return bool(np.all(voxelDegrees[1:-1] == 2))


def drawSegments(segmentList, shape):
    """
    Plot all the segments in `segmentList`. Try to assign different colors to
    the segments connected to the same node.

    Blocks until the plot window is closed.

    Parameters
    ----------
    segmentList : list
        A list containing the segment information. Each sublist represents a
        segment and each element in the sublist represents a centerpoint
        coordinates.
    shape : tuple
        Shape of the vessel volume (used for ploting).
    """
    import pyqtgraph as pg
    import pyqtgraph.opengl as gl

    # Reuse an existing QApplication if there is one (creating a second one fails).
    app = pg.mkQApp()
    w = gl.GLViewWidget()
    w.opts['distance'] = 800
    w.setGeometry(0, 110, 1600, 900)
    offset = np.array(shape) / (-2.0)

    colorList = [pg.glColor('r'), pg.glColor('g'), pg.glColor('b'), pg.glColor('c'), pg.glColor('m'),
                 pg.glColor('y')]
    numOfColors = len(colorList)
    # voxel -> list of [otherEndVoxel, colorIndex] for every segment ending there
    nodeColorDict = {}
    for segment in segmentList:
        startVoxel = segment[0]
        endVoxel = segment[-1]
        nodeColorDict.setdefault(startVoxel, []).append([endVoxel, -1])
        nodeColorDict.setdefault(endVoxel, []).append([startVoxel, -1])

        usedColors = ({colorCode for _, colorCode in nodeColorDict[startVoxel]}
                      | {colorCode for _, colorCode in nodeColorDict[endVoxel]})
        availableColors = [colorCode for colorCode in range(numOfColors) if colorCode not in usedColors]
        chosenColor = availableColors[0] if availableColors else 0
        nodeColorDict[startVoxel][-1][1] = chosenColor
        nodeColorDict[endVoxel][-1][1] = chosenColor

        segmentCoords = np.array(segment)
        aa = gl.GLLinePlotItem(pos=segmentCoords, color=colorList[chosenColor], width=3)
        aa.translate(*offset)
        w.addItem(aa)

    w.show()
    # Qt5 bindings spell it `exec_`, Qt6 bindings `exec`.
    runEventLoop = app.exec_ if hasattr(app, 'exec_') else app.exec
    runEventLoop()


def main():
    """Load the vessel mask, post-process the skeletonization results and save graph, segment list and skeleton."""
    start_time = timeit.default_timer()
    baseFolder = os.path.abspath(os.path.dirname(__file__))

    ## Load existing volume ##
    vesselVolumeMaskFolderPath = baseFolder
    vesselVolumeMaskFileName = 'vesselVolumeMask.nii.gz'
    vesselVolumeMask, vesselVolumeMaskAffine = loadVolume(vesselVolumeMaskFolderPath, vesselVolumeMaskFileName)

    ## Skeletonization ##
    # Uncomment to (re)run the Docker skeletonization before post-processing:
    # analyze(vesselVolumeMask, baseFolder)
    skeletonSegmentFolderPath = os.path.join(baseFolder, 'skeletonizationResult', 'segments_by_cc')
    segmentListRough = combineSkeletonSegments(skeletonSegmentFolderPath)
    shape = vesselVolumeMask.shape
    G, segmentList, errorSegments = processSegments(segmentListRough, shape=shape)

    # Rebuild the graph from the cleaned segments, tagging each edge with its segment index.
    G = nx.Graph()
    for segmentIndex, segment in enumerate(segmentList):
        nx.add_path(G, segment, segmentIndex=segmentIndex)

    ## Save graph representation ##
    graphFileName = 'graphRepresentation.graphml'
    graphFilePath = os.path.join(baseFolder, graphFileName)
    nx.write_graphml(G, graphFilePath)
    print('{} saved to {}.'.format(graphFileName, graphFilePath))

    ## Save segmentList ##
    # Segments have different lengths, so store them in an explicit 1-D object
    # array (NumPy >= 1.24 refuses to build ragged arrays implicitly).
    segmentArray = np.empty(len(segmentList), dtype=object)
    for i, segment in enumerate(segmentList):
        segmentArray[i] = segment
    segmentListFileName = 'segmentList.npz'
    segmentListFilePath = os.path.join(baseFolder, segmentListFileName)
    np.savez_compressed(segmentListFilePath, segmentList=segmentArray)
    print('{} saved to {}.'.format(segmentListFileName, segmentListFilePath))

    ## Save skeleton.nii.gz ##
    skeleton = np.zeros_like(vesselVolumeMask)
    for segment in segmentList:
        skeleton[tuple(np.array(segment).T)] = 1

    skeletonFileName = 'skeleton.nii.gz'
    skeletonFilePath = os.path.join(baseFolder, skeletonFileName)
    saveVolume(skeleton, vesselVolumeMaskAffine, skeletonFilePath, astype=np.uint8)

    elapsed = timeit.default_timer() - start_time
    print('Elapsed: {} sec'.format(elapsed))


if __name__ == "__main__":
    main()