# Hugging Face Space (Docker): FastAPI server for Llama 3.2 3B Instruct (GGUF).
import os | |
from fastapi import FastAPI, HTTPException | |
from pydantic import BaseModel | |
from typing import List, Optional | |
from huggingface_hub import hf_hub_download | |
from ctransformers import AutoModelForCausalLM | |
# ------------------------
# Model configuration
# ------------------------
REPO_ID = "bartowski/Llama-3.2-3B-Instruct-GGUF"  # HF repo hosting the GGUF quantizations
FILENAME = "Llama-3.2-3B-Instruct-Q4_K_L.gguf"    # exact quantized weight file to fetch
MODEL_TYPE = "llama"                              # architecture hint for ctransformers

# ------------------------
# Persistent cache (Docker Spaces -> /data)
# ------------------------
# /data persists across Space restarts, so the multi-GB GGUF downloads only once.
CACHE_DIR = os.environ.get("HUGGINGFACE_HUB_CACHE", "/data/hf_cache")
os.makedirs(CACHE_DIR, exist_ok=True)

app = FastAPI(title="Llama 3.2 3B Instruct (ctransformers)")

# Lazily-initialized module-level singleton; populated by get_model() on first use.
_model = None
def get_model():
    """Download (if needed) and load the GGUF model, returning a cached singleton.

    The model is loaded at most once per process and reused for all requests.
    NOTE(review): first-call initialization is not guarded by a lock; two
    concurrent first requests may both load the model, with the later
    assignment winning. Harmless but slow — confirm if that matters here.

    Returns:
        The ctransformers `AutoModelForCausalLM` instance.
    """
    global _model
    if _model is not None:
        return _model
    # Download the exact GGUF file into the persistent cache directory.
    local_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=FILENAME,
        local_dir=CACHE_DIR,
        # Materialize a real file (deprecated/ignored in newer huggingface_hub,
        # kept for compatibility with older versions).
        local_dir_use_symlinks=False,
    )
    # Load with ctransformers (CPU by default).
    _model = AutoModelForCausalLM.from_pretrained(
        model_path_or_repo_id=os.path.dirname(local_path),
        model_file=os.path.basename(local_path),
        model_type=MODEL_TYPE,
        gpu_layers=int(os.environ.get("GPU_LAYERS", "0")),  # set >0 on GPU Spaces
        context_length=int(os.environ.get("CTX_LEN", "4096")),
    )
    return _model
class GenerateIn(BaseModel):
    """Request body for POST /generate; sampling fields mirror ctransformers kwargs."""
    prompt: str                       # raw prompt text passed verbatim to the model
    max_new_tokens: int = 256         # generation length cap
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 40
    repetition_penalty: float = 1.1
    stop: Optional[List[str]] = None  # optional stop sequences; None disables
class GenerateOut(BaseModel):
    """Response body for POST /generate."""
    completion: str  # text produced by the model
@app.get("/")
def health():
    """Liveness/info endpoint: reports status, model config, cache dir, and routes.

    NOTE(review): no route decorator was present in the source, so this handler
    was never registered with FastAPI. GET / matches the self-describing
    payload below — confirm the intended path (could be /health).
    """
    return {
        "status": "ok",
        "model": {"repo_id": REPO_ID, "filename": FILENAME, "type": MODEL_TYPE},
        "cache_dir": CACHE_DIR,
        "endpoints": {"POST /generate": "Generate a completion"},
    }
@app.post("/generate", response_model=GenerateOut)
def generate(body: GenerateIn):
    """Run one blocking completion and return the generated text.

    NOTE(review): no route decorator was present in the source, so this handler
    was never registered with FastAPI; POST /generate matches the endpoint
    advertised by the health payload.

    Raises:
        HTTPException: 500 with the underlying error message if model
            download, load, or inference fails.
    """
    try:
        model = get_model()
        text = model(
            body.prompt,
            max_new_tokens=body.max_new_tokens,
            temperature=body.temperature,
            top_p=body.top_p,
            top_k=body.top_k,
            repetition_penalty=body.repetition_penalty,
            stop=body.stop,
        )
        return GenerateOut(completion=text)
    except Exception as e:
        # Surface the failure reason to the client; chain for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e