Michele Dolfi
chore: Remove deprecated type aliases and run as pre-commit (#79)
360d0c5 unverified
import asyncio
import logging
import uuid
from typing import Optional
from fastapi import WebSocket
from docling_serve.datamodel.engines import Task, TaskStatus
from docling_serve.datamodel.requests import ConvertDocumentsRequest
from docling_serve.datamodel.responses import (
MessageKind,
TaskStatusResponse,
WebsocketMessage,
)
from docling_serve.engines.async_local.worker import AsyncLocalWorker
from docling_serve.engines.base_orchestrator import BaseOrchestrator
from docling_serve.settings import docling_serve_settings
# Module-level logger, named after this module via ``__name__``.
_log = logging.getLogger(__name__)
class OrchestratorError(Exception):
    """Base class for errors raised by the orchestrator."""
class TaskNotFoundError(OrchestratorError):
    """Raised when a task id is not present in the orchestrator's task registry."""
class AsyncLocalOrchestrator(BaseOrchestrator):
    """In-process orchestrator backed by an ``asyncio.Queue``.

    Tasks are kept in an in-memory registry, their ids are pushed onto an
    asyncio queue, and ``process_queue`` starts ``AsyncLocalWorker`` loops that
    are each handed this orchestrator instance. WebSocket subscribers
    registered per task receive status updates via
    ``notify_task_subscribers``.
    """

    def __init__(self):
        """Create the empty queue, task registry and subscriber registry."""
        # Ids of enqueued tasks, in FIFO order.
        self.task_queue: asyncio.Queue[str] = asyncio.Queue()
        # Every known task, keyed by task id.
        self.tasks: dict[str, Task] = {}
        # Task ids in enqueue order; ``get_queue_position`` reports the
        # 1-based index into this list. Nothing in this class removes
        # entries, so removal is expected to happen elsewhere (e.g. in the
        # worker) -- NOTE(review): confirm.
        self.queue_list: list[str] = []
        # WebSockets to notify for each task id.
        self.task_subscribers: dict[str, set[WebSocket]] = {}

    async def enqueue(self, request: ConvertDocumentsRequest) -> Task:
        """Register a new task for ``request`` and put its id on the queue.

        Args:
            request: The conversion request to attach to the new task.

        Returns:
            The newly created task, identified by a random UUID4 string.
        """
        task_id = str(uuid.uuid4())
        task = Task(task_id=task_id, request=request)
        self.tasks[task_id] = task
        self.queue_list.append(task_id)
        self.task_subscribers[task_id] = set()
        # Publish the id last so all bookkeeping above exists before any
        # consumer of the queue can observe it.
        await self.task_queue.put(task_id)
        return task

    async def queue_size(self) -> int:
        """Return the number of task ids currently sitting in the queue."""
        return self.task_queue.qsize()

    async def get_queue_position(self, task_id: str) -> Optional[int]:
        """Return the 1-based position of ``task_id`` in ``queue_list``.

        Returns:
            The position, or ``None`` when the id is not in ``queue_list``.
        """
        # EAFP: a single scan of the list instead of ``in`` followed by
        # ``index``.
        try:
            return self.queue_list.index(task_id) + 1
        except ValueError:
            return None

    async def task_status(self, task_id: str, wait: float = 0.0) -> Task:
        """Return the task registered under ``task_id``.

        Args:
            task_id: Id of the task to look up.
            wait: Currently unused by this implementation.

        Raises:
            TaskNotFoundError: If ``task_id`` is not a known task.
        """
        if task_id not in self.tasks:
            raise TaskNotFoundError()
        return self.tasks[task_id]

    async def task_result(self, task_id: str):
        """Return the ``result`` attribute of the task ``task_id``.

        Raises:
            TaskNotFoundError: If ``task_id`` is not a known task.
        """
        if task_id not in self.tasks:
            raise TaskNotFoundError()
        return self.tasks[task_id].result

    async def process_queue(self):
        """Start the configured number of workers and wait on all of them.

        The worker count is read from
        ``docling_serve_settings.eng_loc_num_workers``.
        """
        # Create a pool of workers
        workers = []
        for i in range(docling_serve_settings.eng_loc_num_workers):
            _log.debug("Starting worker %s", i)
            w = AsyncLocalWorker(i, self)
            worker_task = asyncio.create_task(w.loop())
            workers.append(worker_task)

        # Wait for all workers to complete (they won't, as they run indefinitely)
        await asyncio.gather(*workers)
        _log.debug("All workers completed.")

    async def notify_task_subscribers(self, task_id: str):
        """Send the current status of ``task_id`` to all of its subscribers.

        Each subscriber receives an ``UPDATE`` message carrying the task id,
        status and queue position. A subscriber's socket is closed after the
        send when the task reports itself as completed.

        Raises:
            RuntimeError: If no subscriber set exists for ``task_id``.
            KeyError: If ``task_id`` has a subscriber set but no task entry.
        """
        if task_id not in self.task_subscribers:
            raise RuntimeError(f"Task {task_id} does not have a subscribers list.")

        task = self.tasks[task_id]
        task_queue_position = await self.get_queue_position(task_id)
        msg = TaskStatusResponse(
            task_id=task.task_id,
            task_status=task.task_status,
            task_position=task_queue_position,
        )
        # The payload is identical for every subscriber: serialise it once.
        payload = WebsocketMessage(
            message=MessageKind.UPDATE, task=msg
        ).model_dump_json()

        # Iterate over a snapshot: the awaits below yield to the event loop,
        # and another coroutine may add or discard subscribers meanwhile,
        # which would raise "Set changed size during iteration" on the live
        # set.
        for websocket in list(self.task_subscribers[task_id]):
            await websocket.send_text(payload)
            if task.is_completed():
                await websocket.close()

    async def notify_queue_positions(self):
        """Notify the subscribers of every task that is still pending."""
        # Snapshot the ids: ``enqueue`` may insert new keys while we are
        # suspended in the await below, which would raise "dictionary changed
        # size during iteration" on the live dict.
        for task_id in list(self.task_subscribers):
            # notify only pending tasks
            if self.tasks[task_id].task_status != TaskStatus.PENDING:
                continue

            await self.notify_task_subscribers(task_id)