# SRC: https://github.com/Image-Py/sknw/blob/master/sknw/sknw.py
import numpy as np
import networkx as nx

from Centerline.graph_utils import subsample_graph


def neighbors(shape):
    """Return the flat-index offsets of every neighbour of a pixel.

    For an array of the given ``shape`` (any dimensionality) this yields the
    ``3 ** dim - 1`` offsets that, added to a raveled (C-order) index, address
    the surrounding pixels (8 in 2-D, 26 in 3-D).
    """
    dim = len(shape)
    block = np.ones([3] * dim)
    block[tuple([1] * dim)] = 0  # exclude the centre pixel itself
    idx = np.where(block > 0)
    idx = np.array(idx, dtype=np.uint8).T
    idx = np.array(idx - [1] * dim)  # shift to relative coordinates in {-1, 0, 1}
    acc = np.cumprod((1,) + shape[::-1][:-1])
    return np.dot(idx, acc[::-1])


# my mark
def mark(img, nbs):  # mark the array use (0, 1, 2)
    """Classify every non-zero pixel of ``img`` in place.

    A pixel with exactly two non-zero neighbours becomes 1 (edge pixel); any
    other non-zero pixel becomes 2 (node pixel). Zero pixels are untouched.

    The neighbours are read at ``p + dp`` without bounds checks, so ``img``
    is expected to carry a zero border (see ``buffer``).
    """
    img = img.ravel()
    for p in range(len(img)):
        if img[p] == 0:
            continue
        s = 0
        for dp in nbs:
            if img[p + dp] != 0:
                s += 1
        if s == 2:
            img[p] = 1
        else:
            img[p] = 2


# trans index to r, c...
def idx2rc(idx, acc):
    """Convert flat indices into per-axis coordinates.

    ``acc`` holds the stride (in elements) of each axis. The result has one
    row per index and one column per axis, with 1 subtracted from every
    coordinate (this undoes the one-pixel padding added by ``buffer``).

    NOTE: ``idx`` is consumed as scratch space and is modified in place.
    """
    rst = np.zeros((len(idx), len(acc)), dtype=np.int16)
    for i in range(len(idx)):
        for j in range(len(acc)):
            rst[i, j] = idx[i] // acc[j]
            idx[i] -= rst[i, j] * acc[j]
    rst -= 1
    return rst


# fill a node (may be two or more points)
def fill(img, p, num, nbs, acc, buf):
    """Flood-fill the connected node region containing ``p`` with ``num``.

    Starting at flat index ``p``, every connected pixel sharing the original
    value of ``img[p]`` is relabelled to ``num``. ``buf`` is used as the
    work queue. Returns the coordinates of all filled pixels (via ``idx2rc``).
    """
    back = img[p]
    img[p] = num
    buf[0] = p
    cur = 0
    s = 1
    while True:
        p = buf[cur]
        for dp in nbs:
            cp = p + dp
            if img[cp] == back:
                img[cp] = num
                buf[s] = cp
                s += 1
        cur += 1
        if cur == s:  # queue exhausted: the whole region has been visited
            break
    return idx2rc(buf[:s], acc)


# trace the edge and use a buffer, then buf.copy, if use [] numba not works
def trace(img, p, nbs, acc, buf):
    """Walk along an edge starting at edge pixel ``p``.

    Visited pixels are recorded in ``buf`` and cleared to 0 in ``img``. The
    walk follows neighbouring pixels valued 1 and stops once two node labels
    (values >= 10) have been seen. Returns ``(start_node, end_node, pts)``
    where the node ids are the labels minus 10 and ``pts`` are the
    coordinates of the visited pixels.
    """
    c1 = 0
    c2 = 0
    newp = 0
    cur = 0
    while True:
        buf[cur] = p
        img[p] = 0
        cur += 1
        for dp in nbs:
            cp = p + dp
            if img[cp] >= 10:
                # First labelled neighbour is the start node, a later one the end.
                if c1 == 0:
                    c1 = img[cp]
                else:
                    c2 = img[cp]
            if img[cp] == 1:
                newp = cp
        p = newp
        if c2 != 0:
            break
    return (c1 - 10, c2 - 10, idx2rc(buf[:cur], acc))


