File size: 4,089 Bytes
2f4e3c5
5faa80b
2f4e3c5
 
 
 
0002aae
2f4e3c5
 
e6745ef
 
2f4e3c5
 
 
 
e6745ef
2f4e3c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5faa80b
e6745ef
 
 
5faa80b
e6745ef
 
 
 
 
 
2f4e3c5
 
5faa80b
2f4e3c5
aaf43fc
0002aae
e6745ef
aaf43fc
e6745ef
2f4e3c5
 
 
 
aaf43fc
e6745ef
aaf43fc
 
e6745ef
 
2f4e3c5
 
 
 
5faa80b
2f4e3c5
aaf43fc
2f4e3c5
aaf43fc
 
2f4e3c5
 
 
e6745ef
 
 
5faa80b
e6745ef
aaf43fc
0002aae
e6745ef
aaf43fc
e6745ef
 
 
 
 
 
aaf43fc
e6745ef
aaf43fc
 
e6745ef
 
 
 
 
 
 
 
5faa80b
aaf43fc
 
e6745ef
 
5faa80b
aaf43fc
 
e6745ef
 
5faa80b
aaf43fc
 
e6745ef
 
2f4e3c5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from lynxkite.core import workspace
from lynxkite_graph_analytics.pytorch import pytorch_core
import torch
import pytest


def make_ws(env, nodes: dict[str, dict], edges: list[tuple[str, str]]):
    """Build a workspace from a compact node/edge description (test helper).

    Args:
        env: Environment handed to ``workspace.Workspace`` (e.g. ``pytorch_core.ENV``).
        nodes: Maps node ID to its data. Each value must contain a ``"title"``
            key; every other key becomes a node param. Optional ``"x"``/``"y"``
            keys also set the node position (they stay in the params as well).
            The caller's dicts are not modified.
        edges: ``(source, target)`` pairs, each written as ``"node_id:handle"``.

    Returns:
        The populated ``workspace.Workspace``.
    """
    ws = workspace.Workspace(env=env)
    for node_id, data in nodes.items():
        # Work on a copy so removing "title" does not mutate the caller's spec
        # (which would make the same spec unusable a second time).
        params = dict(data)
        title = params.pop("title")
        ws.nodes.append(
            workspace.WorkspaceNode(
                id=node_id,
                type="basic",
                data=workspace.WorkspaceNodeData(title=title, params=params),
                position=workspace.Position(
                    x=params.get("x", 0),
                    y=params.get("y", 0),
                ),
            )
        )
    ws.edges = [
        workspace.WorkspaceEdge(
            id=f"{source}->{target}",
            source=source.split(":")[0],
            target=target.split(":")[0],
            sourceHandle=source.split(":")[1],
            targetHandle=target.split(":")[1],
        )
        for source, target in edges
    ]
    return ws


def summarize_layers(m: pytorch_core.ModelConfig) -> str:
    """Return a compact fingerprint of the layers in ``m.model``.

    Each layer contributes the first character of its ``str()`` form; the
    characters are concatenated in iteration order.
    """
    initials = [str(layer)[0] for layer in m.model]
    return "".join(initials)


def summarize_connections(m: pytorch_core.ModelConfig) -> str:
    """Return a compact fingerprint of the wiring in ``m.model._children``.

    Each child is rendered as ``<param initials>-><return initials>``, where
    the initials are the first characters of its ``param_names`` and
    ``return_names``; children are separated by single spaces.
    """

    def initials(names):
        return "".join(name[0] for name in names)

    parts = []
    for child in m.model._children:
        parts.append(f"{initials(child.param_names)}->{initials(child.return_names)}")
    return " ".join(parts)


async def test_build_model():
    """Train a Linear -> Leaky_ReLU model on y = x + 1 and check it fits.

    After 1000 training steps with SGD (lr=0.1), both the last training loss
    and the MSE of a single-sample inference must be below 0.1.
    """
    # Seed the RNG so weight init and the random data are reproducible;
    # otherwise the fixed 0.1 thresholds below can fail intermittently.
    torch.manual_seed(0)
    ws = make_ws(
        pytorch_core.ENV,
        {
            "input": {"title": "Input: tensor"},
            "lin": {"title": "Linear", "output_dim": 4},
            "act": {"title": "Activation", "type": "Leaky_ReLU"},
            "output": {"title": "Output"},
            "label": {"title": "Input: tensor"},
            "loss": {"title": "MSE loss"},
            "optim": {"title": "Optimizer", "type": "SGD", "lr": 0.1},
        },
        [
            ("input:output", "lin:x"),
            ("lin:output", "act:x"),
            ("act:output", "output:x"),
            ("output:x", "loss:x"),
            ("label:output", "loss:y"),
            ("loss:output", "optim:loss"),
        ],
    )
    x = torch.rand(100, 4)
    y = x + 1
    m = pytorch_core.build_model(ws)
    for _ in range(1000):
        loss = m.train({"input_output": x, "label_output": y})
    assert loss < 0.1
    o = m.inference({"input_output": x[:1]})
    error = torch.nn.functional.mse_loss(o["output_x"], x[:1] + 1)
    assert error < 0.1


async def test_build_model_with_repeat():
    """Check how a Repeat node unrolls the lin -> act loop for several counts.

    For 1, 2 and 3 repetitions, the built model's layer summary and
    connection summary must match the expected fingerprint strings.
    """

    def repeated_ws(times):
        # The Repeat node feeds back into lin:x and takes act's output, so the
        # lin -> act pair is repeated `times` times (same_weights disabled).
        nodes = {
            "input": {"title": "Input: tensor"},
            "lin": {"title": "Linear", "output_dim": 8},
            "act": {"title": "Activation", "type": "Leaky_ReLU"},
            "output": {"title": "Output"},
            "label": {"title": "Input: tensor"},
            "loss": {"title": "MSE loss"},
            "optim": {"title": "Optimizer", "type": "SGD", "lr": 0.1},
            "repeat": {"title": "Repeat", "times": times, "same_weights": False},
        }
        edges = [
            ("input:output", "lin:x"),
            ("lin:output", "act:x"),
            ("act:output", "output:x"),
            ("output:x", "loss:x"),
            ("label:output", "loss:y"),
            ("loss:output", "optim:loss"),
            ("repeat:output", "lin:x"),
            ("act:output", "repeat:input"),
        ]
        return make_ws(pytorch_core.ENV, nodes, edges)

    # (repetitions, expected layer summary, expected connection summary)
    expectations = [
        (1, "IL<III", "i->S S->l l->a a->E E->o o->o"),
        (2, "IL<IL<III", "i->S S->l l->a a->S S->l l->a a->E E->o o->o"),
        (3, "IL<IL<IL<III", "i->S S->l l->a a->S S->l l->a a->S S->l l->a a->E E->o o->o"),
    ]
    for times, layers, connections in expectations:
        m = pytorch_core.build_model(repeated_ws(times))
        assert summarize_layers(m) == layers
        assert summarize_connections(m) == connections


if __name__ == "__main__":
    # Run only this file's tests. A bare pytest.main() takes its arguments from
    # sys.argv and collects from the current directory. Also propagate pytest's
    # exit code so failures are visible to the caller.
    raise SystemExit(pytest.main([__file__]))