Spaces:
Paused
Paused
import itertools | |
from collections import defaultdict | |
class Tree(object):
    """A generic n-ary tree node.

    Each node stores an arbitrary ``content`` payload, a list of ``children``,
    a back-reference to its ``parent``, a caller-supplied ``depth`` and a
    free-form ``attribute`` dict. Constructing a node with a non-``None``
    parent automatically appends it to that parent's ``children``.
    """

    def __init__(self, content, parent, depth):
        """Create a node and attach it to ``parent`` if one is given.

        Args:
            content: Arbitrary payload stored on the node.
            parent: Parent node, or ``None`` for a root.
            depth: Depth of this node. Supplied by the caller; it is not
                derived from ``parent``.
        """
        self.content = content
        self.children = []
        self.parent = parent
        # Set depth before registering with the parent so the node is fully
        # initialized by the time ``parent.expand`` sees it.
        self.depth = depth
        self.attribute = {}
        if parent is not None:
            parent.expand(self)

    def expand(self, child):
        """Append a single ``child`` to this node's children (parent link not set)."""
        self.children.append(child)

    def expand_set(self, children):
        """Extend this node's children with the nodes in ``children`` (parent links not set)."""
        self.children += children

    def isroot(self):
        """Return True if this node has no parent."""
        return self.parent is None

    def isleaf(self):
        """Return True if this node has no children."""
        return len(self.children) == 0

    def get_subseq_trajs(self):
        """Return the ``traj`` attribute of every direct child, in order.

        NOTE(review): ``traj`` is not defined by this class; presumably it is
        set by a subclass or by the code building the tree — confirm.
        """
        return [child.traj for child in self.children]

    def get_all_leaves(self, leaf_set=None):
        """Return the leaves of the subtree rooted at this node, depth-first.

        Args:
            leaf_set: Optional list to append the leaves to. A fresh list is
                created when omitted.

        Returns:
            The list the leaves were appended to.
        """
        if leaf_set is None:
            # A shared ``[]`` default would accumulate leaves across calls.
            leaf_set = []
        if self.isleaf():
            leaf_set.append(self)
        else:
            for child in self.children:
                leaf_set = child.get_all_leaves(leaf_set)
        return leaf_set

    def get_nodes_by_level(obj, depth, nodes=None, trim_short_branch=True):
        """Group by depth the nodes of ``obj``'s subtree that lie on a path reaching ``depth``.

        Usable as ``Tree.get_nodes_by_level(node, depth)`` or
        ``node.get_nodes_by_level(depth)``; ``obj`` plays the role of ``self``.

        Args:
            obj: Node to start from; ``obj.depth`` must not exceed ``depth``.
            depth: Target depth.
            nodes: Optional mapping ``depth -> list of nodes`` to accumulate
                into; it must create missing keys like ``defaultdict(list)``.
                A new ``defaultdict(list)`` is used when omitted.
            trim_short_branch: If true, every visited node's ``children`` is
                replaced by only those children whose subtree reaches
                ``depth``. Applied at every level of the recursion.

        Returns:
            ``(nodes, reached)`` where ``reached`` is True iff ``obj``'s
            subtree contains a node at ``depth``. A node is appended to
            ``nodes`` only after its qualifying descendants (post-order).

        Raises:
            AssertionError: if ``obj.depth > depth`` (not checked under ``-O``).
        """
        assert obj.depth <= depth
        if nodes is None:
            nodes = defaultdict(list)
        if obj.depth == depth:
            nodes[depth].append(obj)
            return nodes, True
        if obj.isleaf():
            return nodes, False

        reached = False
        kept = []
        for child in obj.children:
            # Forward trim_short_branch so the caller's choice applies to the
            # whole subtree, not just to ``obj``.
            nodes, child_reached = Tree.get_nodes_by_level(
                child, depth, nodes, trim_short_branch=trim_short_branch
            )
            if child_reached:
                kept.append(child)
            reached = reached or child_reached
        if trim_short_branch:
            obj.children = kept
        if reached:
            nodes[obj.depth].append(obj)
        return nodes, reached

    def get_children(obj):
        """Return the children of a node, or the concatenated children of a list of nodes.

        Args:
            obj: A ``Tree`` instance, or a list of ``Tree`` instances.

        Returns:
            For a ``Tree``: its own ``children`` list (the same object, not a
            copy). For a list: a new flat list of each node's children, in order.

        Raises:
            TypeError: if ``obj`` is neither a ``Tree`` nor a ``list``.
        """
        if isinstance(obj, Tree):
            return obj.children
        if isinstance(obj, list):
            return list(itertools.chain.from_iterable(node.children for node in obj))
        raise TypeError("obj must be a TrajTree or a list")