darabos's picture
"pos" is no longer optional. Plus a comment.
a2c3c92
"""CRDT is used to synchronize workspace state for backend and frontend(s)."""
import asyncio
import contextlib
import enum
import pathlib
import fastapi
import os.path
import pycrdt
import pycrdt_websocket
import pycrdt_websocket.ystore
import uvicorn
import builtins
from lynxkite.core import workspace, ops
router = fastapi.APIRouter()
def ws_exception_handler(exception, log):
if isinstance(exception, builtins.ExceptionGroup):
for ex in exception.exceptions:
if not isinstance(ex, uvicorn.protocols.utils.ClientDisconnected):
log.exception(ex)
else:
log.exception(exception)
return True
class WorkspaceWebsocketServer(pycrdt_websocket.WebsocketServer):
async def init_room(self, name: str) -> pycrdt_websocket.YRoom:
"""Initialize a room for the workspace with the given name.
The workspace is loaded from ".crdt" if it exists there, or from a JSON file, or a new workspace is created.
"""
crdt_path = pathlib.Path(".crdt")
path = crdt_path / f"{name}.crdt"
assert path.is_relative_to(crdt_path), f"Path '{path}' is invalid"
ystore = pycrdt_websocket.ystore.FileYStore(path)
ydoc = pycrdt.Doc()
ydoc["workspace"] = ws = pycrdt.Map()
# Replay updates from the store.
try:
for update, timestamp in [(item[0], item[-1]) async for item in ystore.read()]:
ydoc.apply_update(update)
except pycrdt_websocket.ystore.YDocNotFound:
pass
if "nodes" not in ws:
ws["nodes"] = pycrdt.Array()
if "edges" not in ws:
ws["edges"] = pycrdt.Array()
if "env" not in ws:
ws["env"] = next(iter(ops.CATALOGS), "unset")
# We have two possible sources of truth for the workspaces, the YStore and the JSON files.
# In case we didn't find the workspace in the YStore, we try to load it from the JSON files.
try_to_load_workspace(ws, name)
ws_simple = workspace.Workspace.model_validate(ws.to_py())
clean_input(ws_simple)
# Set the last known version to the current state, so we don't trigger a change event.
last_known_versions[name] = ws_simple
room = pycrdt_websocket.YRoom(
ystore=ystore, ydoc=ydoc, exception_handler=ws_exception_handler
)
room.ws = ws
def on_change(changes):
task = asyncio.create_task(workspace_changed(name, changes, ws))
# We have no way to await workspace_changed(). The best we can do is to
# dereference its result after it's done, so exceptions are logged normally.
task.add_done_callback(lambda t: t.result())
ws.observe_deep(on_change)
return room
async def get_room(self, name: str) -> pycrdt_websocket.YRoom:
"""Get a room by name.
This method overrides the parent get_room method. The original creates an empty room,
with no associated Ydoc. Instead, we want to initialize the the room with a Workspace
object.
"""
if name not in self.rooms:
self.rooms[name] = await self.init_room(name)
room = self.rooms[name]
await self.start_room(room)
return room
class CodeWebsocketServer(WorkspaceWebsocketServer):
async def init_room(self, name: str) -> pycrdt_websocket.YRoom:
"""Initialize a room for a text document with the given name."""
crdt_path = pathlib.Path(".crdt")
path = crdt_path / f"{name}.crdt"
assert path.is_relative_to(crdt_path), f"Path '{path}' is invalid"
ystore = pycrdt_websocket.ystore.FileYStore(path)
ydoc = pycrdt.Doc()
ydoc["text"] = text = pycrdt.Text()
# Replay updates from the store.
try:
for update, timestamp in [(item[0], item[-1]) async for item in ystore.read()]:
ydoc.apply_update(update)
except pycrdt_websocket.ystore.YDocNotFound:
pass
if len(text) == 0:
if os.path.exists(name):
with open(name, encoding="utf-8") as f:
text += f.read().replace("\r\n", "\n")
room = pycrdt_websocket.YRoom(
ystore=ystore, ydoc=ydoc, exception_handler=ws_exception_handler
)
room.text = text
def on_change(changes):
asyncio.create_task(code_changed(name, changes, text))
text.observe(on_change)
return room
last_ws_input = None
def clean_input(ws_pyd):
"""Delete everything that we want to ignore for the purposes of change detection."""
for node in ws_pyd.nodes:
node.data.display = None
node.data.input_metadata = None
node.data.error = None
node.data.status = workspace.NodeStatus.done
for p in list(node.data.params):
if p.startswith("_"):
del node.data.params[p]
if node.data.title == "Comment":
node.data.params = {}
node.position.x = 0
node.position.y = 0
if node.model_extra:
for key in list(node.model_extra.keys()):
delattr(node, key)
def crdt_update(
crdt_obj: pycrdt.Map | pycrdt.Array,
python_obj: dict | list,
non_collaborative_fields: set[str] = set(),
):
"""Update a CRDT object to match a Python object.
The types between the CRDT object and the Python object must match. If the Python object
is a dict, the CRDT object must be a Map. If the Python object is a list, the CRDT object
must be an Array.
Args:
crdt_obj: The CRDT object, that will be updated to match the Python object.
python_obj: The Python object to update with.
non_collaborative_fields: List of fields to treat as a black box. Black boxes are
updated as a whole, instead of having a fine-grained data structure to edit
collaboratively. Useful for complex fields that contain auto-generated data or
metadata.
The default is an empty set.
Raises:
ValueError: If the Python object provided is not a dict or list.
"""
if isinstance(python_obj, dict):
for key, value in python_obj.items():
if key in non_collaborative_fields:
crdt_obj[key] = value
elif isinstance(value, dict):
if crdt_obj.get(key) is None:
crdt_obj[key] = pycrdt.Map()
crdt_update(crdt_obj[key], value, non_collaborative_fields)
elif isinstance(value, list):
if crdt_obj.get(key) is None:
crdt_obj[key] = pycrdt.Array()
crdt_update(crdt_obj[key], value, non_collaborative_fields)
elif isinstance(value, enum.Enum):
crdt_obj[key] = str(value.value)
else:
crdt_obj[key] = value
elif isinstance(python_obj, list):
for i, value in enumerate(python_obj):
if isinstance(value, dict):
if i >= len(crdt_obj):
crdt_obj.append(pycrdt.Map())
crdt_update(crdt_obj[i], value, non_collaborative_fields)
elif isinstance(value, list):
if i >= len(crdt_obj):
crdt_obj.append(pycrdt.Array())
crdt_update(crdt_obj[i], value, non_collaborative_fields)
else:
if isinstance(value, enum.Enum):
value = str(value.value)
if i >= len(crdt_obj):
crdt_obj.append(value)
else:
crdt_obj[i] = value
else:
raise ValueError("Invalid type:", python_obj)
def try_to_load_workspace(ws: pycrdt.Map, name: str):
"""Load the workspace `name`, if it exists, and update the `ws` CRDT object to match its contents.
Args:
ws: CRDT object to udpate with the workspace contents.
name: Name of the workspace to load.
"""
if os.path.exists(name):
ws_pyd = workspace.Workspace.load(name)
crdt_update(
ws,
ws_pyd.model_dump(),
# We treat some fields as black boxes. They are not edited on the frontend.
non_collaborative_fields={"display", "input_metadata", "meta"},
)
last_known_versions = {}
delayed_executions = {}
async def workspace_changed(name: str, changes: pycrdt.MapEvent, ws_crdt: pycrdt.Map):
"""Callback to react to changes in the workspace.
Args:
name: Name of the workspace.
changes: Changes performed to the workspace.
ws_crdt: CRDT object representing the workspace.
"""
ws_pyd = workspace.Workspace.model_validate(ws_crdt.to_py())
# Do not trigger execution for superficial changes.
# This is a quick solution until we build proper caching.
ws_simple = ws_pyd.model_copy(deep=True)
clean_input(ws_simple)
if ws_simple == last_known_versions.get(name):
return
last_known_versions[name] = ws_simple
# Frontend changes that result from typing are delayed to avoid
# rerunning the workspace for every keystroke.
if name in delayed_executions:
delayed_executions[name].cancel()
delay = min(
getattr(change, "keys", {}).get("__execution_delay", {}).get("newValue", 0)
for change in changes
)
if delay:
task = asyncio.create_task(execute(name, ws_crdt, ws_pyd, delay))
delayed_executions[name] = task
else:
await execute(name, ws_crdt, ws_pyd)
async def execute(name: str, ws_crdt: pycrdt.Map, ws_pyd: workspace.Workspace, delay: int = 0):
"""Execute the workspace and update the CRDT object with the results.
Args:
name: Name of the workspace.
ws_crdt: CRDT object representing the workspace.
ws_pyd: Workspace object to execute.
delay: Wait time before executing the workspace. The default is 0.
"""
if delay:
try:
await asyncio.sleep(delay)
except asyncio.CancelledError:
return
print(f"Running {name} in {ws_pyd.env}...")
cwd = pathlib.Path()
path = cwd / name
assert path.is_relative_to(cwd), f"Path '{path}' is invalid"
# Save user changes before executing, in case the execution fails.
ws_pyd.save(path)
ops.load_user_scripts(name)
ws_pyd.connect_crdt(ws_crdt)
ws_pyd.update_metadata()
if not ws_pyd.has_executor():
return
with ws_crdt.doc.transaction():
for nc in ws_crdt["nodes"]:
nc["data"]["status"] = "planned"
ws_pyd.normalize()
await ws_pyd.execute()
ws_pyd.save(path)
print(f"Finished running {name} in {ws_pyd.env}.")
async def code_changed(name: str, changes: pycrdt.TextEvent, text: pycrdt.Text):
contents = str(text).strip() + "\n"
with open(name, "w", encoding="utf-8") as f:
f.write(contents)
@contextlib.asynccontextmanager
async def lifespan(app):
global ws_websocket_server
global code_websocket_server
ws_websocket_server = WorkspaceWebsocketServer(auto_clean_rooms=False)
code_websocket_server = CodeWebsocketServer(auto_clean_rooms=False)
async with ws_websocket_server:
async with code_websocket_server:
yield
print("closing websocket server")
def delete_room(name: str):
if name in ws_websocket_server.rooms:
del ws_websocket_server.rooms[name]
def sanitize_path(path):
return os.path.relpath(os.path.normpath(os.path.join("/", path)), "/")
@router.websocket("/ws/crdt/{room_name:path}")
async def crdt_websocket(websocket: fastapi.WebSocket, room_name: str):
room_name = sanitize_path(room_name)
server = pycrdt_websocket.ASGIServer(ws_websocket_server)
await server({"path": room_name}, websocket._receive, websocket._send)
@router.websocket("/ws/code/crdt/{room_name:path}")
async def code_crdt_websocket(websocket: fastapi.WebSocket, room_name: str):
room_name = sanitize_path(room_name)
server = pycrdt_websocket.ASGIServer(code_websocket_server)
await server({"path": room_name}, websocket._receive, websocket._send)