# parse the image then get the nodes and edges
def parse_struc(img, pts, nbs, acc):
    """Extract nodes and edges from a marked image.

    ``pts`` are the flat indices of the node pixels. Each connected node
    region is labelled with a consecutive number starting at 10, then every
    edge leaving a node pixel is traced.

    Returns ``(nodes, edges)``: a list of coordinate arrays (one per node)
    and a list of ``(start, end, pts)`` tuples (one per edge).
    """
    img = img.ravel()
    # Shared scratch buffer for fill/trace; a single node region or edge
    # longer than this many pixels does not fit.
    buf = np.zeros(131072, dtype=np.int64)
    num = 10
    nodes = []
    for p in pts:
        if img[p] == 2:
            nds = fill(img, p, num, nbs, acc, buf)
            num += 1
            nodes.append(nds)

    edges = []
    for p in pts:
        for dp in nbs:
            if img[p + dp] == 1:
                edge = trace(img, p + dp, nbs, acc, buf)
                edges.append(edge)
    return nodes, edges


# use nodes and edges build a networkx graph
def build_graph(nodes, edges, multi=False):
    """Build a networkx graph from parsed nodes and edges.

    Each node ``i`` stores its pixel coordinates as ``pts`` and their mean as
    ``o``. Each edge stores its pixel coordinates as ``pts`` and the summed
    Euclidean length of its consecutive segments as ``weight``.

    With ``multi=True`` a ``MultiGraph`` is returned, otherwise a ``Graph``.
    """
    graph = nx.MultiGraph() if multi else nx.Graph()
    for i in range(len(nodes)):
        graph.add_node(i, pts=nodes[i], o=nodes[i].mean(axis=0))
    for s, e, pts in edges:
        l = np.linalg.norm(pts[1:] - pts[:-1], axis=1).sum()
        graph.add_edge(s, e, pts=pts, weight=l)
    return graph


def buffer(ske):
    """Return a ``uint16`` copy of ``ske`` padded with a one-pixel zero border."""
    buf = np.zeros(tuple(np.array(ske.shape) + 2), dtype=np.uint16)
    buf[tuple([slice(1, -1)] * buf.ndim)] = ske
    return buf


def build_sknw(ske, multi=False):
    """Build a graph from the skeleton image ``ske``.

    The skeleton is padded, its pixels are classified into nodes and edges,
    and the resulting structure is turned into a networkx graph
    (``MultiGraph`` when ``multi`` is true).
    """
    buf = buffer(ske)
    nbs = neighbors(buf.shape)
    acc = np.cumprod((1,) + buf.shape[::-1][:-1])[::-1]
    mark(buf, nbs)
    pts = np.array(np.where(buf.ravel() == 2))[0]
    nodes, edges = parse_struc(buf, pts, nbs, acc)
    return build_graph(nodes, edges, multi)


# draw the graph
def draw_graph(img, graph, cn=255, ce=128):
    """Paint the nodes and edges of ``graph`` into ``img``.

    Node pixels are written first with value ``cn``, then edge pixels with
    value ``ce``. Works for both ``Graph`` and ``MultiGraph`` instances.

    The pixels are written through ``img.ravel()``, so the caller only sees
    the result when that call returns a view of ``img``.
    """
    acc = np.cumprod((1,) + img.shape[::-1][:-1])[::-1]
    img = img.ravel()
    for idx in graph.nodes():
        pts = graph.nodes[idx]['pts']
        img[np.dot(pts, acc)] = cn
    # edges(data='pts') yields one (u, v, pts) triple per edge for plain and
    # multi graphs alike; indexing graph[s][e] by key only works on a
    # MultiGraph and fails on the Graph produced with multi=False.
    for _, _, pts in graph.edges(data='pts'):
        img[np.dot(pts, acc)] = ce


def get_graph_from_skeleton(mask, subsample=False):
    """Build a (non-multi) graph from the skeleton ``mask``.

    When ``subsample`` is true and the graph has more than one node and at
    least one edge, the graph is passed through ``subsample_graph(graph, 3)``.
    """
    graph = build_sknw(mask, False)
    if len(graph.nodes) > 1 and len(graph.edges) and subsample:
        graph = subsample_graph(graph, 3)
    return graph


if __name__ == '__main__':
    g = nx.MultiGraph()
    g.add_nodes_from([1, 2, 3, 4, 5])
    g.add_edges_from([(1, 2), (1, 3), (2, 3), (4, 5), (5, 4)])
    print(g.nodes())
    print(g.edges())
    a = g.subgraph(1)
    print('d')
    print(a)
    print('d')