import os

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional

from huggingface_hub import hf_hub_download
from ctransformers import AutoModelForCausalLM

# --- Model & cache config ---
REPO_ID = "bartowski/Llama-3.2-3B-Instruct-GGUF"
FILENAME = "Llama-3.2-3B-Instruct-Q4_K_M.gguf"  # safer quant than Q4_K_L
MODEL_TYPE = "llama"

# On Spaces, /data survives restarts only when persistent storage is enabled;
# otherwise the model is re-downloaded on each cold start.
CACHE_DIR = os.environ.get("HUGGINGFACE_HUB_CACHE", "/data/hf_cache")
os.makedirs(CACHE_DIR, exist_ok=True)

# Conservative defaults (override via env if needed)
CTX_LEN = int(os.environ.get("CTX_LEN", "1024"))
BATCH_SIZE = int(os.environ.get("BATCH", "16"))
THREADS = int(os.environ.get("THREADS", "4"))
GPU_LAYERS = int(os.environ.get("GPU_LAYERS", "0"))
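
# Example launch with overrides (illustrative values; assumes this file is
# saved as app.py and the server listens on 7860, the Spaces default).
# Note the batch-size variable is BATCH, matching the lookup above:
#   CTX_LEN=2048 BATCH=32 THREADS=8 uvicorn app:app --host 0.0.0.0 --port 7860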

app = FastAPI(title="Llama 3.2 3B Instruct (ctransformers)")

_model = None


def get_model():
    # Lazy-load on first request: keeps startup fast and downloads the GGUF once.
    global _model
    if _model is not None:
        return _model
    local_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=FILENAME,
        local_dir=CACHE_DIR,
        local_dir_use_symlinks=False,
    )
    _model = AutoModelForCausalLM.from_pretrained(
        model_path_or_repo_id=os.path.dirname(local_path),
        model_file=os.path.basename(local_path),
        model_type=MODEL_TYPE,
        context_length=CTX_LEN,
        batch_size=BATCH_SIZE,
        threads=THREADS,
        gpu_layers=GPU_LAYERS,
        # f16_kv dropped: it is a llama.cpp option, not a documented
        # ctransformers config key, and unknown kwargs can error here.
    )
    return _model
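
# Note: the first request pays the download + load cost (the Q4_K_M file is
# roughly 2 GB); subsequent requests reuse the cached singleton.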


class GenerateIn(BaseModel):
    prompt: str
    max_new_tokens: int = 256
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 40
    repetition_penalty: float = 1.1
    stop: Optional[List[str]] = None


class GenerateOut(BaseModel):
    completion: str
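
# Example JSON body for GenerateIn (illustrative values):
#   {"prompt": "Write a haiku about autumn.",
#    "max_new_tokens": 64,
#    "stop": ["\n\n"]}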


@app.get("/health")
def health():
    return {
        "status": "ok",
        "cache_dir": CACHE_DIR,
        "model": {"repo": REPO_ID, "file": FILENAME, "type": MODEL_TYPE},
        "settings": {
            "CTX_LEN": CTX_LEN,
            "BATCH_SIZE": BATCH_SIZE,
            "THREADS": THREADS,
            "GPU_LAYERS": GPU_LAYERS,
        },
    }


@app.post("/generate", response_model=GenerateOut)
def generate(body: GenerateIn):
    try:
        model = get_model()
        text = model(
            body.prompt,
            max_new_tokens=body.max_new_tokens,
            temperature=body.temperature,
            top_p=body.top_p,
            top_k=body.top_k,
            repetition_penalty=body.repetition_penalty,
            stop=body.stop,
        )
        return GenerateOut(completion=text)
    except Exception as e:
        # Print the full traceback so failures are visible in the Space logs.
        import sys
        import traceback

        traceback.print_exc(file=sys.stderr)
        raise HTTPException(status_code=500, detail=str(e))
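
# Quick check (a sketch; assumes the app.py / port 7860 setup noted above):
#   curl http://localhost:7860/health
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello", "max_new_tokens": 32}'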