Spaces:
Configuration error
Configuration error
| import asyncio | |
| import logging | |
| import uuid | |
| from typing import Optional | |
| from fastapi import WebSocket | |
| from docling_serve.datamodel.engines import Task, TaskStatus | |
| from docling_serve.datamodel.requests import ConvertDocumentsRequest | |
| from docling_serve.datamodel.responses import ( | |
| MessageKind, | |
| TaskStatusResponse, | |
| WebsocketMessage, | |
| ) | |
| from docling_serve.engines.async_local.worker import AsyncLocalWorker | |
| from docling_serve.engines.base_orchestrator import BaseOrchestrator | |
| from docling_serve.settings import docling_serve_settings | |
# Logger for this module; emits the worker start/finish debug messages below.
_log = logging.getLogger(name=__name__)
class OrchestratorError(Exception):
    """Base class for errors raised by the orchestrator."""
class TaskNotFoundError(OrchestratorError):
    """Raised when a requested task id is not known to the orchestrator."""
class AsyncLocalOrchestrator(BaseOrchestrator):
    """In-process orchestrator that runs document-conversion tasks on local workers.

    All state is held in memory on the instance:

    * ``tasks`` maps task ids to their ``Task`` objects;
    * ``task_queue`` carries task ids to the ``AsyncLocalWorker`` loops
      started by :meth:`process_queue`;
    * ``queue_list`` records task ids in enqueue order so queue positions can
      be reported;
    * ``task_subscribers`` holds, per task id, the websockets that receive
      status updates.
    """

    def __init__(self):
        # Task ids waiting to be picked up; presumably consumed by the
        # AsyncLocalWorker instances created in process_queue (not visible here).
        self.task_queue = asyncio.Queue()
        # Every known task, keyed by task id.
        self.tasks: dict[str, Task] = {}
        # Task ids in enqueue order, used for 1-based queue positions. Entries
        # are only appended in this class; removal presumably happens in the
        # worker -- TODO confirm.
        self.queue_list: list[str] = []
        # Websockets subscribed to status updates, keyed by task id.
        self.task_subscribers: dict[str, set[WebSocket]] = {}

    async def enqueue(self, request: ConvertDocumentsRequest) -> Task:
        """Create a task for *request* and put it on the work queue.

        A fresh UUID4 string is used as the task id. The task is stored in
        ``tasks``, appended to ``queue_list``, given an empty subscriber set
        and its id is pushed onto ``task_queue``.

        Returns:
            The newly created ``Task``.
        """
        task_id = str(uuid.uuid4())
        task = Task(task_id=task_id, request=request)
        self.tasks[task_id] = task
        self.queue_list.append(task_id)
        self.task_subscribers[task_id] = set()
        await self.task_queue.put(task_id)
        return task

    async def queue_size(self) -> int:
        """Return the number of task ids currently waiting in ``task_queue``."""
        return self.task_queue.qsize()

    async def get_queue_position(self, task_id: str) -> Optional[int]:
        """Return the 1-based position of *task_id* in ``queue_list``.

        Returns:
            The position, or ``None`` if *task_id* is not in ``queue_list``.
        """
        try:
            return self.queue_list.index(task_id) + 1
        except ValueError:
            return None

    async def task_status(self, task_id: str, wait: float = 0.0) -> Task:
        """Return the ``Task`` object for *task_id*.

        *wait* is currently ignored by this implementation; the task is
        returned immediately.

        Raises:
            TaskNotFoundError: if *task_id* is unknown.
        """
        try:
            return self.tasks[task_id]
        except KeyError:
            # ``from None`` keeps the traceback free of the internal KeyError,
            # matching a plain ``raise TaskNotFoundError()``.
            raise TaskNotFoundError() from None

    async def task_result(self, task_id: str):
        """Return the ``result`` attribute of the task identified by *task_id*.

        Raises:
            TaskNotFoundError: if *task_id* is unknown.
        """
        try:
            task = self.tasks[task_id]
        except KeyError:
            raise TaskNotFoundError() from None
        return task.result

    async def process_queue(self):
        """Start ``eng_loc_num_workers`` worker loops and await all of them.

        Each worker is an ``AsyncLocalWorker`` bound to this orchestrator whose
        ``loop()`` coroutine is scheduled as an asyncio task.
        """
        workers = []
        for i in range(docling_serve_settings.eng_loc_num_workers):
            _log.debug("Starting worker %s", i)
            w = AsyncLocalWorker(i, self)
            workers.append(asyncio.create_task(w.loop()))

        # Wait for all workers to complete (they won't, as they run indefinitely)
        await asyncio.gather(*workers)
        _log.debug("All workers completed.")

    async def notify_task_subscribers(self, task_id: str):
        """Send the current status of *task_id* to each subscribed websocket.

        Every subscriber receives an ``UPDATE`` ``WebsocketMessage`` with the
        task's id, status and queue position. When the task reports
        ``is_completed()``, each websocket is closed after its message is sent.

        Raises:
            RuntimeError: if *task_id* has no subscriber set registered.
        """
        if task_id not in self.task_subscribers:
            raise RuntimeError(f"Task {task_id} does not have a subscribers list.")

        task = self.tasks[task_id]
        task_queue_position = await self.get_queue_position(task_id)
        msg = TaskStatusResponse(
            task_id=task.task_id,
            task_status=task.task_status,
            task_position=task_queue_position,
        )
        # The payload is identical for every subscriber, so serialize it once.
        payload = WebsocketMessage(
            message=MessageKind.UPDATE, task=msg
        ).model_dump_json()

        # Iterate over a snapshot: each ``await`` yields to the event loop, and
        # if another coroutine (e.g. a websocket handler elsewhere -- not
        # visible here) adds or removes a subscriber meanwhile, iterating the
        # live set would raise "Set changed size during iteration".
        for websocket in list(self.task_subscribers[task_id]):
            await websocket.send_text(payload)
            if task.is_completed():
                await websocket.close()

    async def notify_queue_positions(self):
        """Send a status update to the subscribers of every ``PENDING`` task."""
        # Snapshot the ids: ``enqueue`` can insert new keys into
        # ``task_subscribers`` while we are suspended in an ``await`` below,
        # which would break iteration over the live dict view.
        for task_id in list(self.task_subscribers):
            # notify only pending tasks
            if self.tasks[task_id].task_status != TaskStatus.PENDING:
                continue
            await self.notify_task_subscribers(task_id)