import asyncio
import importlib.metadata
import logging
import tempfile
from contextlib import asynccontextmanager
from io import BytesIO
from pathlib import Path
from typing import Annotated, Any, Optional, Union

from fastapi import (
    BackgroundTasks,
    Depends,
    FastAPI,
    HTTPException,
    Query,
    UploadFile,
    WebSocket,
    WebSocketDisconnect,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse

from docling.datamodel.base_models import DocumentStream, InputFormat
from docling.document_converter import DocumentConverter

from docling_serve.datamodel.convert import ConvertDocumentsOptions
from docling_serve.datamodel.requests import (
    ConvertDocumentFileSourcesRequest,
    ConvertDocumentsRequest,
)
from docling_serve.datamodel.responses import (
    ConvertDocumentResponse,
    HealthCheckResponse,
    MessageKind,
    TaskStatusResponse,
    WebsocketMessage,
)
from docling_serve.docling_conversion import (
    convert_documents,
    converters,
    get_pdf_pipeline_opts,
)
from docling_serve.engines import get_orchestrator
from docling_serve.engines.async_local.orchestrator import (
    AsyncLocalOrchestrator,
    TaskNotFoundError,
)
from docling_serve.helper_functions import FormDepends
from docling_serve.response_preparation import process_results
from docling_serve.settings import docling_serve_settings


# Set up custom logging as we'll be intermixed with FastAPI/Uvicorn's logging
class ColoredLogFormatter(logging.Formatter):
    """Log formatter that wraps the level name in an ANSI color code."""

    COLOR_CODES = {
        logging.DEBUG: "\033[94m",  # Blue
        logging.INFO: "\033[92m",  # Green
        logging.WARNING: "\033[93m",  # Yellow
        logging.ERROR: "\033[91m",  # Red
        logging.CRITICAL: "\033[95m",  # Magenta
    }
    RESET_CODE = "\033[0m"

    def format(self, record):
        """Format *record* with a colorized level name.

        The level name is colored only for the duration of this call: the
        same LogRecord instance may be handed to several handlers, and
        permanently mutating ``record.levelname`` would stack color codes
        on every subsequent formatting pass.
        """
        color = self.COLOR_CODES.get(record.levelno, "")
        original_levelname = record.levelname
        record.levelname = f"{color}{original_levelname}{self.RESET_CODE}"
        try:
            return super().format(record)
        finally:
            # Restore the record so other handlers see the plain level name.
            record.levelname = original_levelname


logging.basicConfig(
    level=logging.INFO,  # Set the logging level
    format="%(levelname)s:\t%(asctime)s - %(name)s - %(message)s",
    datefmt="%H:%M:%S",
)
# Override the formatter with the custom ColoredLogFormatter
root_logger = logging.getLogger()  # Get the root logger
for handler in root_logger.handlers:  # Iterate through existing handlers
    if handler.formatter:
        handler.setFormatter(ColoredLogFormatter(handler.formatter._fmt))

_log = logging.getLogger(__name__)


# Context manager to initialize and clean up the lifespan of the FastAPI app
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Warm up the default PDF converter on startup; on shutdown cancel the
    background queue processor and drop the converter cache."""
    # Converter with default options
    pdf_format_option, options_hash = get_pdf_pipeline_opts(ConvertDocumentsOptions())
    converters[options_hash] = DocumentConverter(
        format_options={
            InputFormat.PDF: pdf_format_option,
            InputFormat.IMAGE: pdf_format_option,
        }
    )
    converters[options_hash].initialize_pipeline(InputFormat.PDF)

    orchestrator = get_orchestrator()
    # Start the background queue processor
    queue_task = asyncio.create_task(orchestrator.process_queue())

    yield

    # Cancel the background queue processor on shutdown
    queue_task.cancel()
    try:
        await queue_task
    except asyncio.CancelledError:
        _log.info("Queue processor cancelled.")

    converters.clear()
    # if WITH_UI:
    #     gradio_ui.close()


##################################
# App creation and configuration #
##################################


def create_app():  # noqa: C901
    """Build and configure the Docling Serve FastAPI application.

    Registers the CORS middleware, optionally mounts the Gradio UI, and
    defines the synchronous and asynchronous conversion endpoints plus the
    task status/result endpoints.
    """
    try:
        version = importlib.metadata.version("docling_serve")
    except importlib.metadata.PackageNotFoundError:
        _log.warning("Unable to get docling_serve version, falling back to 0.0.0")
        version = "0.0.0"

    app = FastAPI(
        title="Docling Serve",
        lifespan=lifespan,
        version=version,
    )

    origins = ["*"]
    methods = ["*"]
    headers = ["*"]

    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        allow_credentials=True,
        allow_methods=methods,
        allow_headers=headers,
    )

    # Mount the Gradio app
    if docling_serve_settings.enable_ui:
        try:
            import gradio as gr

            from docling_serve.gradio_ui import ui as gradio_ui

            tmp_output_dir = Path(tempfile.mkdtemp())
            gradio_ui.gradio_output_dir = tmp_output_dir
            app = gr.mount_gradio_app(
                app,
                gradio_ui,
                path="/ui",
                allowed_paths=["./logo.png", tmp_output_dir],
                root_path="/ui",
            )
        except ImportError:
            _log.warning(
                "Docling Serve enable_ui is activated, but gradio is not installed. "
                "Install it with `pip install docling-serve[ui]` "
                "or `pip install gradio`"
            )

    #############################
    # API Endpoints definitions #
    #############################

    # Favicon
    @app.get("/favicon.ico", include_in_schema=False)
    async def favicon():
        response = RedirectResponse(
            url="https://raw.githubusercontent.com/docling-project/docling/refs/heads/main/docs/assets/logo.svg"
        )
        return response

    @app.get("/health")
    def health() -> HealthCheckResponse:
        return HealthCheckResponse()

    # API readiness compatibility for OpenShift AI Workbench
    @app.get("/api", include_in_schema=False)
    def api_check() -> HealthCheckResponse:
        return HealthCheckResponse()

    # Convert a document from URL(s)
    @app.post(
        "/v1alpha/convert/source",
        response_model=ConvertDocumentResponse,
        responses={
            200: {
                "content": {"application/zip": {}},
                # "description": "Return the JSON item or an image.",
            }
        },
    )
    def process_url(
        background_tasks: BackgroundTasks, conversion_request: ConvertDocumentsRequest
    ):
        """Convert documents given as HTTP URLs or inline file sources."""
        sources: list[Union[str, DocumentStream]] = []
        headers: Optional[dict[str, Any]] = None
        if isinstance(conversion_request, ConvertDocumentFileSourcesRequest):
            for file_source in conversion_request.file_sources:
                sources.append(file_source.to_document_stream())
        else:
            for http_source in conversion_request.http_sources:
                sources.append(http_source.url)
                # The first non-empty headers mapping is used for all downloads.
                if headers is None and http_source.headers:
                    headers = http_source.headers

        # Note: results are only an iterator->lazy evaluation
        results = convert_documents(
            sources=sources, options=conversion_request.options, headers=headers
        )

        # The real processing will happen here
        response = process_results(
            background_tasks=background_tasks,
            conversion_options=conversion_request.options,
            conv_results=results,
        )

        return response

    # Convert a document from file(s)
    @app.post(
        "/v1alpha/convert/file",
        response_model=ConvertDocumentResponse,
        responses={
            200: {
                "content": {"application/zip": {}},
            }
        },
    )
    async def process_file(
        background_tasks: BackgroundTasks,
        files: list[UploadFile],
        options: Annotated[
            ConvertDocumentsOptions, FormDepends(ConvertDocumentsOptions)
        ],
    ):
        """Convert documents uploaded as multipart files."""
        _log.info(f"Received {len(files)} files for processing.")

        # Load the uploaded files to Docling DocumentStream
        file_sources = []
        for file in files:
            # Read asynchronously so the event loop is not blocked by file I/O
            # (file.file.read() would be a blocking call inside an async def).
            buf = BytesIO(await file.read())
            name = file.filename if file.filename else "file.pdf"
            file_sources.append(DocumentStream(name=name, stream=buf))

        results = convert_documents(sources=file_sources, options=options)

        response = process_results(
            background_tasks=background_tasks,
            conversion_options=options,
            conv_results=results,
        )

        return response

    # Convert a document from URL(s) using the async api
    @app.post(
        "/v1alpha/convert/source/async",
        response_model=TaskStatusResponse,
    )
    async def process_url_async(
        orchestrator: Annotated[AsyncLocalOrchestrator, Depends(get_orchestrator)],
        conversion_request: ConvertDocumentsRequest,
    ):
        """Enqueue a conversion task and return its initial status/position."""
        task = await orchestrator.enqueue(request=conversion_request)

        task_queue_position = await orchestrator.get_queue_position(
            task_id=task.task_id
        )
        return TaskStatusResponse(
            task_id=task.task_id,
            task_status=task.task_status,
            task_position=task_queue_position,
        )

    # Task status poll
    @app.get(
        "/v1alpha/status/poll/{task_id}",
        response_model=TaskStatusResponse,
    )
    async def task_status_poll(
        orchestrator: Annotated[AsyncLocalOrchestrator, Depends(get_orchestrator)],
        task_id: str,
        wait: Annotated[
            # `description` is the documented Query keyword; `help` is not a
            # valid parameter of fastapi.Query.
            float, Query(description="Number of seconds to wait for a completed status.")
        ] = 0.0,
    ):
        """Poll the status of a task, optionally waiting up to `wait` seconds."""
        try:
            task = await orchestrator.task_status(task_id=task_id, wait=wait)
            task_queue_position = await orchestrator.get_queue_position(task_id=task_id)
        except TaskNotFoundError:
            # Suppress the TaskNotFoundError chain: the 404 is the whole story.
            raise HTTPException(status_code=404, detail="Task not found.") from None
        return TaskStatusResponse(
            task_id=task.task_id,
            task_status=task.task_status,
            task_position=task_queue_position,
        )

    # Task status websocket
    @app.websocket(
        "/v1alpha/status/ws/{task_id}",
    )
    async def task_status_ws(
        websocket: WebSocket,
        orchestrator: Annotated[AsyncLocalOrchestrator, Depends(get_orchestrator)],
        task_id: str,
    ):
        """Stream task status updates over a WebSocket.

        An initial CONNECTION message is sent, then an UPDATE message is sent
        for every text message received from the client.
        """
        await websocket.accept()

        if task_id not in orchestrator.tasks:
            await websocket.send_text(
                WebsocketMessage(
                    message=MessageKind.ERROR, error="Task not found."
                ).model_dump_json()
            )
            await websocket.close()
            return

        task = orchestrator.tasks[task_id]

        # Track active WebSocket connections for this job
        orchestrator.task_subscribers[task_id].add(websocket)

        try:
            task_queue_position = await orchestrator.get_queue_position(task_id=task_id)
            task_response = TaskStatusResponse(
                task_id=task.task_id,
                task_status=task.task_status,
                task_position=task_queue_position,
            )
            await websocket.send_text(
                WebsocketMessage(
                    message=MessageKind.CONNECTION, task=task_response
                ).model_dump_json()
            )
            while True:
                task_queue_position = await orchestrator.get_queue_position(
                    task_id=task_id
                )
                task_response = TaskStatusResponse(
                    task_id=task.task_id,
                    task_status=task.task_status,
                    task_position=task_queue_position,
                )
                await websocket.send_text(
                    WebsocketMessage(
                        message=MessageKind.UPDATE, task=task_response
                    ).model_dump_json()
                )

                # each client message will be interpreted as a request for update
                msg = await websocket.receive_text()
                _log.debug(f"Received message: {msg}")

        except WebSocketDisconnect:
            _log.info(f"WebSocket disconnected for job {task_id}")

        finally:
            # discard() instead of remove(): cleanup must not raise KeyError
            # if this subscriber was already dropped elsewhere.
            orchestrator.task_subscribers[task_id].discard(websocket)

    # Task result
    @app.get(
        "/v1alpha/result/{task_id}",
        response_model=ConvertDocumentResponse,
        responses={
            200: {
                "content": {"application/zip": {}},
            }
        },
    )
    async def task_result(
        orchestrator: Annotated[AsyncLocalOrchestrator, Depends(get_orchestrator)],
        task_id: str,
    ):
        """Return the result of a completed task, or 404 if it is not ready."""
        result = await orchestrator.task_result(task_id=task_id)
        if result is None:
            raise HTTPException(
                status_code=404,
                detail="Task result not found. Please wait for a completion status.",
            )
        return result

    return app