import os
from typing import List, Tuple, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
from llama_cpp_agent.providers import LlamaCppPythonProvider
from llama_cpp_agent.chat_history import BasicChatHistory
from llama_cpp_agent.chat_history.messages import Roles

# Suppress warnings
import warnings
warnings.filterwarnings("ignore")

# Ensure models directory exists
MODEL_DIR = "./models"
os.makedirs(MODEL_DIR, exist_ok=True)

# Model info for download
MODELS_INFO = [
    {
        "repo_id": "bartowski/Dolphin3.0-Llama3.2-1B-GGUF",
        "filename": "Dolphin3.0-Llama3.2-1B-Q4_K_M.gguf"
    },
    {
        "repo_id": "bartowski/Dolphin3.0-Qwen2.5-0.5B-GGUF",
        "filename": "Dolphin3.0-Qwen2.5-0.5B-Q6_K.gguf"
    },
    {
        "repo_id": "bartowski/Qwen2.5-Coder-14B-Instruct-GGUF",
        "filename": "Qwen2.5-Coder-14B-Instruct-Q6_K.gguf"
    }
]

# Download all models if not present
for model_info in MODELS_INFO:
    model_path = os.path.join(MODEL_DIR, model_info["filename"])
    if not os.path.exists(model_path):
        print(f"Downloading {model_info['filename']} from {model_info['repo_id']}...")
        try:
            hf_hub_download(
                repo_id=model_info["repo_id"],
                filename=model_info["filename"],
                local_dir=MODEL_DIR
            )
            print(f"Downloaded {model_info['filename']}")
        except Exception as e:
            print(f"Error downloading {model_info['filename']}: {e}")

# Available model keys (used in API)
AVAILABLE_MODELS = {
    "qwen": "Dolphin3.0-Qwen2.5-0.5B-Q6_K.gguf",
    "llama": "Dolphin3.0-Llama3.2-1B-Q4_K_M.gguf",
    "coder": "Qwen2.5-Coder-14B-Instruct-Q6_K.gguf"
}

# Global LLM instance
llm = None
llm_model = None


def load_model(model_key: str):
    """Load the Llama instance for the given model key, reusing it if already loaded."""
    global llm, llm_model
    model_name = AVAILABLE_MODELS.get(model_key)
    if not model_name:
        raise ValueError(f"Invalid model key: {model_key}")
    model_path = os.path.join(MODEL_DIR, model_name)
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found at {model_path}")
    # Only (re)load when no model is loaded or a different model is requested
    if llm is None or llm_model != model_name:
        llm = Llama(
            model_path=model_path,
            flash_attn=False,
            n_gpu_layers=0,
            n_batch=8,
            n_ctx=2048,
            n_threads=8,
            n_threads_batch=8,
        )
        llm_model = model_name
    return llm
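
# Optional warm-up sketch (not part of the request flow above): load a model at
# startup so the first chat request does not pay the load cost. PRELOAD_MODEL is
# a hypothetical environment variable used only for illustration; by default
# nothing is preloaded.
#
# preload_key = os.environ.get("PRELOAD_MODEL")
# if preload_key:
#     try:
#         load_model(preload_key)
#     except (ValueError, FileNotFoundError) as e:
#         print(f"Preload skipped: {e}")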

class ChatRequest(BaseModel):
    message: str  # Required
    history: Optional[List[Tuple[str, str]]] = []  # Default: empty list
    model: Optional[str] = "qwen"  # Default model key
    system_prompt: Optional[str] = "You are Dolphin, a helpful AI assistant."
    max_tokens: Optional[int] = 1024
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.95
    top_k: Optional[int] = 40
    repeat_penalty: Optional[float] = 1.1
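
# Example request body accepted by ChatRequest (values are illustrative only;
# every field except "message" falls back to the defaults above):
# {
#     "message": "Write a haiku about the sea.",
#     "history": [["Hi", "Hello! How can I help?"]],
#     "model": "llama",
#     "max_tokens": 256,
#     "temperature": 0.5
# }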

class ChatResponse(BaseModel):
    response: str


class ModelInfoResponse(BaseModel):
    models: List[str]


app = FastAPI(
    title="Dolphin 3.0 LLM API",
    description="REST API for Dolphin 3.0 models using Llama.cpp backend.",
    version="1.0",
    docs_url="/docs",  # Only Swagger docs
    redoc_url=None  # Disable ReDoc
)

@app.get("/models", response_model=ModelInfoResponse)
def get_available_models():
    """Returns the list of supported models."""
    return {"models": list(AVAILABLE_MODELS.keys())}


@app.post("/chat", response_model=ChatResponse)
def chat(request: ChatRequest):
    try:
        # Load model (reuses the global instance if the key is unchanged)
        model = load_model(request.model)
        provider = LlamaCppPythonProvider(model)
        agent = LlamaCppAgent(
            provider,
            system_prompt=request.system_prompt,
            predefined_messages_formatter_type=MessagesFormatterType.CHATML,
        )

        settings = provider.get_provider_default_settings()
        settings.temperature = request.temperature
        settings.top_k = request.top_k
        settings.top_p = request.top_p
        settings.max_tokens = request.max_tokens
        settings.repeat_penalty = request.repeat_penalty

        messages = BasicChatHistory()
        # Add history
        for user_msg, assistant_msg in request.history:
            messages.add_message({"role": Roles.user, "content": user_msg})
            messages.add_message({"role": Roles.assistant, "content": assistant_msg})

        # Get response
        response = agent.get_chat_response(
            request.message,
            llm_sampling_settings=settings,
            chat_history=messages,
            print_output=False,
        )
        return {"response": response}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
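
# Example usage once the server is running (a sketch, assuming the /models and
# /chat routes defined above and the default port 8000):
#
#   curl http://localhost:8000/models
#   curl -X POST http://localhost:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Hello!", "model": "qwen"}'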