import os
from typing import List, Tuple, Optional
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
from llama_cpp_agent.providers import LlamaCppPythonProvider
from llama_cpp_agent.chat_history import BasicChatHistory
from llama_cpp_agent.chat_history.messages import Roles
# Suppress warnings
import warnings
warnings.filterwarnings("ignore")
# Ensure models directory exists
MODEL_DIR = "./models"
os.makedirs(MODEL_DIR, exist_ok=True)
# Model info for download
MODELS_INFO = [
    {
        "repo_id": "bartowski/Dolphin3.0-Llama3.2-1B-GGUF",
        "filename": "Dolphin3.0-Llama3.2-1B-Q4_K_M.gguf"
    },
    {
        "repo_id": "bartowski/Dolphin3.0-Qwen2.5-0.5B-GGUF",
        "filename": "Dolphin3.0-Qwen2.5-0.5B-Q6_K.gguf"
    },
    {
        "repo_id": "bartowski/Qwen2.5-Coder-14B-Instruct-GGUF",
        "filename": "Qwen2.5-Coder-14B-Instruct-Q6_K.gguf"
    }
]
# Download all models if not present
for model_info in MODELS_INFO:
    model_path = os.path.join(MODEL_DIR, model_info["filename"])
    if not os.path.exists(model_path):
        print(f"Downloading {model_info['filename']} from {model_info['repo_id']}...")
        try:
            hf_hub_download(
                repo_id=model_info["repo_id"],
                filename=model_info["filename"],
                local_dir=MODEL_DIR
            )
            print(f"Downloaded {model_info['filename']}")
        except Exception as e:
            print(f"Error downloading {model_info['filename']}: {e}")
# Available model keys (used in API)
AVAILABLE_MODELS = {
    "qwen": "Dolphin3.0-Qwen2.5-0.5B-Q6_K.gguf",
    "llama": "Dolphin3.0-Llama3.2-1B-Q4_K_M.gguf",
    "coder": "Qwen2.5-Coder-14B-Instruct-Q6_K.gguf"
}
# Global LLM instance
llm = None
llm_model = None
def load_model(model_key: str):
    """Load (or reuse) the Llama instance for the given model key.

    The instance is cached in the module-level `llm` variable and only
    re-created when a different model is requested.
    """
    global llm, llm_model
    model_name = AVAILABLE_MODELS.get(model_key)
    if not model_name:
        raise ValueError(f"Invalid model key: {model_key}")
    model_path = os.path.join(MODEL_DIR, model_name)
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found at {model_path}")
    if llm is None or llm_model != model_name:
        llm = Llama(
            model_path=model_path,
            flash_attn=False,
            n_gpu_layers=0,
            n_batch=8,
            n_ctx=2048,
            n_threads=8,
            n_threads_batch=8,
        )
        llm_model = model_name
    return llm
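# Usage sketch (illustrative, not executed at import time): repeated calls with the
# same key reuse the cached instance; a new key replaces it.
#   load_model("qwen")   # loads Dolphin3.0-Qwen2.5-0.5B-Q6_K.gguf
#   load_model("qwen")   # returns the already-loaded instance
#   load_model("coder")  # swaps in Qwen2.5-Coder-14B-Instruct-Q6_K.gguf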
class ChatRequest(BaseModel):
    message: str  # Required
    history: Optional[List[Tuple[str, str]]] = []  # Default: empty list of (user, assistant) pairs
    model: Optional[str] = "qwen"  # Default model key
    system_prompt: Optional[str] = "You are Dolphin, a helpful AI assistant."
    max_tokens: Optional[int] = 1024
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.95
    top_k: Optional[int] = 40
    repeat_penalty: Optional[float] = 1.1
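# Illustrative request body for /chat (example values, only "message" is required):
# {
#   "message": "What is llama.cpp?",
#   "history": [["Hi", "Hello! How can I help?"]],
#   "model": "llama",
#   "max_tokens": 256
# }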
class ChatResponse(BaseModel):
    response: str

class ModelInfoResponse(BaseModel):
    models: List[str]
app = FastAPI(
    title="Dolphin 3.0 LLM API",
    description="REST API for Dolphin 3.0 models using Llama.cpp backend.",
    version="1.0",
    docs_url="/docs",  # Only Swagger docs
    redoc_url=None     # Disable ReDoc
)
@app.get("/models", response_model=ModelInfoResponse)
def get_available_models():
    """Returns the list of supported models."""
    return {"models": list(AVAILABLE_MODELS.keys())}
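# Example (assuming the server is running locally on port 8000):
#   curl http://localhost:8000/models
#   -> {"models": ["qwen", "llama", "coder"]}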
@app.post("/chat", response_model=ChatResponse)
def chat(request: ChatRequest):
    """Runs a single chat turn against the selected model and returns its reply."""
    try:
        # Load (or reuse) the requested model
        load_model(request.model)
        provider = LlamaCppPythonProvider(llm)
        agent = LlamaCppAgent(
            provider,
            system_prompt=request.system_prompt,
            predefined_messages_formatter_type=MessagesFormatterType.CHATML,
        )
        settings = provider.get_provider_default_settings()
        settings.temperature = request.temperature
        settings.top_k = request.top_k
        settings.top_p = request.top_p
        settings.max_tokens = request.max_tokens
        settings.repeat_penalty = request.repeat_penalty
        messages = BasicChatHistory()
        # Add prior turns as alternating user/assistant messages
        for user_msg, assistant_msg in request.history:
            messages.add_message({"role": Roles.user, "content": user_msg})
            messages.add_message({"role": Roles.assistant, "content": assistant_msg})
        # Get response
        response = agent.get_chat_response(
            request.message,
            llm_sampling_settings=settings,
            chat_history=messages,
            print_output=False,
        )
        return {"response": response}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
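# Example request (illustrative; assumes the server is running locally on port 8000):
#   curl -X POST http://localhost:8000/chat \
#     -H "Content-Type: application/json" \
#     -d '{"message": "Write a haiku about whales.", "model": "qwen"}'
#   -> {"response": "..."}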
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
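# Alternatively (assuming this module is saved as app.py), start the server with:
#   uvicorn app:app --host 0.0.0.0 --port 8000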