Spaces:
Configuration error
Configuration error
import asyncio | |
import importlib.metadata | |
import logging | |
import tempfile | |
from contextlib import asynccontextmanager | |
from io import BytesIO | |
from pathlib import Path | |
from typing import Annotated, Any, Optional, Union | |
from fastapi import ( | |
BackgroundTasks, | |
Depends, | |
FastAPI, | |
HTTPException, | |
Query, | |
UploadFile, | |
WebSocket, | |
WebSocketDisconnect, | |
) | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.responses import RedirectResponse | |
from docling.datamodel.base_models import DocumentStream, InputFormat | |
from docling.document_converter import DocumentConverter | |
from docling_serve.datamodel.convert import ConvertDocumentsOptions | |
from docling_serve.datamodel.requests import ( | |
ConvertDocumentFileSourcesRequest, | |
ConvertDocumentsRequest, | |
) | |
from docling_serve.datamodel.responses import ( | |
ConvertDocumentResponse, | |
HealthCheckResponse, | |
MessageKind, | |
TaskStatusResponse, | |
WebsocketMessage, | |
) | |
from docling_serve.docling_conversion import ( | |
convert_documents, | |
converters, | |
get_pdf_pipeline_opts, | |
) | |
from docling_serve.engines import get_orchestrator | |
from docling_serve.engines.async_local.orchestrator import ( | |
AsyncLocalOrchestrator, | |
TaskNotFoundError, | |
) | |
from docling_serve.helper_functions import FormDepends | |
from docling_serve.response_preparation import process_results | |
from docling_serve.settings import docling_serve_settings | |
# Set up custom logging as we'll be intermixes with FastAPI/Uvicorn's logging | |
class ColoredLogFormatter(logging.Formatter): | |
COLOR_CODES = { | |
logging.DEBUG: "\033[94m", # Blue | |
logging.INFO: "\033[92m", # Green | |
logging.WARNING: "\033[93m", # Yellow | |
logging.ERROR: "\033[91m", # Red | |
logging.CRITICAL: "\033[95m", # Magenta | |
} | |
RESET_CODE = "\033[0m" | |
def format(self, record): | |
color = self.COLOR_CODES.get(record.levelno, "") | |
record.levelname = f"{color}{record.levelname}{self.RESET_CODE}" | |
return super().format(record) | |
logging.basicConfig( | |
level=logging.INFO, # Set the logging level | |
format="%(levelname)s:\t%(asctime)s - %(name)s - %(message)s", | |
datefmt="%H:%M:%S", | |
) | |
# Override the formatter with the custom ColoredLogFormatter | |
root_logger = logging.getLogger() # Get the root logger | |
for handler in root_logger.handlers: # Iterate through existing handlers | |
if handler.formatter: | |
handler.setFormatter(ColoredLogFormatter(handler.formatter._fmt)) | |
_log = logging.getLogger(__name__) | |
# Context manager to initialize and clean up the lifespan of the FastAPI app | |
async def lifespan(app: FastAPI): | |
# Converter with default options | |
pdf_format_option, options_hash = get_pdf_pipeline_opts(ConvertDocumentsOptions()) | |
converters[options_hash] = DocumentConverter( | |
format_options={ | |
InputFormat.PDF: pdf_format_option, | |
InputFormat.IMAGE: pdf_format_option, | |
} | |
) | |
converters[options_hash].initialize_pipeline(InputFormat.PDF) | |
orchestrator = get_orchestrator() | |
# Start the background queue processor | |
queue_task = asyncio.create_task(orchestrator.process_queue()) | |
yield | |
# Cancel the background queue processor on shutdown | |
queue_task.cancel() | |
try: | |
await queue_task | |
except asyncio.CancelledError: | |
_log.info("Queue processor cancelled.") | |
converters.clear() | |
# if WITH_UI: | |
# gradio_ui.close() | |
################################## | |
# App creation and configuration # | |
################################## | |
def create_app(): # noqa: C901 | |
try: | |
version = importlib.metadata.version("docling_serve") | |
except importlib.metadata.PackageNotFoundError: | |
_log.warning("Unable to get docling_serve version, falling back to 0.0.0") | |
version = "0.0.0" | |
app = FastAPI( | |
title="Docling Serve", | |
lifespan=lifespan, | |
version=version, | |
) | |
origins = ["*"] | |
methods = ["*"] | |
headers = ["*"] | |
app.add_middleware( | |
CORSMiddleware, | |
allow_origins=origins, | |
allow_credentials=True, | |
allow_methods=methods, | |
allow_headers=headers, | |
) | |
# Mount the Gradio app | |
if docling_serve_settings.enable_ui: | |
try: | |
import gradio as gr | |
from docling_serve.gradio_ui import ui as gradio_ui | |
tmp_output_dir = Path(tempfile.mkdtemp()) | |
gradio_ui.gradio_output_dir = tmp_output_dir | |
app = gr.mount_gradio_app( | |
app, | |
gradio_ui, | |
path="/ui", | |
allowed_paths=["./logo.png", tmp_output_dir], | |
root_path="/ui", | |
) | |
except ImportError: | |
_log.warning( | |
"Docling Serve enable_ui is activated, but gradio is not installed. " | |
"Install it with `pip install docling-serve[ui]` " | |
"or `pip install gradio`" | |
) | |
############################# | |
# API Endpoints definitions # | |
############################# | |
# Favicon | |
async def favicon(): | |
response = RedirectResponse( | |
url="https://raw.githubusercontent.com/docling-project/docling/refs/heads/main/docs/assets/logo.svg" | |
) | |
return response | |
def health() -> HealthCheckResponse: | |
return HealthCheckResponse() | |
# API readiness compatibility for OpenShift AI Workbench | |
def api_check() -> HealthCheckResponse: | |
return HealthCheckResponse() | |
# Convert a document from URL(s) | |
def process_url( | |
background_tasks: BackgroundTasks, conversion_request: ConvertDocumentsRequest | |
): | |
sources: list[Union[str, DocumentStream]] = [] | |
headers: Optional[dict[str, Any]] = None | |
if isinstance(conversion_request, ConvertDocumentFileSourcesRequest): | |
for file_source in conversion_request.file_sources: | |
sources.append(file_source.to_document_stream()) | |
else: | |
for http_source in conversion_request.http_sources: | |
sources.append(http_source.url) | |
if headers is None and http_source.headers: | |
headers = http_source.headers | |
# Note: results are only an iterator->lazy evaluation | |
results = convert_documents( | |
sources=sources, options=conversion_request.options, headers=headers | |
) | |
# The real processing will happen here | |
response = process_results( | |
background_tasks=background_tasks, | |
conversion_options=conversion_request.options, | |
conv_results=results, | |
) | |
return response | |
# Convert a document from file(s) | |
async def process_file( | |
background_tasks: BackgroundTasks, | |
files: list[UploadFile], | |
options: Annotated[ | |
ConvertDocumentsOptions, FormDepends(ConvertDocumentsOptions) | |
], | |
): | |
_log.info(f"Received {len(files)} files for processing.") | |
# Load the uploaded files to Docling DocumentStream | |
file_sources = [] | |
for file in files: | |
buf = BytesIO(file.file.read()) | |
name = file.filename if file.filename else "file.pdf" | |
file_sources.append(DocumentStream(name=name, stream=buf)) | |
results = convert_documents(sources=file_sources, options=options) | |
response = process_results( | |
background_tasks=background_tasks, | |
conversion_options=options, | |
conv_results=results, | |
) | |
return response | |
# Convert a document from URL(s) using the async api | |
async def process_url_async( | |
orchestrator: Annotated[AsyncLocalOrchestrator, Depends(get_orchestrator)], | |
conversion_request: ConvertDocumentsRequest, | |
): | |
task = await orchestrator.enqueue(request=conversion_request) | |
task_queue_position = await orchestrator.get_queue_position( | |
task_id=task.task_id | |
) | |
return TaskStatusResponse( | |
task_id=task.task_id, | |
task_status=task.task_status, | |
task_position=task_queue_position, | |
) | |
# Task status poll | |
async def task_status_poll( | |
orchestrator: Annotated[AsyncLocalOrchestrator, Depends(get_orchestrator)], | |
task_id: str, | |
wait: Annotated[ | |
float, Query(help="Number of seconds to wait for a completed status.") | |
] = 0.0, | |
): | |
try: | |
task = await orchestrator.task_status(task_id=task_id, wait=wait) | |
task_queue_position = await orchestrator.get_queue_position(task_id=task_id) | |
except TaskNotFoundError: | |
raise HTTPException(status_code=404, detail="Task not found.") | |
return TaskStatusResponse( | |
task_id=task.task_id, | |
task_status=task.task_status, | |
task_position=task_queue_position, | |
) | |
# Task status websocket | |
async def task_status_ws( | |
websocket: WebSocket, | |
orchestrator: Annotated[AsyncLocalOrchestrator, Depends(get_orchestrator)], | |
task_id: str, | |
): | |
await websocket.accept() | |
if task_id not in orchestrator.tasks: | |
await websocket.send_text( | |
WebsocketMessage( | |
message=MessageKind.ERROR, error="Task not found." | |
).model_dump_json() | |
) | |
await websocket.close() | |
return | |
task = orchestrator.tasks[task_id] | |
# Track active WebSocket connections for this job | |
orchestrator.task_subscribers[task_id].add(websocket) | |
try: | |
task_queue_position = await orchestrator.get_queue_position(task_id=task_id) | |
task_response = TaskStatusResponse( | |
task_id=task.task_id, | |
task_status=task.task_status, | |
task_position=task_queue_position, | |
) | |
await websocket.send_text( | |
WebsocketMessage( | |
message=MessageKind.CONNECTION, task=task_response | |
).model_dump_json() | |
) | |
while True: | |
task_queue_position = await orchestrator.get_queue_position( | |
task_id=task_id | |
) | |
task_response = TaskStatusResponse( | |
task_id=task.task_id, | |
task_status=task.task_status, | |
task_position=task_queue_position, | |
) | |
await websocket.send_text( | |
WebsocketMessage( | |
message=MessageKind.UPDATE, task=task_response | |
).model_dump_json() | |
) | |
# each client message will be interpreted as a request for update | |
msg = await websocket.receive_text() | |
_log.debug(f"Received message: {msg}") | |
except WebSocketDisconnect: | |
_log.info(f"WebSocket disconnected for job {task_id}") | |
finally: | |
orchestrator.task_subscribers[task_id].remove(websocket) | |
# Task result | |
async def task_result( | |
orchestrator: Annotated[AsyncLocalOrchestrator, Depends(get_orchestrator)], | |
task_id: str, | |
): | |
result = await orchestrator.task_result(task_id=task_id) | |
if result is None: | |
raise HTTPException( | |
status_code=404, | |
detail="Task result not found. Please wait for a completion status.", | |
) | |
return result | |
return app | |