darabos commited on
Commit
28c40a9
·
1 Parent(s): abb1488

Control plug position for each input/output.

Browse files
server/ops.py CHANGED
@@ -50,14 +50,30 @@ class Parameter(BaseConfig):
50
  type = typeof(default) if default else None
51
  return Parameter(name=name, default=default, type=type)
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  class Op(BaseConfig):
55
  func: callable = pydantic.Field(exclude=True)
56
  name: str
57
  params: dict[str, Parameter]
58
- inputs: dict[str, Type] # name -> type
59
- outputs: dict[str, Type] # name -> type
60
- type: str # The UI to use for this operation.
61
  sub_nodes: list[Op] = None # If set, these nodes can be placed inside the operation's node.
62
 
63
  def __call__(self, *inputs, **params):
@@ -70,10 +86,10 @@ class Op(BaseConfig):
70
  params[p] = float(params[p])
71
  # Convert inputs.
72
  inputs = list(inputs)
73
- for i, (x, t) in enumerate(zip(inputs, self.inputs.values())):
74
- if t == nx.Graph and isinstance(x, Bundle):
75
  inputs[i] = x.to_nx()
76
- elif t == Bundle and isinstance(x, nx.Graph):
77
  inputs[i] = Bundle.from_nx(x)
78
  res = self.func(*inputs, **params)
79
  return res
@@ -147,14 +163,14 @@ def op(name, *, view='basic', sub_nodes=None):
147
  sig = inspect.signature(func)
148
  # Positional arguments are inputs.
149
  inputs = {
150
- name: param.annotation
151
  for name, param in sig.parameters.items()
152
  if param.kind != param.KEYWORD_ONLY}
153
  params = {}
154
  for n, param in sig.parameters.items():
155
  if param.kind == param.KEYWORD_ONLY:
156
  params[n] = Parameter.basic(n, param.default, param.annotation)
157
- outputs = {'output': 'yes'} if view == 'basic' else {} # Maybe more fancy later.
158
  op = Op(func=func, name=name, params=params, inputs=inputs, outputs=outputs, type=view)
159
  if sub_nodes is not None:
160
  op.sub_nodes = sub_nodes
@@ -168,13 +184,22 @@ def no_op(*args, **kwargs):
168
  return args[0]
169
  return Bundle()
170
 
171
- def register_passive_op(name, inputs={'input': Bundle}, outputs={'output': Bundle}, params=[]):
172
  '''A passive operation has no associated code.'''
173
- op = Op(no_op, name, params={p.name: p for p in params}, inputs=inputs, outputs=outputs, type='basic')
 
 
 
 
 
 
 
 
 
174
  ALL_OPS[name] = op
175
  return op
176
 
177
  def register_area(name, params=[]):
178
  '''A node that represents an area. It can contain other nodes, but does not restrict movement in any way.'''
179
- op = register_passive_op(name, params=params, inputs={}, outputs={})
180
  op.type = 'area'
 
50
  type = typeof(default) if default else None
51
  return Parameter(name=name, default=default, type=type)
52
 
53
+ class Input(BaseConfig):
54
+ name: str
55
+ type: Type
56
+ position: str = 'left'
57
+
58
+ class Output(BaseConfig):
59
+ name: str
60
+ type: Type
61
+ position: str = 'right'
62
+
63
+ MULTI_INPUT = Input(name='multi', type='*')
64
+ def basic_inputs(*names):
65
+ return {name: Input(name=name, type=None) for name in names}
66
+ def basic_outputs(*names):
67
+ return {name: Output(name=name, type=None) for name in names}
68
+
69
 
70
  class Op(BaseConfig):
71
  func: callable = pydantic.Field(exclude=True)
72
  name: str
73
  params: dict[str, Parameter]
74
+ inputs: dict[str, Input]
75
+ outputs: dict[str, Output]
76
+ type: str = 'basic' # The UI to use for this operation.
77
  sub_nodes: list[Op] = None # If set, these nodes can be placed inside the operation's node.
78
 
79
  def __call__(self, *inputs, **params):
 
86
  params[p] = float(params[p])
87
  # Convert inputs.
88
  inputs = list(inputs)
89
+ for i, (x, p) in enumerate(zip(inputs, self.inputs.values())):
90
+ if p.type == nx.Graph and isinstance(x, Bundle):
91
  inputs[i] = x.to_nx()
92
+ elif p.type == Bundle and isinstance(x, nx.Graph):
93
  inputs[i] = Bundle.from_nx(x)
94
  res = self.func(*inputs, **params)
95
  return res
 
163
  sig = inspect.signature(func)
164
  # Positional arguments are inputs.
165
  inputs = {
166
+ name: Input(name=name, type=param.annotation)
167
  for name, param in sig.parameters.items()
168
  if param.kind != param.KEYWORD_ONLY}
169
  params = {}
170
  for n, param in sig.parameters.items():
171
  if param.kind == param.KEYWORD_ONLY:
172
  params[n] = Parameter.basic(n, param.default, param.annotation)
173
+ outputs = {'output': Output(name='output', type=None)} if view == 'basic' else {} # Maybe more fancy later.
174
  op = Op(func=func, name=name, params=params, inputs=inputs, outputs=outputs, type=view)
175
  if sub_nodes is not None:
176
  op.sub_nodes = sub_nodes
 
184
  return args[0]
185
  return Bundle()
186
 
187
+ def register_passive_op(name, inputs=[], outputs=['output'], params=[]):
188
  '''A passive operation has no associated code.'''
189
+ op = Op(
190
+ func=no_op,
191
+ name=name,
192
+ params={p.name: p for p in params},
193
+ inputs=dict(
194
+ (i, Input(name=i, type=None)) if isinstance(i, str)
195
+ else (i.name, i) for i in inputs),
196
+ outputs=dict(
197
+ (o, Output(name=o, type=None)) if isinstance(o, str)
198
+ else (o.name, o) for o in outputs))
199
  ALL_OPS[name] = op
200
  return op
201
 
202
  def register_area(name, params=[]):
203
  '''A node that represents an area. It can contain other nodes, but does not restrict movement in any way.'''
204
+ op = register_passive_op(name, params=params)
205
  op.type = 'area'
server/pytorch_model_ops.py CHANGED
@@ -19,15 +19,15 @@ def register_layer(name):
19
  def decorator(func):
20
  sig = inspect.signature(func)
21
  inputs = {
22
- name: param.annotation
23
  for name, param in sig.parameters.items()
24
  if param.kind != param.KEYWORD_ONLY}
25
  params = {
26
  name: ops.Parameter.basic(name, param.default, param.annotation)
27
  for name, param in sig.parameters.items()
28
  if param.kind == param.KEYWORD_ONLY}
29
- outputs = {'x': 'tensor'}
30
- LAYERS[name] = ops.Op(func=func, name=name, params=params, inputs=inputs, outputs=outputs, type='vertical')
31
  return func
32
  return decorator
33
 
 
19
  def decorator(func):
20
  sig = inspect.signature(func)
21
  inputs = {
22
+ name: ops.Input(name=name, type=param.annotation, position='bottom')
23
  for name, param in sig.parameters.items()
24
  if param.kind != param.KEYWORD_ONLY}
25
  params = {
26
  name: ops.Parameter.basic(name, param.default, param.annotation)
27
  for name, param in sig.parameters.items()
28
  if param.kind == param.KEYWORD_ONLY}
29
+ outputs = {'x': ops.Output(name='x', type='tensor', position='top')}
30
+ LAYERS[name] = ops.Op(func=func, name=name, params=params, inputs=inputs, outputs=outputs)
31
  return func
32
  return decorator
33
 
server/workspace.py CHANGED
@@ -21,6 +21,8 @@ class WorkspaceNodeData(BaseConfig):
21
  params: dict
22
  display: Optional[object] = None
23
  error: Optional[str] = None
 
 
24
 
25
  class WorkspaceNode(BaseConfig):
26
  id: str
@@ -118,6 +120,7 @@ def _update_metadata(ws):
118
  continue
119
  if op:
120
  data.meta = op
 
121
  if data.error == 'Unknown operation.':
122
  data.error = None
123
  else:
 
21
  params: dict
22
  display: Optional[object] = None
23
  error: Optional[str] = None
24
+ # Also contains a "meta" field when going out.
25
+ # This is ignored when coming back from the frontend.
26
 
27
  class WorkspaceNode(BaseConfig):
28
  id: str
 
120
  continue
121
  if op:
122
  data.meta = op
123
+ node.type = op.type
124
  if data.error == 'Unknown operation.':
125
  data.error = None
126
  else:
web/src/LynxKiteFlow.svelte CHANGED
@@ -16,7 +16,6 @@
16
  } from '@xyflow/svelte';
17
  import { useQuery, useMutation, useQueryClient } from '@sveltestack/svelte-query';
18
  import NodeWithParams from './NodeWithParams.svelte';
19
- import NodeWithParamsVertical from './NodeWithParamsVertical.svelte';
20
  import NodeWithVisualization from './NodeWithVisualization.svelte';
21
  import NodeWithTableView from './NodeWithTableView.svelte';
22
  import NodeWithSubFlow from './NodeWithSubFlow.svelte';
@@ -47,7 +46,6 @@
47
 
48
  const nodeTypes: NodeTypes = {
49
  basic: NodeWithParams,
50
- vertical: NodeWithParamsVertical,
51
  visualization: NodeWithVisualization,
52
  table_view: NodeWithTableView,
53
  sub_flow: NodeWithSubFlow,
@@ -88,8 +86,6 @@
88
  title: meta.name,
89
  params: Object.fromEntries(
90
  Object.values(meta.params).map((p) => [p.name, p.default])),
91
- inputs: meta.inputs,
92
- outputs: meta.outputs,
93
  },
94
  };
95
  node.position = screenToFlowPosition({x: nodeSearchSettings.pos.x, y: nodeSearchSettings.pos.y});
 
