File size: 12,614 Bytes
16d905e
a4a5700
22bc712
 
44657b5
 
22bc712
360d0c5
22bc712
16d905e
 
 
 
 
 
 
 
 
 
22bc712
 
 
16d905e
 
 
 
 
22bc712
 
16d905e
 
 
 
 
 
 
 
 
22bc712
 
 
44657b5
16d905e
 
 
 
 
e6a25a6
16d905e
e6a25a6
44657b5
84568d7
22bc712
 
 
 
 
 
 
 
 
 
84568d7
22bc712
 
 
 
84568d7
 
22bc712
 
 
 
 
84568d7
22bc712
 
 
 
 
84568d7
22bc712
44657b5
 
22bc712
44657b5
 
84568d7
22bc712
84568d7
 
22bc712
 
84568d7
 
 
 
 
16d905e
 
 
 
 
44657b5
 
16d905e
 
 
 
 
 
 
84568d7
16d905e
e6a25a6
 
44657b5
 
22bc712
 
 
 
 
16d905e
a4a5700
 
 
 
 
 
 
e6a25a6
 
 
a4a5700
22bc712
84568d7
e6a25a6
 
 
84568d7
e6a25a6
 
 
 
 
 
84568d7
22bc712
e6a25a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258f9e2
e6a25a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44657b5
e6a25a6
 
 
360d0c5
 
e6a25a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84568d7
e6a25a6
 
360d0c5
e6a25a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16d905e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6a25a6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
import asyncio
import importlib.metadata
import logging
import tempfile
from contextlib import asynccontextmanager
from io import BytesIO
from pathlib import Path
from typing import Annotated, Any, Optional, Union

