darabos commited on
Commit
e7d2291
·
1 Parent(s): 253ca3d

Split build_model() so we can repeat a part.

Browse files
examples/Model definition CHANGED
@@ -1,12 +1,5 @@
1
  {
2
  "edges": [
3
- {
4
- "id": "MSE loss 1 Optimizer 2",
5
- "source": "MSE loss 1",
6
- "sourceHandle": "loss",
7
- "target": "Optimizer 2",
8
- "targetHandle": "loss"
9
- },
10
  {
11
  "id": "Repeat 1 Linear 2",
12
  "source": "Repeat 1",
@@ -14,13 +7,6 @@
14
  "target": "Linear 2",
15
  "targetHandle": "x"
16
  },
17
- {
18
- "id": "Activation 1 MSE loss 1",
19
- "source": "Activation 1",
20
- "sourceHandle": "output",
21
- "target": "MSE loss 1",
22
- "targetHandle": "x"
23
- },
24
  {
25
  "id": "Linear 2 Activation 1",
26
  "source": "Linear 2",
@@ -35,72 +21,37 @@
35
  "target": "Repeat 1",
36
  "targetHandle": "input"
37
  },
38
- {
39
- "id": "Input: tensor 3 MSE loss 1",
40
- "source": "Input: tensor 3",
41
- "sourceHandle": "x",
42
- "target": "MSE loss 1",
43
- "targetHandle": "y"
44
- },
45
  {
46
  "id": "Input: tensor 1 Linear 2",
47
  "source": "Input: tensor 1",
48
  "sourceHandle": "x",
49
  "target": "Linear 2",
50
  "targetHandle": "x"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  }
52
  ],
53
  "env": "PyTorch model",
54
  "nodes": [
55
- {
56
- "data": {
57
- "display": null,
58
- "error": null,
59
- "input_metadata": null,
60
- "meta": {
61
- "inputs": {
62
- "x": {
63
- "name": "x",
64
- "position": "bottom",
65
- "type": {
66
- "type": "tensor"
67
- }
68
- },
69
- "y": {
70
- "name": "y",
71
- "position": "bottom",
72
- "type": {
73
- "type": "tensor"
74
- }
75
- }
76
- },
77
- "name": "MSE loss",
78
- "outputs": {
79
- "loss": {
80
- "name": "loss",
81
- "position": "top",
82
- "type": {
83
- "type": "tensor"
84
- }
85
- }
86
- },
87
- "params": {},
88
- "type": "basic"
89
- },
90
- "params": {},
91
- "status": "planned",
92
- "title": "MSE loss"
93
- },
94
- "dragHandle": ".bg-primary",
95
- "height": 200.0,
96
- "id": "MSE loss 1",
97
- "position": {
98
- "x": 315.0,
99
- "y": -510.0
100
- },
101
- "type": "basic",
102
- "width": 200.0
103
- },
104
  {
105
  "data": {
106
  "__execution_delay": 0.0,
@@ -384,7 +335,7 @@
384
  "height": 200.0,
385
  "id": "Input: tensor 1",
386
  "position": {
387
- "x": 97.60681762952905,
388
  "y": 293.6278596776366
389
  },
390
  "type": "basic",
@@ -434,8 +385,61 @@
434
  "height": 200.0,
435
  "id": "Input: tensor 3",
436
  "position": {
437
- "x": 862.4359094222825,
438
- "y": -290.0677203273021
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
439
  },
440
  "type": "basic",
441
  "width": 200.0
 
1
  {
2
  "edges": [
 
 
 
 
 
 
 
3
  {
4
  "id": "Repeat 1 Linear 2",
5
  "source": "Repeat 1",
 
7
  "target": "Linear 2",
8
  "targetHandle": "x"
9
  },
 
 
 
 
 
 
 
10
  {
11
  "id": "Linear 2 Activation 1",
12
  "source": "Linear 2",
 
21
  "target": "Repeat 1",
22
  "targetHandle": "input"
23
  },
 
 
 
 
 
 
 
24
  {
25
  "id": "Input: tensor 1 Linear 2",
26
  "source": "Input: tensor 1",
27
  "sourceHandle": "x",
28
  "target": "Linear 2",
29
  "targetHandle": "x"
30
+ },
31
+ {
32
+ "id": "MSE loss 2 Optimizer 2",
33
+ "source": "MSE loss 2",
34
+ "sourceHandle": "output",
35
+ "target": "Optimizer 2",
36
+ "targetHandle": "loss"
37
+ },
38
+ {
39
+ "id": "Activation 1 MSE loss 2",
40
+ "source": "Activation 1",
41
+ "sourceHandle": "output",
42
+ "target": "MSE loss 2",
43
+ "targetHandle": "x"
44
+ },
45
+ {
46
+ "id": "Input: tensor 3 MSE loss 2",
47
+ "source": "Input: tensor 3",
48
+ "sourceHandle": "x",
49
+ "target": "MSE loss 2",
50
+ "targetHandle": "y"
51
  }
52
  ],
53
  "env": "PyTorch model",
54
  "nodes": [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  {
56
  "data": {
57
  "__execution_delay": 0.0,
 
335
  "height": 200.0,
336
  "id": "Input: tensor 1",
337
  "position": {
338
+ "x": 85.83561484252238,
339
  "y": 293.6278596776366
340
  },
341
  "type": "basic",
 
385
  "height": 200.0,
386
  "id": "Input: tensor 3",
387
  "position": {
388
+ "x": 485.8840220312055,
389
+ "y": -149.86223034126274
390
+ },
391
+ "type": "basic",
392
+ "width": 200.0
393
+ },
394
+ {
395
+ "data": {
396
+ "display": null,
397
+ "error": null,
398
+ "input_metadata": null,
399
+ "meta": {
400
+ "inputs": {
401
+ "x": {
402
+ "name": "x",
403
+ "position": "bottom",
404
+ "type": {
405
+ "type": "<class 'inspect._empty'>"
406
+ }
407
+ },
408
+ "y": {
409
+ "name": "y",
410
+ "position": "bottom",
411
+ "type": {
412
+ "type": "<class 'inspect._empty'>"
413
+ }
414
+ }
415
+ },
416
+ "name": "MSE loss",
417
+ "outputs": {
418
+ "output": {
419
+ "name": "output",
420
+ "position": "top",
421
+ "type": {
422
+ "type": "None"
423
+ }
424
+ }
425
+ },
426
+ "params": {},
427
+ "position": {
428
+ "x": 937.0,
429
+ "y": 270.0
430
+ },
431
+ "type": "basic"
432
+ },
433
+ "params": {},
434
+ "status": "planned",
435
+ "title": "MSE loss"
436
+ },
437
+ "dragHandle": ".bg-primary",
438
+ "height": 200.0,
439
+ "id": "MSE loss 2",
440
+ "position": {
441
+ "x": 690.0,
442
+ "y": -480.0
443
  },
444
  "type": "basic",
445
  "width": 200.0
examples/Model use CHANGED
@@ -579,54 +579,54 @@
579
  ],
580
  "data": [
581
  [
582
- "[0.31518555 0.49643308 0.11509258 0.95458382]",
583
- "[1.31518555 1.49643302 1.11509252 1.95458388]",
584
- "[1.3819222450256348, -0.005686390213668346, 1.3793643712997437, 1.581865906715393]"
585
  ],
586
  [
587
- "[0.02162331 0.81861657 0.92468154 0.07808572]",
588
- "[1.02162337 1.81861663 1.92468154 1.07808566]",
589
- "[1.312654972076416, -0.00689137727022171, 1.4941580295562744, 1.243792176246643]"
590
  ],
591
  [
592
- "[0.94221359 0.57740951 0.98649532 0.40934443]",
593
- "[1.94221354 1.57740951 1.98649526 1.40934443]",
594
- "[1.9255921840667725, -0.008701151236891747, 1.751355767250061, 1.79597806930542]"
595
  ],
596
  [
597
- "[0.34084332 0.73018837 0.54168713 0.91440833]",
598
- "[1.34084332 1.73018837 1.54168713 1.91440833]",
599
- "[1.6509568691253662, -0.007272087037563324, 1.5942981243133545, 1.81572687625885]"
600
  ],
601
  [
602
- "[0.85566247 0.83362883 0.48424995 0.25265992]",
603
- "[1.85566247 1.83362889 1.48424995 1.25265992]",
604
- "[1.7482354640960693, -0.0063837491907179356, 1.4504402875900269, 1.5329445600509644]"
605
  ],
606
  [
607
- "[0.02235305 0.52774918 0.7331115 0.84358269]",
608
- "[1.02235305 1.52774918 1.7331115 1.84358263]",
609
- "[1.3979142904281616, -0.007555779069662094, 1.6136289834976196, 1.6417407989501953]"
610
  ],
611
  [
612
- "[0.9829582 0.59269661 0.40120947 0.95487177]",
613
- "[1.9829582 1.59269667 1.40120947 1.95487177]",
614
- "[1.9523842334747314, -0.00748100271448493, 1.6264307498931885, 1.9942888021469116]"
615
  ],
616
  [
617
- "[0.49584109 0.80599248 0.07096875 0.75872749]",
618
- "[1.49584103 1.80599248 1.07096875 1.75872755]",
619
- "[1.5513110160827637, -0.005337317008525133, 1.3384482860565186, 1.5973539352416992]"
620
  ],
621
  [
622
- "[0.00497234 0.39319336 0.57054168 0.75150961]",
623
- "[1.00497234 1.39319336 1.57054162 1.75150967]",
624
- "[1.2277441024780273, -0.0067505668848752975, 1.4969637393951416, 1.4524610042572021]"
625
  ],
626
  [
627
- "[0.59492421 0.90274489 0.38069052 0.46101224]",
628
- "[1.59492421 1.90274489 1.38069057 1.46101224]",
629
- "[1.6593225002288818, -0.006088308058679104, 1.4240546226501465, 1.570335865020752]"
630
  ]
631
  ]
632
  },
@@ -648,10 +648,6 @@
648
  "[0.11560339 0.57495481 0.76535827 0.0391947 ]",
649
  "[1.11560345 1.57495475 1.76535821 1.0391947 ]"
650
  ],
651
- [
652
- "[0.19409031 0.68692201 0.60667384 0.57829887]",
653
- "[1.19409037 1.68692207 1.60667384 1.57829881]"
654
- ],
655
  [
656
  "[0.76807946 0.98855817 0.08259124 0.01730657]",
657
  "[1.76807952 1.98855817 1.0825913 1.01730657]"
@@ -684,10 +680,6 @@
684
  "[0.98324287 0.99464184 0.14008355 0.47651017]",
685
  "[1.98324287 1.99464178 1.14008355 1.47651017]"
686
  ],
687
- [
688
- "[0.11693293 0.49860179 0.55020827 0.88832849]",
689
- "[1.11693287 1.49860179 1.55020833 1.88832855]"
690
- ],
691
  [
692
  "[0.48959708 0.48549271 0.32688856 0.356677 ]",
693
  "[1.48959708 1.48549271 1.32688856 1.35667706]"
@@ -716,10 +708,6 @@
716
  "[0.24388778 0.07268471 0.68350857 0.73431659]",
717
  "[1.24388778 1.07268476 1.68350863 1.73431659]"
718
  ],
719
- [
720
- "[0.62569475 0.9881897 0.83639616 0.9828859 ]",
721
- "[1.62569475 1.9881897 1.83639622 1.98288584]"
722
- ],
723
  [
724
  "[0.56922203 0.98222166 0.76851749 0.28615737]",
725
  "[1.56922197 1.9822216 1.76851749 1.28615737]"
@@ -732,14 +720,14 @@
732
  "[0.90817457 0.89270043 0.38583666 0.66566533]",
733
  "[1.90817451 1.89270043 1.3858366 1.66566539]"
734
  ],
735
- [
736
- "[0.48507756 0.80808765 0.77162558 0.47834778]",
737
- "[1.48507762 1.80808759 1.77162552 1.47834778]"
738
- ],
739
  [
740
  "[0.68062544 0.98093534 0.14778823 0.53244978]",
741
  "[1.68062544 1.98093534 1.14778829 1.53244972]"
742
  ],
 
 
 
 
743
  [
744
  "[0.79121011 0.54161114 0.69369799 0.1520769 ]",
745
  "[1.79121017 1.54161119 1.69369793 1.15207696]"
@@ -785,20 +773,20 @@
785
  "[1.78956437 1.87284744 1.06880784 1.03455889]"
786
  ],
787
  [
788
- "[0.44330525 0.09997386 0.89025736 0.90507984]",
789
- "[1.44330525 1.09997392 1.89025736 1.90507984]"
790
  ],
791
  [
792
- "[0.72290605 0.96945059 0.68354797 0.15270454]",
793
- "[1.72290611 1.96945059 1.68354797 1.15270448]"
794
  ],
795
  [
796
- "[0.75292218 0.81470108 0.49657214 0.56217098]",
797
- "[1.75292218 1.81470108 1.49657214 1.56217098]"
798
  ],
799
  [
800
- "[0.33480108 0.59181517 0.76198453 0.98062384]",
801
- "[1.33480108 1.59181523 1.76198459 1.98062384]"
802
  ],
803
  [
804
  "[0.52784437 0.54268694 0.12358981 0.72116476]",
@@ -808,6 +796,10 @@
808
  "[0.73217702 0.65233225 0.44077861 0.33837909]",
809
  "[1.73217702 1.65233231 1.44077861 1.33837914]"
810
  ],
 
 
 
 
811
  [
812
  "[0.60110539 0.3618983 0.32342511 0.98672163]",
813
  "[1.60110545 1.3618983 1.32342505 1.98672163]"
@@ -836,13 +828,17 @@
836
  "[0.18720162 0.74115586 0.98626411 0.30355608]",
837
  "[1.18720162 1.74115586 1.98626411 1.30355608]"
838
  ],
 
 
 
 
839
  [
840
  "[0.95928186 0.84273899 0.71514636 0.38619852]",
841
  "[1.95928192 1.84273899 1.7151463 1.38619852]"
842
  ],
843
  [
844
- "[0.32565445 0.90939188 0.07488042 0.13730896]",
845
- "[1.32565451 1.90939188 1.07488036 1.13730896]"
846
  ],
847
  [
848
  "[0.79905868 0.89367443 0.75429088 0.3190186 ]",
@@ -873,8 +869,12 @@
873
  "[1.60075855 1.12234759 1.00614405 1.30560958]"
874
  ],
875
  [
876
- "[0.39147133 0.29854035 0.84663737 0.58175623]",
877
- "[1.39147139 1.29854035 1.84663737 1.58175623]"
 
 
 
 
878
  ],
879
  [
880
  "[0.6080932 0.56563014 0.32107437 0.72599429]",
@@ -908,6 +908,10 @@
908
  "[0.76933283 0.86241865 0.44114518 0.65644735]",
909
  "[1.76933289 1.86241865 1.44114518 1.65644741]"
910
  ],
 
 
 
 
911
  [
912
  "[0.15064228 0.03198934 0.25754827 0.51484001]",
913
  "[1.15064228 1.03198934 1.25754833 1.51484001]"
@@ -916,10 +920,6 @@
916
  "[0.12024075 0.21342516 0.56858408 0.58644271]",
917
  "[1.12024069 1.21342516 1.56858408 1.58644271]"
918
  ],
919
- [
920
- "[0.91730917 0.22574073 0.09591609 0.33056474]",
921
- "[1.91730917 1.22574067 1.09591603 1.33056474]"
922
- ],
923
  [
924
  "[0.49691743 0.61873293 0.90698647 0.94486356]",
925
  "[1.49691749 1.61873293 1.90698647 1.94486356]"
@@ -960,6 +960,10 @@
960
  "[0.47856545 0.46267092 0.6376707 0.84747767]",
961
  "[1.47856545 1.46267092 1.63767076 1.84747767]"
962
  ],
 
 
 
 
963
  [
964
  "[0.43500566 0.66041756 0.80293626 0.96224713]",
965
  "[1.43500566 1.66041756 1.80293632 1.96224713]"
@@ -988,10 +992,6 @@
988
  "[0.72795159 0.79317838 0.27832931 0.96576637]",
989
  "[1.72795153 1.79317832 1.27832937 1.96576643]"
990
  ],
991
- [
992
- "[0.87608397 0.93200487 0.80169648 0.37758952]",
993
- "[1.87608397 1.93200493 1.80169654 1.37758946]"
994
- ],
995
  [
996
  "[0.68891573 0.25576538 0.96339929 0.503833 ]",
997
  "[1.68891573 1.25576544 1.96339929 1.50383306]"
@@ -1000,7 +1000,7 @@
1000
  }
1001
  },
1002
  "other": {
1003
- "model": "ModelConfig(model=Sequential(\n (0) - Identity(): Input__tensor_1_x -> START_Repeat_1_output\n (1) - Linear(in_features=4, out_features=4, bias=True): START_Repeat_1_output -> Linear_2_output\n (2) - <function leaky_relu at 0x759513340220>: Linear_2_output -> Activation_1_output\n (3) - Identity(): Activation_1_output -> END_Repeat_1_output\n (4) - Identity(): END_Repeat_1_output -> END_Repeat_1_output\n), model_inputs=['Input__tensor_1_x'], model_outputs=['END_Repeat_1_output'], loss_inputs=['END_Repeat_1_output', 'Input__tensor_3_x'], loss=Sequential(\n (0) - <function mse_loss at 0x759513341d00>: END_Repeat_1_output, Input__tensor_3_x -> MSE_loss_1_loss\n (1) - Identity(): MSE_loss_1_loss -> loss\n), optimizer=SGD (\nParameter Group 0\n dampening: 0\n differentiable: False\n foreach: None\n fused: None\n lr: 0.1\n maximize: False\n momentum: 0\n nesterov: False\n weight_decay: 0\n), source_workspace=None, trained=True)"
1004
  },
1005
  "relations": []
1006
  },
@@ -1270,7 +1270,7 @@
1270
  "type": "basic"
1271
  },
1272
  "params": {
1273
- "epochs": "150",
1274
  "input_mapping": "{\"map\":{\"Input__tensor_1_x\":{\"df\":\"df_train\",\"column\":\"x\"},\"Input__tensor_3_x\":{\"df\":\"df_train\",\"column\":\"y\"}}}",
1275
  "model_name": "model"
1276
  },
 
579
  ],
580
  "data": [
581
  [
582
+ "[0.33480108 0.59181517 0.76198453 0.98062384]",
583
+ "[1.33480108 1.59181523 1.76198459 1.98062384]",
584
+ "[1.3419755697250366, 1.5946478843688965, 1.7717586755752563, 1.9897401332855225]"
585
  ],
586
  [
587
+ "[0.91730917 0.22574073 0.09591609 0.33056474]",
588
+ "[1.91730917 1.22574067 1.09591603 1.33056474]",
589
+ "[1.900892972946167, 1.2247941493988037, 1.0862866640090942, 1.323314905166626]"
590
  ],
591
  [
592
+ "[0.32565445 0.90939188 0.07488042 0.13730896]",
593
+ "[1.32565451 1.90939188 1.07488036 1.13730896]",
594
+ "[1.3460955619812012, 1.8960161209106445, 1.0530263185501099, 1.1075329780578613]"
595
  ],
596
  [
597
+ "[0.87608397 0.93200487 0.80169648 0.37758952]",
598
+ "[1.87608397 1.93200493 1.80169654 1.37758946]",
599
+ "[1.87070894241333, 1.9386992454528809, 1.8151044845581055, 1.3952441215515137]"
600
  ],
601
  [
602
+ "[0.39147133 0.29854035 0.84663737 0.58175623]",
603
+ "[1.39147139 1.29854035 1.84663737 1.58175623]",
604
+ "[1.3877646923065186, 1.2995290756225586, 1.847062587738037, 1.583693265914917]"
605
  ],
606
  [
607
+ "[0.48507756 0.80808765 0.77162558 0.47834778]",
608
+ "[1.48507762 1.80808759 1.77162552 1.47834778]",
609
+ "[1.490919828414917, 1.8087174892425537, 1.7757861614227295, 1.4824031591415405]"
610
  ],
611
  [
612
+ "[0.75292218 0.81470108 0.49657214 0.56217098]",
613
+ "[1.75292218 1.81470108 1.49657214 1.56217098]",
614
+ "[1.7527031898498535, 1.8176040649414062, 1.503413438796997, 1.570152759552002]"
615
  ],
616
  [
617
+ "[0.11693293 0.49860179 0.55020827 0.88832849]",
618
+ "[1.11693287 1.49860179 1.55020833 1.88832855]",
619
+ "[1.1314976215362549, 1.4944026470184326, 1.546830177307129, 1.8803892135620117]"
620
  ],
621
  [
622
+ "[0.19409031 0.68692201 0.60667384 0.57829887]",
623
+ "[1.19409037 1.68692207 1.60667384 1.57829881]",
624
+ "[1.2091591358184814, 1.6816589832305908, 1.6011345386505127, 1.5684995651245117]"
625
  ],
626
  [
627
+ "[0.62569475 0.9881897 0.83639616 0.9828859 ]",
628
+ "[1.62569475 1.9881897 1.83639622 1.98288584]",
629
+ "[1.6314740180969238, 1.996805191040039, 1.8592857122421265, 2.0075552463531494]"
630
  ]
631
  ]
632
  },
 
