Spaces:
Running
Running
| '''Boxes for defining and using PyTorch models.''' | |
| from enum import Enum | |
| import inspect | |
| from . import ops | |
# Registry of op definitions keyed by name. Filled in by register_layer()
# (ops of type 'vertical') and register_area() (ops of type 'area').
LAYERS = {}
def define_pytorch_model(*, sub_flow):
    '''Wrap the textual form of `sub_flow` in an ops.Bundle.

    Prints `sub_flow` to stdout, then returns a Bundle whose `other` mapping
    holds str(sub_flow) under the key 'model'.
    '''
    print('sub_flow:', sub_flow)
    model_description = str(sub_flow)
    bundle = ops.Bundle(other={'model': model_description})
    return bundle
def train_pytorch_model(model, graph):
    '''Placeholder training step: return a greeting that names the model.

    `graph` is accepted but not used yet.
    '''
    # TODO: import torch lazily inside this function once real training is
    # implemented; importing it is slow, so keep it out of module import time.
    return ' '.join(('hello', str(model)))
def register_layer(name):
    '''Build a decorator that registers a function as a 'vertical' op in LAYERS.

    The decorated function's signature is inspected:
      * every parameter that is not keyword-only becomes an input, mapped to
        its annotation (inspect.Parameter.empty when unannotated);
      * every keyword-only parameter becomes an ops.Parameter built from its
        name, default (inspect.Parameter.empty when it has none) and
        annotation.
    Each registered op has a single output 'x' of type 'tensor'. The decorated
    function itself is returned unchanged.

    Args:
      name: key under which the op is stored in LAYERS; also the op's name.
    '''
    def decorator(func):
        layer_inputs = {}
        layer_params = {}
        # A single pass over the signature, splitting keyword-only parameters
        # (configuration) from everything else (inputs). The loop variable is
        # deliberately not called `name`, to avoid confusing it with the
        # registry key captured from the enclosing scope.
        for arg_name, arg in inspect.signature(func).parameters.items():
            if arg.kind == arg.KEYWORD_ONLY:
                layer_params[arg_name] = ops.Parameter(arg_name, arg.default, arg.annotation)
            else:
                layer_inputs[arg_name] = arg.annotation
        LAYERS[name] = ops.Op(
            func,
            name,
            params=layer_params,
            inputs=layer_inputs,
            outputs={'x': 'tensor'},
            type='vertical',
        )
        return func
    return decorator
def layernorm(x):
    '''Placeholder LayerNorm layer: return its display label.

    `x` is not used yet.
    '''
    label = 'LayerNorm'
    return label
def dropout(x, *, p=0.5):
    '''Placeholder Dropout layer: return a label that includes the rate `p`.

    `x` is not used yet. `p` defaults to 0.5.
    '''
    return 'Dropout ({})'.format(p)
def linear(*, output_dim: int):
    '''Placeholder Linear layer: return a label that includes `output_dim`.

    NOTE(review): unlike layernorm/dropout, this takes no positional `x`
    input; confirm that this is intended.
    '''
    return 'Linear {}'.format(output_dim)
class GraphConv(Enum):
    '''Graph convolution variants accepted by graph_convolution's `type`.

    Each member's value is the same string as its name.
    '''
    GCNConv = 'GCNConv'
    GATConv = 'GATConv'
    GATv2Conv = 'GATv2Conv'
    SAGEConv = 'SAGEConv'


def graph_convolution(x, edges, *, type: GraphConv):
    '''Placeholder graph-convolution layer: return a fixed label.

    `x`, `edges` and `type` are not used yet. The keyword-only name `type`
    shadows the builtin but is part of the public signature, so it is kept.
    '''
    label = 'GraphConv'
    return label
class Nonlinearity(Enum):
    '''Activation functions accepted by nonlinearity's `type`.

    Each member's value is the same string as its name.
    '''
    Mish = 'Mish'
    ReLU = 'ReLU'
    Tanh = 'Tanh'


def nonlinearity(x, *, type: Nonlinearity):
    '''Placeholder nonlinearity layer: return the name of the selected activation.

    Args:
      x: the layer input (not used yet by this placeholder).
      type: the activation to apply, normally a Nonlinearity member. A raw
        value such as 'Tanh' is also accepted.

    Returns:
      The activation's name, e.g. 'Mish' for Nonlinearity.Mish.
    '''
    # Enum members expose their label via `.value`. Anything without a
    # `.value` (e.g. a plain string) is used as-is. The keyword-only name
    # `type` shadows the builtin but is part of the public signature.
    return str(getattr(type, 'value', type))
def register_area(name, params=()):
    '''Register an 'area' op under `name` in LAYERS.

    An area is a node that can contain other nodes but does not restrict
    their movement in any way. The registered op runs ops.no_op and has no
    inputs or outputs.

    Args:
      name: key in LAYERS and the op's name.
      params: iterable of ops.Parameter objects, each keyed by its `.name`
        attribute in the op's params. The default is an immutable empty tuple,
        so no list object is shared between calls.
    '''
    op = ops.Op(
        ops.no_op,
        name,
        params={p.name: p for p in params},
        inputs={},
        outputs={},
        type='area',
    )
    LAYERS[name] = op
# 'Repeat' area with one integer parameter, `times`, whose default is 1.
register_area(
    'Repeat',
    params=(ops.Parameter('times', 1, int),),
)