hyzhou404's picture
private scenes
7f3c2df
import itertools
from collections import defaultdict
class Tree(object):
    """A generic tree node that knows its parent, children and depth.

    Each node stores an arbitrary ``content`` payload, a list of child
    nodes, a reference to its parent (``None`` for a root), an explicit
    ``depth`` value supplied by the caller, and a free-form ``attribute``
    dictionary that starts out empty.
    """

    def __init__(self, content, parent, depth):
        """Create a node and attach it to ``parent`` when one is given.

        Args:
            content: Payload stored on the node as-is.
            parent: Parent node, or ``None`` to create a root. When given,
                the new node is appended to the parent's children via
                ``parent.expand``.
            depth: Depth value stored on the node; it is not derived from
                the parent, the caller is responsible for consistency.
        """
        self.content = content
        self.children = list()
        self.parent = parent
        self.depth = depth
        self.attribute = dict()
        # Register with the parent last so the node is fully initialised
        # by the time the parent (or an overriding ``expand``) sees it.
        if parent is not None:
            parent.expand(self)

    def expand(self, child):
        """Append a single ``child`` to this node's children."""
        self.children.append(child)

    def expand_set(self, children):
        """Append every node in ``children`` to this node's children."""
        self.children += children

    def isroot(self):
        """Return ``True`` when the node has no parent."""
        return self.parent is None

    def isleaf(self):
        """Return ``True`` when the node has no children."""
        return len(self.children) == 0

    def get_subseq_trajs(self):
        """Return the ``traj`` attribute of every direct child.

        NOTE(review): ``Tree`` itself never sets ``traj``; presumably a
        subclass or the caller provides it -- confirm before relying on it.
        """
        return [child.traj for child in self.children]

    def get_all_leaves(self, leaf_set=None):
        """Collect all leaves below (and including) this node.

        Leaves are gathered depth-first, following the order of each
        node's ``children`` list.

        Args:
            leaf_set: Optional list to append the leaves to. When omitted a
                fresh list is created for this call, so results from
                separate calls never leak into each other.

        Returns:
            The list holding the collected leaves (``leaf_set`` itself when
            one was passed in).
        """
        # A ``None`` sentinel replaces the former mutable ``[]`` default,
        # which was shared between calls and kept growing.
        if leaf_set is None:
            leaf_set = []
        if self.isleaf():
            leaf_set.append(self)
        else:
            for child in self.children:
                leaf_set = child.get_all_leaves(leaf_set)
        return leaf_set

    @staticmethod
    def get_nodes_by_level(obj, depth, nodes=None, trim_short_branch=True):
        """Group the nodes lying on branches that reach ``depth`` by level.

        Starting from ``obj``, the tree is walked depth-first. Nodes whose
        ``depth`` equals the target are recorded, and so is every ancestor
        (down from ``obj``) that has at least one descendant at the target
        depth. Children are visited before their parent is recorded.

        Args:
            obj: Node to start from; its ``depth`` must not exceed ``depth``.
            depth: Target depth value.
            nodes: Mapping from depth to list of nodes used as accumulator.
                A new ``defaultdict(list)`` is created when omitted.
            trim_short_branch: When true, children whose subtree never
                reaches ``depth`` are removed from their parent's
                ``children`` list. This mutates the tree in place; the
                ``parent`` reference of removed children is left untouched.
                The flag applies to the whole traversal.

        Returns:
            A tuple ``(nodes, reached)`` where ``nodes`` is the accumulator
            and ``reached`` tells whether ``obj`` or one of its descendants
            sits at the target depth.
        """
        assert obj.depth <= depth
        if nodes is None:
            nodes = defaultdict(list)

        if obj.depth == depth:
            nodes[depth].append(obj)
            return nodes, True

        # A leaf above the target depth is a branch that ends too early.
        if obj.isleaf():
            return nodes, False

        reached = False
        kept_children = []
        for child in obj.children:
            # Forward the trimming choice so that ``trim_short_branch=False``
            # really leaves every level of the tree untouched.
            nodes, child_reached = Tree.get_nodes_by_level(
                child, depth, nodes, trim_short_branch
            )
            if child_reached:
                kept_children.append(child)
                reached = True
        if trim_short_branch:
            obj.children = kept_children
        if reached:
            nodes[obj.depth].append(obj)
        return nodes, reached

    @staticmethod
    def get_children(obj):
        """Return the children of a node, or of every node in a list.

        Args:
            obj: A ``Tree`` instance or a list of nodes.

        Returns:
            For a ``Tree``, its own ``children`` list (not a copy). For a
            list, a new flat list with the children of each node in order.

        Raises:
            TypeError: If ``obj`` is neither a ``Tree`` nor a list.
        """
        if isinstance(obj, Tree):
            return obj.children
        elif isinstance(obj, list):
            children = [node.children for node in obj]
            children = list(itertools.chain.from_iterable(children))
            return children
        else:
            raise TypeError("obj must be a TrajTree or a list")