File size: 3,445 Bytes
16d905e
 
 
360d0c5
16d905e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
360d0c5
 
 
16d905e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import asyncio
import logging
import uuid
from typing import Optional

from fastapi import WebSocket

from docling_serve.datamodel.engines import Task, TaskStatus
from docling_serve.datamodel.requests import ConvertDocumentsRequest
from docling_serve.datamodel.responses import (
    MessageKind,
    TaskStatusResponse,
    WebsocketMessage,
)
from docling_serve.engines.async_local.worker import AsyncLocalWorker
from docling_serve.engines.base_orchestrator import BaseOrchestrator
from docling_serve.settings import docling_serve_settings

_log = logging.getLogger(__name__)


class OrchestratorError(Exception):
    pass


class TaskNotFoundError(OrchestratorError):
    pass


class AsyncLocalOrchestrator(BaseOrchestrator):
    def __init__(self):
        self.task_queue = asyncio.Queue()
        self.tasks: dict[str, Task] = {}
        self.queue_list: list[str] = []
        self.task_subscribers: dict[str, set[WebSocket]] = {}

    async def enqueue(self, request: ConvertDocumentsRequest) -> Task:
        task_id = str(uuid.uuid4())
        task = Task(task_id=task_id, request=request)
        self.tasks[task_id] = task
        self.queue_list.append(task_id)
        self.task_subscribers[task_id] = set()
        await self.task_queue.put(task_id)
        return task

    async def queue_size(self) -> int:
        return self.task_queue.qsize()

    async def get_queue_position(self, task_id: str) -> Optional[int]:
        return (
            self.queue_list.index(task_id) + 1 if task_id in self.queue_list else None
        )

    async def task_status(self, task_id: str, wait: float = 0.0) -> Task:
        if task_id not in self.tasks:
            raise TaskNotFoundError()
        return self.tasks[task_id]

    async def task_result(self, task_id: str):
        if task_id not in self.tasks:
            raise TaskNotFoundError()
        return self.tasks[task_id].result

    async def process_queue(self):
        # Create a pool of workers
        workers = []
        for i in range(docling_serve_settings.eng_loc_num_workers):
            _log.debug(f"Starting worker {i}")
            w = AsyncLocalWorker(i, self)
            worker_task = asyncio.create_task(w.loop())
            workers.append(worker_task)

        # Wait for all workers to complete (they won't, as they run indefinitely)
        await asyncio.gather(*workers)
        _log.debug("All workers completed.")

    async def notify_task_subscribers(self, task_id: str):
        if task_id not in self.task_subscribers:
            raise RuntimeError(f"Task {task_id} does not have a subscribers list.")

        task = self.tasks[task_id]
        task_queue_position = await self.get_queue_position(task_id)
        msg = TaskStatusResponse(
            task_id=task.task_id,
            task_status=task.task_status,
            task_position=task_queue_position,
        )
        for websocket in self.task_subscribers[task_id]:
            await websocket.send_text(
                WebsocketMessage(message=MessageKind.UPDATE, task=msg).model_dump_json()
            )
            if task.is_completed():
                await websocket.close()

    async def notify_queue_positions(self):
        for task_id in self.task_subscribers.keys():
            # notify only pending tasks
            if self.tasks[task_id].task_status != TaskStatus.PENDING:
                continue

            await self.notify_task_subscribers(task_id)