648
  "[0.11560339 0.57495481 0.76535827 0.0391947 ]",
649
  "[1.11560345 1.57495475 1.76535821 1.0391947 ]"
650
  ],
 
 
 
 
651
  [
652
  "[0.76807946 0.98855817 0.08259124 0.01730657]",
653
  "[1.76807952 1.98855817 1.0825913 1.01730657]"
 
680
  "[0.98324287 0.99464184 0.14008355 0.47651017]",
681
  "[1.98324287 1.99464178 1.14008355 1.47651017]"
682
  ],
 
 
 
 
683
  [
684
  "[0.48959708 0.48549271 0.32688856 0.356677 ]",
685
  "[1.48959708 1.48549271 1.32688856 1.35667706]"
 
708
  "[0.24388778 0.07268471 0.68350857 0.73431659]",
709
  "[1.24388778 1.07268476 1.68350863 1.73431659]"
710
  ],
 
 
 
 
711
  [
712
  "[0.56922203 0.98222166 0.76851749 0.28615737]",
713
  "[1.56922197 1.9822216 1.76851749 1.28615737]"
 
720
  "[0.90817457 0.89270043 0.38583666 0.66566533]",
721
  "[1.90817451 1.89270043 1.3858366 1.66566539]"
722
  ],
 
 
 
 
