#!/usr/bin/env python
# copied from: https://github.com/NVIDIA/TensorRT-LLM/blob/v0.18.1/examples/apps/fastapi_server.py
import asyncio
import base64
import io
import logging
import signal
from http import HTTPStatus
from typing import Optional

import click
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response
from PIL import Image

from tensorrt_llm.executor import CppExecutorError, RequestError

from dolphin_runner import DolphinRunner, InferenceConfig

TIMEOUT_KEEP_ALIVE = 5  # seconds.


async def decode_image(image_base64: str) -> Image.Image:
    image_data = base64.b64decode(image_base64)
    image = Image.open(io.BytesIO(image_data))
    return image
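
# For reference, a client is expected to send the image as a base64 string.
# A minimal sketch of producing that payload (hypothetical file name):
#
#     with open("page.png", "rb") as f:
#         image_base64 = base64.b64encode(f.read()).decode("utf-8")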


class LlmServer:

    def __init__(self, runner: DolphinRunner):
        self.runner = runner
        self.app = FastAPI()
        self.register_routes()

    def register_routes(self):
        self.app.add_api_route("/health", self.health, methods=["GET"])
        self.app.add_api_route("/generate", self.generate, methods=["POST"])

    async def health(self) -> Response:
        return Response(status_code=200)

    async def generate(self, request: Request) -> Response:
        """Generate a completion for the request.

        The request should be a JSON object with the following fields:
        - prompt: the prompt to use for the generation.
        - image_base64: the base64-encoded image to use for the generation.
        """
        request_dict = await request.json()

        prompt = request_dict.pop("prompt", "")
        logging.info(f"request prompt: {prompt}")
        image_base64 = request_dict.pop("image_base64", "")
        image = await decode_image(image_base64)

        try:
            output_texts = self.runner.run([prompt], [image], 4024)
            output_texts = [texts[0] for texts in output_texts]
            return JSONResponse({"text": output_texts[0]})
        except RequestError as e:
            return JSONResponse(content=str(e),
                                status_code=HTTPStatus.BAD_REQUEST)
        except CppExecutorError:
            # If an internal executor error is raised, shut down the server.
            signal.raise_signal(signal.SIGINT)
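
    # Example call against a running server (hypothetical values; <PROMPT>
    # and <BASE64> are placeholders for a real prompt and encoded image):
    #
    #   curl -X POST http://localhost:8000/generate \
    #        -H "Content-Type: application/json" \
    #        -d '{"prompt": "<PROMPT>", "image_base64": "<BASE64>"}'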

    async def __call__(self, host, port):
        config = uvicorn.Config(self.app,
                                host=host,
                                port=port,
                                log_level="info",
                                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
        await uvicorn.Server(config).serve()


# Click options restored here: the bare `entrypoint()` call under `__main__`
# would otherwise fail, since the function has required parameters. Option
# names mirror the parameters; the defaults are assumptions.
@click.command()
@click.option("--hf_model_dir", type=str, required=True)
@click.option("--visual_engine_dir", type=str, required=True)
@click.option("--llm_engine_dir", type=str, required=True)
@click.option("--max_batch_size", type=int, default=1)
@click.option("--max_new_tokens", type=int, default=4024)
@click.option("--host", type=str, default=None)
@click.option("--port", type=int, default=8000)
def entrypoint(hf_model_dir: str,
               visual_engine_dir: str,
               llm_engine_dir: str,
               max_batch_size: int,
               max_new_tokens: int,
               host: Optional[str] = None,
               port: int = 8000):
    host = host or "0.0.0.0"
    port = port or 8000
    logging.info(f"Starting server at {host}:{port}")

    config = InferenceConfig(
        max_new_tokens=max_new_tokens,
        batch_size=max_batch_size,
        log_level="info",
        hf_model_dir=hf_model_dir,
        visual_engine_dir=visual_engine_dir,
        llm_engine_dir=llm_engine_dir,
    )

    dolphin_runner = DolphinRunner(config)
    server = LlmServer(runner=dolphin_runner)

    asyncio.run(server(host, port))


if __name__ == "__main__":
    entrypoint()
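
# Example launch (hypothetical paths, and assuming this file is saved as
# api_server.py; the flags correspond to the click options above):
#
#   python api_server.py --hf_model_dir ./Dolphin \
#       --visual_engine_dir ./engines/vision \
#       --llm_engine_dir ./engines/llm \
#       --max_batch_size 1 --max_new_tokens 4024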