import os
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional
from huggingface_hub import hf_hub_download
from ctransformers import AutoModelForCausalLM
# --- Model & cache config ---
REPO_ID = "bartowski/Llama-3.2-3B-Instruct-GGUF"
FILENAME = "Llama-3.2-3B-Instruct-Q4_K_M.gguf" # safer quant than Q4_K_L
MODEL_TYPE = "llama"
CACHE_DIR = os.environ.get("HUGGINGFACE_HUB_CACHE", "/data/hf_cache")
os.makedirs(CACHE_DIR, exist_ok=True)
# Conservative defaults (override via env if needed)
CTX_LEN = int(os.environ.get("CTX_LEN", "1024"))
BATCH_SIZE = int(os.environ.get("BATCH", "16"))
THREADS = int(os.environ.get("THREADS", "4"))
GPU_LAYERS = int(os.environ.get("GPU_LAYERS", "0"))
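# Example override (illustrative only; the values and the app.py filename are
# assumptions, not part of the original Space):
#   CTX_LEN=2048 BATCH=32 THREADS=8 python app.py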
app = FastAPI(title="Llama 3.2 3B Instruct (ctransformers)")
_model = None

def get_model():
    """Lazily download the GGUF file and load it on first use."""
    global _model
    if _model is not None:
        return _model
    local_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=FILENAME,
        local_dir=CACHE_DIR,
        local_dir_use_symlinks=False,
    )
    _model = AutoModelForCausalLM.from_pretrained(
        model_path_or_repo_id=os.path.dirname(local_path),
        model_file=os.path.basename(local_path),
        model_type=MODEL_TYPE,
        context_length=CTX_LEN,
        batch_size=BATCH_SIZE,
        threads=THREADS,
        gpu_layers=GPU_LAYERS,
    )
    return _model
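
# Editor's sketch (not wired into the endpoints): ctransformers performs raw
# text completion, so clients are expected to send fully templated prompts.
# A hypothetical helper for wrapping a plain message in the Llama 3 instruct
# chat template:
def format_llama3_prompt(user_message: str, system: str = "You are a helpful assistant.") -> str:
    """Wrap a user message in the Llama 3 instruct chat template (illustrative)."""
    return (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
        f"{system}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
        f"{user_message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    )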

class GenerateIn(BaseModel):
    prompt: str
    max_new_tokens: int = 256
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 40
    repetition_penalty: float = 1.1
    stop: Optional[List[str]] = None

class GenerateOut(BaseModel):
    completion: str

@app.get("/")
def health():
    return {
        "status": "ok",
        "cache_dir": CACHE_DIR,
        "model": {"repo": REPO_ID, "file": FILENAME, "type": MODEL_TYPE},
        "settings": {
            "CTX_LEN": CTX_LEN,
            "BATCH_SIZE": BATCH_SIZE,
            "THREADS": THREADS,
            "GPU_LAYERS": GPU_LAYERS,
        },
    }

@app.post("/generate", response_model=GenerateOut)
def generate(body: GenerateIn):
    try:
        model = get_model()
        text = model(
            body.prompt,
            max_new_tokens=body.max_new_tokens,
            temperature=body.temperature,
            top_p=body.top_p,
            top_k=body.top_k,
            repetition_penalty=body.repetition_penalty,
            stop=body.stop,
        )
        return GenerateOut(completion=text)
    except Exception as e:
        # Print the full traceback so it shows up in the Space logs.
        import sys
        import traceback
        traceback.print_exc(file=sys.stderr)
        raise HTTPException(status_code=500, detail=str(e))
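
# Minimal local smoke test (sketch): assumes this file is named app.py, that
# uvicorn is installed, and that the Space serves on 7860 (the HF Spaces default).
#   uvicorn app:app --host 0.0.0.0 --port 7860
#   curl -s -X POST http://localhost:7860/generate \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "Q: What is GGUF?\nA:", "max_new_tokens": 64}'
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))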