import os
from contextlib import asynccontextmanager
from typing import List

import gradio as gr
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams

# --- Configuration ---
MODEL_ID = os.getenv("MODEL_ID", "meta-llama/Llama-3.1-8B-Instruct")
engine = None

# --- Lifespan Manager for Model Loading ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the vLLM engine once at startup and keep it in the module-level `engine`."""
    global engine
    print(f"Lifespan startup: Loading model {MODEL_ID}...")
    engine_args = AsyncEngineArgs(
        model=MODEL_ID,
        # Use the model's own tokenizer (the default). Pointing at
        # "hf-internal-testing/llama-tokenizer" would pair Llama 3.1 with an
        # incompatible Llama 1 tokenizer.
        tensor_parallel_size=1,
        gpu_memory_utilization=0.90,
        download_dir="/data/huggingface",
    )
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    print("Model loading complete.")
    yield

# 1. Create the FastAPI app instance
app = FastAPI(lifespan=lifespan)

# --- API Data Models ---
class ChatMessage(BaseModel):
    role: str
    content: str

class ChatCompletionRequest(BaseModel):
    messages: List[ChatMessage]
    model: str = MODEL_ID
    temperature: float = 0.7
    max_tokens: int = 1024

# 2. Define the API endpoint on the FastAPI `app` object
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    if not engine:
        raise HTTPException(status_code=503, detail="Model is not ready or has failed to load.")

    # For simplicity, only the latest message is used as the prompt; no chat
    # template is applied to the full conversation.
    user_prompt = request.messages[-1].content
    sampling_params = SamplingParams(temperature=request.temperature, max_tokens=request.max_tokens)
    request_id = f"api-{os.urandom(4).hex()}"

    # engine.generate() returns an async generator of incremental RequestOutputs;
    # iterate it to completion and keep the final one.
    results_generator = engine.generate(user_prompt, sampling_params, request_id)
    final_output = None
    async for request_output in results_generator:
        final_output = request_output

    return {
        "choices": [{"message": {"role": "assistant", "content": final_output.outputs[0].text}}]
    }

# 3. Create the Gradio UI
async def gradio_predict(prompt: str):
    if not engine:
        yield "Model is not ready. Please wait a few moments after startup."
        return
    sampling_params = SamplingParams(temperature=0.7, max_tokens=1024)
    stream = engine.generate(prompt, sampling_params, f"gradio-req-{os.urandom(4).hex()}")
    # Stream partial generations to the UI as they arrive.
    async for result in stream:
        yield result.outputs[0].text

gradio_ui = gr.Blocks()
with gradio_ui:
    gr.Markdown(f"# vLLM Server for {MODEL_ID}")
    gr.Markdown("This UI and the `/v1/chat/completions` API are served from the same container.")
    with gr.Row():
        inp = gr.Textbox(lines=4, label="Input")
        out = gr.Textbox(lines=10, label="Output", interactive=False)
    btn = gr.Button("Generate")
    btn.click(fn=gradio_predict, inputs=inp, outputs=out)

# 4. Mount the Gradio UI onto the FastAPI app
app = gr.mount_gradio_app(app, gradio_ui, path="/")
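
# Example usage (a sketch; the module name "app", host, and port 8000 are
# assumptions, not part of the code above):
#
#   uvicorn app:app --host 0.0.0.0 --port 8000
#
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello!"}]}'
#
# The Gradio UI is served by the same process at http://localhost:8000/.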