File size: 6,083 Bytes
d768932
 
31d7b86
2720e84
d768932
2720e84
 
 
 
 
1300a60
 
2720e84
 
 
d768932
2720e84
1300a60
 
 
 
 
 
2720e84
 
d768932
2720e84
 
d768932
2720e84
d768932
2720e84
 
d768932
 
 
2720e84
 
 
d768932
 
 
 
 
 
eee9365
1300a60
 
 
31d7b86
d768932
31d7b86
f2d4392
d768932
31d7b86
 
2720e84
 
 
 
 
 
 
 
d768932
31d7b86
d768932
 
31d7b86
 
 
f2d4392
31d7b86
 
 
 
 
 
d768932
eee9365
31d7b86
 
eee9365
 
 
31d7b86
 
eee9365
31d7b86
 
 
eee9365
d768932
 
31d7b86
 
 
 
 
 
 
eee9365
31d7b86
 
 
eee9365
31d7b86
 
 
 
 
 
d768932
31d7b86
eee9365
 
 
d768932
 
eee9365
 
d768932
 
eee9365
f2d4392
 
 
 
 
31d7b86
d768932
31d7b86
de3a4c7
 
31d7b86
f2d4392
31d7b86
f2d4392
de3a4c7
 
f2d4392
 
 
901487a
f2d4392
 
de3a4c7
f2d4392
 
 
 
 
 
 
 
 
 
 
 
 
 
eee9365
901487a
 
 
 
 
 
 
d768932
2720e84
 
 
1300a60
 
 
 
2720e84
 
f2d4392
2720e84
d768932
2720e84
 
 
d768932
2720e84
 
 
1300a60
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
"""CRDT is used to synchronize workspace state for backend and frontend(s)."""

import asyncio
import contextlib
import enum
import fastapi
import os.path
import pycrdt
import pycrdt_websocket
import pycrdt_websocket.ystore
import uvicorn
import builtins

# FastAPI router carrying this module's CRDT websocket endpoint
# (``/ws/crdt/{room_name}``); it must be included into the application.
router = fastapi.APIRouter()


def ws_exception_handler(exception, log):
    """Log an exception raised while serving a CRDT room.

    Used as the ``exception_handler`` of each ``YRoom``. For an
    ``ExceptionGroup``, uvicorn ``ClientDisconnected`` errors (at any nesting
    depth) are treated as expected and not logged; every other leaf exception
    is logged individually with its own traceback. Any other exception is
    logged as-is.

    Args:
        exception: the exception raised by the room.
        log: logger used for reporting.

    Returns:
        Always ``True`` (pycrdt-websocket treats a truthy return as
        "handled").
    """
    # Imported explicitly: ``import uvicorn`` does not necessarily load the
    # ``uvicorn.protocols.utils`` submodule, so attribute access on the
    # package could raise AttributeError.
    from uvicorn.protocols.utils import ClientDisconnected

    if isinstance(exception, builtins.ExceptionGroup):
        # ``split`` recurses into nested groups; ``rest`` is None when every
        # leaf was a client disconnect.
        _, rest = exception.split(ClientDisconnected)
        for ex in _leaf_exceptions(rest):
            # exc_info=ex attaches this leaf's traceback instead of whatever
            # exception happens to be active in the caller.
            log.exception(ex, exc_info=ex)
    else:
        log.exception(exception, exc_info=exception)
    return True


def _leaf_exceptions(exc):
    """Yield the non-group exceptions inside ``exc``; nothing for ``None``."""
    if exc is None:
        return
    if isinstance(exc, builtins.BaseExceptionGroup):
        for inner in exc.exceptions:
            yield from _leaf_exceptions(inner)
    else:
        yield exc


