import os
import gradio as gr
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List
from contextlib import asynccontextmanager
from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams
# --- Configuration ---
MODEL_ID = os.getenv("MODEL_ID", "meta-llama/Llama-3.1-8B-Instruct")
engine = None
# --- Lifespan Manager for Model Loading ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    global engine
    print(f"Lifespan startup: Loading model {MODEL_ID}...")
    engine_args = AsyncEngineArgs(
        model=MODEL_ID,  # the tokenizer defaults to the model's own
        tensor_parallel_size=1,
        gpu_memory_utilization=0.90,
        download_dir="/data/huggingface",
    )
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    print("Model loading complete.")
    yield
# 1. Create the FastAPI app instance
app = FastAPI(lifespan=lifespan)
# --- API Data Models ---
class ChatMessage(BaseModel):
    role: str
    content: str

class ChatCompletionRequest(BaseModel):
    messages: List[ChatMessage]
    model: str = MODEL_ID
    temperature: float = 0.7
    max_tokens: int = 1024
# 2. Define the API endpoint on the FastAPI `app` object
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    if not engine:
        raise HTTPException(status_code=503, detail="Model is not ready or has failed to load.")
    user_prompt = request.messages[-1].content
    sampling_params = SamplingParams(temperature=request.temperature, max_tokens=request.max_tokens)
    request_id = f"api-{os.urandom(4).hex()}"
    # engine.generate returns an async generator of RequestOutput objects;
    # iterate it and keep the last result as the completed generation.
    final_output = None
    async for request_output in engine.generate(user_prompt, sampling_params, request_id):
        final_output = request_output
    return {
        "choices": [{"message": {"role": "assistant", "content": final_output.outputs[0].text}}]
    }
# 3. Create the Gradio UI
async def gradio_predict(prompt: str):
    if not engine:
        yield "Model is not ready. Please wait a few moments after startup."
        return
    sampling_params = SamplingParams(temperature=0.7, max_tokens=1024)
    stream = engine.generate(prompt, sampling_params, f"gradio-req-{os.urandom(4).hex()}")
    async for result in stream:
        yield result.outputs[0].text
gradio_ui = gr.Blocks()
with gradio_ui:
    gr.Markdown(f"# vLLM Server for {MODEL_ID}")
    gr.Markdown("This UI and the `/v1/chat/completions` API are served from the same container.")
    with gr.Row():
        inp = gr.Textbox(lines=4, label="Input")
        out = gr.Textbox(lines=10, label="Output", interactive=False)
    btn = gr.Button("Generate")
    btn.click(fn=gradio_predict, inputs=inp, outputs=out)
# 4. Mount the Gradio UI onto the FastAPI app
app = gr.mount_gradio_app(app, gradio_ui, path="/")
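
# A minimal sketch of how a client might call the mounted API once the server is
# running, e.g. via `uvicorn app:app --host 0.0.0.0 --port 7860` (the port and
# hostname are assumptions, not defined by this app; the `requests` snippet is
# illustrative and not part of the Space itself):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/v1/chat/completions",
#       json={"messages": [{"role": "user", "content": "Hello!"}]},
#   )
#   print(resp.json()["choices"][0]["message"]["content"])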