import asyncio import importlib.metadata import logging import tempfile from contextlib import asynccontextmanager from io import BytesIO from pathlib import Path from typing import Annotated, Any, Dict, List, Optional, Union from fastapi import ( BackgroundTasks, Depends, FastAPI, HTTPException, Query, UploadFile, WebSocket, WebSocketDisconnect, ) from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import RedirectResponse from docling.datamodel.base_models import DocumentStream, InputFormat from docling.document_converter import DocumentConverter from docling_serve.datamodel.convert import ConvertDocumentsOptions from docling_serve.datamodel.requests import ( ConvertDocumentFileSourcesRequest, ConvertDocumentsRequest, ) from docling_serve.datamodel.responses import ( ConvertDocumentResponse, HealthCheckResponse, MessageKind, TaskStatusResponse, WebsocketMessage, ) from docling_serve.docling_conversion import ( convert_documents, converters, get_pdf_pipeline_opts, ) from docling_serve.engines import get_orchestrator from docling_serve.engines.async_local.orchestrator import ( AsyncLocalOrchestrator, TaskNotFoundError, ) from docling_serve.helper_functions import FormDepends from docling_serve.response_preparation import process_results from docling_serve.settings import docling_serve_settings # Set up custom logging as we'll be intermixes with FastAPI/Uvicorn's logging class ColoredLogFormatter(logging.Formatter): COLOR_CODES = { logging.DEBUG: "\033[94m", # Blue logging.INFO: "\033[92m", # Green logging.WARNING: "\033[93m", # Yellow logging.ERROR: "\033[91m", # Red logging.CRITICAL: "\033[95m", # Magenta } RESET_CODE = "\033[0m" def format(self, record): color = self.COLOR_CODES.get(record.levelno, "") record.levelname = f"{color}{record.levelname}{self.RESET_CODE}" return super().format(record) logging.basicConfig( level=logging.INFO, # Set the logging level format="%(levelname)s:\t%(asctime)s - %(name)s - %(message)s", datefmt="%H:%M:%S", ) # Override the formatter with the custom ColoredLogFormatter root_logger = logging.getLogger() # Get the root logger for handler in root_logger.handlers: # Iterate through existing handlers if handler.formatter: handler.setFormatter(ColoredLogFormatter(handler.formatter._fmt)) _log = logging.getLogger(__name__) # Context manager to initialize and clean up the lifespan of the FastAPI app @asynccontextmanager async def lifespan(app: FastAPI): # Converter with default options pdf_format_option, options_hash = get_pdf_pipeline_opts(ConvertDocumentsOptions()) converters[options_hash] = DocumentConverter( format_options={ InputFormat.PDF: pdf_format_option, InputFormat.IMAGE: pdf_format_option, } ) converters[options_hash].initialize_pipeline(InputFormat.PDF) orchestrator = get_orchestrator() # Start the background queue processor queue_task = asyncio.create_task(orchestrator.process_queue()) yield # Cancel the background queue processor on shutdown queue_task.cancel() try: await queue_task except asyncio.CancelledError: _log.info("Queue processor cancelled.") converters.clear() # if WITH_UI: # gradio_ui.close() ################################## # App creation and configuration # ################################## def create_app(): # noqa: C901 try: version = importlib.metadata.version("docling_serve") except importlib.metadata.PackageNotFoundError: _log.warning("Unable to get docling_serve version, falling back to 0.0.0") version = "0.0.0" app = FastAPI( title="Docling Serve", lifespan=lifespan, version=version, ) origins = ["*"] methods = ["*"] headers = ["*"] app.add_middleware( CORSMiddleware, allow_origins=origins, allow_credentials=True, allow_methods=methods, allow_headers=headers, ) # Mount the Gradio app if docling_serve_settings.enable_ui: try: import gradio as gr from docling_serve.gradio_ui import ui as gradio_ui tmp_output_dir = Path(tempfile.mkdtemp()) gradio_ui.gradio_output_dir = tmp_output_dir app = gr.mount_gradio_app( app, gradio_ui, path="/ui", allowed_paths=["./logo.png", tmp_output_dir], root_path="/ui", ) except ImportError: _log.warning( "Docling Serve enable_ui is activated, but gradio is not installed. " "Install it with `pip install docling-serve[ui]` " "or `pip install gradio`" ) ############################# # API Endpoints definitions # ############################# # Favicon @app.get("/favicon.ico", include_in_schema=False) async def favicon(): response = RedirectResponse( url="https://ds4sd.github.io/docling/assets/logo.png" ) return response @app.get("/health") def health() -> HealthCheckResponse: return HealthCheckResponse() # API readiness compatibility for OpenShift AI Workbench @app.get("/api", include_in_schema=False) def api_check() -> HealthCheckResponse: return HealthCheckResponse() # Convert a document from URL(s) @app.post( "/v1alpha/convert/source", response_model=ConvertDocumentResponse, responses={ 200: { "content": {"application/zip": {}}, # "description": "Return the JSON item or an image.", } }, ) def process_url( background_tasks: BackgroundTasks, conversion_request: ConvertDocumentsRequest ): sources: List[Union[str, DocumentStream]] = [] headers: Optional[Dict[str, Any]] = None if isinstance(conversion_request, ConvertDocumentFileSourcesRequest): for file_source in conversion_request.file_sources: sources.append(file_source.to_document_stream()) else: for http_source in conversion_request.http_sources: sources.append(http_source.url) if headers is None and http_source.headers: headers = http_source.headers # Note: results are only an iterator->lazy evaluation results = convert_documents( sources=sources, options=conversion_request.options, headers=headers ) # The real processing will happen here response = process_results( background_tasks=background_tasks, conversion_options=conversion_request.options, conv_results=results, ) return response # Convert a document from file(s) @app.post( "/v1alpha/convert/file", response_model=ConvertDocumentResponse, responses={ 200: { "content": {"application/zip": {}}, } }, ) async def process_file( background_tasks: BackgroundTasks, files: List[UploadFile], options: Annotated[ ConvertDocumentsOptions, FormDepends(ConvertDocumentsOptions) ], ): _log.info(f"Received {len(files)} files for processing.") # Load the uploaded files to Docling DocumentStream file_sources = [] for file in files: buf = BytesIO(file.file.read()) name = file.filename if file.filename else "file.pdf" file_sources.append(DocumentStream(name=name, stream=buf)) results = convert_documents(sources=file_sources, options=options) response = process_results( background_tasks=background_tasks, conversion_options=options, conv_results=results, ) return response # Convert a document from URL(s) using the async api @app.post( "/v1alpha/convert/source/async", response_model=TaskStatusResponse, ) async def process_url_async( orchestrator: Annotated[AsyncLocalOrchestrator, Depends(get_orchestrator)], conversion_request: ConvertDocumentsRequest, ): task = await orchestrator.enqueue(request=conversion_request) task_queue_position = await orchestrator.get_queue_position( task_id=task.task_id ) return TaskStatusResponse( task_id=task.task_id, task_status=task.task_status, task_position=task_queue_position, ) # Task status poll @app.get( "/v1alpha/status/poll/{task_id}", response_model=TaskStatusResponse, ) async def task_status_poll( orchestrator: Annotated[AsyncLocalOrchestrator, Depends(get_orchestrator)], task_id: str, wait: Annotated[ float, Query(help="Number of seconds to wait for a completed status.") ] = 0.0, ): try: task = await orchestrator.task_status(task_id=task_id, wait=wait) task_queue_position = await orchestrator.get_queue_position(task_id=task_id) except TaskNotFoundError: raise HTTPException(status_code=404, detail="Task not found.") return TaskStatusResponse( task_id=task.task_id, task_status=task.task_status, task_position=task_queue_position, ) # Task status websocket @app.websocket( "/v1alpha/status/ws/{task_id}", ) async def task_status_ws( websocket: WebSocket, orchestrator: Annotated[AsyncLocalOrchestrator, Depends(get_orchestrator)], task_id: str, ): await websocket.accept() if task_id not in orchestrator.tasks: await websocket.send_text( WebsocketMessage( message=MessageKind.ERROR, error="Task not found." ).model_dump_json() ) await websocket.close() return task = orchestrator.tasks[task_id] # Track active WebSocket connections for this job orchestrator.task_subscribers[task_id].add(websocket) try: task_queue_position = await orchestrator.get_queue_position(task_id=task_id) task_response = TaskStatusResponse( task_id=task.task_id, task_status=task.task_status, task_position=task_queue_position, ) await websocket.send_text( WebsocketMessage( message=MessageKind.CONNECTION, task=task_response ).model_dump_json() ) while True: task_queue_position = await orchestrator.get_queue_position( task_id=task_id ) task_response = TaskStatusResponse( task_id=task.task_id, task_status=task.task_status, task_position=task_queue_position, ) await websocket.send_text( WebsocketMessage( message=MessageKind.UPDATE, task=task_response ).model_dump_json() ) # each client message will be interpreted as a request for update msg = await websocket.receive_text() _log.debug(f"Received message: {msg}") except WebSocketDisconnect: _log.info(f"WebSocket disconnected for job {task_id}") finally: orchestrator.task_subscribers[task_id].remove(websocket) # Task result @app.get( "/v1alpha/result/{task_id}", response_model=ConvertDocumentResponse, responses={ 200: { "content": {"application/zip": {}}, } }, ) async def task_result( orchestrator: Annotated[AsyncLocalOrchestrator, Depends(get_orchestrator)], task_id: str, ): result = await orchestrator.task_result(task_id=task_id) if result is None: raise HTTPException( status_code=404, detail="Task result not found. Please wait for a completion status.", ) return result return app