import os
import gradio as gr
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List
from contextlib import asynccontextmanager
from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams
# --- Configuration ---
MODEL_ID = os.getenv("MODEL_ID", "meta-llama/Llama-3.1-8B-Instruct")
engine = None
# --- Lifespan Manager for Model Loading ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    global engine
    print(f"Lifespan startup: Loading model {MODEL_ID}...")
    engine_args = AsyncEngineArgs(
        model=MODEL_ID,
        # Use the model's own tokenizer; the Llama-1 test tokenizer
        # ("hf-internal-testing/llama-tokenizer") is incompatible with Llama 3.1.
        tensor_parallel_size=1,
        gpu_memory_utilization=0.90,
        download_dir="/data/huggingface"
    )
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    print("Model loading complete.")
    yield
# 1. Create the FastAPI app instance
app = FastAPI(lifespan=lifespan)
# --- API Data Models ---
class ChatMessage(BaseModel):
    role: str
    content: str

class ChatCompletionRequest(BaseModel):
    messages: List[ChatMessage]
    model: str = MODEL_ID
    temperature: float = 0.7
    max_tokens: int = 1024
# 2. Define the API endpoint on the FastAPI `app` object
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
if not engine:
return {"error": "Model is not ready or has failed to load."}, 503
user_prompt = request.messages[-1].content
sampling_params = SamplingParams(temperature=request.temperature, max_tokens=request.max_tokens)
request_id = f"api-{os.urandom(4).hex()}"
results_generator = engine.generate(user_prompt, sampling_params, request_id)
final_output = await results_generator
return {
"choices": [{"message": {"role": "assistant", "content": final_output.outputs[0].text}}]
}
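# Example call against the endpoint above (a hedged sketch: it assumes the app is
# reachable on port 7860, the default port for a Hugging Face Space):
#
#   curl -X POST http://localhost:7860/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello!"}], "max_tokens": 64}'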
# 3. Create the Gradio UI
async def gradio_predict(prompt: str):
    if not engine:
        yield "Model is not ready. Please wait a few moments after startup."
        return
    sampling_params = SamplingParams(temperature=0.7, max_tokens=1024)
    stream = engine.generate(prompt, sampling_params, f"gradio-req-{os.urandom(4).hex()}")
    # Each RequestOutput carries the full text generated so far, so yielding it
    # streams the growing response into the Gradio output box.
    async for result in stream:
        yield result.outputs[0].text
gradio_ui = gr.Blocks()
with gradio_ui:
    gr.Markdown(f"# vLLM Server for {MODEL_ID}")
    gr.Markdown("This UI and the `/v1/chat/completions` API are served from the same container.")
    with gr.Row():
        inp = gr.Textbox(lines=4, label="Input")
        out = gr.Textbox(lines=10, label="Output", interactive=False)
    btn = gr.Button("Generate")
    btn.click(fn=gradio_predict, inputs=inp, outputs=out)
# 4. Mount the Gradio UI onto the FastAPI app
app = gr.mount_gradio_app(app, gradio_ui, path="/")
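# 5. Optional local entry point (a sketch under assumptions: the Space's own start
#    command normally launches the server, and uvicorn is assumed to be installed).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)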