File size: 2,110 Bytes
3010d5b
3885cb0
3010d5b
 
 
 
 
75c875f
 
 
3010d5b
 
3885cb0
 
75c875f
3885cb0
 
 
3010d5b
 
 
 
 
28c40a9
3010d5b
 
 
d4a220c
3010d5b
 
28c40a9
 
3010d5b
 
 
 
801415b
3010d5b
 
 
801415b
3010d5b
 
 
 
 
 
3885cb0
 
 
 
 
 
3010d5b
801415b
3010d5b
 
3885cb0
 
 
 
 
3010d5b
801415b
3010d5b
 
801415b
 
d4a220c
801415b
 
d4a220c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
'''Boxes for defining and using PyTorch models.'''
from enum import Enum
import inspect
from . import ops

# Registry of layer and area ops, keyed by their display name. It is filled in
# by register_layer() and register_area() and is passed to the
# "Define PyTorch model" op as its sub_nodes.
LAYERS = {}

# Decorator factory used to register the boxes defined in this module.
# NOTE(review): 'LynxKite' is presumably the name of the op catalog -- confirm in ops.
op = ops.op_registration('LynxKite')

@op("Define PyTorch model", sub_nodes=LAYERS)
def define_pytorch_model(*, sub_flow):
  '''Packages the received sub-flow into a bundle.

  The sub-flow is printed to stdout, then its string form is stored under
  the 'model' key of the returned bundle's `other` dictionary.
  '''
  print('sub_flow:', sub_flow)
  model_description = str(sub_flow)
  return ops.Bundle(other={'model': model_description})

@op("Train PyTorch model")
def train_pytorch_model(model, graph):
  '''Placeholder for model training.

  Nothing is trained yet: `graph` is ignored and the result is the text
  'hello ' followed by the string form of `model`.
  '''
  # torch is meant to be imported lazily here, because importing it is slow.
  return f'hello {model!s}'

def register_layer(name):
  '''Returns a decorator that registers a function in LAYERS under `name`.

  The signature of the decorated function defines the op: keyword-only
  arguments become parameters (built from their default and annotation), and
  every other argument becomes an input with position 'bottom'. Each layer
  gets a single output named 'x' of type 'tensor' with position 'top'.
  The decorated function is returned unchanged.
  '''
  def decorator(func):
    arguments = inspect.signature(func).parameters
    # First pass: everything that is not keyword-only is an input.
    inputs = {}
    for arg_name, arg in arguments.items():
      if arg.kind == arg.KEYWORD_ONLY:
        continue
      inputs[arg_name] = ops.Input(name=arg_name, type=arg.annotation, position='bottom')
    # Second pass: keyword-only arguments are parameters.
    params = {}
    for arg_name, arg in arguments.items():
      if arg.kind != arg.KEYWORD_ONLY:
        continue
      params[arg_name] = ops.Parameter.basic(arg_name, arg.default, arg.annotation)
    outputs = {'x': ops.Output(name='x', type='tensor', position='top')}
    layer_op = ops.Op(func=func, name=name, params=params, inputs=inputs, outputs=outputs)
    LAYERS[name] = layer_op
    return func
  return decorator

@register_layer('LayerNorm')
def layernorm(x):
  '''Layer with a single input `x` and no parameters; returns a fixed text label.'''
  label = 'LayerNorm'
  return label

@register_layer('Dropout')
def dropout(x, *, p=0.5):
  '''Layer with a single input `x` and a keyword-only parameter `p` (default 0.5).

  Returns a text label that includes the value of `p`.
  NOTE(review): `p` is presumably the dropout probability -- confirm.
  '''
  label = f'Dropout ({p})'
  return label

@register_layer('Linear')
def linear(*, output_dim: int):
  '''Layer with no inputs and a required keyword-only integer parameter `output_dim`.

  Returns a text label that includes the value of `output_dim`.
  '''
  label = f'Linear {output_dim}'
  return label

class GraphConv(Enum):
  '''Choices for the `type` parameter of the "Graph Convolution" layer.

  The value of every member is the same string as its name.
  '''
  GCNConv = 'GCNConv'
  GATConv = 'GATConv'
  GATv2Conv = 'GATv2Conv'
  SAGEConv = 'SAGEConv'

@register_layer('Graph Convolution')
def graph_convolution(x, edges, *, type: GraphConv):
  '''Layer with inputs `x` and `edges` and a keyword-only `type` parameter.

  Returns a fixed text label; none of the arguments affect the result.
  '''
  label = 'GraphConv'
  return label

class Nonlinearity(Enum):
  '''Choices for the `type` parameter of the "Nonlinearity" layer.

  The value of every member is the same string as its name.
  '''
  Mish = 'Mish'
  ReLU = 'ReLU'
  Tanh = 'Tanh'

@register_layer('Nonlinearity')
def nonlinearity(x, *, type: Nonlinearity):
  '''Layer with a single input `x` and a keyword-only `type` parameter.

  Returns a fixed text label.
  NOTE(review): the label is always 'ReLU', whatever `type` is -- presumably
  a placeholder; confirm before relying on it.
  '''
  label = 'ReLU'
  return label

def register_area(name, params=()):
  '''Registers an area node in LAYERS under `name`.

  An area can contain other nodes, but does not restrict movement in any way.
  The registered op has no inputs and no outputs, uses ops.no_op as its
  function, and has type 'area'.

  Args:
    name: Name of the area; also used as its key in LAYERS.
    params: Iterable of parameter objects. Each one is stored under its
      `.name` attribute. Defaults to no parameters.
  '''
  # The default is an immutable tuple, so no mutable object is shared between
  # calls. The local is not named `op`, to avoid shadowing the module-level
  # `op` decorator factory.
  area_params = {p.name: p for p in params}
  area = ops.Op(
    func=ops.no_op, name=name, params=area_params, inputs={}, outputs={}, type='area')
  LAYERS[name] = area

# Registers the 'Repeat' area with a single 'times' parameter (default 1, type int).
# NOTE(review): presumably the contents of the area are repeated `times` times -- confirm.
register_area('Repeat', params=[ops.Parameter.basic('times', 1, int)])