Spaces:
Configuration error
Configuration error
import asyncio | |
import logging | |
import uuid | |
from typing import Optional | |
from fastapi import WebSocket | |
from docling_serve.datamodel.engines import Task, TaskStatus | |
from docling_serve.datamodel.requests import ConvertDocumentsRequest | |
from docling_serve.datamodel.responses import ( | |
MessageKind, | |
TaskStatusResponse, | |
WebsocketMessage, | |
) | |
from docling_serve.engines.async_local.worker import AsyncLocalWorker | |
from docling_serve.engines.base_orchestrator import BaseOrchestrator | |
from docling_serve.settings import docling_serve_settings | |
_log = logging.getLogger(__name__) | |
class OrchestratorError(Exception):
    """Base class for errors raised by the orchestrator."""
class TaskNotFoundError(OrchestratorError):
    """Raised when a task id is not known to the orchestrator."""
class AsyncLocalOrchestrator(BaseOrchestrator):
    """Orchestrator that queues conversion tasks for in-process async workers.

    Tasks are registered in ``tasks``, their ids are pushed onto an
    ``asyncio.Queue`` for the workers started by ``process_queue``, and
    WebSocket subscribers registered in ``task_subscribers`` can be notified
    about status changes.
    """

    def __init__(self):
        # Ids of tasks handed to the workers.
        self.task_queue = asyncio.Queue()
        # Every task created by ``enqueue``, keyed by task id.
        self.tasks: dict[str, Task] = {}
        # Task ids in submission order; used to compute queue positions.
        # NOTE(review): nothing in this class removes ids from this list --
        # presumably the worker does; confirm.
        self.queue_list: list[str] = []
        # WebSockets to notify about each task, keyed by task id.
        self.task_subscribers: dict[str, set[WebSocket]] = {}

    def _lookup_task(self, task_id: str) -> Task:
        """Return the task registered under ``task_id``.

        Raises:
            TaskNotFoundError: if no task with that id is registered.
        """
        try:
            return self.tasks[task_id]
        except KeyError:
            raise TaskNotFoundError(f"Task {task_id} not found.") from None

    async def enqueue(self, request: ConvertDocumentsRequest) -> Task:
        """Create a task for ``request``, register it and queue its id.

        Returns:
            The newly created task.
        """
        task_id = str(uuid.uuid4())
        task = Task(task_id=task_id, request=request)
        self.tasks[task_id] = task
        self.queue_list.append(task_id)
        self.task_subscribers[task_id] = set()
        await self.task_queue.put(task_id)
        return task

    async def queue_size(self) -> int:
        """Return the number of task ids currently in the work queue."""
        return self.task_queue.qsize()

    async def get_queue_position(self, task_id: str) -> Optional[int]:
        """Return the 1-based position of ``task_id`` in ``queue_list``.

        Returns ``None`` when the id is not in the list.
        """
        try:
            # Single scan instead of a membership test followed by index().
            return self.queue_list.index(task_id) + 1
        except ValueError:
            return None

    async def task_status(self, task_id: str, wait: float = 0.0) -> Task:
        """Return the task registered under ``task_id``.

        ``wait`` is accepted but not used by this implementation.

        Raises:
            TaskNotFoundError: if no task with that id is registered.
        """
        return self._lookup_task(task_id)

    async def task_result(self, task_id: str):
        """Return the ``result`` attribute of the task ``task_id``.

        Raises:
            TaskNotFoundError: if no task with that id is registered.
        """
        return self._lookup_task(task_id).result

    async def process_queue(self):
        """Start the configured number of workers and wait for all of them."""
        # Create a pool of workers
        workers = []
        for i in range(docling_serve_settings.eng_loc_num_workers):
            _log.debug("Starting worker %s", i)
            w = AsyncLocalWorker(i, self)
            worker_task = asyncio.create_task(w.loop())
            workers.append(worker_task)

        # Wait for all workers to complete (they won't, as they run indefinitely)
        await asyncio.gather(*workers)
        _log.debug("All workers completed.")

    async def notify_task_subscribers(self, task_id: str):
        """Send the current status of ``task_id`` to each of its subscribers.

        A subscriber's WebSocket is closed after sending when the task
        reports itself as completed.

        Raises:
            RuntimeError: if ``task_id`` has no entry in ``task_subscribers``.
        """
        if task_id not in self.task_subscribers:
            raise RuntimeError(f"Task {task_id} does not have a subscribers list.")

        task = self.tasks[task_id]
        task_queue_position = await self.get_queue_position(task_id)
        msg = TaskStatusResponse(
            task_id=task.task_id,
            task_status=task.task_status,
            task_position=task_queue_position,
        )
        # The payload is the same for every subscriber; serialize it once.
        payload = WebsocketMessage(
            message=MessageKind.UPDATE, task=msg
        ).model_dump_json()

        # Iterate over a snapshot: each await below yields to the event loop,
        # and mutating the live set meanwhile (presumably done by the
        # WebSocket endpoint -- confirm) would raise "Set changed size
        # during iteration".
        for websocket in list(self.task_subscribers[task_id]):
            await websocket.send_text(payload)
            if task.is_completed():
                await websocket.close()

    async def notify_queue_positions(self):
        """Notify the subscribers of every task that is still pending."""
        # Snapshot the keys: ``enqueue`` adds entries to ``task_subscribers``
        # and may run while this loop is suspended in an await, which would
        # otherwise raise "dictionary changed size during iteration".
        for task_id in list(self.task_subscribers):
            # notify only pending tasks
            if self.tasks[task_id].task_status != TaskStatus.PENDING:
                continue

            await self.notify_task_subscribers(task_id)