File size: 2,227 Bytes
7f3c2df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import itertools
from collections import defaultdict
class Tree(object):
    """A generic n-ary tree node holding an arbitrary ``content`` payload.

    Each node keeps a list of ``children``, a back-reference to its
    ``parent`` (``None`` for a root), a caller-supplied ``depth`` and a
    free-form ``attribute`` dict for user annotations.
    """

    def __init__(self, content, parent, depth):
        """Create a node and, if ``parent`` is given, attach it as a child.

        Args:
            content: Payload stored on the node as-is.
            parent: Parent ``Tree`` node, or ``None`` to create a root.
            depth: Depth of this node. It is NOT computed from ``parent``;
                presumably callers pass ``parent.depth + 1`` (root = 0) --
                ``get_nodes_by_level`` relies on this being consistent.
        """
        self.content = content
        self.children = list()
        self.parent = parent
        self.depth = depth
        self.attribute = dict()
        # Register with the parent only once this node is fully initialised,
        # so an overridden ``expand`` never sees a half-built child.
        if parent is not None:
            parent.expand(self)

    def expand(self, child):
        """Append a single ``child`` to this node's children."""
        self.children.append(child)

    def expand_set(self, children):
        """Append every node in ``children`` to this node's children.

        Note: unlike construction with ``parent=self``, this does not update
        the added nodes' ``parent`` attribute.
        """
        self.children += children

    def isroot(self):
        """Return True if this node has no parent."""
        return self.parent is None

    def isleaf(self):
        """Return True if this node has no children."""
        return not self.children

    def get_subseq_trajs(self):
        """Return the ``traj`` attribute of each direct child, in order.

        ``traj`` is not set anywhere in this class; presumably a subclass or
        caller assigns it -- verify against usage.
        """
        return [child.traj for child in self.children]

    def get_all_leaves(self, leaf_set=None):
        """Collect the leaf nodes of the subtree rooted at this node.

        Leaves are gathered in depth-first, left-to-right order.

        Args:
            leaf_set: Optional list to append leaves to. A fresh list is
                created when omitted (a shared ``[]`` default would leak
                leaves between unrelated calls).

        Returns:
            The list the leaves were appended to.
        """
        if leaf_set is None:
            leaf_set = []
        if self.isleaf():
            leaf_set.append(self)
        else:
            for child in self.children:
                leaf_set = child.get_all_leaves(leaf_set)
        return leaf_set

    @staticmethod
    def get_nodes_by_level(obj, depth, nodes=None, trim_short_branch=True):
        """Group the nodes lying on a path from ``obj`` down to ``depth``.

        A node is collected (under its own ``depth`` key) only if some
        descendant of it -- or the node itself -- sits exactly at ``depth``.
        Within each level, nodes are appended in post-order (children before
        their parent).

        Args:
            obj: Node to start from; its ``depth`` must not exceed ``depth``.
            depth: Target depth.
            nodes: Optional mapping ``depth -> list of nodes`` to extend;
                a new ``defaultdict(list)`` is created when omitted.
            trim_short_branch: If True, every visited node's ``children`` is
                replaced (in place) by only those children whose branch
                reaches ``depth``. Applied recursively to the whole subtree.

        Returns:
            ``(nodes, reached)`` where ``reached`` is True if the subtree
            under ``obj`` contains a node at ``depth``.

        Raises:
            AssertionError: if ``obj.depth > depth`` (also raised for any
                visited descendant deeper than ``depth``).
        """
        assert obj.depth <= depth
        if nodes is None:
            nodes = defaultdict(list)
        if obj.depth == depth:
            nodes[depth].append(obj)
            return nodes, True
        if obj.isleaf():
            # A leaf above the target depth: this branch is too short.
            return nodes, False

        flag = False
        kept_children = []
        for child in obj.children:
            # Forward trim_short_branch so the caller's choice applies to
            # the entire subtree, not just to ``obj``.
            nodes, child_flag = Tree.get_nodes_by_level(
                child, depth, nodes, trim_short_branch=trim_short_branch)
            if child_flag:
                kept_children.append(child)
            flag = flag or child_flag
        if trim_short_branch:
            obj.children = kept_children
        if flag:
            nodes[obj.depth].append(obj)
        return nodes, flag

    @staticmethod
    def get_children(obj):
        """Return the children of a node, or of every node in a list.

        Args:
            obj: A ``Tree`` node or a list of ``Tree`` nodes.

        Returns:
            For a node, its own ``children`` list (not a copy). For a list,
            a new flat list of all the nodes' children in order.

        Raises:
            TypeError: if ``obj`` is neither a ``Tree`` nor a list.
        """
        if isinstance(obj, Tree):
            return obj.children
        if isinstance(obj, list):
            return list(itertools.chain.from_iterable(
                node.children for node in obj))
        raise TypeError("obj must be a Tree or a list")