darabos commited on
Commit
083e188
·
1 Parent(s): 6f123b5

High level boxes for Neural ODE + GNN demo.

Browse files
lynxkite-graph-analytics/src/lynxkite_graph_analytics/pytorch_model_ops.py CHANGED
@@ -29,7 +29,11 @@ reg("Input: graph edges", outputs=["edges"])
29
  reg("Input: label", outputs=["y"])
30
  reg("Input: positive sample", outputs=["x_pos"])
31
  reg("Input: negative sample", outputs=["x_neg"])
 
 
32
 
 
 
33
  reg("Attention", inputs=["q", "k", "v"], outputs=["x", "weights"])
34
  reg("LayerNorm", inputs=["x"])
35
  reg("Dropout", inputs=["x"], params=[P.basic("p", 0.5)])
@@ -82,6 +86,14 @@ ops.register_passive_op(
82
  params=[ops.Parameter.basic("times", 1, int)],
83
  )
84
 
 
 
 
 
 
 
 
 
85
 
86
  def build_model(ws: workspace.Workspace, inputs: dict):
87
  """Builds the model described in the workspace."""
 
29
  reg("Input: label", outputs=["y"])
30
  reg("Input: positive sample", outputs=["x_pos"])
31
  reg("Input: negative sample", outputs=["x_neg"])
32
+ reg("Input: sequential", outputs=["y"])
33
+ reg("Input: zeros", outputs=["x"])
34
 
35
+ reg("LSTM", inputs=["x", "h"], outputs=["x", "h"])
36
+ reg("Neural ODE", inputs=["x"])
37
  reg("Attention", inputs=["q", "k", "v"], outputs=["x", "weights"])
38
  reg("LayerNorm", inputs=["x"])
39
  reg("Dropout", inputs=["x"], params=[P.basic("p", 0.5)])
 
86
  params=[ops.Parameter.basic("times", 1, int)],
87
  )
88
 
89
+ ops.register_passive_op(
90
+ ENV,
91
+ "Recurrent chain",
92
+ inputs=[ops.Input(name="input", position="top", type="tensor")],
93
+ outputs=[ops.Output(name="output", position="bottom", type="tensor")],
94
+ params=[],
95
+ )
96
+
97
 
98
  def build_model(ws: workspace.Workspace, inputs: dict):
99
  """Builds the model described in the workspace."""