Spaces:
Running
Running
Repeat boxes working.
Browse files
examples/Model definition
CHANGED
|
@@ -438,8 +438,8 @@
|
|
| 438 |
"height": 200.0,
|
| 439 |
"id": "MSE loss 2",
|
| 440 |
"position": {
|
| 441 |
-
"x":
|
| 442 |
-
"y": -
|
| 443 |
},
|
| 444 |
"type": "basic",
|
| 445 |
"width": 200.0
|
|
|
|
| 438 |
"height": 200.0,
|
| 439 |
"id": "MSE loss 2",
|
| 440 |
"position": {
|
| 441 |
+
"x": 309.4422414664647,
|
| 442 |
+
"y": -552.1056805642488
|
| 443 |
},
|
| 444 |
"type": "basic",
|
| 445 |
"width": 200.0
|
examples/Model use
CHANGED
|
@@ -579,54 +579,54 @@
|
|
| 579 |
],
|
| 580 |
"data": [
|
| 581 |
[
|
| 582 |
-
"[0.
|
| 583 |
-
"[1.
|
| 584 |
-
"[1.
|
| 585 |
],
|
| 586 |
[
|
| 587 |
-
"[0.
|
| 588 |
-
"[1.
|
| 589 |
-
"[1.
|
| 590 |
],
|
| 591 |
[
|
| 592 |
-
"[0.
|
| 593 |
-
"[1.
|
| 594 |
-
"[1.
|
| 595 |
],
|
| 596 |
[
|
| 597 |
-
"[0.
|
| 598 |
-
"[1.
|
| 599 |
-
"[1.
|
| 600 |
],
|
| 601 |
[
|
| 602 |
-
"[0.
|
| 603 |
-
"[1.
|
| 604 |
-
"[1.
|
| 605 |
],
|
| 606 |
[
|
| 607 |
-
"[0.
|
| 608 |
-
"[1.
|
| 609 |
-
"[1.
|
| 610 |
],
|
| 611 |
[
|
| 612 |
-
"[0.
|
| 613 |
-
"[1.
|
| 614 |
-
"[1.
|
| 615 |
],
|
| 616 |
[
|
| 617 |
-
"[0.
|
| 618 |
-
"[1.
|
| 619 |
-
"[1.
|
| 620 |
],
|
| 621 |
[
|
| 622 |
-
"[0.
|
| 623 |
-
"[1.
|
| 624 |
-
"[1.
|
| 625 |
],
|
| 626 |
[
|
| 627 |
-
"[0.
|
| 628 |
-
"[1.
|
| 629 |
-
"[1.
|
| 630 |
]
|
| 631 |
]
|
| 632 |
},
|
|
@@ -648,6 +648,10 @@
|
|
| 648 |
"[0.11560339 0.57495481 0.76535827 0.0391947 ]",
|
| 649 |
"[1.11560345 1.57495475 1.76535821 1.0391947 ]"
|
| 650 |
],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 651 |
[
|
| 652 |
"[0.76807946 0.98855817 0.08259124 0.01730657]",
|
| 653 |
"[1.76807952 1.98855817 1.0825913 1.01730657]"
|
|
@@ -681,12 +685,12 @@
|
|
| 681 |
"[1.98324287 1.99464178 1.14008355 1.47651017]"
|
| 682 |
],
|
| 683 |
[
|
| 684 |
-
"[0.
|
| 685 |
-
"[1.
|
| 686 |
],
|
| 687 |
[
|
| 688 |
-
"[0.
|
| 689 |
-
"[1.
|
| 690 |
],
|
| 691 |
[
|
| 692 |
"[0.04508126 0.76880038 0.80721325 0.62542385]",
|
|
@@ -708,6 +712,10 @@
|
|
| 708 |
"[0.24388778 0.07268471 0.68350857 0.73431659]",
|
| 709 |
"[1.24388778 1.07268476 1.68350863 1.73431659]"
|
| 710 |
],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 711 |
[
|
| 712 |
"[0.56922203 0.98222166 0.76851749 0.28615737]",
|
| 713 |
"[1.56922197 1.9822216 1.76851749 1.28615737]"
|
|
@@ -720,6 +728,10 @@
|
|
| 720 |
"[0.90817457 0.89270043 0.38583666 0.66566533]",
|
| 721 |
"[1.90817451 1.89270043 1.3858366 1.66566539]"
|
| 722 |
],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 723 |
[
|
| 724 |
"[0.68062544 0.98093534 0.14778823 0.53244978]",
|
| 725 |
"[1.68062544 1.98093534 1.14778829 1.53244972]"
|
|
@@ -740,10 +752,6 @@
|
|
| 740 |
"[0.23942459 0.90487361 0.69337189 0.65089428]",
|
| 741 |
"[1.23942459 1.90487361 1.69337189 1.65089428]"
|
| 742 |
],
|
| 743 |
-
[
|
| 744 |
-
"[0.94516498 0.08422136 0.5608117 0.07652664]",
|
| 745 |
-
"[1.94516492 1.08422136 1.56081176 1.07652664]"
|
| 746 |
-
],
|
| 747 |
[
|
| 748 |
"[0.26661873 0.45946234 0.13510543 0.81294441]",
|
| 749 |
"[1.26661873 1.4594624 1.13510537 1.81294441]"
|
|
@@ -772,10 +780,6 @@
|
|
| 772 |
"[0.78956431 0.87284744 0.06880784 0.03455889]",
|
| 773 |
"[1.78956437 1.87284744 1.06880784 1.03455889]"
|
| 774 |
],
|
| 775 |
-
[
|
| 776 |
-
"[0.94221359 0.57740951 0.98649532 0.40934443]",
|
| 777 |
-
"[1.94221354 1.57740951 1.98649526 1.40934443]"
|
| 778 |
-
],
|
| 779 |
[
|
| 780 |
"[0.00497234 0.39319336 0.57054168 0.75150961]",
|
| 781 |
"[1.00497234 1.39319336 1.57054162 1.75150967]"
|
|
@@ -788,6 +792,14 @@
|
|
| 788 |
"[0.72290605 0.96945059 0.68354797 0.15270454]",
|
| 789 |
"[1.72290611 1.96945059 1.68354797 1.15270448]"
|
| 790 |
],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 791 |
[
|
| 792 |
"[0.52784437 0.54268694 0.12358981 0.72116476]",
|
| 793 |
"[1.52784443 1.54268694 1.12358975 1.7211647 ]"
|
|
@@ -796,10 +808,6 @@
|
|
| 796 |
"[0.73217702 0.65233225 0.44077861 0.33837909]",
|
| 797 |
"[1.73217702 1.65233231 1.44077861 1.33837914]"
|
| 798 |
],
|
| 799 |
-
[
|
| 800 |
-
"[0.34084332 0.73018837 0.54168713 0.91440833]",
|
| 801 |
-
"[1.34084332 1.73018837 1.54168713 1.91440833]"
|
| 802 |
-
],
|
| 803 |
[
|
| 804 |
"[0.60110539 0.3618983 0.32342511 0.98672163]",
|
| 805 |
"[1.60110545 1.3618983 1.32342505 1.98672163]"
|
|
@@ -836,6 +844,10 @@
|
|
| 836 |
"[0.95928186 0.84273899 0.71514636 0.38619852]",
|
| 837 |
"[1.95928192 1.84273899 1.7151463 1.38619852]"
|
| 838 |
],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 839 |
[
|
| 840 |
"[0.9829582 0.59269661 0.40120947 0.95487177]",
|
| 841 |
"[1.9829582 1.59269667 1.40120947 1.95487177]"
|
|
@@ -844,10 +856,6 @@
|
|
| 844 |
"[0.79905868 0.89367443 0.75429088 0.3190186 ]",
|
| 845 |
"[1.79905868 1.89367437 1.75429082 1.3190186 ]"
|
| 846 |
],
|
| 847 |
-
[
|
| 848 |
-
"[0.54914117 0.03810108 0.87531954 0.73044223]",
|
| 849 |
-
"[1.54914117 1.03810108 1.87531948 1.73044229]"
|
| 850 |
-
],
|
| 851 |
[
|
| 852 |
"[0.67418337 0.79634351 0.23229051 0.71345252]",
|
| 853 |
"[1.67418337 1.79634356 1.23229051 1.71345258]"
|
|
@@ -860,14 +868,14 @@
|
|
| 860 |
"[0.81788456 0.58174163 0.29376316 0.7971254 ]",
|
| 861 |
"[1.81788456 1.58174157 1.29376316 1.79712534]"
|
| 862 |
],
|
| 863 |
-
[
|
| 864 |
-
"[0.94559073 0.65736622 0.25761551 0.48553199]",
|
| 865 |
-
"[1.94559073 1.65736628 1.25761557 1.48553205]"
|
| 866 |
-
],
|
| 867 |
[
|
| 868 |
"[0.60075855 0.12234765 0.00614399 0.30560958]",
|
| 869 |
"[1.60075855 1.12234759 1.00614405 1.30560958]"
|
| 870 |
],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 871 |
[
|
| 872 |
"[0.02162331 0.81861657 0.92468154 0.07808572]",
|
| 873 |
"[1.02162337 1.81861663 1.92468154 1.07808566]"
|
|
@@ -896,10 +904,6 @@
|
|
| 896 |
"[0.60609657 0.96257663 0.19292736 0.95702219]",
|
| 897 |
"[1.60609651 1.96257663 1.19292736 1.95702219]"
|
| 898 |
],
|
| 899 |
-
[
|
| 900 |
-
"[0.80654246 0.08253473 0.74478531 0.71257162]",
|
| 901 |
-
"[1.8065424 1.08253479 1.74478531 1.71257162]"
|
| 902 |
-
],
|
| 903 |
[
|
| 904 |
"[0.70167565 0.26930219 0.5660674 0.61194974]",
|
| 905 |
"[1.70167565 1.26930213 1.56606746 1.61194968]"
|
|
@@ -908,10 +912,6 @@
|
|
| 908 |
"[0.76933283 0.86241865 0.44114518 0.65644735]",
|
| 909 |
"[1.76933289 1.86241865 1.44114518 1.65644741]"
|
| 910 |
],
|
| 911 |
-
[
|
| 912 |
-
"[0.59492421 0.90274489 0.38069052 0.46101224]",
|
| 913 |
-
"[1.59492421 1.90274489 1.38069057 1.46101224]"
|
| 914 |
-
],
|
| 915 |
[
|
| 916 |
"[0.15064228 0.03198934 0.25754827 0.51484001]",
|
| 917 |
"[1.15064228 1.03198934 1.25754833 1.51484001]"
|
|
@@ -920,6 +920,10 @@
|
|
| 920 |
"[0.12024075 0.21342516 0.56858408 0.58644271]",
|
| 921 |
"[1.12024069 1.21342516 1.56858408 1.58644271]"
|
| 922 |
],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 923 |
[
|
| 924 |
"[0.49691743 0.61873293 0.90698647 0.94486356]",
|
| 925 |
"[1.49691749 1.61873293 1.90698647 1.94486356]"
|
|
@@ -948,18 +952,10 @@
|
|
| 948 |
"[0.80893755 0.92237449 0.88346356 0.93164903]",
|
| 949 |
"[1.80893755 1.92237449 1.88346362 1.93164897]"
|
| 950 |
],
|
| 951 |
-
[
|
| 952 |
-
"[0.12858278 0.09930819 0.83222693 0.72485673]",
|
| 953 |
-
"[1.12858272 1.09930825 1.83222699 1.72485673]"
|
| 954 |
-
],
|
| 955 |
[
|
| 956 |
"[0.72470158 0.4940322 0.41027349 0.89364016]",
|
| 957 |
"[1.72470164 1.49403214 1.41027355 1.89364016]"
|
| 958 |
],
|
| 959 |
-
[
|
| 960 |
-
"[0.47856545 0.46267092 0.6376707 0.84747767]",
|
| 961 |
-
"[1.47856545 1.46267092 1.63767076 1.84747767]"
|
| 962 |
-
],
|
| 963 |
[
|
| 964 |
"[0.49584109 0.80599248 0.07096875 0.75872749]",
|
| 965 |
"[1.49584103 1.80599248 1.07096875 1.75872755]"
|
|
@@ -992,6 +988,10 @@
|
|
| 992 |
"[0.72795159 0.79317838 0.27832931 0.96576637]",
|
| 993 |
"[1.72795153 1.79317832 1.27832937 1.96576643]"
|
| 994 |
],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 995 |
[
|
| 996 |
"[0.68891573 0.25576538 0.96339929 0.503833 ]",
|
| 997 |
"[1.68891573 1.25576544 1.96339929 1.50383306]"
|
|
@@ -1000,7 +1000,7 @@
|
|
| 1000 |
}
|
| 1001 |
},
|
| 1002 |
"other": {
|
| 1003 |
-
"model": "ModelConfig(model=Sequential(\n (0) - Identity(): Input__tensor_1_x -> START_Repeat_1_output\n (1) - Linear(in_features=4, out_features=4, bias=True): START_Repeat_1_output -> Linear_2_output\n (2) - <function leaky_relu at
|
| 1004 |
},
|
| 1005 |
"relations": []
|
| 1006 |
},
|
|
@@ -1035,8 +1035,8 @@
|
|
| 1035 |
"Input__tensor_1_x"
|
| 1036 |
],
|
| 1037 |
"loss_inputs": [
|
| 1038 |
-
"
|
| 1039 |
-
"
|
| 1040 |
],
|
| 1041 |
"outputs": [
|
| 1042 |
"END_Repeat_1_output"
|
|
@@ -1210,8 +1210,8 @@
|
|
| 1210 |
"Input__tensor_1_x"
|
| 1211 |
],
|
| 1212 |
"loss_inputs": [
|
| 1213 |
-
"
|
| 1214 |
-
"
|
| 1215 |
],
|
| 1216 |
"outputs": [
|
| 1217 |
"END_Repeat_1_output"
|
|
@@ -1270,7 +1270,7 @@
|
|
| 1270 |
"type": "basic"
|
| 1271 |
},
|
| 1272 |
"params": {
|
| 1273 |
-
"epochs": "
|
| 1274 |
"input_mapping": "{\"map\":{\"Input__tensor_1_x\":{\"df\":\"df_train\",\"column\":\"x\"},\"Input__tensor_3_x\":{\"df\":\"df_train\",\"column\":\"y\"}}}",
|
| 1275 |
"model_name": "model"
|
| 1276 |
},
|
|
@@ -1322,8 +1322,8 @@
|
|
| 1322 |
"Input__tensor_1_x"
|
| 1323 |
],
|
| 1324 |
"loss_inputs": [
|
| 1325 |
-
"
|
| 1326 |
-
"
|
| 1327 |
],
|
| 1328 |
"outputs": [
|
| 1329 |
"END_Repeat_1_output"
|
|
|
|
| 579 |
],
|
| 580 |
"data": [
|
| 581 |
[
|
| 582 |
+
"[0.94559073 0.65736622 0.25761551 0.48553199]",
|
| 583 |
+
"[1.94559073 1.65736628 1.25761557 1.48553205]",
|
| 584 |
+
"[1.5948047637939453, 1.619612693786621, 1.5269112586975098, -0.008817584253847599]"
|
| 585 |
],
|
| 586 |
[
|
| 587 |
+
"[0.47856545 0.46267092 0.6376707 0.84747767]",
|
| 588 |
+
"[1.47856545 1.46267092 1.63767076 1.84747767]",
|
| 589 |
+
"[1.5928349494934082, 1.6176562309265137, 1.52553391456604, -0.008808750659227371]"
|
| 590 |
],
|
| 591 |
[
|
| 592 |
+
"[0.59492421 0.90274489 0.38069052 0.46101224]",
|
| 593 |
+
"[1.59492421 1.90274489 1.38069057 1.46101224]",
|
| 594 |
+
"[1.5592631101608276, 1.5841729640960693, 1.5020174980163574, -0.008660631254315376]"
|
| 595 |
],
|
| 596 |
[
|
| 597 |
+
"[0.12858278 0.09930819 0.83222693 0.72485673]",
|
| 598 |
+
"[1.12858272 1.09930825 1.83222699 1.72485673]",
|
| 599 |
+
"[1.469609260559082, 1.4953691959381104, 1.4395854473114014, -0.00825517252087593]"
|
| 600 |
],
|
| 601 |
[
|
| 602 |
+
"[0.94516498 0.08422136 0.5608117 0.07652664]",
|
| 603 |
+
"[1.94516492 1.08422136 1.56081176 1.07652664]",
|
| 604 |
+
"[1.5648787021636963, 1.5899150371551514, 1.5060429573059082, -0.008683123625814915]"
|
| 605 |
],
|
| 606 |
[
|
| 607 |
+
"[0.54914117 0.03810108 0.87531954 0.73044223]",
|
| 608 |
+
"[1.54914117 1.03810108 1.87531948 1.73044229]",
|
| 609 |
+
"[1.6262837648391724, 1.650843620300293, 1.548863410949707, -0.00895910244435072]"
|
| 610 |
],
|
| 611 |
[
|
| 612 |
+
"[0.94221359 0.57740951 0.98649532 0.40934443]",
|
| 613 |
+
"[1.94221354 1.57740951 1.98649526 1.40934443]",
|
| 614 |
+
"[1.8493703603744507, 1.8721930980682373, 1.704444169998169, -0.009961890056729317]"
|
| 615 |
],
|
| 616 |
[
|
| 617 |
+
"[0.80654246 0.08253473 0.74478531 0.71257162]",
|
| 618 |
+
"[1.8065424 1.08253479 1.74478531 1.71257162]",
|
| 619 |
+
"[1.672502040863037, 1.6967051029205322, 1.5810983180999756, -0.009166811592876911]"
|
| 620 |
],
|
| 621 |
[
|
| 622 |
+
"[0.50272274 0.54912758 0.17663097 0.79070699]",
|
| 623 |
+
"[1.50272274 1.54912758 1.17663097 1.79070699]",
|
| 624 |
+
"[1.4309396743774414, 1.4570224285125732, 1.4126286506652832, -0.008081027306616306]"
|
| 625 |
],
|
| 626 |
[
|
| 627 |
+
"[0.34084332 0.73018837 0.54168713 0.91440833]",
|
| 628 |
+
"[1.34084332 1.73018837 1.54168713 1.91440833]",
|
| 629 |
+
"[1.5581963062286377, 1.5832865238189697, 1.5013742446899414, -0.008653069846332073]"
|
| 630 |
]
|
| 631 |
]
|
| 632 |
},
|
|
|
|
| 648 |
"[0.11560339 0.57495481 0.76535827 0.0391947 ]",
|
| 649 |
"[1.11560345 1.57495475 1.76535821 1.0391947 ]"
|
| 650 |
],
|
| 651 |
+
[
|
| 652 |
+
"[0.19409031 0.68692201 0.60667384 0.57829887]",
|
| 653 |
+
"[1.19409037 1.68692207 1.60667384 1.57829881]"
|
| 654 |
+
],
|
| 655 |
[
|
| 656 |
"[0.76807946 0.98855817 0.08259124 0.01730657]",
|
| 657 |
"[1.76807952 1.98855817 1.0825913 1.01730657]"
|
|
|
|
| 685 |
"[1.98324287 1.99464178 1.14008355 1.47651017]"
|
| 686 |
],
|
| 687 |
[
|
| 688 |
+
"[0.11693293 0.49860179 0.55020827 0.88832849]",
|
| 689 |
+
"[1.11693287 1.49860179 1.55020833 1.88832855]"
|
| 690 |
],
|
| 691 |
[
|
| 692 |
+
"[0.48959708 0.48549271 0.32688856 0.356677 ]",
|
| 693 |
+
"[1.48959708 1.48549271 1.32688856 1.35667706]"
|
| 694 |
],
|
| 695 |
[
|
| 696 |
"[0.04508126 0.76880038 0.80721325 0.62542385]",
|
|
|
|
| 712 |
"[0.24388778 0.07268471 0.68350857 0.73431659]",
|
| 713 |
"[1.24388778 1.07268476 1.68350863 1.73431659]"
|
| 714 |
],
|
| 715 |
+
[
|
| 716 |
+
"[0.62569475 0.9881897 0.83639616 0.9828859 ]",
|
| 717 |
+
"[1.62569475 1.9881897 1.83639622 1.98288584]"
|
| 718 |
+
],
|
| 719 |
[
|
| 720 |
"[0.56922203 0.98222166 0.76851749 0.28615737]",
|
| 721 |
"[1.56922197 1.9822216 1.76851749 1.28615737]"
|
|
|
|
| 728 |
"[0.90817457 0.89270043 0.38583666 0.66566533]",
|
| 729 |
"[1.90817451 1.89270043 1.3858366 1.66566539]"
|
| 730 |
],
|
| 731 |
+
[
|
| 732 |
+
"[0.48507756 0.80808765 0.77162558 0.47834778]",
|
| 733 |
+
"[1.48507762 1.80808759 1.77162552 1.47834778]"
|
| 734 |
+
],
|
| 735 |
[
|
| 736 |
"[0.68062544 0.98093534 0.14778823 0.53244978]",
|
| 737 |
"[1.68062544 1.98093534 1.14778829 1.53244972]"
|
|
|
|
| 752 |
"[0.23942459 0.90487361 0.69337189 0.65089428]",
|
| 753 |
"[1.23942459 1.90487361 1.69337189 1.65089428]"
|
| 754 |
],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 755 |
[
|
| 756 |
"[0.26661873 0.45946234 0.13510543 0.81294441]",
|
| 757 |
"[1.26661873 1.4594624 1.13510537 1.81294441]"
|
|
|
|
| 780 |
"[0.78956431 0.87284744 0.06880784 0.03455889]",
|
| 781 |
"[1.78956437 1.87284744 1.06880784 1.03455889]"
|
| 782 |
],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 783 |
[
|
| 784 |
"[0.00497234 0.39319336 0.57054168 0.75150961]",
|
| 785 |
"[1.00497234 1.39319336 1.57054162 1.75150967]"
|
|
|
|
| 792 |
"[0.72290605 0.96945059 0.68354797 0.15270454]",
|
| 793 |
"[1.72290611 1.96945059 1.68354797 1.15270448]"
|
| 794 |
],
|
| 795 |
+
[
|
| 796 |
+
"[0.75292218 0.81470108 0.49657214 0.56217098]",
|
| 797 |
+
"[1.75292218 1.81470108 1.49657214 1.56217098]"
|
| 798 |
+
],
|
| 799 |
+
[
|
| 800 |
+
"[0.33480108 0.59181517 0.76198453 0.98062384]",
|
| 801 |
+
"[1.33480108 1.59181523 1.76198459 1.98062384]"
|
| 802 |
+
],
|
| 803 |
[
|
| 804 |
"[0.52784437 0.54268694 0.12358981 0.72116476]",
|
| 805 |
"[1.52784443 1.54268694 1.12358975 1.7211647 ]"
|
|
|
|
| 808 |
"[0.73217702 0.65233225 0.44077861 0.33837909]",
|
| 809 |
"[1.73217702 1.65233231 1.44077861 1.33837914]"
|
| 810 |
],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 811 |
[
|
| 812 |
"[0.60110539 0.3618983 0.32342511 0.98672163]",
|
| 813 |
"[1.60110545 1.3618983 1.32342505 1.98672163]"
|
|
|
|
| 844 |
"[0.95928186 0.84273899 0.71514636 0.38619852]",
|
| 845 |
"[1.95928192 1.84273899 1.7151463 1.38619852]"
|
| 846 |
],
|
| 847 |
+
[
|
| 848 |
+
"[0.32565445 0.90939188 0.07488042 0.13730896]",
|
| 849 |
+
"[1.32565451 1.90939188 1.07488036 1.13730896]"
|
| 850 |
+
],
|
| 851 |
[
|
| 852 |
"[0.9829582 0.59269661 0.40120947 0.95487177]",
|
| 853 |
"[1.9829582 1.59269667 1.40120947 1.95487177]"
|
|
|
|
| 856 |
"[0.79905868 0.89367443 0.75429088 0.3190186 ]",
|
| 857 |
"[1.79905868 1.89367437 1.75429082 1.3190186 ]"
|
| 858 |
],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 859 |
[
|
| 860 |
"[0.67418337 0.79634351 0.23229051 0.71345252]",
|
| 861 |
"[1.67418337 1.79634356 1.23229051 1.71345258]"
|
|
|
|
| 868 |
"[0.81788456 0.58174163 0.29376316 0.7971254 ]",
|
| 869 |
"[1.81788456 1.58174157 1.29376316 1.79712534]"
|
| 870 |
],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 871 |
[
|
| 872 |
"[0.60075855 0.12234765 0.00614399 0.30560958]",
|
| 873 |
"[1.60075855 1.12234759 1.00614405 1.30560958]"
|
| 874 |
],
|
| 875 |
+
[
|
| 876 |
+
"[0.39147133 0.29854035 0.84663737 0.58175623]",
|
| 877 |
+
"[1.39147139 1.29854035 1.84663737 1.58175623]"
|
| 878 |
+
],
|
| 879 |
[
|
| 880 |
"[0.02162331 0.81861657 0.92468154 0.07808572]",
|
| 881 |
"[1.02162337 1.81861663 1.92468154 1.07808566]"
|
|
|
|
| 904 |
"[0.60609657 0.96257663 0.19292736 0.95702219]",
|
| 905 |
"[1.60609651 1.96257663 1.19292736 1.95702219]"
|
| 906 |
],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 907 |
[
|
| 908 |
"[0.70167565 0.26930219 0.5660674 0.61194974]",
|
| 909 |
"[1.70167565 1.26930213 1.56606746 1.61194968]"
|
|
|
|
| 912 |
"[0.76933283 0.86241865 0.44114518 0.65644735]",
|
| 913 |
"[1.76933289 1.86241865 1.44114518 1.65644741]"
|
| 914 |
],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 915 |
[
|
| 916 |
"[0.15064228 0.03198934 0.25754827 0.51484001]",
|
| 917 |
"[1.15064228 1.03198934 1.25754833 1.51484001]"
|
|
|
|
| 920 |
"[0.12024075 0.21342516 0.56858408 0.58644271]",
|
| 921 |
"[1.12024069 1.21342516 1.56858408 1.58644271]"
|
| 922 |
],
|
| 923 |
+
[
|
| 924 |
+
"[0.91730917 0.22574073 0.09591609 0.33056474]",
|
| 925 |
+
"[1.91730917 1.22574067 1.09591603 1.33056474]"
|
| 926 |
+
],
|
| 927 |
[
|
| 928 |
"[0.49691743 0.61873293 0.90698647 0.94486356]",
|
| 929 |
"[1.49691749 1.61873293 1.90698647 1.94486356]"
|
|
|
|
| 952 |
"[0.80893755 0.92237449 0.88346356 0.93164903]",
|
| 953 |
"[1.80893755 1.92237449 1.88346362 1.93164897]"
|
| 954 |
],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 955 |
[
|
| 956 |
"[0.72470158 0.4940322 0.41027349 0.89364016]",
|
| 957 |
"[1.72470164 1.49403214 1.41027355 1.89364016]"
|
| 958 |
],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 959 |
[
|
| 960 |
"[0.49584109 0.80599248 0.07096875 0.75872749]",
|
| 961 |
"[1.49584103 1.80599248 1.07096875 1.75872755]"
|
|
|
|
| 988 |
"[0.72795159 0.79317838 0.27832931 0.96576637]",
|
| 989 |
"[1.72795153 1.79317832 1.27832937 1.96576643]"
|
| 990 |
],
|
| 991 |
+
[
|
| 992 |
+
"[0.87608397 0.93200487 0.80169648 0.37758952]",
|
| 993 |
+
"[1.87608397 1.93200493 1.80169654 1.37758946]"
|
| 994 |
+
],
|
| 995 |
[
|
| 996 |
"[0.68891573 0.25576538 0.96339929 0.503833 ]",
|
| 997 |
"[1.68891573 1.25576544 1.96339929 1.50383306]"
|
|
|
|
| 1000 |
}
|
| 1001 |
},
|
| 1002 |
"other": {
|
| 1003 |
+
"model": "ModelConfig(model=Sequential(\n (0) - Identity(): Input__tensor_1_x -> START_Repeat_1_output\n (1) - Linear(in_features=4, out_features=4, bias=True): START_Repeat_1_output -> Linear_2_output\n (2) - <function leaky_relu at 0x75e660938220>: Linear_2_output -> Activation_1_output\n (3) - Identity(): Activation_1_output -> START_Repeat_1_output\n (4) - Linear(in_features=4, out_features=4, bias=True): START_Repeat_1_output -> Linear_2_output\n (5) - <function leaky_relu at 0x75e660938220>: Linear_2_output -> Activation_1_output\n (6) - Identity(): Activation_1_output -> START_Repeat_1_output\n (7) - Linear(in_features=4, out_features=4, bias=True): START_Repeat_1_output -> Linear_2_output\n (8) - <function leaky_relu at 0x75e660938220>: Linear_2_output -> Activation_1_output\n (9) - Identity(): Activation_1_output -> END_Repeat_1_output\n (10) - Identity(): END_Repeat_1_output -> END_Repeat_1_output\n), model_inputs=['Input__tensor_1_x'], model_outputs=['END_Repeat_1_output'], loss_inputs=['Input__tensor_3_x', 'END_Repeat_1_output'], loss=Sequential(\n (0) - <function mse_loss at 0x75e660939d00>: END_Repeat_1_output, Input__tensor_3_x -> MSE_loss_2_output\n (1) - Identity(): MSE_loss_2_output -> loss\n), optimizer=SGD (\nParameter Group 0\n dampening: 0\n differentiable: False\n foreach: None\n fused: None\n lr: 0.1\n maximize: False\n momentum: 0\n nesterov: False\n weight_decay: 0\n), source_workspace=None, trained=True)"
|
| 1004 |
},
|
| 1005 |
"relations": []
|
| 1006 |
},
|
|
|
|
| 1035 |
"Input__tensor_1_x"
|
| 1036 |
],
|
| 1037 |
"loss_inputs": [
|
| 1038 |
+
"Input__tensor_3_x",
|
| 1039 |
+
"END_Repeat_1_output"
|
| 1040 |
],
|
| 1041 |
"outputs": [
|
| 1042 |
"END_Repeat_1_output"
|
|
|
|
| 1210 |
"Input__tensor_1_x"
|
| 1211 |
],
|
| 1212 |
"loss_inputs": [
|
| 1213 |
+
"Input__tensor_3_x",
|
| 1214 |
+
"END_Repeat_1_output"
|
| 1215 |
],
|
| 1216 |
"outputs": [
|
| 1217 |
"END_Repeat_1_output"
|
|
|
|
| 1270 |
"type": "basic"
|
| 1271 |
},
|
| 1272 |
"params": {
|
| 1273 |
+
"epochs": "110",
|
| 1274 |
"input_mapping": "{\"map\":{\"Input__tensor_1_x\":{\"df\":\"df_train\",\"column\":\"x\"},\"Input__tensor_3_x\":{\"df\":\"df_train\",\"column\":\"y\"}}}",
|
| 1275 |
"model_name": "model"
|
| 1276 |
},
|
|
|
|
| 1322 |
"Input__tensor_1_x"
|
| 1323 |
],
|
| 1324 |
"loss_inputs": [
|
| 1325 |
+
"Input__tensor_3_x",
|
| 1326 |
+
"END_Repeat_1_output"
|
| 1327 |
],
|
| 1328 |
"outputs": [
|
| 1329 |
"END_Repeat_1_output"
|
lynxkite-core/src/lynxkite/core/ops.py
CHANGED
|
@@ -13,7 +13,7 @@ from typing_extensions import Annotated
|
|
| 13 |
if typing.TYPE_CHECKING:
|
| 14 |
from . import workspace
|
| 15 |
|
| 16 |
-
CATALOGS = {}
|
| 17 |
EXECUTORS = {}
|
| 18 |
|
| 19 |
typeof = type # We have some arguments called "type".
|
|
|
|
| 13 |
if typing.TYPE_CHECKING:
|
| 14 |
from . import workspace
|
| 15 |
|
| 16 |
+
CATALOGS: dict[str, dict[str, "Op"]] = {}
|
| 17 |
EXECUTORS = {}
|
| 18 |
|
| 19 |
typeof = type # We have some arguments called "type".
|
lynxkite-graph-analytics/src/lynxkite_graph_analytics/pytorch_model_ops.py
CHANGED
|
@@ -214,6 +214,9 @@ class ModelConfig:
|
|
| 214 |
source_workspace: str | None = None
|
| 215 |
trained: bool = False
|
| 216 |
|
|
|
|
|
|
|
|
|
|
| 217 |
def _forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
| 218 |
model_inputs = [inputs[i] for i in self.model_inputs]
|
| 219 |
output = self.model(*model_inputs)
|
|
@@ -270,7 +273,7 @@ class ModelBuilder:
|
|
| 270 |
def __init__(self, ws: workspace.Workspace, inputs: dict[str, torch.Tensor]):
|
| 271 |
self.catalog = ops.CATALOGS[ENV]
|
| 272 |
optimizers = []
|
| 273 |
-
self.nodes = {}
|
| 274 |
for node in ws.nodes:
|
| 275 |
self.nodes[node.id] = node
|
| 276 |
if node.data.title == "Optimizer":
|
|
@@ -279,8 +282,8 @@ class ModelBuilder:
|
|
| 279 |
assert len(optimizers) == 1, f"More than one optimizer found: {optimizers}"
|
| 280 |
[self.optimizer] = optimizers
|
| 281 |
self.dependencies = {n.id: [] for n in ws.nodes}
|
| 282 |
-
self.in_edges = {}
|
| 283 |
-
self.out_edges = {}
|
| 284 |
repeats = []
|
| 285 |
for e in ws.edges:
|
| 286 |
if self.nodes[e.target].data.title == "Repeat":
|
|
@@ -367,21 +370,12 @@ class ModelBuilder:
|
|
| 367 |
t = node.data.title
|
| 368 |
op = self.catalog[t]
|
| 369 |
p = op.convert_params(node.data.params)
|
| 370 |
-
inputs = {}
|
| 371 |
-
for n in self.in_edges.get(node_id, []):
|
| 372 |
-
for b, h in self.in_edges[node_id][n]:
|
| 373 |
-
i = _to_id(b, h)
|
| 374 |
-
inputs[n] = i
|
| 375 |
-
outputs = {}
|
| 376 |
-
for out in self.out_edges.get(node_id, []):
|
| 377 |
-
i = _to_id(node_id, out)
|
| 378 |
-
outputs[out] = i
|
| 379 |
match t:
|
| 380 |
case "Repeat":
|
| 381 |
if node_id.startswith("END "):
|
| 382 |
repeat_id = node_id.removeprefix("END ")
|
| 383 |
start_id = f"START {repeat_id}"
|
| 384 |
-
|
| 385 |
after_start = self.all_downstream(start_id)
|
| 386 |
after_end = self.all_downstream(node_id)
|
| 387 |
before_end = self.all_upstream(node_id)
|
|
@@ -390,28 +384,64 @@ class ModelBuilder:
|
|
| 390 |
assert affected_nodes == repeated_nodes, (
|
| 391 |
f"edges leave repeated section '{repeat_id}':\n{affected_nodes - repeated_nodes}"
|
| 392 |
)
|
| 393 |
-
for
|
| 394 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 395 |
case "Optimizer" | "Input: tensor" | "Input: graph edges" | "Input: sequential":
|
| 396 |
return
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
self.layers.append(layer)
|
| 400 |
|
| 401 |
-
def run_op(self,
|
| 402 |
"""Returns the layer produced by this op."""
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
if op.func != ops.no_op:
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
layer =
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 415 |
return layer
|
| 416 |
|
| 417 |
def build_model(self) -> ModelConfig:
|
|
@@ -462,7 +492,6 @@ class ModelBuilder:
|
|
| 462 |
p = op.convert_params(self.nodes[self.optimizer].data.params)
|
| 463 |
o = getattr(torch.optim, p["type"].name)
|
| 464 |
cfg["optimizer"] = o(cfg["model"].parameters(), lr=p["lr"])
|
| 465 |
-
print(cfg)
|
| 466 |
return ModelConfig(**cfg)
|
| 467 |
|
| 468 |
|
|
|
|
| 214 |
source_workspace: str | None = None
|
| 215 |
trained: bool = False
|
| 216 |
|
| 217 |
+
def num_parameters(self) -> int:
|
| 218 |
+
return sum(p.numel() for p in self.model.parameters())
|
| 219 |
+
|
| 220 |
def _forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
| 221 |
model_inputs = [inputs[i] for i in self.model_inputs]
|
| 222 |
output = self.model(*model_inputs)
|
|
|
|
| 273 |
def __init__(self, ws: workspace.Workspace, inputs: dict[str, torch.Tensor]):
|
| 274 |
self.catalog = ops.CATALOGS[ENV]
|
| 275 |
optimizers = []
|
| 276 |
+
self.nodes: dict[str, workspace.WorkspaceNode] = {}
|
| 277 |
for node in ws.nodes:
|
| 278 |
self.nodes[node.id] = node
|
| 279 |
if node.data.title == "Optimizer":
|
|
|
|
| 282 |
assert len(optimizers) == 1, f"More than one optimizer found: {optimizers}"
|
| 283 |
[self.optimizer] = optimizers
|
| 284 |
self.dependencies = {n.id: [] for n in ws.nodes}
|
| 285 |
+
self.in_edges: dict[str, dict[str, list[(str, str)]]] = {}
|
| 286 |
+
self.out_edges: dict[str, dict[str, list[(str, str)]]] = {}
|
| 287 |
repeats = []
|
| 288 |
for e in ws.edges:
|
| 289 |
if self.nodes[e.target].data.title == "Repeat":
|
|
|
|
| 370 |
t = node.data.title
|
| 371 |
op = self.catalog[t]
|
| 372 |
p = op.convert_params(node.data.params)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 373 |
match t:
|
| 374 |
case "Repeat":
|
| 375 |
if node_id.startswith("END "):
|
| 376 |
repeat_id = node_id.removeprefix("END ")
|
| 377 |
start_id = f"START {repeat_id}"
|
| 378 |
+
[last_output] = self.in_edges[node_id]["input"]
|
| 379 |
after_start = self.all_downstream(start_id)
|
| 380 |
after_end = self.all_downstream(node_id)
|
| 381 |
before_end = self.all_upstream(node_id)
|
|
|
|
| 384 |
assert affected_nodes == repeated_nodes, (
|
| 385 |
f"edges leave repeated section '{repeat_id}':\n{affected_nodes - repeated_nodes}"
|
| 386 |
)
|
| 387 |
+
repeated_layers = [e for e in self.layers if e._origin_id in repeated_nodes]
|
| 388 |
+
for i in range(p["times"] - 1):
|
| 389 |
+
# Copy repeat section's output to repeat section's input.
|
| 390 |
+
self.layers.append(
|
| 391 |
+
self.empty_layer(
|
| 392 |
+
node_id,
|
| 393 |
+
inputs=[_to_id(*last_output)],
|
| 394 |
+
outputs=[_to_id(start_id, "output")],
|
| 395 |
+
)
|
| 396 |
+
)
|
| 397 |
+
# Repeat the layers in the section.
|
| 398 |
+
for layer in repeated_layers:
|
| 399 |
+
if p["same_weights"]:
|
| 400 |
+
self.layers.append(
|
| 401 |
+
Layer(
|
| 402 |
+
layer.module,
|
| 403 |
+
shapes=layer.shapes,
|
| 404 |
+
_origin_id=layer._origin_id,
|
| 405 |
+
_inputs=layer._inputs,
|
| 406 |
+
_outputs=layer._outputs,
|
| 407 |
+
)
|
| 408 |
+
)
|
| 409 |
+
else:
|
| 410 |
+
self.run_node(layer._origin_id)
|
| 411 |
+
self.layers.append(self.run_op(node_id, op, p))
|
| 412 |
case "Optimizer" | "Input: tensor" | "Input: graph edges" | "Input: sequential":
|
| 413 |
return
|
| 414 |
+
case _:
|
| 415 |
+
self.layers.append(self.run_op(node_id, op, p))
|
|
|
|
| 416 |
|
| 417 |
+
def run_op(self, node_id: str, op: ops.Op, params) -> Layer:
|
| 418 |
"""Returns the layer produced by this op."""
|
| 419 |
+
inputs = [_to_id(*i) for n in op.inputs for i in self.in_edges[node_id][n]]
|
| 420 |
+
outputs = [_to_id(node_id, n) for n in op.outputs]
|
| 421 |
+
layer = self.empty_layer(node_id, inputs, outputs)
|
| 422 |
if op.func != ops.no_op:
|
| 423 |
+
op_layer = op.func(*layer._inputs, **params)
|
| 424 |
+
layer.module = op_layer.module
|
| 425 |
+
layer.shapes = op_layer.shapes
|
| 426 |
+
for o in layer._outputs:
|
| 427 |
+
self.sizes[o._id] = o.shape
|
| 428 |
+
return layer
|
| 429 |
+
|
| 430 |
+
def empty_layer(self, id: str, inputs: list[str], outputs: list[str]) -> Layer:
|
| 431 |
+
"""Creates an identity layer. Assumes that outputs have the same shapes as inputs."""
|
| 432 |
+
layer_inputs = [TensorRef(i, shape=self.sizes.get(i, 1)) for i in inputs]
|
| 433 |
+
layer_outputs = []
|
| 434 |
+
for i, o in zip(inputs, outputs):
|
| 435 |
+
shape = self.sizes.get(i, 1)
|
| 436 |
+
layer_outputs.append(TensorRef(o, shape=shape))
|
| 437 |
+
self.sizes[o] = shape
|
| 438 |
+
layer = Layer(
|
| 439 |
+
torch.nn.Identity(),
|
| 440 |
+
shapes=[self.sizes[o._id] for o in layer_outputs],
|
| 441 |
+
_inputs=layer_inputs,
|
| 442 |
+
_outputs=layer_outputs,
|
| 443 |
+
_origin_id=id,
|
| 444 |
+
)
|
| 445 |
return layer
|
| 446 |
|
| 447 |
def build_model(self) -> ModelConfig:
|
|
|
|
| 492 |
p = op.convert_params(self.nodes[self.optimizer].data.params)
|
| 493 |
o = getattr(torch.optim, p["type"].name)
|
| 494 |
cfg["optimizer"] = o(cfg["model"].parameters(), lr=p["lr"])
|
|
|
|
| 495 |
return ModelConfig(**cfg)
|
| 496 |
|
| 497 |
|