723
  [
724
  "[0.68062544 0.98093534 0.14778823 0.53244978]",
725
  "[1.68062544 1.98093534 1.14778829 1.53244972]"
726
  ],
727
+ [
728
+ "[0.31518555 0.49643308 0.11509258 0.95458382]",
729
+ "[1.31518555 1.49643302 1.11509252 1.95458388]"
730
+ ],
731
  [
732
  "[0.79121011 0.54161114 0.69369799 0.1520769 ]",
733
  "[1.79121017 1.54161119 1.69369793 1.15207696]"
 
773
  "[1.78956437 1.87284744 1.06880784 1.03455889]"
774
  ],
775
  [
776
+ "[0.94221359 0.57740951 0.98649532 0.40934443]",
777
+ "[1.94221354 1.57740951 1.98649526 1.40934443]"
778
  ],
779
  [
780
+ "[0.00497234 0.39319336 0.57054168 0.75150961]",
781
+ "[1.00497234 1.39319336 1.57054162 1.75150967]"
782
  ],
783
  [
784
+ "[0.44330525 0.09997386 0.89025736 0.90507984]",
785
+ "[1.44330525 1.09997392 1.89025736 1.90507984]"
786
  ],
787
  [
788
+ "[0.72290605 0.96945059 0.68354797 0.15270454]",
789
+ "[1.72290611 1.96945059 1.68354797 1.15270448]"
790
  ],
