import os
import threading
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional
from huggingface_hub import hf_hub_download
from ctransformers import AutoModelForCausalLM

# --- Model & cache config ---
REPO_ID    = "bartowski/Llama-3.2-3B-Instruct-GGUF"
FILENAME   = "Llama-3.2-3B-Instruct-Q4_K_M.gguf"  # Q4_K_M is smaller than Q4_K_L, so safer on limited RAM
MODEL_TYPE = "llama"

CACHE_DIR = os.environ.get("HUGGINGFACE_HUB_CACHE", "/data/hf_cache")
os.makedirs(CACHE_DIR, exist_ok=True)

# Conservative defaults (override via env if needed)
CTX_LEN     = int(os.environ.get("CTX_LEN", "1024"))
BATCH_SIZE  = int(os.environ.get("BATCH_SIZE", "16"))
THREADS     = int(os.environ.get("THREADS", "4"))
GPU_LAYERS  = int(os.environ.get("GPU_LAYERS", "0"))
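
# Example overrides (illustrative only; assumes this file is saved as app.py
# and that the server is launched with uvicorn):
#   CTX_LEN=2048 THREADS=8 BATCH_SIZE=32 uvicorn app:app --host 0.0.0.0 --port 7860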

app = FastAPI(title="Llama 3.2 3B Instruct (ctransformers)")
_model = None
# One lock guards lazy loading and generation: the shared llama.cpp context is
# not safe for concurrent calls from FastAPI's worker threads.
_model_lock = threading.Lock()

def get_model():
    """Lazily download and load the GGUF model on first use."""
    global _model
    if _model is not None:
        return _model

    # Download the GGUF file into the persistent cache dir. Recent huggingface_hub
    # releases deprecate local_dir_use_symlinks and copy directly into local_dir.
    local_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=FILENAME,
        local_dir=CACHE_DIR,
        local_dir_use_symlinks=False,
    )

    _model = AutoModelForCausalLM.from_pretrained(
        model_path_or_repo_id=os.path.dirname(local_path),
        model_file=os.path.basename(local_path),
        model_type=MODEL_TYPE,
        context_length=CTX_LEN,
        batch_size=BATCH_SIZE,
        threads=THREADS,
        gpu_layers=GPU_LAYERS,
        # note: f16_kv is a llama-cpp-python option; ctransformers' Config does not accept it
    )
    return _model

class GenerateIn(BaseModel):
    prompt: str
    max_new_tokens: int = 256
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 40
    repetition_penalty: float = 1.1
    stop: Optional[List[str]] = None

class GenerateOut(BaseModel):
    completion: str

@app.get("/")
def health():
    return {
        "status": "ok",
        "cache_dir": CACHE_DIR,
        "model": {"repo": REPO_ID, "file": FILENAME, "type": MODEL_TYPE},
        "settings": {
            "CTX_LEN": CTX_LEN,
            "BATCH_SIZE": BATCH_SIZE,
            "THREADS": THREADS,
            "GPU_LAYERS": GPU_LAYERS,
        },
    }

@app.post("/generate", response_model=GenerateOut)
def generate(body: GenerateIn):
    try:
        # Serialize loading and generation: the model is shared process-wide and
        # cannot run concurrent requests safely.
        with _model_lock:
            model = get_model()
            text = model(
                body.prompt,
                max_new_tokens=body.max_new_tokens,
                temperature=body.temperature,
                top_p=body.top_p,
                top_k=body.top_k,
                repetition_penalty=body.repetition_penalty,
                stop=body.stop,
            )
        return GenerateOut(completion=text)
    except Exception as e:
        # helpful for debugging in Space logs
        import sys, traceback
        traceback.print_exc(file=sys.stderr)
        raise HTTPException(status_code=500, detail=str(e))
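
# --- Optional local entrypoint ---
# A minimal sketch for running this file directly, outside the Space's own launch
# command; it assumes uvicorn is installed and defaults to the usual Spaces port.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))

# Example request once the server is up (illustrative values):
#   curl -X POST http://localhost:7860/generate \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "Explain GGUF in one sentence.", "max_new_tokens": 64}'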