Split build_model() so we can repeat a part.
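For orientation, this is the shape of the refactor (a sketch assembled from the pytorch_model_ops.py diff below, not extra code in the commit): the public build_model() entry point stays, but it now only instantiates a ModelBuilder that holds the shared state, so the repeat-handling steps can live in separate methods.

    def build_model(ws, inputs):
        """Builds the model described in the workspace."""
        builder = ModelBuilder(ws, inputs)  # collects nodes, edges and repeat regions
        return builder.build_model()        # walks the graph and returns a ModelConfig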
examples/Model definition
CHANGED
@@ -1,12 +1,5 @@
 {
   "edges": [
-    {
-      "id": "MSE loss 1 Optimizer 2",
-      "source": "MSE loss 1",
-      "sourceHandle": "loss",
-      "target": "Optimizer 2",
-      "targetHandle": "loss"
-    },
     {
       "id": "Repeat 1 Linear 2",
       "source": "Repeat 1",
@@ -14,13 +7,6 @@
       "target": "Linear 2",
       "targetHandle": "x"
     },
-    {
-      "id": "Activation 1 MSE loss 1",
-      "source": "Activation 1",
-      "sourceHandle": "output",
-      "target": "MSE loss 1",
-      "targetHandle": "x"
-    },
     {
       "id": "Linear 2 Activation 1",
       "source": "Linear 2",
@@ -35,72 +21,37 @@
       "target": "Repeat 1",
       "targetHandle": "input"
     },
-    {
-      "id": "Input: tensor 3 MSE loss 1",
-      "source": "Input: tensor 3",
-      "sourceHandle": "x",
-      "target": "MSE loss 1",
-      "targetHandle": "y"
-    },
     {
       "id": "Input: tensor 1 Linear 2",
       "source": "Input: tensor 1",
       "sourceHandle": "x",
       "target": "Linear 2",
       "targetHandle": "x"
+    },
+    {
+      "id": "MSE loss 2 Optimizer 2",
+      "source": "MSE loss 2",
+      "sourceHandle": "output",
+      "target": "Optimizer 2",
+      "targetHandle": "loss"
+    },
+    {
+      "id": "Activation 1 MSE loss 2",
+      "source": "Activation 1",
+      "sourceHandle": "output",
+      "target": "MSE loss 2",
+      "targetHandle": "x"
+    },
+    {
+      "id": "Input: tensor 3 MSE loss 2",
+      "source": "Input: tensor 3",
+      "sourceHandle": "x",
+      "target": "MSE loss 2",
+      "targetHandle": "y"
     }
   ],
   "env": "PyTorch model",
   "nodes": [
-    {
-      "data": {
-        "display": null,
-        "error": null,
-        "input_metadata": null,
-        "meta": {
-          "inputs": {
-            "x": {
-              "name": "x",
-              "position": "bottom",
-              "type": {
-                "type": "tensor"
-              }
-            },
-            "y": {
-              "name": "y",
-              "position": "bottom",
-              "type": {
-                "type": "tensor"
-              }
-            }
-          },
-          "name": "MSE loss",
-          "outputs": {
-            "loss": {
-              "name": "loss",
-              "position": "top",
-              "type": {
-                "type": "tensor"
-              }
-            }
-          },
-          "params": {},
-          "type": "basic"
-        },
-        "params": {},
-        "status": "planned",
-        "title": "MSE loss"
-      },
-      "dragHandle": ".bg-primary",
-      "height": 200.0,
-      "id": "MSE loss 1",
-      "position": {
-        "x": 315.0,
-        "y": -510.0
-      },
-      "type": "basic",
-      "width": 200.0
-    },
     {
       "data": {
         "__execution_delay": 0.0,
@@ -384,7 +335,7 @@
       "height": 200.0,
       "id": "Input: tensor 1",
       "position": {
-        "x":
+        "x": 85.83561484252238,
         "y": 293.6278596776366
       },
       "type": "basic",
@@ -434,8 +385,61 @@
       "height": 200.0,
       "id": "Input: tensor 3",
       "position": {
-        "x":
-        "y": -
+        "x": 485.8840220312055,
+        "y": -149.86223034126274
+      },
+      "type": "basic",
+      "width": 200.0
+    },
+    {
+      "data": {
+        "display": null,
+        "error": null,
+        "input_metadata": null,
+        "meta": {
+          "inputs": {
+            "x": {
+              "name": "x",
+              "position": "bottom",
+              "type": {
+                "type": "<class 'inspect._empty'>"
+              }
+            },
+            "y": {
+              "name": "y",
+              "position": "bottom",
+              "type": {
+                "type": "<class 'inspect._empty'>"
+              }
+            }
+          },
+          "name": "MSE loss",
+          "outputs": {
+            "output": {
+              "name": "output",
+              "position": "top",
+              "type": {
+                "type": "None"
+              }
+            }
+          },
+          "params": {},
+          "position": {
+            "x": 937.0,
+            "y": 270.0
+          },
+          "type": "basic"
+        },
+        "params": {},
+        "status": "planned",
+        "title": "MSE loss"
+      },
+      "dragHandle": ".bg-primary",
+      "height": 200.0,
+      "id": "MSE loss 2",
+      "position": {
+        "x": 690.0,
+        "y": -480.0
       },
       "type": "basic",
       "width": 200.0
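The node metadata above, inputs x and y with empty inferred types and a single generic output named output, is what the registry records once "MSE loss" is defined as a decorated Python function instead of a bare reg() call (see the pytorch_model_ops.py diff further down). A minimal restatement of that definition:

    @op("MSE loss")
    def mse_loss(x, y):
        return Layer(torch.nn.functional.mse_loss, shape=[1])

This is also why the example workspace now wires "MSE loss 2" through a sourceHandle of "output" rather than "loss".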
examples/Model use
CHANGED
@@ -579,54 +579,54 @@
         ],
         "data": [
           [
-            "[0.
-            "[1.
-            "[1.
+            "[0.33480108 0.59181517 0.76198453 0.98062384]",
+            "[1.33480108 1.59181523 1.76198459 1.98062384]",
+            "[1.3419755697250366, 1.5946478843688965, 1.7717586755752563, 1.9897401332855225]"
           ],
           [
-            "[0.
-            "[1.
-            "[1.
+            "[0.91730917 0.22574073 0.09591609 0.33056474]",
+            "[1.91730917 1.22574067 1.09591603 1.33056474]",
+            "[1.900892972946167, 1.2247941493988037, 1.0862866640090942, 1.323314905166626]"
           ],
           [
-            "[0.
-            "[1.
-            "[1.
+            "[0.32565445 0.90939188 0.07488042 0.13730896]",
+            "[1.32565451 1.90939188 1.07488036 1.13730896]",
+            "[1.3460955619812012, 1.8960161209106445, 1.0530263185501099, 1.1075329780578613]"
          ],
          [
-            "[0.
-            "[1.
-            "[1.
+            "[0.87608397 0.93200487 0.80169648 0.37758952]",
+            "[1.87608397 1.93200493 1.80169654 1.37758946]",
+            "[1.87070894241333, 1.9386992454528809, 1.8151044845581055, 1.3952441215515137]"
           ],
           [
-            "[0.
-            "[1.
-            "[1.
+            "[0.39147133 0.29854035 0.84663737 0.58175623]",
+            "[1.39147139 1.29854035 1.84663737 1.58175623]",
+            "[1.3877646923065186, 1.2995290756225586, 1.847062587738037, 1.583693265914917]"
           ],
           [
-            "[0.
-            "[1.
-            "[1.
+            "[0.48507756 0.80808765 0.77162558 0.47834778]",
+            "[1.48507762 1.80808759 1.77162552 1.47834778]",
+            "[1.490919828414917, 1.8087174892425537, 1.7757861614227295, 1.4824031591415405]"
           ],
           [
-            "[0.
-            "[1.
-            "[1.
+            "[0.75292218 0.81470108 0.49657214 0.56217098]",
+            "[1.75292218 1.81470108 1.49657214 1.56217098]",
+            "[1.7527031898498535, 1.8176040649414062, 1.503413438796997, 1.570152759552002]"
           ],
           [
-            "[0.
-            "[1.
-            "[1.
+            "[0.11693293 0.49860179 0.55020827 0.88832849]",
+            "[1.11693287 1.49860179 1.55020833 1.88832855]",
+            "[1.1314976215362549, 1.4944026470184326, 1.546830177307129, 1.8803892135620117]"
           ],
           [
-            "[0.
-            "[1.
-            "[1.
+            "[0.19409031 0.68692201 0.60667384 0.57829887]",
+            "[1.19409037 1.68692207 1.60667384 1.57829881]",
+            "[1.2091591358184814, 1.6816589832305908, 1.6011345386505127, 1.5684995651245117]"
           ],
           [
-            "[0.
-            "[1.
-            "[1.
+            "[0.62569475 0.9881897 0.83639616 0.9828859 ]",
+            "[1.62569475 1.9881897 1.83639622 1.98288584]",
+            "[1.6314740180969238, 1.996805191040039, 1.8592857122421265, 2.0075552463531494]"
           ]
         ]
       },
@@ -648,10 +648,6 @@
           "[0.11560339 0.57495481 0.76535827 0.0391947 ]",
           "[1.11560345 1.57495475 1.76535821 1.0391947 ]"
         ],
-        [
-          "[0.19409031 0.68692201 0.60667384 0.57829887]",
-          "[1.19409037 1.68692207 1.60667384 1.57829881]"
-        ],
         [
           "[0.76807946 0.98855817 0.08259124 0.01730657]",
           "[1.76807952 1.98855817 1.0825913 1.01730657]"
@@ -684,10 +680,6 @@
           "[0.98324287 0.99464184 0.14008355 0.47651017]",
          "[1.98324287 1.99464178 1.14008355 1.47651017]"
         ],
-        [
-          "[0.11693293 0.49860179 0.55020827 0.88832849]",
-          "[1.11693287 1.49860179 1.55020833 1.88832855]"
-        ],
         [
           "[0.48959708 0.48549271 0.32688856 0.356677 ]",
           "[1.48959708 1.48549271 1.32688856 1.35667706]"
@@ -716,10 +708,6 @@
           "[0.24388778 0.07268471 0.68350857 0.73431659]",
           "[1.24388778 1.07268476 1.68350863 1.73431659]"
         ],
-        [
-          "[0.62569475 0.9881897 0.83639616 0.9828859 ]",
-          "[1.62569475 1.9881897 1.83639622 1.98288584]"
-        ],
         [
           "[0.56922203 0.98222166 0.76851749 0.28615737]",
           "[1.56922197 1.9822216 1.76851749 1.28615737]"
@@ -732,14 +720,14 @@
           "[0.90817457 0.89270043 0.38583666 0.66566533]",
           "[1.90817451 1.89270043 1.3858366 1.66566539]"
         ],
-        [
-          "[0.48507756 0.80808765 0.77162558 0.47834778]",
-          "[1.48507762 1.80808759 1.77162552 1.47834778]"
-        ],
         [
           "[0.68062544 0.98093534 0.14778823 0.53244978]",
           "[1.68062544 1.98093534 1.14778829 1.53244972]"
         ],
+        [
+          "[0.31518555 0.49643308 0.11509258 0.95458382]",
+          "[1.31518555 1.49643302 1.11509252 1.95458388]"
+        ],
         [
           "[0.79121011 0.54161114 0.69369799 0.1520769 ]",
           "[1.79121017 1.54161119 1.69369793 1.15207696]"
@@ -785,20 +773,20 @@
           "[1.78956437 1.87284744 1.06880784 1.03455889]"
         ],
         [
-          "[0.
-          "[1.
+          "[0.94221359 0.57740951 0.98649532 0.40934443]",
+          "[1.94221354 1.57740951 1.98649526 1.40934443]"
         ],
         [
-          "[0.
-          "[1.
+          "[0.00497234 0.39319336 0.57054168 0.75150961]",
+          "[1.00497234 1.39319336 1.57054162 1.75150967]"
         ],
         [
-          "[0.
-          "[1.
+          "[0.44330525 0.09997386 0.89025736 0.90507984]",
+          "[1.44330525 1.09997392 1.89025736 1.90507984]"
         ],
         [
-          "[0.
-          "[1.
+          "[0.72290605 0.96945059 0.68354797 0.15270454]",
+          "[1.72290611 1.96945059 1.68354797 1.15270448]"
         ],
         [
           "[0.52784437 0.54268694 0.12358981 0.72116476]",
@@ -808,6 +796,10 @@
           "[0.73217702 0.65233225 0.44077861 0.33837909]",
           "[1.73217702 1.65233231 1.44077861 1.33837914]"
         ],
+        [
+          "[0.34084332 0.73018837 0.54168713 0.91440833]",
+          "[1.34084332 1.73018837 1.54168713 1.91440833]"
+        ],
         [
           "[0.60110539 0.3618983 0.32342511 0.98672163]",
           "[1.60110545 1.3618983 1.32342505 1.98672163]"
@@ -836,13 +828,17 @@
           "[0.18720162 0.74115586 0.98626411 0.30355608]",
           "[1.18720162 1.74115586 1.98626411 1.30355608]"
         ],
+        [
+          "[0.85566247 0.83362883 0.48424995 0.25265992]",
+          "[1.85566247 1.83362889 1.48424995 1.25265992]"
+        ],
         [
           "[0.95928186 0.84273899 0.71514636 0.38619852]",
           "[1.95928192 1.84273899 1.7151463 1.38619852]"
         ],
         [
-          "[0.
-          "[1.
+          "[0.9829582 0.59269661 0.40120947 0.95487177]",
+          "[1.9829582 1.59269667 1.40120947 1.95487177]"
         ],
         [
           "[0.79905868 0.89367443 0.75429088 0.3190186 ]",
@@ -873,8 +869,12 @@
           "[1.60075855 1.12234759 1.00614405 1.30560958]"
         ],
         [
-          "[0.
-          "[1.
+          "[0.02162331 0.81861657 0.92468154 0.07808572]",
+          "[1.02162337 1.81861663 1.92468154 1.07808566]"
+        ],
+        [
+          "[0.02235305 0.52774918 0.7331115 0.84358269]",
+          "[1.02235305 1.52774918 1.7331115 1.84358263]"
         ],
         [
           "[0.6080932 0.56563014 0.32107437 0.72599429]",
@@ -908,6 +908,10 @@
           "[0.76933283 0.86241865 0.44114518 0.65644735]",
           "[1.76933289 1.86241865 1.44114518 1.65644741]"
         ],
+        [
+          "[0.59492421 0.90274489 0.38069052 0.46101224]",
+          "[1.59492421 1.90274489 1.38069057 1.46101224]"
+        ],
         [
           "[0.15064228 0.03198934 0.25754827 0.51484001]",
           "[1.15064228 1.03198934 1.25754833 1.51484001]"
@@ -916,10 +920,6 @@
           "[0.12024075 0.21342516 0.56858408 0.58644271]",
           "[1.12024069 1.21342516 1.56858408 1.58644271]"
         ],
-        [
-          "[0.91730917 0.22574073 0.09591609 0.33056474]",
-          "[1.91730917 1.22574067 1.09591603 1.33056474]"
-        ],
         [
           "[0.49691743 0.61873293 0.90698647 0.94486356]",
           "[1.49691749 1.61873293 1.90698647 1.94486356]"
@@ -960,6 +960,10 @@
           "[0.47856545 0.46267092 0.6376707 0.84747767]",
           "[1.47856545 1.46267092 1.63767076 1.84747767]"
         ],
+        [
+          "[0.49584109 0.80599248 0.07096875 0.75872749]",
+          "[1.49584103 1.80599248 1.07096875 1.75872755]"
+        ],
         [
           "[0.43500566 0.66041756 0.80293626 0.96224713]",
           "[1.43500566 1.66041756 1.80293632 1.96224713]"
@@ -988,10 +992,6 @@
           "[0.72795159 0.79317838 0.27832931 0.96576637]",
           "[1.72795153 1.79317832 1.27832937 1.96576643]"
         ],
-        [
-          "[0.87608397 0.93200487 0.80169648 0.37758952]",
-          "[1.87608397 1.93200493 1.80169654 1.37758946]"
-        ],
         [
           "[0.68891573 0.25576538 0.96339929 0.503833 ]",
           "[1.68891573 1.25576544 1.96339929 1.50383306]"
@@ -1000,7 +1000,7 @@
         }
       },
      "other": {
-        "model": "ModelConfig(model=Sequential(\n (0) - Identity(): Input__tensor_1_x -> START_Repeat_1_output\n (1) - Linear(in_features=4, out_features=4, bias=True): START_Repeat_1_output -> Linear_2_output\n (2) - <function leaky_relu at
+        "model": "ModelConfig(model=Sequential(\n (0) - Identity(): Input__tensor_1_x -> START_Repeat_1_output\n (1) - Linear(in_features=4, out_features=4, bias=True): START_Repeat_1_output -> Linear_2_output\n (2) - <function leaky_relu at 0x759ed4f2c360>: Linear_2_output -> Activation_1_output\n (3) - Identity(): Activation_1_output -> END_Repeat_1_output\n (4) - Identity(): END_Repeat_1_output -> END_Repeat_1_output\n), model_inputs=['Input__tensor_1_x'], model_outputs=['END_Repeat_1_output'], loss_inputs=['END_Repeat_1_output', 'Input__tensor_3_x'], loss=Sequential(\n (0) - <function mse_loss at 0x759ed4f2de40>: END_Repeat_1_output, Input__tensor_3_x -> MSE_loss_2_output\n (1) - Identity(): MSE_loss_2_output -> loss\n), optimizer=SGD (\nParameter Group 0\n dampening: 0\n differentiable: False\n foreach: None\n fused: None\n lr: 0.1\n maximize: False\n momentum: 0\n nesterov: False\n weight_decay: 0\n), source_workspace=None, trained=True)"
       },
       "relations": []
     },
@@ -1270,7 +1270,7 @@
         "type": "basic"
       },
      "params": {
-        "epochs": "
+        "epochs": "1500",
        "input_mapping": "{\"map\":{\"Input__tensor_1_x\":{\"df\":\"df_train\",\"column\":\"x\"},\"Input__tensor_3_x\":{\"df\":\"df_train\",\"column\":\"y\"}}}",
        "model_name": "model"
      },
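In this example the input_mapping ties dataframe column x to Input__tensor_1_x and column y to Input__tensor_3_x, and the rows above suggest y is simply x + 1, with the third entry in the first table looking like the trained model's prediction. A small sanity-check sketch (the tensors are copied from one row above; the reading of the third column as the prediction is an assumption, not stated in the commit):

    import torch

    x = torch.tensor([0.33480108, 0.59181517, 0.76198453, 0.98062384])
    y = x + 1  # matches "[1.33480108 1.59181523 1.76198459 1.98062384]" up to float32 rounding
    pred = torch.tensor([1.3419755697250366, 1.5946478843688965,
                         1.7717586755752563, 1.9897401332855225])
    print((pred - y).abs().max())  # small residual, consistent with 1500 epochs of SGD on MSE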
lynxkite-graph-analytics/src/lynxkite_graph_analytics/pytorch_model_ops.py
CHANGED
@@ -3,7 +3,6 @@
 import copy
 import enum
 import graphlib
-import types
 
 import pydantic
 from lynxkite.core import ops, workspace
@@ -100,6 +99,11 @@ def activation(x, *, type: ActivationTypes = ActivationTypes.ReLU):
     return Layer(f, shape=x.shape)
 
 
+@op("MSE loss")
+def mse_loss(x, y):
+    return Layer(torch.nn.functional.mse_loss, shape=[1])
+
+
 reg("Softmax", inputs=["x"])
 reg(
     "Graph conv",
@@ -111,7 +115,6 @@ reg("Concatenate", inputs=["a", "b"], outputs=["x"])
 reg("Add", inputs=["a", "b"], outputs=["x"])
 reg("Subtract", inputs=["a", "b"], outputs=["x"])
 reg("Multiply", inputs=["a", "b"], outputs=["x"])
-reg("MSE loss", inputs=["x", "y"], outputs=["loss"])
 reg("Triplet margin loss", inputs=["x", "x_pos", "x_neg"], outputs=["loss"])
 reg("Cross-entropy loss", inputs=["x", "y"], outputs=["loss"])
 reg(
@@ -161,10 +164,10 @@ def _to_id(*strings: str) -> str:
 
 
 @dataclasses.dataclass
-class OpInput:
+class TensorRef:
     """Ops get their inputs like this. They have to return a Layer made for this input."""
 
-    id: str
+    _id: str
     shape: tuple[int, ...]
 
 
@@ -175,12 +178,21 @@ class Layer:
     module: torch.nn.Module
     shapes: list[tuple[int, ...]] | None = None  # One for each output.
     shape: dataclasses.InitVar[tuple[int, ...] | None] = None  # Convenience for single output.
+    # Set by ModelBuilder.
+    _origin_id: str | None = None
+    _inputs: list[TensorRef] | None = None
+    _outputs: list[TensorRef] | None = None
 
     def __post_init__(self, shape):
         assert not self.shapes or not shape, "Cannot set both shapes and shape."
         if shape:
             self.shapes = [shape]
 
+    def _for_sequential(self):
+        inputs = ", ".join(i._id for i in self._inputs)
+        outputs = ", ".join(o._id for o in self._outputs)
+        return self.module, f"{inputs} -> {outputs}"
+
 
 class ColumnSpec(pydantic.BaseModel):
     df: str
@@ -246,204 +258,212 @@ class ModelConfig:
     }
 
 
-def _add_op(op, params, inputs, outputs, sizes, layers):
-    op_inputs = []
-    for i in op.inputs.keys():
-        id = getattr(inputs, i)
-        op_inputs.append(OpInput(id, shape=sizes.get(id, 1)))
-    if op.func != ops.no_op:
-        layer = op.func(*op_inputs, **params)
-    else:
-        layer = Layer(torch.nn.Identity(), shapes=[i.shape for i in op_inputs])
-    input_ids = ", ".join(i.id for i in op_inputs)
-    output_ids = []
-    for o, shape in zip(op.outputs.keys(), layer.shapes):
-        id = getattr(outputs, o)
-        output_ids.append(id)
-        sizes[id] = shape
-    output_ids = ", ".join(output_ids)
-    layers.append((layer.module, f"{input_ids} -> {output_ids}"))
-
-
-def _all_dependencies(node: str, dependencies: dict[str, list[str]]) -> set[str]:
-    """Returns all dependencies of a node."""
-    deps = set()
-    for dep in dependencies[node]:
-        deps.add(dep)
-        deps.update(_all_dependencies(dep, dependencies))
-    return deps
-
-
 def build_model(ws: workspace.Workspace, inputs: dict[str, torch.Tensor]) -> ModelConfig:
     """Builds the model described in the workspace."""
-    ...
-        dependencies[
-    ...
+    builder = ModelBuilder(ws, inputs)
+    return builder.build_model()
+
+
+class ModelBuilder:
+    """The state shared between methods that are used to build the model."""
+
+    def __init__(self, ws: workspace.Workspace, inputs: dict[str, torch.Tensor]):
+        self.catalog = ops.CATALOGS[ENV]
+        optimizers = []
+        self.nodes = {}
+        for node in ws.nodes:
+            self.nodes[node.id] = node
+            if node.data.title == "Optimizer":
+                optimizers.append(node.id)
+        assert optimizers, "No optimizer found."
+        assert len(optimizers) == 1, f"More than one optimizer found: {optimizers}"
+        [self.optimizer] = optimizers
+        self.dependencies = {n.id: [] for n in ws.nodes}
+        self.in_edges = {}
+        self.out_edges = {}
+        repeats = []
+        for e in ws.edges:
+            if self.nodes[e.target].data.title == "Repeat":
+                repeats.append(e.target)
+            self.dependencies[e.target].append(e.source)
+            self.in_edges.setdefault(e.target, {}).setdefault(e.targetHandle, []).append(
+                (e.source, e.sourceHandle)
+            )
+            self.out_edges.setdefault(e.source, {}).setdefault(e.sourceHandle, []).append(
+                (e.target, e.targetHandle)
+            )
+        # Split repeat boxes into start and end, and insert them into the flow.
+        # TODO: Think about recursive repeats.
+        for repeat in repeats:
+            start_id = f"START {repeat}"
+            end_id = f"END {repeat}"
+            # repeat -> first <- real_input
+            # ...becomes...
+            # real_input -> start -> first
+            first, firsth = self.out_edges[repeat]["output"][0]
+            [(real_input, real_inputh)] = [
+                k for k in self.in_edges[first][firsth] if k != (repeat, "output")
+            ]
+            self.dependencies[first].remove(repeat)
+            self.dependencies[first].append(start_id)
+            self.dependencies[start_id] = [real_input]
+            self.out_edges[real_input][real_inputh] = [
+                k if k != (first, firsth) else (start_id, "input")
+                for k in self.out_edges[real_input][real_inputh]
+            ]
+            self.in_edges[start_id] = {"input": [(real_input, real_inputh)]}
+            self.out_edges[start_id] = {"output": [(first, firsth)]}
+            self.in_edges[first][firsth] = [(start_id, "output")]
+            # repeat <- last -> real_output
+            # ...becomes...
+            # last -> end -> real_output
+            last, lasth = self.in_edges[repeat]["input"][0]
+            [(real_output, real_outputh)] = [
+                k for k in self.out_edges[last][lasth] if k != (repeat, "input")
+            ]
+            del self.dependencies[repeat]
+            self.dependencies[end_id] = [last]
+            self.dependencies[real_output].append(end_id)
+            self.out_edges[last][lasth] = [(end_id, "input")]
+            self.in_edges[end_id] = {"input": [(last, lasth)]}
+            self.out_edges[end_id] = {"output": [(real_output, real_outputh)]}
+            self.in_edges[real_output][real_outputh] = [
+                k if k != (last, lasth) else (end_id, "output")
+                for k in self.in_edges[real_output][real_outputh]
+            ]
+        self.inv_dependencies = {n: [] for n in self.dependencies}
+        for k, v in self.dependencies.items():
+            for i in v:
+                self.inv_dependencies[i].append(k)
+        self.sizes = {}
+        for k, i in inputs.items():
+            self.sizes[k] = i.shape[-1]
+        self.layers = []
+
+    def all_upstream(self, node: str) -> set[str]:
+        """Returns all nodes upstream of a node."""
+        deps = set()
+        for dep in self.dependencies[node]:
+            deps.add(dep)
+            deps.update(self.all_upstream(dep))
+        return deps
+
+    def all_downstream(self, node: str) -> set[str]:
+        """Returns all nodes downstream of a node."""
+        deps = set()
+        for dep in self.inv_dependencies[node]:
+            deps.add(dep)
+            deps.update(self.all_downstream(dep))
+        return deps
+
+    def run_node(self, node_id: str) -> None:
+        """Adds the layer(s) produced by this node to self.layers."""
         if node_id.startswith("START "):
-            node = nodes[node_id.removeprefix("START ")]
+            node = self.nodes[node_id.removeprefix("START ")]
         elif node_id.startswith("END "):
-            node = nodes[node_id.removeprefix("END ")]
+            node = self.nodes[node_id.removeprefix("END ")]
         else:
-            node = nodes[node_id]
+            node = self.nodes[node_id]
         t = node.data.title
-        op = catalog[t]
+        op = self.catalog[t]
         p = op.convert_params(node.data.params)
-        for b in dependencies[node_id]:
-            regions[node_id] |= regions[b]
-        if "loss" in t:
-            regions[node_id].add("loss")
         inputs = {}
-        for n in in_edges.get(node_id, []):
-            for b, h in in_edges[node_id][n]:
+        for n in self.in_edges.get(node_id, []):
+            for b, h in self.in_edges[node_id][n]:
                 i = _to_id(b, h)
                 inputs[n] = i
-                if "loss" in regions[node_id]:
-                    used_in_loss.add(i)
-                else:
-                    used_in_model.add(i)
         outputs = {}
-        for out in out_edges.get(node_id, []):
+        for out in self.out_edges.get(node_id, []):
             i = _to_id(node_id, out)
             outputs[out] = i
-            if not t.startswith("Input:"):  # The outputs of inputs are not "made" by us.
-                if "loss" in regions[node_id]:
-                    made_in_loss.add(i)
-                else:
-                    made_in_model.add(i)
-        inputs = types.SimpleNamespace(**inputs)
-        outputs = types.SimpleNamespace(**outputs)
-        ls = loss_layers if "loss" in regions[node_id] else layers
         match t:
-            case "MSE loss":
-                ls.append(
-                    (
-                        torch.nn.functional.mse_loss,
-                        f"{inputs.x}, {inputs.y} -> {outputs.loss}",
-                    )
-                )
             case "Repeat":
-                ...
-                sizes[outputs.output] = sizes.get(inputs.input, 1)
-                if node_id.startswith("START "):
-                    regions[node_id].add(("repeat", node_id.removeprefix("START ")))
-                else:
+                if node_id.startswith("END "):
                     repeat_id = node_id.removeprefix("END ")
                     start_id = f"START {repeat_id}"
                     print(f"repeat {repeat_id} ending")
-                    after_start =
-                    after_end =
-                    before_end =
-                    affected_nodes = after_start - after_end
+                    after_start = self.all_downstream(start_id)
+                    after_end = self.all_downstream(node_id)
+                    before_end = self.all_upstream(node_id)
+                    affected_nodes = after_start - after_end - {node_id}
                     repeated_nodes = after_start & before_end
                     assert affected_nodes == repeated_nodes, (
                         f"edges leave repeated section '{repeat_id}':\n{affected_nodes - repeated_nodes}"
                     )
-                    regions[node_id].remove(("repeat", repeat_id))
                     for n in repeated_nodes:
                         print(f"repeating {n}")
             case "Optimizer" | "Input: tensor" | "Input: graph edges" | "Input: sequential":
-    ...
+                return
+        layer = self.run_op(op, p, inputs, outputs)
+        layer._origin_id = node_id
+        self.layers.append(layer)
+
+    def run_op(self, op, params, inputs: dict[str, str], outputs: dict[str, str]) -> Layer:
+        """Returns the layer produced by this op."""
+        op_inputs = [
+            TensorRef(inputs[i], shape=self.sizes.get(inputs[i], 1)) for i in op.inputs.keys()
+        ]
+        if op.func != ops.no_op:
+            layer = op.func(*op_inputs, **params)
+        else:
+            layer = Layer(torch.nn.Identity(), shapes=[i.shape for i in op_inputs])
+        layer._inputs = op_inputs
+        layer._outputs = []
+        for o, shape in zip(op.outputs.keys(), layer.shapes):
+            layer._outputs.append(TensorRef(outputs[o], shape=shape))
+            self.sizes[outputs[o]] = shape
+        return layer
+
+    def build_model(self) -> ModelConfig:
+        # Walk the graph in topological order.
+        ts = graphlib.TopologicalSorter(self.dependencies)
+        for node_id in ts.static_order():
+            self.run_node(node_id)
+        return self.get_config()
+
+    def get_config(self) -> ModelConfig:
+        # Split the design into model and loss.
+        loss_nodes = set()
+        for node_id in self.nodes:
+            if "loss" in self.nodes[node_id].data.title:
+                loss_nodes.add(node_id)
+                loss_nodes |= self.all_downstream(node_id)
+        layers = []
+        loss_layers = []
+        for layer in self.layers:
+            if layer._origin_id in loss_nodes:
+                loss_layers.append(layer)
+            else:
+                layers.append(layer)
+        used_in_model = set(input._id for layer in layers for input in layer._inputs)
+        used_in_loss = set(input._id for layer in loss_layers for input in layer._inputs)
+        made_in_model = set(output._id for layer in layers for output in layer._outputs)
+        made_in_loss = set(output._id for layer in loss_layers for output in layer._outputs)
+        layers = [layer._for_sequential() for layer in layers]
+        loss_layers = [layer._for_sequential() for layer in loss_layers]
+        cfg = {}
+        cfg["model_inputs"] = list(used_in_model - made_in_model)
+        cfg["model_outputs"] = list(made_in_model & used_in_loss)
+        cfg["loss_inputs"] = list(used_in_loss - made_in_loss)
+        # Make sure the trained output is output from the last model layer.
+        outputs = ", ".join(cfg["model_outputs"])
+        layers.append((torch.nn.Identity(), f"{outputs} -> {outputs}"))
+        # Create model.
+        cfg["model"] = pyg.nn.Sequential(", ".join(cfg["model_inputs"]), layers)
+        # Make sure the loss is output from the last loss layer.
+        [(lossb, lossh)] = self.in_edges[self.optimizer]["loss"]
+        lossi = _to_id(lossb, lossh)
+        loss_layers.append((torch.nn.Identity(), f"{lossi} -> loss"))
+        # Create loss function.
+        cfg["loss"] = pyg.nn.Sequential(", ".join(cfg["loss_inputs"]), loss_layers)
+        assert not list(cfg["loss"].parameters()), f"loss should have no parameters: {loss_layers}"
+        # Create optimizer.
+        op = self.catalog["Optimizer"]
+        p = op.convert_params(self.nodes[self.optimizer].data.params)
+        o = getattr(torch.optim, p["type"].name)
+        cfg["optimizer"] = o(cfg["model"].parameters(), lr=p["lr"])
+        print(cfg)
+        return ModelConfig(**cfg)
 
 
 def to_tensors(b: core.Bundle, m: ModelMapping | None) -> dict[str, torch.Tensor]:
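The (module, "inputs -> outputs") pairs returned by Layer._for_sequential() follow the torch_geometric Sequential convention that get_config() feeds into pyg.nn.Sequential, and the trained ModelConfig repr in examples/Model use shows exactly such a pipeline. A standalone sketch of that convention (the layer names here are illustrative, not from the repo):

    import torch
    import torch_geometric as pyg

    # Plain callables are allowed alongside nn.Modules, which is how the built
    # model ends up containing <function leaky_relu ...> and <function mse_loss ...>.
    layers = [
        (torch.nn.Identity(), "x -> h"),               # e.g. the START marker of a Repeat box
        (torch.nn.Linear(4, 4), "h -> z"),             # "Linear 2"
        (torch.nn.functional.leaky_relu, "z -> out"),  # "Activation 1"
    ]
    model = pyg.nn.Sequential("x", layers)
    print(model(torch.rand(10, 4)).shape)  # torch.Size([10, 4])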