791
  [
792
  "[0.52784437 0.54268694 0.12358981 0.72116476]",
 
796
  "[0.73217702 0.65233225 0.44077861 0.33837909]",
797
  "[1.73217702 1.65233231 1.44077861 1.33837914]"
798
  ],
799
+ [
800
+ "[0.34084332 0.73018837 0.54168713 0.91440833]",
801
+ "[1.34084332 1.73018837 1.54168713 1.91440833]"
802
+ ],
803
  [
804
  "[0.60110539 0.3618983 0.32342511 0.98672163]",
805
  "[1.60110545 1.3618983 1.32342505 1.98672163]"
 
828
  "[0.18720162 0.74115586 0.98626411 0.30355608]",
829
  "[1.18720162 1.74115586 1.98626411 1.30355608]"
830
  ],
831
+ [
832
+ "[0.85566247 0.83362883 0.48424995 0.25265992]",
833
+ "[1.85566247 1.83362889 1.48424995 1.25265992]"
834
+ ],
835
  [
836
  "[0.95928186 0.84273899 0.71514636 0.38619852]",
837
  "[1.95928192 1.84273899 1.7151463 1.38619852]"
838
  ],
839
  [
840
+ "[0.9829582 0.59269661 0.40120947 0.95487177]",
841
+ "[1.9829582 1.59269667 1.40120947 1.95487177]"
842
  ],
843
  [
844
  "[0.79905868 0.89367443 0.75429088 0.3190186 ]",
 
869
  "[1.60075855 1.12234759 1.00614405 1.30560958]"
870
  ],
871
  [
872
+ "[0.02162331 0.81861657 0.92468154 0.07808572]",
873
+ "[1.02162337 1.81861663 1.92468154 1.07808566]"
874
+ ],
875
+ [
876
+ "[0.02235305 0.52774918 0.7331115 0.84358269]",
877
+ "[1.02235305 1.52774918 1.7331115 1.84358263]"
878
  ],
879
  [
880
  "[0.6080932 0.56563014 0.32107437 0.72599429]",
 
908
  "[0.76933283 0.86241865 0.44114518 0.65644735]",
909
  "[1.76933289 1.86241865 1.44114518 1.65644741]"
910
  ],
911
+ [
912
+ "[0.59492421 0.90274489 0.38069052 0.46101224]",
913
+ "[1.59492421 1.90274489 1.38069057 1.46101224]"
914
+ ],
915
  [
916
  "[0.15064228 0.03198934 0.25754827 0.51484001]",
917
  "[1.15064228 1.03198934 1.25754833 1.51484001]"
 
920
  "[0.12024075 0.21342516 0.56858408 0.58644271]",
921
  "[1.12024069 1.21342516 1.56858408 1.58644271]"
922
  ],
 
 
 
 