from fastapi import (
    BackgroundTasks,
    Depends,
    FastAPI,
    HTTPException,
    Query,
    UploadFile,
    WebSocket,
    WebSocketDisconnect,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse

from docling.datamodel.base_models import DocumentStream, InputFormat
from docling.document_converter import DocumentConverter

from docling_serve.datamodel.convert import ConvertDocumentsOptions
from docling_serve.datamodel.requests import (
    ConvertDocumentFileSourcesRequest,
    ConvertDocumentsRequest,
)
from docling_serve.datamodel.responses import (
    ConvertDocumentResponse,
    HealthCheckResponse,
    MessageKind,
    TaskStatusResponse,
    WebsocketMessage,
)
from docling_serve.docling_conversion import (
    convert_documents,
    converters,
    get_pdf_pipeline_opts,
)
from docling_serve.engines import get_orchestrator
from docling_serve.engines.async_local.orchestrator import (
    AsyncLocalOrchestrator,
    TaskNotFoundError,
)
from docling_serve.helper_functions import FormDepends
from docling_serve.response_preparation import process_results
from docling_serve.settings import docling_serve_settings


# Set up custom logging, as our output will be intermixed with FastAPI/Uvicorn's logging
class ColoredLogFormatter(logging.Formatter):
    """Formatter that wraps the level name in an ANSI color escape sequence.

    The color is looked up by the record's numeric level in ``COLOR_CODES``;
    levels without an entry get no color prefix (the reset code is still
    appended).
    """

    COLOR_CODES = {
        logging.DEBUG: "\033[94m",  # Blue
        logging.INFO: "\033[92m",  # Green
        logging.WARNING: "\033[93m",  # Yellow
        logging.ERROR: "\033[91m",  # Red
        logging.CRITICAL: "\033[95m",  # Magenta
    }
    RESET_CODE = "\033[0m"

    def format(self, record):
        """Return the formatted text of *record* with a colorized level name.

        The record's ``levelname`` is only replaced for the duration of the
        formatting call and restored afterwards. A log record is shared by
        every handler it is dispatched to, so leaving the escape codes in
        place would leak them into other handlers' output and nest them if
        the same record were formatted again.
        """
        color = self.COLOR_CODES.get(record.levelno, "")
        plain_levelname = record.levelname
        record.levelname = f"{color}{plain_levelname}{self.RESET_CODE}"
        try:
            return super().format(record)
        finally:
            record.levelname = plain_levelname


# Configure the root logger; basicConfig installs a handler on the root logger
# only if it has none yet.
logging.basicConfig(
    level=logging.INFO,  # Set the logging level
    format="%(levelname)s:\t%(asctime)s - %(name)s - %(message)s",
    datefmt="%H:%M:%S",
)

# Override the formatter with the custom ColoredLogFormatter.
# Both the format string and the date format of the existing formatter are
# carried over; passing only the format string would silently discard the
# configured datefmt and fall back to the default timestamp layout.
root_logger = logging.getLogger()  # Get the root logger
for handler in root_logger.handlers:  # Iterate through existing handlers
    if handler.formatter:
        handler.setFormatter(
            ColoredLogFormatter(
                fmt=handler.formatter._fmt,
                datefmt=handler.formatter.datefmt,
            )
        )

# Module-level logger used throughout this file
_log = logging.getLogger(__name__)


# Context manager to initialize and clean up the lifespan of the FastAPI app
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Set up and tear down the resources that live as long as the app.

    On startup a ``DocumentConverter`` built from the default
    ``ConvertDocumentsOptions`` is stored in ``converters`` under its options
    hash, its PDF pipeline is initialized up front, and the orchestrator's
    ``process_queue()`` coroutine is started as a background task.

    On shutdown the background task is cancelled and awaited, and the
    ``converters`` cache is cleared. The teardown runs in ``finally`` blocks so
    it also happens when an exception is raised into the generator at the
    ``yield`` point.

    Args:
        app: The FastAPI application this lifespan is attached to (unused here).
    """
    # Converter with default options
    pdf_format_option, options_hash = get_pdf_pipeline_opts(ConvertDocumentsOptions())
    converters[options_hash] = DocumentConverter(
        format_options={
            InputFormat.PDF: pdf_format_option,
            InputFormat.IMAGE: pdf_format_option,
        }
    )

    converters[options_hash].initialize_pipeline(InputFormat.PDF)

    orchestrator = get_orchestrator()

    # Start the background queue processor
    queue_task = asyncio.create_task(orchestrator.process_queue())

    try:
        yield
    finally:
        # Cancel the background queue processor on shutdown
        queue_task.cancel()
        try:
            await queue_task
        except asyncio.CancelledError:
            _log.info("Queue processor cancelled.")
        finally:
            # Drop the cached converters even if the queue task ended with an error
            converters.clear()

    # if WITH_UI:
    #     gradio_ui.close()


##################################
# App creation and configuration #
##################################


def create_app():  # noqa: C901
    """Build and configure the Docling Serve FastAPI application.

    The app is created with this module's ``lifespan`` handler and CORS
    middleware that allows any origin, method and header. When
    ``docling_serve_settings.enable_ui`` is set and gradio can be imported, the
    Gradio UI is mounted under ``/ui``. The health, conversion (sync and async),
    task status (poll and websocket) and task result endpoints are then
    registered.

    Endpoint handlers are documented with ``#`` comments rather than docstrings
    so that the generated OpenAPI descriptions are not changed as a side effect.

    Returns:
        The configured application (the object returned by
        ``gr.mount_gradio_app`` when the UI was mounted).
    """
    try:
        version = importlib.metadata.version("docling_serve")
    except importlib.metadata.PackageNotFoundError:
        _log.warning("Unable to get docling_serve version, falling back to 0.0.0")

        version = "0.0.0"

    app = FastAPI(
        title="Docling Serve",
        lifespan=lifespan,
        version=version,
    )

    # NOTE(review): wildcard origins together with allow_credentials=True is
    # very permissive — confirm this is intended for production deployments.
    origins = ["*"]
    methods = ["*"]
    headers = ["*"]

    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        allow_credentials=True,
        allow_methods=methods,
        allow_headers=headers,
    )

    # Mount the Gradio app
    if docling_serve_settings.enable_ui:
        try:
            # Imported lazily: gradio is an optional dependency
            import gradio as gr

            from docling_serve.gradio_ui import ui as gradio_ui

            # Fresh temporary directory handed to the UI for its output files
            tmp_output_dir = Path(tempfile.mkdtemp())
            gradio_ui.gradio_output_dir = tmp_output_dir
            app = gr.mount_gradio_app(
                app,
                gradio_ui,
                path="/ui",
                allowed_paths=["./logo.png", tmp_output_dir],
                root_path="/ui",
            )
        except ImportError:
            _log.warning(
                "Docling Serve enable_ui is activated, but gradio is not installed. "
                "Install it with `pip install docling-serve[ui]` "
                "or `pip install gradio`"
            )

    #############################
    # API Endpoints definitions #
    #############################

    # Favicon: redirect to the logo hosted in the docling repository
    @app.get("/favicon.ico", include_in_schema=False)
    async def favicon():
        response = RedirectResponse(
            url="https://raw.githubusercontent.com/docling-project/docling/refs/heads/main/docs/assets/logo.svg"
        )
        return response

    # Health check
    @app.get("/health")
    def health() -> HealthCheckResponse:
        return HealthCheckResponse()

    # API readiness compatibility for OpenShift AI Workbench
    @app.get("/api", include_in_schema=False)
    def api_check() -> HealthCheckResponse:
        return HealthCheckResponse()

    # Convert a document from URL(s)
    @app.post(
        "/v1alpha/convert/source",
        response_model=ConvertDocumentResponse,
        responses={
            200: {
                "content": {"application/zip": {}},
                # "description": "Return the JSON item or an image.",
            }
        },
    )
    def process_url(
        background_tasks: BackgroundTasks, conversion_request: ConvertDocumentsRequest
    ):
        sources: list[Union[str, DocumentStream]] = []
        headers: Optional[dict[str, Any]] = None
        if isinstance(conversion_request, ConvertDocumentFileSourcesRequest):
            # File sources are turned into in-memory document streams
            for file_source in conversion_request.file_sources:
                sources.append(file_source.to_document_stream())
        else:
            for http_source in conversion_request.http_sources:
                sources.append(http_source.url)
                # Only the first non-empty headers mapping is kept and
                # forwarded to convert_documents
                if headers is None and http_source.headers:
                    headers = http_source.headers

        # Note: results are only an iterator->lazy evaluation
        results = convert_documents(
            sources=sources, options=conversion_request.options, headers=headers
        )

        # The real processing will happen here
        response = process_results(
            background_tasks=background_tasks,
            conversion_options=conversion_request.options,
            conv_results=results,
        )

        return response

    # Convert a document from file(s)
    @app.post(
        "/v1alpha/convert/file",
        response_model=ConvertDocumentResponse,
        responses={
            200: {
                "content": {"application/zip": {}},
            }
        },
    )
    async def process_file(
        background_tasks: BackgroundTasks,
        files: list[UploadFile],
        options: Annotated[
            ConvertDocumentsOptions, FormDepends(ConvertDocumentsOptions)
        ],
    ):
        _log.info(f"Received {len(files)} files for processing.")

        # Load the uploaded files to Docling DocumentStream
        file_sources = []
        for file in files:
            # Use the async read of UploadFile instead of a blocking read on
            # the underlying file object inside this coroutine
            buf = BytesIO(await file.read())
            # Uploads without a filename get a default PDF name
            name = file.filename if file.filename else "file.pdf"
            file_sources.append(DocumentStream(name=name, stream=buf))

        results = convert_documents(sources=file_sources, options=options)

        response = process_results(
            background_tasks=background_tasks,
            conversion_options=options,
            conv_results=results,
        )

        return response

    # Convert a document from URL(s) using the async api:
    # the request is enqueued and the task status is returned immediately
    @app.post(
        "/v1alpha/convert/source/async",
        response_model=TaskStatusResponse,
    )
    async def process_url_async(
        orchestrator: Annotated[AsyncLocalOrchestrator, Depends(get_orchestrator)],
        conversion_request: ConvertDocumentsRequest,
    ):
        task = await orchestrator.enqueue(request=conversion_request)
        task_queue_position = await orchestrator.get_queue_position(
            task_id=task.task_id
        )
        return TaskStatusResponse(
            task_id=task.task_id,
            task_status=task.task_status,
            task_position=task_queue_position,
        )

    # Task status poll; responds with 404 when the task id is unknown
    @app.get(
        "/v1alpha/status/poll/{task_id}",
        response_model=TaskStatusResponse,
    )
    async def task_status_poll(
        orchestrator: Annotated[AsyncLocalOrchestrator, Depends(get_orchestrator)],
        task_id: str,
        wait: Annotated[
            float,
            # `description` is the Query argument that documents the parameter
            # in the OpenAPI schema
            Query(description="Number of seconds to wait for a completed status."),
        ] = 0.0,
    ):
        try:
            task = await orchestrator.task_status(task_id=task_id, wait=wait)
            task_queue_position = await orchestrator.get_queue_position(task_id=task_id)
        except TaskNotFoundError:
            raise HTTPException(status_code=404, detail="Task not found.")
        return TaskStatusResponse(
            task_id=task.task_id,
            task_status=task.task_status,
            task_position=task_queue_position,
        )

    # Task status websocket
    @app.websocket(
        "/v1alpha/status/ws/{task_id}",
    )
    async def task_status_ws(
        websocket: WebSocket,
        orchestrator: Annotated[AsyncLocalOrchestrator, Depends(get_orchestrator)],
        task_id: str,
    ):
        await websocket.accept()

        # Unknown task: report the error over the socket and close it
        if task_id not in orchestrator.tasks:
            await websocket.send_text(
                WebsocketMessage(
                    message=MessageKind.ERROR, error="Task not found."
                ).model_dump_json()
            )
            await websocket.close()
            return

        task = orchestrator.tasks[task_id]

        # Track active WebSocket connections for this job
        orchestrator.task_subscribers[task_id].add(websocket)

        try:
            # Initial message announcing the connection with the current status
            task_queue_position = await orchestrator.get_queue_position(task_id=task_id)
            task_response = TaskStatusResponse(
                task_id=task.task_id,
                task_status=task.task_status,
                task_position=task_queue_position,
            )
            await websocket.send_text(
                WebsocketMessage(
                    message=MessageKind.CONNECTION, task=task_response
                ).model_dump_json()
            )
            # Send an update, then block until the client asks for the next one
            while True:
                task_queue_position = await orchestrator.get_queue_position(
                    task_id=task_id
                )
                task_response = TaskStatusResponse(
                    task_id=task.task_id,
                    task_status=task.task_status,
                    task_position=task_queue_position,
                )
                await websocket.send_text(
                    WebsocketMessage(
                        message=MessageKind.UPDATE, task=task_response
                    ).model_dump_json()
                )
                # each client message will be interpreted as a request for update
                msg = await websocket.receive_text()
                _log.debug(f"Received message: {msg}")

        except WebSocketDisconnect:
            _log.info(f"WebSocket disconnected for job {task_id}")

        finally:
            # Always unregister this connection from the task's subscribers
            orchestrator.task_subscribers[task_id].remove(websocket)

    # Task result; responds with 404 while no result is available
    @app.get(
        "/v1alpha/result/{task_id}",
        response_model=ConvertDocumentResponse,
        responses={
            200: {
                "content": {"application/zip": {}},
            }
        },
    )
    async def task_result(
        orchestrator: Annotated[AsyncLocalOrchestrator, Depends(get_orchestrator)],
        task_id: str,
    ):
        result = await orchestrator.task_result(task_id=task_id)
        if result is None:
            raise HTTPException(
                status_code=404,
                detail="Task result not found. Please wait for a completion status.",
            )
        return result

    return app