import asyncio
import logging
import time
from typing import TYPE_CHECKING, Any, Optional, Union

from fastapi import BackgroundTasks

from docling.datamodel.base_models import DocumentStream

from docling_serve.datamodel.engines import TaskStatus
from docling_serve.datamodel.requests import ConvertDocumentFileSourcesRequest
from docling_serve.datamodel.responses import ConvertDocumentResponse
from docling_serve.docling_conversion import convert_documents
from docling_serve.response_preparation import process_results

if TYPE_CHECKING:
    from docling_serve.engines.async_local.orchestrator import AsyncLocalOrchestrator

_log = logging.getLogger(__name__)

class AsyncLocalWorker:
    def __init__(self, worker_id: int, orchestrator: "AsyncLocalOrchestrator"):
        self.worker_id = worker_id
        self.orchestrator = orchestrator
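
    # Main worker loop: repeatedly take a task id from the orchestrator's queue,
    # run the conversion in a worker thread, and publish status updates to
    # subscribers.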
    async def loop(self):
        _log.debug(f"Starting loop for worker {self.worker_id}")
        while True:
            task_id: str = await self.orchestrator.task_queue.get()
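            # The task is no longer waiting in the queue; remove it from the
            # orchestrator's queue_list so reported queue positions stay accurate.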
            self.orchestrator.queue_list.remove(task_id)

            if task_id not in self.orchestrator.tasks:
                raise RuntimeError(f"Task {task_id} not found.")
            task = self.orchestrator.tasks[task_id]

            try:
                task.task_status = TaskStatus.STARTED
                _log.info(f"Worker {self.worker_id} processing task {task_id}")

                # Notify clients about task updates
                await self.orchestrator.notify_task_subscribers(task_id)

                # Notify clients about queue updates
                await self.orchestrator.notify_queue_positions()

                # Get the current event loop
                asyncio.get_event_loop()
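                # (Return value unused: asyncio.to_thread below runs the job instead
                # of the commented-out run_coroutine_threadsafe alternative.)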

                # Define the blocking conversion job for this task.
                # TODO: send progress/partial updates to the client, e.g. when a
                # document in the batch is done.
                def run_conversion():
                    sources: list[Union[str, DocumentStream]] = []
                    headers: Optional[dict[str, Any]] = None

                    if isinstance(task.request, ConvertDocumentFileSourcesRequest):
                        for file_source in task.request.file_sources:
                            sources.append(file_source.to_document_stream())
                    else:
                        for http_source in task.request.http_sources:
                            sources.append(http_source.url)
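                            # convert_documents() takes a single headers mapping, so
                            # reuse the first non-empty headers found among the sources.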
                            if headers is None and http_source.headers:
                                headers = http_source.headers

                    # Note: results are only an iterator->lazy evaluation
                    results = convert_documents(
                        sources=sources,
                        options=task.request.options,
                        headers=headers,
                    )

                    # The real processing happens here: process_results() consumes
                    # the lazy conversion results.
                    response = process_results(
                        background_tasks=BackgroundTasks(),
                        conversion_options=task.request.options,
                        conv_results=results,
                    )

                    return response

                # Run the conversion in a thread to avoid blocking the event loop.
                start_time = time.monotonic()
                # future = asyncio.run_coroutine_threadsafe(
                #     run_conversion(),
                #     loop=loop
                # )
                # response = future.result()
                response = await asyncio.to_thread(
                    run_conversion,
                )
                processing_time = time.monotonic() - start_time

                if not isinstance(response, ConvertDocumentResponse):
                    _log.error(
                        f"Worker {self.worker_id} got un-processable "
                        f"result for {task_id}: {type(response)}"
                    )

                task.result = response
                task.request = None
                task.task_status = TaskStatus.SUCCESS
                _log.info(
                    f"Worker {self.worker_id} completed job {task_id} "
                    f"in {processing_time:.2f} seconds"
                )

            except Exception as e:
                _log.error(
                    f"Worker {self.worker_id} failed to process job {task_id}: {e}"
                )
                task.task_status = TaskStatus.FAILURE

            finally:
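                # Run regardless of success or failure: notify subscribers of the
                # final status and release the queue slot.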
                await self.orchestrator.notify_task_subscribers(task_id)
                self.orchestrator.task_queue.task_done()
                _log.debug(f"Worker {self.worker_id} completely done with {task_id}")