# SPDX-License-Identifier: Apache-2.0
"""
NOTE: This API server is used only for demonstrating usage of AsyncEngine
and simple performance benchmarks. It is not intended for production use.
For production use, we recommend using our OpenAI-compatible server.
We are also not going to accept PRs modifying this file; please
change `vllm/entrypoints/openai/api_server.py` instead.
"""
import asyncio
import base64
import io
import json
import ssl
from argparse import Namespace
from collections.abc import AsyncGenerator
from typing import Any, Optional

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from PIL import Image

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.utils import with_cancellation
from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, random_uuid, set_ulimit
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger("api_server")

TIMEOUT_KEEP_ALIVE = 5  # seconds

app = FastAPI()
engine: Optional[AsyncLLMEngine] = None


@app.get("/health")
async def health() -> Response:
    """Health check."""
    return Response(status_code=200)


@app.post("/generate")
@with_cancellation
async def generate(request: Request) -> Response:
    """Generate completion for the request.

    The request should be a JSON object with the following fields:
    - encoder_prompt: the encoder prompt to use for the generation.
    - decoder_prompt: the decoder prompt to use for the generation.
    - image_base64: the base64-encoded input image.
    - stream: whether to stream the results or not.
    - other fields: the sampling parameters (see `SamplingParams` for details).
    """
    request_dict = await request.json()
    return await _generate(request_dict, raw_request=request)
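

# Minimal client sketch for the /generate route above. Illustrative only: it
# assumes the third-party `requests` package, a server on localhost:8000, and
# a Dolphin-style decoder prompt; adjust the prompt to your model.
#
#   import base64
#   import requests
#
#   with open("page.png", "rb") as f:
#       image_base64 = base64.b64encode(f.read()).decode("utf-8")
#   resp = requests.post(
#       "http://localhost:8000/generate",
#       json={
#           "encoder_prompt": "",
#           "decoder_prompt": "Read text in the image.",
#           "image_base64": image_base64,
#           "max_tokens": 512,
#       },
#   )
#   print(resp.json()["text"])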


async def decode_image(image_base64: str) -> Image.Image:
    """Decode a base64-encoded image into a PIL Image."""
    image_data = base64.b64decode(image_base64)
    image = Image.open(io.BytesIO(image_data))
    return image


async def custom_process_prompt(
        encoder_prompt: str, decoder_prompt: str,
        image_base64: str) -> ExplicitEncoderDecoderPrompt:
    """Build an explicit encoder/decoder prompt from the request fields."""
    assert engine is not None
    tokenizer = engine.engine.get_tokenizer_group().tokenizer
    image = await decode_image(image_base64)
    if encoder_prompt == "":
        encoder_prompt = "0" * 783  # Placeholder encoder prompt; for Dolphin.
    if decoder_prompt == "":
        # Start decoding from BOS when no decoder prompt is given.
        # TokensPrompt expects a list of token ids, not a bare id.
        decoder_prompt_ids = [tokenizer.bos_token_id]
    else:
        decoder_prompt = f"<s>{decoder_prompt.strip()} <Answer/>"
        decoder_prompt_ids = tokenizer(decoder_prompt,
                                       add_special_tokens=False)["input_ids"]
    enc_dec_prompt = ExplicitEncoderDecoderPrompt(
        encoder_prompt=TextPrompt(prompt=encoder_prompt,
                                  multi_modal_data={"image": image}),
        decoder_prompt=TokensPrompt(prompt_token_ids=decoder_prompt_ids),
    )
    return enc_dec_prompt
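

# Illustrative usage of custom_process_prompt (sketch; `img_b64` is a
# hypothetical base64 string and the decoder prompt is just an example):
#
#   enc_dec_prompt = await custom_process_prompt(
#       encoder_prompt="",
#       decoder_prompt="Read text in the image.",
#       image_base64=img_b64,
#   )
#   # encoder side: placeholder text plus the decoded image as multimodal data
#   # decoder side: token ids for "<s>Read text in the image. <Answer/>"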


async def _generate(request_dict: dict, raw_request: Request) -> Response:
    encoder_prompt = request_dict.pop("encoder_prompt", "")
    decoder_prompt = request_dict.pop("decoder_prompt", "")
    image_base64 = request_dict.pop("image_base64", "")
    stream = request_dict.pop("stream", False)
    # The remaining fields are interpreted as sampling parameters.
    sampling_params = SamplingParams(**request_dict)
    request_id = random_uuid()

    assert engine is not None
    enc_dec_prompt = await custom_process_prompt(encoder_prompt,
                                                 decoder_prompt, image_base64)
    results_generator = engine.generate(enc_dec_prompt, sampling_params,
                                        request_id)

    # Streaming case: emit one JSON line per partial result.
    async def stream_results() -> AsyncGenerator[bytes, None]:
        async for request_output in results_generator:
            prompt = request_output.prompt
            assert prompt is not None
            text_outputs = [
                prompt + output.text for output in request_output.outputs
            ]
            ret = {"text": text_outputs}
            yield (json.dumps(ret) + "\n").encode("utf-8")

    if stream:
        return StreamingResponse(stream_results())

    # Non-streaming case: drain the generator and return the final output.
    final_output = None
    try:
        async for request_output in results_generator:
            final_output = request_output
    except asyncio.CancelledError:
        # Client disconnected before generation finished.
        return Response(status_code=499)

    assert final_output is not None
    prompt = final_output.prompt
    assert prompt is not None
    text_outputs = [
        prompt + output.text.strip() for output in final_output.outputs
    ]
    ret = {"text": text_outputs}
    return JSONResponse(ret)
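

# Consuming the streaming response (illustrative sketch; assumes the
# third-party `requests` package and a server on localhost:8000). Each line
# of the body is a standalone JSON object of the form {"text": [...]}:
#
#   import json
#   import requests
#
#   payload = {"decoder_prompt": "Read text in the image.",
#              "image_base64": image_base64, "stream": True}
#   with requests.post("http://localhost:8000/generate",
#                      json=payload, stream=True) as resp:
#       for line in resp.iter_lines():
#           if line:
#               print(json.loads(line)["text"])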


def build_app(args: Namespace) -> FastAPI:
    global app

    app.root_path = args.root_path
    return app


async def init_app(
    args: Namespace,
    llm_engine: Optional[AsyncLLMEngine] = None,
) -> FastAPI:
    app = build_app(args)

    global engine
    # Reuse the engine provided by the caller, if any; otherwise build one
    # from the CLI arguments.
    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = (llm_engine
              if llm_engine is not None else AsyncLLMEngine.from_engine_args(
                  engine_args, usage_context=UsageContext.API_SERVER))
    app.state.engine_client = engine

    return app


async def run_server(args: Namespace,
                     llm_engine: Optional[AsyncLLMEngine] = None,
                     **uvicorn_kwargs: Any) -> None:
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    set_ulimit()

    app = await init_app(args, llm_engine)
    assert engine is not None

    shutdown_task = await serve_http(
        app,
        sock=None,
        enable_ssl_refresh=args.enable_ssl_refresh,
        host=args.host,
        port=args.port,
        log_level=args.log_level,
        timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
        ssl_keyfile=args.ssl_keyfile,
        ssl_certfile=args.ssl_certfile,
        ssl_ca_certs=args.ssl_ca_certs,
        ssl_cert_reqs=args.ssl_cert_reqs,
        **uvicorn_kwargs,
    )

    await shutdown_task
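

# Programmatic launch (illustrative sketch): a caller embedding this server
# can pass a pre-built AsyncLLMEngine instead of letting run_server construct
# one from the CLI arguments. The model name below is a placeholder.
#
#   engine_args = AsyncEngineArgs(model="<your-encoder-decoder-model>")
#   llm_engine = AsyncLLMEngine.from_engine_args(engine_args)
#   asyncio.run(run_server(args, llm_engine=llm_engine))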


if __name__ == "__main__":
    parser = FlexibleArgumentParser()
    parser.add_argument("--host", type=str, default=None)
    parser.add_argument("--port", type=parser.check_port, default=8000)
    parser.add_argument("--ssl-keyfile", type=str, default=None)
    parser.add_argument("--ssl-certfile", type=str, default=None)
    parser.add_argument("--ssl-ca-certs",
                        type=str,
                        default=None,
                        help="The CA certificates file")
    parser.add_argument(
        "--enable-ssl-refresh",
        action="store_true",
        default=False,
        help="Refresh SSL Context when SSL certificate files change")
    parser.add_argument(
        "--ssl-cert-reqs",
        type=int,
        default=int(ssl.CERT_NONE),
        help="Whether client certificate is required "
        "(see stdlib ssl module's CERT_* constants)")
    parser.add_argument(
        "--root-path",
        type=str,
        default=None,
        help="FastAPI root_path when app is behind a path-based routing proxy")
    parser.add_argument("--log-level", type=str, default="debug")
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()

    asyncio.run(run_server(args))