923
  [
924
  "[0.49691743 0.61873293 0.90698647 0.94486356]",
925
  "[1.49691749 1.61873293 1.90698647 1.94486356]"
 
960
  "[0.47856545 0.46267092 0.6376707 0.84747767]",
961
  "[1.47856545 1.46267092 1.63767076 1.84747767]"
962
  ],
963
+ [
964
+ "[0.49584109 0.80599248 0.07096875 0.75872749]",
965
+ "[1.49584103 1.80599248 1.07096875 1.75872755]"
966
+ ],
967
  [
968
  "[0.43500566 0.66041756 0.80293626 0.96224713]",
969
  "[1.43500566 1.66041756 1.80293632 1.96224713]"
 
992
  "[0.72795159 0.79317838 0.27832931 0.96576637]",
993
  "[1.72795153 1.79317832 1.27832937 1.96576643]"
994
  ],
 
 
 
 
995
  [
996
  "[0.68891573 0.25576538 0.96339929 0.503833 ]",
997
  "[1.68891573 1.25576544 1.96339929 1.50383306]"
 
1000
  }
1001
  },
1002
  "other": {
1003
+ "model": "ModelConfig(model=Sequential(\n (0) - Identity(): Input__tensor_1_x -> START_Repeat_1_output\n (1) - Linear(in_features=4, out_features=4, bias=True): START_Repeat_1_output -> Linear_2_output\n (2) - <function leaky_relu at 0x759ed4f2c360>: Linear_2_output -> Activation_1_output\n (3) - Identity(): Activation_1_output -> END_Repeat_1_output\n (4) - Identity(): END_Repeat_1_output -> END_Repeat_1_output\n), model_inputs=['Input__tensor_1_x'], model_outputs=['END_Repeat_1_output'], loss_inputs=['END_Repeat_1_output', 'Input__tensor_3_x'], loss=Sequential(\n (0) - <function mse_loss at 0x759ed4f2de40>: END_Repeat_1_output, Input__tensor_3_x -> MSE_loss_2_output\n (1) - Identity(): MSE_loss_2_output -> loss\n), optimizer=SGD (\nParameter Group 0\n dampening: 0\n differentiable: False\n foreach: None\n fused: None\n lr: 0.1\n maximize: False\n momentum: 0\n nesterov: False\n weight_decay: 0\n), source_workspace=None, trained=True)"
1004
  },
1005
  "relations": []
1006
  },
 
1270
  "type": "basic"
1271
  },
1272
  "params": {
1273
+ "epochs": "1500",
1274
  "input_mapping": "{\"map\":{\"Input__tensor_1_x\":{\"df\":\"df_train\",\"column\":\"x\"},\"Input__tensor_3_x\":{\"df\":\"df_train\",\"column\":\"y\"}}}",
1275
  "model_name": "model"
1276
  },
lynxkite-graph-analytics/src/lynxkite_graph_analytics/pytorch_model_ops.py CHANGED
@@ -3,7 +3,6 @@
3
  import copy
4
  import enum
5
  import graphlib
6
- import types
7
 
8
  import pydantic
9
  from lynxkite.core import ops, workspace
@@ -100,6 +99,11 @@ def activation(x, *, type: ActivationTypes = ActivationTypes.ReLU):
100
  return Layer(f, shape=x.shape)
101
 
102
 
 
 
 
 
 
103
  reg("Softmax", inputs=["x"])