16
  } from '@xyflow/svelte';
17
  import { useQuery, useMutation, useQueryClient } from '@sveltestack/svelte-query';
18
  import NodeWithParams from './NodeWithParams.svelte';
 
19
  import NodeWithVisualization from './NodeWithVisualization.svelte';
20
  import NodeWithTableView from './NodeWithTableView.svelte';
21
  import NodeWithSubFlow from './NodeWithSubFlow.svelte';
 
46
 
47
  const nodeTypes: NodeTypes = {
48
  basic: NodeWithParams,
 
49
  visualization: NodeWithVisualization,
50
  table_view: NodeWithTableView,
51
  sub_flow: NodeWithSubFlow,
 
86
  title: meta.name,
87
  params: Object.fromEntries(
88
  Object.values(meta.params).map((p) => [p.name, p.default])),
 
 
89
  },
90
  };
91
  node.position = screenToFlowPosition({x: nodeSearchSettings.pos.x, y: nodeSearchSettings.pos.y});
web/src/LynxKiteNode.svelte CHANGED
@@ -29,8 +29,8 @@
29
  function asPx(n: number | undefined) {
30
  return n ? n + 'px' : undefined;
31
  }
32
- $: inputs = Object.entries(data.inputs || {});
33
- $: outputs = Object.entries(data.outputs || {});
34
  const handleOffsetDirection = { top: 'left', bottom: 'left', left: 'top', right: 'top' };
35
  </script>
36
 
@@ -47,18 +47,18 @@
47
  {/if}
48
  <slot />
49
  {/if}
50
- {#each inputs as [name, input], i}
51
  <Handle
52
- id={name} type="target" position={targetPosition || 'left'}
53
- style="{handleOffsetDirection[targetPosition || 'left']}: {100 * (i + 1) / (inputs.length + 1)}%">
54
- {#if inputs.length>1}<span class="handle-name">{name.replace(/_/g, " ")}</span>{/if}
55
  </Handle>
56
  {/each}
57
- {#each outputs as [name, output], i}
58
  <Handle
59
- id={name} type="source" position={sourcePosition || 'right'}
60
- style="{handleOffsetDirection[sourcePosition || 'right']}: {100 * (i + 1) / (outputs.length + 1)}%">
61
- {#if outputs.length>1}<span class="handle-name">{name.replace(/_/g, " ")}</span>{/if}
62
  </Handle>
63
  {/each}
64
  </div>
 
29
  function asPx(n: number | undefined) {
30
  return n ? n + 'px' : undefined;
31
  }
32
+ $: inputs = Object.values(data.meta?.inputs || {});
33
+ $: outputs = Object.values(data.meta?.outputs || {});
34
  const handleOffsetDirection = { top: 'left', bottom: 'left', left: 'top', right: 'top' };
35
  </script>
36
 
 
47
  {/if}
48
  <slot />
49
  {/if}
50
+ {#each inputs as input, i}
51
  <Handle
52
+ id={input.name} type="target" position={input.position}
53
+ style="{handleOffsetDirection[input.position]}: {100 * (i + 1) / (inputs.length + 1)}%">
54
+ {#if inputs.length>1}<span class="handle-name">{input.name.replace(/_/g, " ")}</span>{/if}
55
  </Handle>
56
  {/each}
57
+ {#each outputs as output, i}
58
  <Handle
59
+ id={output.name} type="source" position={output.position}
60
+ style="{handleOffsetDirection[output.position]}: {100 * (i + 1) / (outputs.length + 1)}%">
61
+ {#if outputs.length>1}<span class="handle-name">{output.name.replace(/_/g, " ")}</span>{/if}
62
  </Handle>
63
  {/each}
64
  </div>
web/src/NodeWithParamsVertical.svelte DELETED
@@ -1,33 +0,0 @@
1
- <script lang="ts">
2
- import { type NodeProps, useSvelteFlow } from '@xyflow/svelte';
3
- import LynxKiteNode from './LynxKiteNode.svelte';
4
- import NodeParameter from './NodeParameter.svelte';
5
- type $$Props = NodeProps;
6
- export let id: $$Props['id'];
7
- export let data: $$Props['data'];
8
- const { updateNodeData } = useSvelteFlow();
9
- $: metaParams = data.meta?.params;
10
- </script>
11
-
12
- <LynxKiteNode {...$$props} sourcePosition="top" targetPosition="bottom">
13
- {#each Object.entries(data.params) as [name, value]}
14
- <NodeParameter
15
- {name}
16
- {value}
17
- meta={metaParams?.[name]}
18
- onChange={(newValue) => updateNodeData(id, { params: { ...data.params, [name]: newValue } })}
19
- />
20
- {/each}
21
- </LynxKiteNode>
22
- <style>
23
- .param {
24
- padding: 8px;
25
- }
26
- .param label {
27
- font-size: 12px;
28
- display: block;
29
- }
30
- .param input {
31
- width: calc(100% - 8px);
32
- }
33
- </style>