Spaces:
Paused
Paused
File size: 2,227 Bytes
7f3c2df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import itertools
from collections import defaultdict
class Tree(object):
def __init__(self, content, parent, depth):
self.content = content
self.children = list()
self.parent = parent
if parent is not None:
parent.expand(self)
self.depth = depth
self.attribute = dict()
def expand(self, child):
self.children.append(child)
def expand_set(self, children):
self.children += children
def isroot(self):
return self.parent is None
def isleaf(self):
return len(self.children) == 0
def get_subseq_trajs(self):
return [child.traj for child in self.children]
def get_all_leaves(self,leaf_set=[]):
if self.isleaf():
leaf_set.append(self)
else:
for child in self.children:
leaf_set = child.get_all_leaves(leaf_set)
return leaf_set
@staticmethod
def get_nodes_by_level(obj,depth,nodes=None,trim_short_branch=True):
assert obj.depth<=depth
if nodes is None:
nodes = defaultdict(lambda: list())
if obj.depth==depth:
nodes[depth].append(obj)
return nodes, True
else:
if obj.isleaf():
return nodes, False
else:
flag = False
children_flags = dict()
for child in obj.children:
nodes, child_flag = Tree.get_nodes_by_level(child,depth,nodes)
children_flags[child] = child_flag
flag = flag or child_flag
if trim_short_branch:
obj.children = [child for child in obj.children if children_flags[child]]
if flag:
nodes[obj.depth].append(obj)
return nodes, flag
@staticmethod
def get_children(obj):
if isinstance(obj, Tree):
return obj.children
elif isinstance(obj, list):
children = [node.children for node in obj]
children = list(itertools.chain.from_iterable(children))
return children
else:
raise TypeError("obj must be a TrajTree or a list") |