import os
import gradio as gr
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List
from contextlib import asynccontextmanager
from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams
# --- Configuration ---
MODEL_ID = os.getenv("MODEL_ID", "meta-llama/Llama-3.1-8B-Instruct")
engine = None
# --- Lifespan Manager for Model Loading ---
# This is the correct way to load a model on startup in FastAPI.
@asynccontextmanager
async def lifespan(app: FastAPI):
    global engine
    print(f"Lifespan startup: Loading model {MODEL_ID}...")
    engine_args = AsyncEngineArgs(
        model=MODEL_ID,
        # Let vLLM load the tokenizer bundled with MODEL_ID; the legacy
        # hf-internal-testing/llama-tokenizer is not compatible with Llama 3.1.
        tensor_parallel_size=1,
        gpu_memory_utilization=0.90,
        download_dir="/data/huggingface",  # Cache directory inside the container
    )
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    print("Model loading complete.")
    yield
    # Cleanup logic can be added here if needed

# 1. Create the FastAPI app instance FIRST
app = FastAPI(lifespan=lifespan)
# --- API Data Models ---
class ChatMessage(BaseModel):
    role: str
    content: str

class ChatCompletionRequest(BaseModel):
    messages: List[ChatMessage]
    model: str = MODEL_ID
    temperature: float = 0.7
    max_tokens: int = 1024
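
# Illustrative request body for the models above (note that the handler below
# only uses the content of the last message):
# {
#   "messages": [{"role": "user", "content": "Hello!"}],
#   "temperature": 0.7,
#   "max_tokens": 256
# }
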
# 2. Define the API endpoint on the FastAPI `app` object
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    if not engine:
        # Returning a (dict, status) tuple does not set the HTTP status code in
        # FastAPI, so raise an HTTPException to actually return a 503.
        raise HTTPException(status_code=503, detail="Model is not ready or has failed to load.")
    user_prompt = request.messages[-1].content
    sampling_params = SamplingParams(temperature=request.temperature, max_tokens=request.max_tokens)
    request_id = f"api-{os.urandom(4).hex()}"
    # AsyncLLMEngine.generate returns an async generator, not an awaitable;
    # iterate it and keep the last RequestOutput as the final result.
    results_generator = engine.generate(user_prompt, sampling_params, request_id)
    final_output = None
    async for request_output in results_generator:
        final_output = request_output
    return {
        "choices": [{"message": {"role": "assistant", "content": final_output.outputs[0].text}}]
    }
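
# Example call (a sketch, assuming the server is reachable on port 7860, the
# usual port for a Hugging Face Space; adjust host/port for your deployment):
#
#   curl -X POST http://localhost:7860/v1/chat/completions \
#        -H "Content-Type: application/json" \
#        -d '{"messages": [{"role": "user", "content": "Hello!"}]}'
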
# 3. Create the Gradio UI in a separate object
async def gradio_predict(prompt: str):
    if not engine:
        yield "Model is not ready. Please wait a few moments after startup."
        return
    sampling_params = SamplingParams(temperature=0.7, max_tokens=1024)
    stream = engine.generate(prompt, sampling_params, f"gradio-req-{os.urandom(4).hex()}")
    async for result in stream:
        yield result.outputs[0].text

gradio_ui = gr.Blocks()
with gradio_ui:
    gr.Markdown(f"# VLLM Server for {MODEL_ID}")
    gr.Markdown("This UI and the `/v1/chat/completions` API are served from the same container.")
    with gr.Row():
        inp = gr.Textbox(lines=4, label="Input")
        out = gr.Textbox(lines=10, label="Output", interactive=False)
    btn = gr.Button("Generate")
    btn.click(fn=gradio_predict, inputs=inp, outputs=out)

# 4. Mount the Gradio UI onto the FastAPI app at the root path
app = gr.mount_gradio_app(app, gradio_ui, path="/")
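
# Optional local entry point: a minimal sketch assuming uvicorn is installed and
# port 7860 is free. In a Space the container's CMD normally starts the server,
# so this block is only useful when running the file directly.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)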