Spaces:
Running
Running
| """Boxes for defining PyTorch models.""" | |
| from lynxkite.core import ops | |
| from lynxkite.core.ops import Parameter as P | |
# Workspace environment under which every box in this module is registered.
ENV = "PyTorch model"


def reg(name, inputs=None, outputs=None, params=None):
    """Register a passive box in the ``ENV`` workspace with tensor-typed ports.

    Args:
        name: Display name of the box.
        inputs: Names of the input ports (type ``"tensor"``, placed at the
            bottom of the box). ``None`` means no inputs.
        outputs: Names of the output ports (type ``"tensor"``, placed at the
            top of the box). ``None`` reuses the input port names, which suits
            layers that map ``x`` to ``x``. Pass ``[]`` explicitly for a box
            without outputs.
        params: Box parameters, e.g. built with ``P.basic`` / ``P.options``.
            ``None`` means no parameters.

    Returns:
        Whatever ``ops.register_passive_op`` returns.
    """
    # Fresh lists per call. The previous ``[]`` defaults were single shared
    # objects that were passed on to the op registry for every box that
    # relied on them.
    inputs = [] if inputs is None else list(inputs)
    outputs = list(inputs) if outputs is None else list(outputs)
    params = [] if params is None else list(params)
    return ops.register_passive_op(
        ENV,
        name,
        inputs=[
            ops.Input(name=port, position="bottom", type="tensor") for port in inputs
        ],
        outputs=[
            ops.Output(name=port, position="top", type="tensor") for port in outputs
        ],
        params=params,
    )
# Source boxes: no input ports, one tensor output each. Registered in table order.
for _box_name, _port in (
    ("Input: features", "x"),
    ("Input: graph edges", "edges"),
    ("Input: label", "y"),
    ("Input: positive sample", "x_pos"),
    ("Input: negative sample", "x_neg"),
):
    reg(_box_name, outputs=[_port])
# Do not leak the loop temporaries into the module namespace.
del _box_name, _port
# Basic layers. Attention takes q/k/v and emits x. The others take x and,
# since no outputs are given, reg() mirrors the input name as the output.
reg(
    name="Attention",
    inputs=["q", "k", "v"],
    outputs=["x"],
)
reg(
    name="LayerNorm",
    inputs=["x"],
)
reg(
    name="Dropout",
    inputs=["x"],
    params=[P.basic("p", 0.5)],
)
reg(
    name="Linear",
    inputs=["x"],
    params=[P.basic("output_dim", "same")],
)
# Graph convolution: node tensor x plus edge tensor in, updated x out.
# The concrete layer type is chosen from a fixed option list.
reg(
    name="Graph conv",
    inputs=["x", "edges"],
    outputs=["x"],
    params=[
        P.options("type", ["GCNConv", "GATConv", "GATv2Conv", "SAGEConv"]),
    ],
)
# Activation function box: x in, x out (outputs mirror inputs via reg()).
reg(
    name="Activation",
    inputs=["x"],
    params=[
        P.options("type", ["ReLU", "LeakyReLU", "Tanh", "Mish"]),
    ],
)
# Loss boxes: each takes the listed tensors and exposes a single "loss" port.
for _box_name, _ports in (
    ("Supervised loss", ["x", "y"]),
    ("Triplet loss", ["x", "x_pos", "x_neg"]),
):
    reg(_box_name, inputs=_ports, outputs=["loss"])
# Do not leak the loop temporaries into the module namespace.
del _box_name, _ports
# Optimizer box: consumes "loss", selects an optimizer type and learning rate.
# outputs=[] is explicit because reg() would otherwise mirror "loss" as an output.
reg(
    name="Optimizer",
    inputs=["loss"],
    outputs=[],
    params=[
        P.options(
            "type",
            [
                "AdamW", "Adafactor", "Adagrad", "SGD",
                "Lion", "Paged AdamW", "Galore AdamW",
            ],
        ),
        P.basic("lr", 0.001),
    ],
)
# "Repeat" is registered directly instead of via reg(). reg() always puts
# inputs at the bottom and outputs at the top, and this box has the opposite
# layout: input on top, output at the bottom.
ops.register_passive_op(
    ENV,
    "Repeat",
    inputs=[ops.Input(name="input", position="top", type="tensor")],
    outputs=[ops.Output(name="output", position="bottom", type="tensor")],
    # P is this module's alias for ops.Parameter.
    params=[P.basic("times", 1, int)],
)