Leon4gr45 committed on
Commit 0442491 · verified · 1 Parent(s): fb40632

Create app.py

Files changed (1)
  1. app.py +90 -0
app.py ADDED
@@ -0,0 +1,90 @@
+ import os
+ import gradio as gr
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
+ from typing import List
+ from contextlib import asynccontextmanager
+
+ from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams
+
+ # --- Configuration ---
+ MODEL_ID = os.getenv("MODEL_ID", "meta-llama/Llama-3.1-8B-Instruct")
+ engine = None
+
+ # --- Lifespan Manager for Model Loading ---
+ # This is the recommended way to load a model on startup in FastAPI.
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     global engine
+     print(f"Lifespan startup: Loading model {MODEL_ID}...")
+     engine_args = AsyncEngineArgs(
+         model=MODEL_ID,
+         tokenizer=MODEL_ID,  # use the tokenizer that ships with the model
+         tensor_parallel_size=1,
+         gpu_memory_utilization=0.90,
+         download_dir="/data/huggingface"  # cache directory inside the container
+     )
+     engine = AsyncLLMEngine.from_engine_args(engine_args)
+     print("Model loading complete.")
+     yield
+     # Cleanup logic can be added here if needed.
+
+ # 1. Create the FastAPI app instance FIRST.
+ app = FastAPI(lifespan=lifespan)
+
+ # --- API Data Models ---
+ class ChatMessage(BaseModel):
+     role: str
+     content: str
+
+ class ChatCompletionRequest(BaseModel):
+     messages: List[ChatMessage]
+     model: str = MODEL_ID
+     temperature: float = 0.7
+     max_tokens: int = 1024
+
+ # 2. Define the API endpoint on the FastAPI `app` object.
+ @app.post("/v1/chat/completions")
+ async def chat_completions(request: ChatCompletionRequest):
+     if not engine:
+         raise HTTPException(status_code=503, detail="Model is not ready or has failed to load.")
+
+     # The last message is used as a raw prompt; no chat template is applied.
+     user_prompt = request.messages[-1].content
+     sampling_params = SamplingParams(temperature=request.temperature, max_tokens=request.max_tokens)
+     request_id = f"api-{os.urandom(4).hex()}"
+
+     # engine.generate() returns an async generator; iterate it and keep the final output.
+     final_output = None
+     async for request_output in engine.generate(user_prompt, sampling_params, request_id):
+         final_output = request_output
+
+     return {
+         "choices": [{"message": {"role": "assistant", "content": final_output.outputs[0].text}}]
+     }
+
+ # 3. Create the Gradio UI in a separate object.
+ async def gradio_predict(prompt: str):
+     if not engine:
+         yield "Model is not ready. Please wait a few moments after startup."
+         return
+
+     sampling_params = SamplingParams(temperature=0.7, max_tokens=1024)
+     stream = engine.generate(prompt, sampling_params, f"gradio-req-{os.urandom(4).hex()}")
+
+     # Each result holds the full text generated so far, so the output box updates as tokens arrive.
+     async for result in stream:
+         yield result.outputs[0].text
+
+ gradio_ui = gr.Blocks()
+ with gradio_ui:
+     gr.Markdown(f"# vLLM Server for {MODEL_ID}")
+     gr.Markdown("This UI and the `/v1/chat/completions` API are served from the same container.")
+     with gr.Row():
+         inp = gr.Textbox(lines=4, label="Input")
+         out = gr.Textbox(lines=10, label="Output", interactive=False)
+     btn = gr.Button("Generate")
+     btn.click(fn=gradio_predict, inputs=inp, outputs=out)
+
+ # 4. Mount the Gradio UI onto the FastAPI app at the root path.
+ app = gr.mount_gradio_app(app, gradio_ui, path="/")
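
For reference, a minimal client sketch for exercising the `/v1/chat/completions` endpoint added above. It assumes the app is being served on localhost port 7860 (for example via `uvicorn app:app --host 0.0.0.0 --port 7860`) and that the `requests` package is installed; the host, port, filename, and prompt are illustrative assumptions, not part of this commit.

# client_example.py -- illustrative only, not part of this commit.
# Sends one chat request to the FastAPI endpoint defined in app.py.
# Assumes the server is reachable at http://localhost:7860 (adjust as needed).
import requests

payload = {
    "messages": [{"role": "user", "content": "Write a haiku about GPUs."}],
    "temperature": 0.7,
    "max_tokens": 256,
}

resp = requests.post("http://localhost:7860/v1/chat/completions", json=payload, timeout=300)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])

The endpoint only populates `choices[0].message`, so clients that read just that field will work, while OpenAI-style fields such as `id` and `usage` are not returned.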