class WebsocketServer(pycrdt_websocket.WebsocketServer):
    """pycrdt websocket server whose rooms hold a persisted workspace document.

    Each room is backed by a ``FileYStore`` at ``crdt_data/{name}.crdt`` and
    exposes the document's ``workspace`` map as ``room.ws``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # init_room awaits the store, so without serialization two concurrent
        # get_room calls for the same new name could each build a room and
        # one would silently replace the other (splitting its clients).
        self._room_init_lock = asyncio.Lock()

    async def init_room(self, name):
        """Build the YRoom for ``name``, restoring its persisted state.

        Replays every update stored in the room's file store into a fresh
        document, ensures ``nodes``/``edges``/``env`` exist in the
        ``workspace`` map (seeding from ``try_to_load_workspace`` when
        ``env`` is missing), and runs ``workspace_changed`` on every change.
        """
        ystore = pycrdt_websocket.ystore.FileYStore(f"crdt_data/{name}.crdt")
        ydoc = pycrdt.Doc()
        ydoc["workspace"] = ws = pycrdt.Map()
        # Replay updates from the store. Only the first field of each stored
        # item (the update payload) is needed here.
        try:
            updates = [item[0] async for item in ystore.read()]
        except pycrdt_websocket.ystore.YDocNotFound:
            updates = []
        for update in updates:
            ydoc.apply_update(update)
        if "nodes" not in ws:
            ws["nodes"] = pycrdt.Array()
        if "edges" not in ws:
            ws["edges"] = pycrdt.Array()
        if "env" not in ws:
            ws["env"] = "unset"
            try_to_load_workspace(ws, name)
        room = pycrdt_websocket.YRoom(
            ystore=ystore, ydoc=ydoc, exception_handler=ws_exception_handler
        )
        room.ws = ws

        # The event loop only keeps weak references to tasks; hold each
        # change-handler task until it completes so it cannot be garbage
        # collected mid-run.
        pending_tasks = set()

        def on_change(changes):
            task = asyncio.create_task(workspace_changed(name, changes, ws))
            pending_tasks.add(task)
            task.add_done_callback(pending_tasks.discard)

        ws.observe_deep(on_change)
        return room

    async def get_room(self, name: str) -> pycrdt_websocket.YRoom:
        """Return the started room for ``name``, creating it on first use."""
        if name not in self.rooms:
            async with self._room_init_lock:
                # Re-check: another coroutine may have created it meanwhile.
                if name not in self.rooms:
                    self.rooms[name] = await self.init_room(name)
        room = self.rooms[name]
        await self.start_room(room)
        return room


# NOTE(review): not referenced anywhere in the visible code of this module;
# possibly used by another module or left over — confirm before removing.
last_ws_input = None


def clean_input(ws_pyd):
    """Strip per-node fields that must not cause a re-execution (in place).

    For every node: ``data.display`` and ``data.error`` become ``None``, the
    position is reset to ``(0, 0)``, and any pydantic extra fields are
    deleted. Returns ``None``.
    """
    for n in ws_pyd.nodes:
        data = n.data
        data.display = data.error = None
        pos = n.position
        pos.x = pos.y = 0
        extras = n.model_extra
        if not extras:
            continue
        # Snapshot the names first: deleting an extra mutates model_extra.
        for extra_name in list(extras):
            delattr(n, extra_name)


def crdt_update(crdt_obj, python_obj, boxes=frozenset()):
    """Recursively copy a plain Python dict/list into a CRDT container.

    Args:
        crdt_obj: target container whose kind matches ``python_obj``
            (map-like for a dict, list-like for a list).
        python_obj: ``dict`` or ``list`` to copy. Nested dicts/lists are
            written into nested ``pycrdt.Map``/``pycrdt.Array`` objects, which
            are created only where the target has none yet.
        boxes: dict keys whose values are stored as-is (opaque) rather than
            converted into nested CRDT containers.

    Enum members are stored as ``str(member)``, both in maps and in lists.
    Existing list entries past ``len(python_obj)`` and map keys absent from
    ``python_obj`` are left untouched.

    Raises:
        ValueError: if ``python_obj`` is neither a dict nor a list.
    """
    if isinstance(python_obj, dict):
        for key, value in python_obj.items():
            if key in boxes:
                crdt_obj[key] = value
            elif isinstance(value, (dict, list)):
                if crdt_obj.get(key) is None:
                    crdt_obj[key] = _new_crdt_container(value)
                crdt_update(crdt_obj[key], value, boxes)
            else:
                crdt_obj[key] = _crdt_scalar(value)
    elif isinstance(python_obj, list):
        for i, value in enumerate(python_obj):
            if isinstance(value, (dict, list)):
                if i >= len(crdt_obj):
                    crdt_obj.append(_new_crdt_container(value))
                crdt_update(crdt_obj[i], value, boxes)
            elif i >= len(crdt_obj):
                crdt_obj.append(_crdt_scalar(value))
            else:
                crdt_obj[i] = _crdt_scalar(value)
    else:
        raise ValueError(f"Invalid type: {python_obj!r}")


def _crdt_scalar(value):
    """Prepare a leaf value for CRDT storage (Enum members become ``str``)."""
    return str(value) if isinstance(value, enum.Enum) else value


def _new_crdt_container(value):
    """Return an empty ``pycrdt.Map`` for a dict, ``pycrdt.Array`` otherwise."""
    return pycrdt.Map() if isinstance(value, dict) else pycrdt.Array()


def try_to_load_workspace(ws, name):
    """Seed the CRDT workspace map ``ws`` from ``data/{name}`` if that file exists.

    The file is read with ``workspace.load`` and its ``model_dump()`` is
    copied into ``ws`` via ``crdt_update``, keeping ``display`` values as
    opaque boxes. Does nothing when the file is absent.
    """
    from . import workspace

    json_path = f"data/{name}"
    if not os.path.exists(json_path):
        return
    loaded = workspace.load(json_path)
    crdt_update(ws, loaded.model_dump(), boxes={"display"})


# Room name -> cleaned copy of the last workspace that was scheduled for
# execution; workspace_changed compares against it to skip superficial edits.
# NOTE(review): entries are never removed in the visible code.
last_known_versions = {}
# Room name -> most recent delayed execution task, kept so a newer change can
# cancel it.
delayed_executions = {}


async def workspace_changed(name, changes, ws_crdt):
    """Re-execute room ``name``'s workspace after its CRDT map changed.

    Args:
        name: room name, used as the key into ``last_known_versions`` and
            ``delayed_executions``.
        changes: the change events delivered by ``observe_deep``.
        ws_crdt: the room's ``workspace`` CRDT map.

    Nothing happens when the workspace, after ``clean_input``, equals the
    last version seen for this room. Otherwise any pending delayed run for
    the room is cancelled. The delay is the minimum ``__execution_delay``
    over all changes (changes that do not set it count as 0, and an empty
    ``changes`` gives 0). A non-zero delay schedules a background run;
    zero executes immediately.
    """
    from . import workspace

    ws_pyd = workspace.Workspace.model_validate(ws_crdt.to_py())
    # Do not trigger execution for superficial changes.
    # This is a quick solution until we build proper caching.
    clean_input(ws_pyd)
    if ws_pyd == last_known_versions.get(name):
        return
    last_known_versions[name] = ws_pyd.model_copy(deep=True)
    # Frontend changes that result from typing are delayed to avoid
    # rerunning the workspace for every keystroke.
    if name in delayed_executions:
        delayed_executions[name].cancel()
    delay = min((_execution_delay(change) for change in changes), default=0)
    if delay:
        task = asyncio.create_task(execute(ws_crdt, ws_pyd, delay))
        delayed_executions[name] = task
    else:
        await execute(ws_crdt, ws_pyd)


def _execution_delay(change):
    """Return the ``__execution_delay`` a change sets, or 0 if none.

    Changes without a ``keys`` mapping, or whose entry has a missing/null
    ``newValue``, count as 0 (a null would otherwise break ``min``).
    """
    key_change = getattr(change, "keys", {}).get("__execution_delay", {})
    return key_change.get("newValue", 0) or 0


async def execute(ws_crdt, ws_pyd, delay=0):
    """Run ``workspace.execute`` and write node results back into the CRDT.

    Args:
        ws_crdt: the room's ``workspace`` CRDT map; its ``nodes`` array
            receives each node's ``display`` and ``error``.
        ws_pyd: pydantic workspace passed to ``workspace.execute``.
        delay: seconds to wait before executing. If the task is cancelled
            while waiting, it returns without executing anything.

    Results are written in a single transaction. CRDT nodes and
    ``ws_pyd.nodes`` are paired by position (``zip``); surplus entries on
    either side are skipped.
    NOTE(review): if nodes are added/removed/reordered during the delay or
    execution, positional pairing may attach results to the wrong node —
    confirm whether matching by node id is needed.
    """
    from . import workspace

    if delay:
        try:
            await asyncio.sleep(delay)
        except asyncio.CancelledError:
            return
    await workspace.execute(ws_pyd)
    with ws_crdt.doc.transaction():
        for crdt_node, py_node in zip(ws_crdt["nodes"], ws_pyd.nodes):
            if "data" not in crdt_node:
                crdt_node["data"] = pycrdt.Map()
            node_data = crdt_node["data"]
            # The display value is stored opaquely rather than as nested CRDTs.
            node_data["display"] = py_node.data.display
            node_data["error"] = py_node.data.error


@contextlib.asynccontextmanager
async def lifespan(app):
    """Application lifespan: run the shared CRDT websocket server.

    Creates the module-global ``websocket_server`` (with
    ``auto_clean_rooms=False``), keeps it running while the app is up, and
    prints a message once it has been shut down. ``app`` is unused.
    """
    global websocket_server
    server = WebsocketServer(auto_clean_rooms=False)
    websocket_server = server
    async with server:
        yield
    print("closing websocket server")


def sanitize_path(path):
    """Return ``path`` normalized and relative, unable to climb above its root.

    The path is joined onto ``/`` and normalized, so leading ``..`` segments
    collapse at the root and absolute paths lose their leading slash; the
    result is then expressed relative to ``/``.
    """
    root = "/"
    anchored = os.path.normpath(os.path.join(root, path))
    return os.path.relpath(anchored, root)


@router.websocket("/ws/crdt/{room_name}")
async def crdt_websocket(websocket: fastapi.WebSocket, room_name: str):
    """Serve one CRDT websocket connection for ``room_name``.

    The room name is passed through ``sanitize_path`` and handed to a fresh
    ``pycrdt_websocket.ASGIServer`` around the shared ``websocket_server`` as
    the ASGI scope ``path`` (presumably what pycrdt_websocket uses to pick
    the room — confirm).
    NOTE(review): relies on the private ``_receive``/``_send`` attributes of
    the Starlette WebSocket, which may change between versions.
    """
    safe_name = sanitize_path(room_name)
    asgi_server = pycrdt_websocket.ASGIServer(websocket_server)
    scope = {"path": safe_name}
    await asgi_server(scope, websocket._receive, websocket._send)