104
  reg(
105
  "Graph conv",
@@ -111,7 +115,6 @@ reg("Concatenate", inputs=["a", "b"], outputs=["x"])
111
  reg("Add", inputs=["a", "b"], outputs=["x"])
112
  reg("Subtract", inputs=["a", "b"], outputs=["x"])
113
  reg("Multiply", inputs=["a", "b"], outputs=["x"])
114
- reg("MSE loss", inputs=["x", "y"], outputs=["loss"])
115
  reg("Triplet margin loss", inputs=["x", "x_pos", "x_neg"], outputs=["loss"])
116
  reg("Cross-entropy loss", inputs=["x", "y"], outputs=["loss"])
117
  reg(
@@ -161,10 +164,10 @@ def _to_id(*strings: str) -> str:
161
 
162
 
163
  @dataclasses.dataclass
164
- class OpInput:
165
  """Ops get their inputs like this. They have to return a Layer made for this input."""
166
 
167
- id: str
168
  shape: tuple[int, ...]
169
 
170
 
@@ -175,12 +178,21 @@ class Layer:
175
  module: torch.nn.Module
176
  shapes: list[tuple[int, ...]] | None = None # One for each output.
177
  shape: dataclasses.InitVar[tuple[int, ...] | None] = None # Convenience for single output.
 
 
 
 
178
 
179
  def __post_init__(self, shape):
180
  assert not self.shapes or not shape, "Cannot set both shapes and shape."
181
  if shape:
182
  self.shapes = [shape]
183
 
 
 
 
 
 
184
 
185
  class ColumnSpec(pydantic.BaseModel):
186
  df: str
@@ -246,204 +258,212 @@ class ModelConfig:
246
  }
247
 
248
 
249
- def _add_op(op, params, inputs, outputs, sizes, layers):
250
- op_inputs = []
251
- for i in op.inputs.keys():
252
- id = getattr(inputs, i)
253
- op_inputs.append(OpInput(id, shape=sizes.get(id, 1)))
254
- if op.func != ops.no_op:
255
- layer = op.func(*op_inputs, **params)
256
- else:
257
- layer = Layer(torch.nn.Identity(), shapes=[i.shape for i in op_inputs])
258
- input_ids = ", ".join(i.id for i in op_inputs)
259
- output_ids = []
260
- for o, shape in zip(op.outputs.keys(), layer.shapes):
261
- id = getattr(outputs, o)
262
- output_ids.append(id)
263
- sizes[id] = shape
264
- output_ids = ", ".join(output_ids)
265
- layers.append((layer.module, f"{input_ids} -> {output_ids}"))
266
-
267
-
268
- def _all_dependencies(node: str, dependencies: dict[str, list[str]]) -> set[str]:
269
- """Returns all dependencies of a node."""
270
- deps = set()
271
- for dep in dependencies[node]:
272
- deps.add(dep)
273
- deps.update(_all_dependencies(dep, dependencies))
274
- return deps
275
-
276
-
277
  def build_model(ws: workspace.Workspace, inputs: dict[str, torch.Tensor]) -> ModelConfig:
278
  """Builds the model described in the workspace."""
279
- catalog = ops.CATALOGS[ENV]
280
- optimizers = []
281
- nodes = {}
282
- for node in ws.nodes:
283
- nodes[node.id] = node
284
- if node.data.title == "Optimizer":
285
- optimizers.append(node.id)
286
- assert optimizers, "No optimizer found."
287
- assert len(optimizers) == 1, f"More than one optimizer found: {optimizers}"
288
- [optimizer] = optimizers
289
- dependencies = {n.id: [] for n in ws.nodes}
290
- inv_dependencies = {n.id: [] for n in ws.nodes}
291
- in_edges = {}
292
- out_edges = {}
293
- repeats = []
294
- for e in ws.edges:
295
- if nodes[e.target].data.title == "Repeat":
296
- repeats.append(e.target)
297
- dependencies[e.target].append(e.source)
298
- inv_dependencies[e.source].append(e.target)
299
- in_edges.setdefault(e.target, {}).setdefault(e.targetHandle, []).append(
300
- (e.source, e.sourceHandle)
301
- )
302
- out_edges.setdefault(e.source, {}).setdefault(e.sourceHandle, []).append(
303
- (e.target, e.targetHandle)
304
- )
305
- # Split repeat boxes into start and end, and insert them into the flow.
306
- # TODO: Think about recursive repeats.
307
- for repeat in repeats:
308
- start_id = f"START {repeat}"
309
- end_id = f"END {repeat}"
310
- # repeat -> first <- real_input
311
- # ...becomes...
312
- # real_input -> start -> first
313
- first, firsth = out_edges[repeat]["output"][0]
314
- [(real_input, real_inputh)] = [
315
- k for k in in_edges[first][firsth] if k != (repeat, "output")
316
- ]
317
- dependencies[first].remove(repeat)
318
- dependencies[first].append(start_id)
319
- dependencies[start_id] = [real_input]
320
- out_edges[real_input][real_inputh] = [
321
- k if k != (first, firsth) else (start_id, "input")
322
- for k in out_edges[real_input][real_inputh]
323
- ]
324
- in_edges[start_id] = {"input": [(real_input, real_inputh)]}
325
- out_edges[start_id] = {"output": [(first, firsth)]}
326
- in_edges[first][firsth] = [(start_id, "output")]
327
- # repeat <- last -> real_output
328
- # ...becomes...
329
- # last -> end -> real_output
330
- last, lasth = in_edges[repeat]["input"][0]
331
- [(real_output, real_outputh)] = [
332
- k for k in out_edges[last][lasth] if k != (repeat, "input")
333
- ]
334
- del dependencies[repeat]
335
- dependencies[end_id] = [last]
336
- dependencies[real_output].append(end_id)
337
- out_edges[last][lasth] = [(end_id, "input")]
338
- in_edges[end_id] = {"input": [(last, lasth)]}
339
- out_edges[end_id] = {"output": [(real_output, real_outputh)]}
340
- in_edges[real_output][real_outputh] = [
341
- k if k != (last, lasth) else (end_id, "output")
342
- for k in in_edges[real_output][real_outputh]
343
- ]
344
- # Walk the graph in topological order.
345
- sizes = {}
346
- for k, i in inputs.items():
347
- sizes[k] = i.shape[-1]
348
- ts = graphlib.TopologicalSorter(dependencies)
349
- layers = []
350
- loss_layers = []
351
- regions: dict[str, set[str]] = {node_id: set() for node_id in dependencies}
352
- cfg = {}
353
- used_in_model = set()
354
- made_in_model = set()
355
- used_in_loss = set()
356
- made_in_loss = set()
357
- for node_id in ts.static_order():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358
  if node_id.startswith("START "):
359
- node = nodes[node_id.removeprefix("START ")]
360
  elif node_id.startswith("END "):
361
- node = nodes[node_id.removeprefix("END ")]
362
  else:
363
- node = nodes[node_id]
364
  t = node.data.title
365
- op = catalog[t]
366
  p = op.convert_params(node.data.params)
367
- for b in dependencies[node_id]:
368
- regions[node_id] |= regions[b]
369
- if "loss" in t:
370
- regions[node_id].add("loss")
371
  inputs = {}
372
- for n in in_edges.get(node_id, []):
373
- for b, h in in_edges[node_id][n]:
374
  i = _to_id(b, h)
375
  inputs[n] = i
376
- if "loss" in regions[node_id]:
377
- used_in_loss.add(i)
378
- else:
379
- used_in_model.add(i)
380
  outputs = {}
381
- for out in out_edges.get(node_id, []):
382
  i = _to_id(node_id, out)
383
  outputs[out] = i
384
- if not t.startswith("Input:"): # The outputs of inputs are not "made" by us.
385
- if "loss" in regions[node_id]:
386
- made_in_loss.add(i)
387
- else:
388
- made_in_model.add(i)
389
- inputs = types.SimpleNamespace(**inputs)
390
- outputs = types.SimpleNamespace(**outputs)
391
- ls = loss_layers if "loss" in regions[node_id] else layers
392
  match t:
393
- case "MSE loss":
394
- ls.append(
395
- (
396
- torch.nn.functional.mse_loss,
397
- f"{inputs.x}, {inputs.y} -> {outputs.loss}",
398
- )
399
- )
400
  case "Repeat":
401
- ls.append((torch.nn.Identity(), f"{inputs.input} -> {outputs.output}"))
402
- sizes[outputs.output] = sizes.get(inputs.input, 1)
403
- if node_id.startswith("START "):
404
- regions[node_id].add(("repeat", node_id.removeprefix("START ")))
405
- else:
406
  repeat_id = node_id.removeprefix("END ")
407
  start_id = f"START {repeat_id}"
408
  print(f"repeat {repeat_id} ending")
409
- after_start = _all_dependencies(start_id, inv_dependencies)
410
- after_end = _all_dependencies(node_id, inv_dependencies)
411
- before_end = _all_dependencies(node_id, dependencies)
412
- affected_nodes = after_start - after_end
413
  repeated_nodes = after_start & before_end
414
  assert affected_nodes == repeated_nodes, (
415
  f"edges leave repeated section '{repeat_id}':\n{affected_nodes - repeated_nodes}"
416
  )
417
- regions[node_id].remove(("repeat", repeat_id))
418
  for n in repeated_nodes:
419
  print(f"repeating {n}")
420
  case "Optimizer" | "Input: tensor" | "Input: graph edges" | "Input: sequential":
421
- pass
422
- case _:
423
- _add_op(op, p, inputs, outputs, sizes, ls)
424
- cfg["model_inputs"] = list(used_in_model - made_in_model)
425
- cfg["model_outputs"] = list(made_in_model & used_in_loss)
426
- cfg["loss_inputs"] = list(used_in_loss - made_in_loss)
427
- # Make sure the trained output is output from the last model layer.
428
- outputs = ", ".join(cfg["model_outputs"])
429
- layers.append((torch.nn.Identity(), f"{outputs} -> {outputs}"))
430
- # Create model.
431
- cfg["model"] = pyg.nn.Sequential(", ".join(cfg["model_inputs"]), layers)
432
- # Make sure the loss is output from the last loss layer.
433
- [(lossb, lossh)] = in_edges[optimizer]["loss"]
434
- lossi = _to_id(lossb, lossh)
435
- loss_layers.append((torch.nn.Identity(), f"{lossi} -> loss"))
436
- # Create loss function.
437
- cfg["loss"] = pyg.nn.Sequential(", ".join(cfg["loss_inputs"]), loss_layers)
438
- assert not list(cfg["loss"].parameters()), (
439
- f"loss should have no parameters: {list(cfg['loss'].parameters())}"
440
- )
441
- # Create optimizer.
442
- op = catalog["Optimizer"]
443
- p = op.convert_params(nodes[optimizer].data.params)
444
- o = getattr(torch.optim, p["type"].name)
445
- cfg["optimizer"] = o(cfg["model"].parameters(), lr=p["lr"])
446
- return ModelConfig(**cfg)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
447
 
448
 
449
  def to_tensors(b: core.Bundle, m: ModelMapping | None) -> dict[str, torch.Tensor]:
 
3
  import copy
4
  import enum
5
  import graphlib
 
6
 
7
  import pydantic
8
  from lynxkite.core import ops, workspace
 
99
  return Layer(f, shape=x.shape)
100
 
101
 
102
+ @op("MSE loss")
103
+ def mse_loss(x, y):
104
+ return Layer(torch.nn.functional.mse_loss, shape=[1])
105
+
106
+
107
  reg("Softmax", inputs=["x"])
108
  reg(
109
  "Graph conv",
 
115
  reg("Add", inputs=["a", "b"], outputs=["x"])
116
  reg("Subtract", inputs=["a", "b"], outputs=["x"])
117
  reg("Multiply", inputs=["a", "b"], outputs=["x"])
 
118
  reg("Triplet margin loss", inputs=["x", "x_pos", "x_neg"], outputs=["loss"])
119
  reg("Cross-entropy loss", inputs=["x", "y"], outputs=["loss"])
120
  reg(
 
164
 
165
 
166
  @dataclasses.dataclass
167
+ class TensorRef:
168
  """Ops get their inputs like this. They have to return a Layer made for this input."""
169
 
170
+ _id: str
171
  shape: tuple[int, ...]
172
 
173
 
 
178
  module: torch.nn.Module
179
  shapes: list[tuple[int, ...]] | None = None # One for each output.
180
  shape: dataclasses.InitVar[tuple[int, ...] | None] = None # Convenience for single output.
181
+ # Set by ModelBuilder.
182
+ _origin_id: str | None = None
183
+ _inputs: list[TensorRef] | None = None
184
+ _outputs: list[TensorRef] | None = None
185
 
186
  def __post_init__(self, shape):
187
  assert not self.shapes or not shape, "Cannot set both shapes and shape."
188
  if shape:
189
  self.shapes = [shape]
190
 
191
+ def _for_sequential(self):
192
+ inputs = ", ".join(i._id for i in self._inputs)
193
+ outputs = ", ".join(o._id for o in self._outputs)
194
+ return self.module, f"{inputs} -> {outputs}"
195
+
196
 
197
  class ColumnSpec(pydantic.BaseModel):
198
  df: str
 
258
  }
259
 
260
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
  def build_model(ws: workspace.Workspace, inputs: dict[str, torch.Tensor]) -> ModelConfig:
262
  """Builds the model described in the workspace."""
263
+ builder = ModelBuilder(ws, inputs)
264
+ return builder.build_model()
265
+
266
+
267
+ class ModelBuilder:
268
+ """The state shared between methods that are used to build the model."""
269
+
270
+ def __init__(self, ws: workspace.Workspace, inputs: dict[str, torch.Tensor]):
271
+ self.catalog = ops.CATALOGS[ENV]
272
+ optimizers = []
273
+ self.nodes = {}
274
+ for node in ws.nodes:
275
+ self.nodes[node.id] = node
276
+ if node.data.title == "Optimizer":
277
+ optimizers.append(node.id)
278
+ assert optimizers, "No optimizer found."
279
+ assert len(optimizers) == 1, f"More than one optimizer found: {optimizers}"
280
+ [self.optimizer] = optimizers
281
+ self.dependencies = {n.id: [] for n in ws.nodes}
282
+ self.in_edges = {}
283
+ self.out_edges = {}
284
+ repeats = []
285
+ for e in ws.edges:
286
+ if self.nodes[e.target].data.title == "Repeat":
287
+ repeats.append(e.target)
288
+ self.dependencies[e.target].append(e.source)
289
+ self.in_edges.setdefault(e.target, {}).setdefault(e.targetHandle, []).append(
290
+ (e.source, e.sourceHandle)
291
+ )
292
+ self.out_edges.setdefault(e.source, {}).setdefault(e.sourceHandle, []).append(
293
+ (e.target, e.targetHandle)
294
+ )
295
+ # Split repeat boxes into start and end, and insert them into the flow.
296
+ # TODO: Think about recursive repeats.
297
+ for repeat in repeats:
298
+ start_id = f"START {repeat}"
299
+ end_id = f"END {repeat}"
300
+ # repeat -> first <- real_input
301
+ # ...becomes...
302
+ # real_input -> start -> first
303
+ first, firsth = self.out_edges[repeat]["output"][0]
304
+ [(real_input, real_inputh)] = [
305
+ k for k in self.in_edges[first][firsth] if k != (repeat, "output")
306
+ ]
307
+ self.dependencies[first].remove(repeat)
308
+ self.dependencies[first].append(start_id)
309
+ self.dependencies[start_id] = [real_input]
310
+ self.out_edges[real_input][real_inputh] = [
311
+ k if k != (first, firsth) else (start_id, "input")
312
+ for k in self.out_edges[real_input][real_inputh]
313
+ ]
314
+ self.in_edges[start_id] = {"input": [(real_input, real_inputh)]}
315
+ self.out_edges[start_id] = {"output": [(first, firsth)]}
316
+ self.in_edges[first][firsth] = [(start_id, "output")]
317
+ # repeat <- last -> real_output
318
+ # ...becomes...
319
+ # last -> end -> real_output
320
+ last, lasth = self.in_edges[repeat]["input"][0]
321
+ [(real_output, real_outputh)] = [
322
+ k for k in self.out_edges[last][lasth] if k != (repeat, "input")
323
+ ]
324
+ del self.dependencies[repeat]
325
+ self.dependencies[end_id] = [last]
326
+ self.dependencies[real_output].append(end_id)
327
+ self.out_edges[last][lasth] = [(end_id, "input")]
328
+ self.in_edges[end_id] = {"input": [(last, lasth)]}
329
+ self.out_edges[end_id] = {"output": [(real_output, real_outputh)]}
330
+ self.in_edges[real_output][real_outputh] = [
331
+ k if k != (last, lasth) else (end_id, "output")
332
+ for k in self.in_edges[real_output][real_outputh]
333
+ ]
334
+ self.inv_dependencies = {n: [] for n in self.dependencies}
335
+ for k, v in self.dependencies.items():
336
+ for i in v:
337
+ self.inv_dependencies[i].append(k)
338
+ self.sizes = {}
339
+ for k, i in inputs.items():
340
+ self.sizes[k] = i.shape[-1]
341
+ self.layers = []
342
+
343
+ def all_upstream(self, node: str) -> set[str]:
344
+ """Returns all nodes upstream of a node."""
345
+ deps = set()
346
+ for dep in self.dependencies[node]:
347
+ deps.add(dep)
348
+ deps.update(self.all_upstream(dep))
349
+ return deps
350
+
351
+ def all_downstream(self, node: str) -> set[str]:
352
+ """Returns all nodes downstream of a node."""
353
+ deps = set()
354
+ for dep in self.inv_dependencies[node]:
355
+ deps.add(dep)
356
+ deps.update(self.all_downstream(dep))
357
+ return deps
358
+
359
+ def run_node(self, node_id: str) -> None:
360
+ """Adds the layer(s) produced by this node to self.layers."""
361
  if node_id.startswith("START "):
362
+ node = self.nodes[node_id.removeprefix("START ")]
363
  elif node_id.startswith("END "):
364
+ node = self.nodes[node_id.removeprefix("END ")]
365
  else:
366
+ node = self.nodes[node_id]
367
  t = node.data.title
368
+ op = self.catalog[t]
369
  p = op.convert_params(node.data.params)
 
 
 
 
370
  inputs = {}
371
+ for n in self.in_edges.get(node_id, []):
372
+ for b, h in self.in_edges[node_id][n]:
373
  i = _to_id(b, h)
374
  inputs[n] = i
 
 
 
 
375
  outputs = {}
376
+ for out in self.out_edges.get(node_id, []):
377
  i = _to_id(node_id, out)
378
  outputs[out] = i
 
 
 
 
 
 
 
 
379
  match t:
 
 
 
 
 
 
 
380
  case "Repeat":
381
+ if node_id.startswith("END "):
 
 
 
 
382
  repeat_id = node_id.removeprefix("END ")
383
  start_id = f"START {repeat_id}"
384
  print(f"repeat {repeat_id} ending")
385
+ after_start = self.all_downstream(start_id)
386
+ after_end = self.all_downstream(node_id)
387
+ before_end = self.all_upstream(node_id)
388
+ affected_nodes = after_start - after_end - {node_id}
389
  repeated_nodes = after_start & before_end
390
  assert affected_nodes == repeated_nodes, (
391
  f"edges leave repeated section '{repeat_id}':\n{affected_nodes - repeated_nodes}"
392
  )
 
393
  for n in repeated_nodes:
394
  print(f"repeating {n}")
395
  case "Optimizer" | "Input: tensor" | "Input: graph edges" | "Input: sequential":
396
+ return
397
+ layer = self.run_op(op, p, inputs, outputs)
398
+ layer._origin_id = node_id
399
+ self.layers.append(layer)
400
+
401
+ def run_op(self, op, params, inputs: dict[str, str], outputs: dict[str, str]) -> Layer:
402
+ """Returns the layer produced by this op."""
403
+ op_inputs = [
404
+ TensorRef(inputs[i], shape=self.sizes.get(inputs[i], 1)) for i in op.inputs.keys()
405
+ ]
406
+ if op.func != ops.no_op:
407
+ layer = op.func(*op_inputs, **params)
408
+ else:
409
+ layer = Layer(torch.nn.Identity(), shapes=[i.shape for i in op_inputs])
410
+ layer._inputs = op_inputs
411
+ layer._outputs = []
412
+ for o, shape in zip(op.outputs.keys(), layer.shapes):
413
+ layer._outputs.append(TensorRef(outputs[o], shape=shape))
414
+ self.sizes[outputs[o]] = shape
415
+ return layer
416
+
417
+ def build_model(self) -> ModelConfig:
418
+ # Walk the graph in topological order.
419
+ ts = graphlib.TopologicalSorter(self.dependencies)
420
+ for node_id in ts.static_order():
421
+ self.run_node(node_id)
422
+ return self.get_config()
423
+
424
+ def get_config(self) -> ModelConfig:
425
+ # Split the design into model and loss.
426
+ loss_nodes = set()
427
+ for node_id in self.nodes:
428
+ if "loss" in self.nodes[node_id].data.title:
429
+ loss_nodes.add(node_id)
430
+ loss_nodes |= self.all_downstream(node_id)
431
+ layers = []
432
+ loss_layers = []
433
+ for layer in self.layers:
434
+ if layer._origin_id in loss_nodes:
435
+ loss_layers.append(layer)
436
+ else:
437
+ layers.append(layer)
438
+ used_in_model = set(input._id for layer in layers for input in layer._inputs)
439
+ used_in_loss = set(input._id for layer in loss_layers for input in layer._inputs)
440
+ made_in_model = set(output._id for layer in layers for output in layer._outputs)
441
+ made_in_loss = set(output._id for layer in loss_layers for output in layer._outputs)
442
+ layers = [layer._for_sequential() for layer in layers]
443
+ loss_layers = [layer._for_sequential() for layer in loss_layers]
444
+ cfg = {}
445
+ cfg["model_inputs"] = list(used_in_model - made_in_model)
446
+ cfg["model_outputs"] = list(made_in_model & used_in_loss)
447
+ cfg["loss_inputs"] = list(used_in_loss - made_in_loss)
448
+ # Make sure the trained output is output from the last model layer.
449
+ outputs = ", ".join(cfg["model_outputs"])
450
+ layers.append((torch.nn.Identity(), f"{outputs} -> {outputs}"))
451
+ # Create model.
452
+ cfg["model"] = pyg.nn.Sequential(", ".join(cfg["model_inputs"]), layers)
453
+ # Make sure the loss is output from the last loss layer.
454
+ [(lossb, lossh)] = self.in_edges[self.optimizer]["loss"]
455
+ lossi = _to_id(lossb, lossh)
456
+ loss_layers.append((torch.nn.Identity(), f"{lossi} -> loss"))
457
+ # Create loss function.
458
+ cfg["loss"] = pyg.nn.Sequential(", ".join(cfg["loss_inputs"]), loss_layers)
459
+ assert not list(cfg["loss"].parameters()), f"loss should have no parameters: {loss_layers}"
460
+ # Create optimizer.
461
+ op = self.catalog["Optimizer"]
462
+ p = op.convert_params(self.nodes[self.optimizer].data.params)
463
+ o = getattr(torch.optim, p["type"].name)
464
+ cfg["optimizer"] = o(cfg["model"].parameters(), lr=p["lr"])
465
+ print(cfg)
466
+ return ModelConfig(**cfg)
467
 
468
 
469
  def to_tensors(b: core.Bundle, m: ModelMapping | None) -> dict[str, torch.Tensor]: