File size: 4,569 Bytes
ca01fa3
 
b6d30cb
ca01fa3
b6d30cb
75c875f
ca01fa3
75c875f
 
 
ca01fa3
 
 
 
75c875f
9e91869
ca01fa3
 
 
75c875f
0c44583
abadddf
0c44583
b6d30cb
 
75c875f
 
 
d4a220c
 
 
b6d30cb
 
 
 
 
ca01fa3
75c875f
0c44583
b6d30cb
 
 
 
76e9e8e
b6d30cb
abadddf
b6d30cb
abadddf
a6b7675
 
b6d30cb
a6b7675
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abadddf
a6b7675
 
 
 
 
 
 
 
 
 
9e91869
b6d30cb
a18645a
75c875f
3885cb0
a18645a
 
 
 
3885cb0
 
 
a18645a
 
75c875f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
'''Some operations. To be split into separate files when we have more.'''
from . import ops
import matplotlib
import networkx as nx
import pandas as pd
import traceback

# Decorator factory for this module: `@op("Title")` is applied to each operation function below.
# Presumably this registers the function under the 'LynxKite' catalog that `execute` looks
# operations up in by title -- confirm against `ops.op_registration`.
op = ops.op_registration('LynxKite')

@op("Import Parquet")
def import_parquet(*, filename: str):
  '''Loads the Parquet file at `filename` with pandas and returns the resulting table.'''
  table = pd.read_parquet(filename)
  return table

@op("Create scale-free graph")
def create_scale_free_graph(*, nodes: int = 10):
  '''Generates a scale-free graph with `nodes` nodes using NetworkX's generator.'''
  generated = nx.scale_free_graph(nodes)
  return generated

@op("Compute PageRank")
@ops.nx_node_attribute_func('pagerank')
def compute_pagerank(graph: nx.Graph, *, damping=0.85, iterations=100):
  '''Runs NetworkX PageRank on the graph and returns the score of each node.

  `damping` is passed to NetworkX as `alpha` and `iterations` as `max_iter`.
  '''
  scores = nx.pagerank(
    graph,
    alpha=damping,
    max_iter=iterations,
  )
  return scores


@op("Sample graph")
def sample_graph(graph: nx.Graph, *, nodes: int = 100):
  '''Takes a (preferably connected) subgraph with at most the given number of nodes.

  The sample is grown by following edges from the first node of the graph. If that runs out
  of reachable nodes before the sample is full, the search restarts from the next node that
  has not been sampled yet. The result is an independent copy of the induced subgraph, so it
  has the same graph type and attributes as the input. A non-positive `nodes` gives an empty
  graph.
  '''
  sample = set()
  for seed in graph.nodes:
    if len(sample) >= nodes:
      break
    if seed in sample:
      continue
    sample.add(seed)
    frontier = [seed]
    while frontier and len(sample) < nodes:
      current = frontier.pop()
      for neighbor in graph.neighbors(current):
        if neighbor in sample:
          continue
        sample.add(neighbor)
        frontier.append(neighbor)
        if len(sample) >= nodes:
          break
  # subgraph() returns a read-only view. Copy it so later operations can modify the sample.
  return graph.subgraph(sample).copy()


def _map_color(value):
  '''Maps numeric values to a list of '#rrggbb' colors from the viridis colormap.

  `value` is the node attribute column selected in `visualize_graph` (a pandas Series there).
  Values are rescaled linearly so that the minimum gets the first color of the map and the
  maximum the last one. When all values are equal there is no range to rescale, so every
  value gets the middle color of the map.
  '''
  try:
    cmap = matplotlib.colormaps['viridis']
  except AttributeError:
    # Older Matplotlib versions have no colormap registry. Newer ones removed cm.get_cmap().
    cmap = matplotlib.cm.get_cmap('viridis')
  low = value.min()
  span = value.max() - low
  if span == 0:
    # Avoid 0/0, which would turn every value into NaN.
    normalized = value * 0.0 + 0.5
  else:
    normalized = (value - low) / span
  rgba = cmap(normalized)
  return ['#{:02x}{:02x}{:02x}'.format(int(r*255), int(g*255), int(b*255)) for r, g, b in rgba[:, :3]]

@op("Visualize graph", view="visualization")
def visualize_graph(graph: ops.Bundle, *, color_nodes_by: 'node_attribute' = None):
  '''Builds the chart configuration used to draw the graph.

  Reads the 'nodes' and 'edges' tables of the bundle. Nodes are placed with a NetworkX spring
  layout and, when `color_nodes_by` names a column of the nodes table, colored by that column.
  Parallel edges (same source and target) are drawn only once. Returns a plain dictionary;
  its keys look like an ECharts graph series option -- NOTE(review): confirm with the frontend.
  '''
  nodes = graph.dfs['nodes'].copy()
  if color_nodes_by:
    nodes['color'] = _map_color(nodes[color_nodes_by])
  nodes = nodes.to_records()
  edges = graph.dfs['edges'].drop_duplicates(['source', 'target'])
  edges = edges.to_records()
  # An empty graph counts as one node in the formulas below to avoid dividing by zero.
  node_count = max(1, len(nodes))
  # Fewer layout iterations for bigger graphs.
  pos = nx.spring_layout(graph.to_nx(), iterations=max(1, int(10000 / node_count)))
  # Adjust node size to cover the same area no matter how many nodes there are.
  symbol_size = 50 / node_count ** 0.5
  v = {
    'animationDuration': 500,
    'animationEasingUpdate': 'quinticInOut',
    'series': [
      {
        'type': 'graph',
        'roam': True,
        'lineStyle': {
          'color': 'gray',
          'curveness': 0.3,
        },
        'emphasis': {
          'focus': 'adjacency',
          'lineStyle': {
            'width': 10,
          }
        },
        'data': [
          {
            'id': str(n.id),
            'x': float(pos[n.id][0]), 'y': float(pos[n.id][1]),
            'symbolSize': symbol_size,
            'itemStyle': {'color': n.color} if color_nodes_by else {},
          }
          for n in nodes],
        'links': [
          {'source': str(r.source), 'target': str(r.target)}
          for r in edges],
      },
    ],
  }
  return v

@op("View tables", view="table_view")
def view_tables(bundle: ops.Bundle):
  '''Converts the bundle into plain Python data for the table view.

  Each dataframe becomes its column names (as strings) plus its rows as nested lists.
  The bundle's `relations` and `other` fields are passed through unchanged.
  '''
  dataframes = {}
  for name, df in bundle.dfs.items():
    column_names = [str(c) for c in df.columns]
    rows = df.values.tolist()
    dataframes[name] = {
      'columns': column_names,
      'data': rows,
    }
  return {
    'dataframes': dataframes,
    'relations': bundle.relations,
    'other': bundle.other,
  }

@ops.register_executor('LynxKite')
def execute(ws):
    '''Runs every top-level node of the workspace in dependency order.

    A node runs once all of its inputs (the sources of the edges pointing at it) have produced
    an output. Results are written back onto each node's `data`: `error` is cleared on success
    or set to the exception message on failure, `view` receives the output of visualization
    and table view operations, and `inputs` is refreshed for operations with a flexible input.
    Nothing is returned.

    Each node is executed at most once per call. Execution stops when a full pass over the
    nodes starts nothing new, so nodes downstream of a failed node, or nodes whose inputs can
    never be satisfied (such as a dependency cycle), are left untouched instead of keeping the
    loop running.
    '''
    catalog = ops.CATALOGS['LynxKite']
    # Nodes are responsible for interpreting/executing their child nodes.
    nodes = [n for n in ws.nodes if not n.parentId]
    children = {}
    for n in ws.nodes:
        if n.parentId:
            children.setdefault(n.parentId, []).append(n)
    # Group edge sources by target once. Edge order is kept because it is the argument order.
    sources_by_target = {}
    for edge in ws.edges:
        sources_by_target.setdefault(edge.target, []).append(edge.source)
    outputs = {}
    failed = set()
    progress = True
    while progress:
        progress = False
        for node in nodes:
            if node.id in outputs or node.id in failed:
                continue
            sources = sources_by_target.get(node.id, [])
            if not all(source in outputs for source in sources):
                # Some input is not available yet. Try again in the next pass.
                continue
            progress = True
            inputs = [outputs[source] for source in sources]
            data = node.data
            operation = catalog[data.title]
            params = {**data.params}
            if operation.sub_nodes:
                sub_nodes = children.get(node.id, [])
                sub_node_ids = {sub_node.id for sub_node in sub_nodes}
                sub_edges = [edge for edge in ws.edges if edge.source in sub_node_ids]
                params['sub_flow'] = {'nodes': sub_nodes, 'edges': sub_edges}
            try:
                output = operation(*inputs, **params)
            except Exception as e:
                # One failing operation must not stop the unrelated parts of the workspace.
                traceback.print_exc()
                data.error = str(e)
                failed.add(node.id)
                continue
            if len(operation.inputs) == 1 and operation.inputs.get('multi') == '*':
                # It's a flexible input. Create n+1 handles.
                data.inputs = {f'input{i}': None for i in range(len(inputs) + 1)}
            data.error = None
            outputs[node.id] = output
            if operation.type == 'visualization' or operation.type == 'table_view':
                data.view = output