Update app.py
app.py CHANGED
```diff
@@ -5,14 +5,20 @@ from typing import List, Optional
 from huggingface_hub import hf_hub_download
 from ctransformers import AutoModelForCausalLM
 
-
-
+# --- Model & cache config ---
+REPO_ID = "bartowski/Llama-3.2-3B-Instruct-GGUF"
+FILENAME = "Llama-3.2-3B-Instruct-Q4_K_M.gguf"  # safer quant than Q4_K_L
 MODEL_TYPE = "llama"
 
-
 CACHE_DIR = os.environ.get("HUGGINGFACE_HUB_CACHE", "/data/hf_cache")
 os.makedirs(CACHE_DIR, exist_ok=True)
 
+# Conservative defaults (override via env if needed)
+CTX_LEN = int(os.environ.get("CTX_LEN", "1024"))
+BATCH_SIZE = int(os.environ.get("BATCH", "16"))
+THREADS = int(os.environ.get("THREADS", "4"))
+GPU_LAYERS = int(os.environ.get("GPU_LAYERS", "0"))
+
 app = FastAPI(title="Llama 3.2 3B Instruct (ctransformers)")
 _model = None
 
```
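The new settings are read with bare `int(os.environ.get(...))`, so a malformed value (say `CTX_LEN=2k` among the Space's variables) would raise at import time. A slightly more defensive variant of the same pattern, sketched with a hypothetical `env_int` helper that is not part of the Space's code:

```python
import os

def env_int(name: str, default: int) -> int:
    """Read an integer setting from the environment, keeping the default on bad input."""
    try:
        return int(os.environ.get(name, default))
    except ValueError:
        return default

CTX_LEN = env_int("CTX_LEN", 1024)  # same default as in the diff
```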
```diff
@@ -32,12 +38,11 @@ def get_model():
         model_path_or_repo_id=os.path.dirname(local_path),
         model_file=os.path.basename(local_path),
         model_type=MODEL_TYPE,
-        context_length=
-
-
-
-
-
+        context_length=CTX_LEN,
+        batch_size=BATCH_SIZE,
+        threads=THREADS,
+        gpu_layers=GPU_LAYERS,
+        f16_kv=True,
     )
     return _model
 
```
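The hunk only touches the keyword arguments. For context, a minimal sketch of how the surrounding `get_model()` presumably fits together, assuming the GGUF file is fetched with `hf_hub_download` (the constants come from the first hunk; the download call itself is not shown in this diff):

```python
import os
from huggingface_hub import hf_hub_download
from ctransformers import AutoModelForCausalLM

_model = None

def get_model():
    global _model
    if _model is None:
        # Download the GGUF once (or reuse it from CACHE_DIR on warm starts).
        local_path = hf_hub_download(
            repo_id=REPO_ID, filename=FILENAME, cache_dir=CACHE_DIR
        )
        _model = AutoModelForCausalLM.from_pretrained(
            model_path_or_repo_id=os.path.dirname(local_path),
            model_file=os.path.basename(local_path),
            model_type=MODEL_TYPE,
            context_length=CTX_LEN,   # 1024 by default
            batch_size=BATCH_SIZE,    # 16
            threads=THREADS,          # 4
            gpu_layers=GPU_LAYERS,    # 0 = CPU-only, the safe Space default
            f16_kv=True,
        )
    return _model
```

With `GPU_LAYERS=0` everything runs on CPU, which matches a free CPU Space; raising it only helps on GPU hardware.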
```diff
@@ -55,7 +60,17 @@ class GenerateOut(BaseModel):
 
 @app.get("/")
 def health():
-    return {
+    return {
+        "status": "ok",
+        "cache_dir": CACHE_DIR,
+        "model": {"repo": REPO_ID, "file": FILENAME, "type": MODEL_TYPE},
+        "settings": {
+            "CTX_LEN": CTX_LEN,
+            "BATCH_SIZE": BATCH_SIZE,
+            "THREADS": THREADS,
+            "GPU_LAYERS": GPU_LAYERS,
+        },
+    }
 
 @app.post("/generate", response_model=GenerateOut)
 def generate(body: GenerateIn):
```
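The expanded health payload makes the running configuration visible without digging through logs. A quick way to poke it, assuming a local run on the usual Spaces port (the URL is an assumption, not part of the commit):

```python
import requests

BASE_URL = "http://localhost:7860"  # hypothetical local run; use the Space URL when deployed

info = requests.get(f"{BASE_URL}/", timeout=10).json()
print(info["model"]["file"])        # "Llama-3.2-3B-Instruct-Q4_K_M.gguf"
print(info["settings"]["CTX_LEN"])  # 1024 unless overridden
```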
```diff
@@ -72,4 +87,7 @@ def generate(body: GenerateIn):
         )
         return GenerateOut(completion=text)
     except Exception as e:
+        # helpful for debugging in Space logs
+        import sys, traceback
+        traceback.print_exc(file=sys.stderr)
         raise HTTPException(status_code=500, detail=str(e))
```
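`traceback.print_exc(file=sys.stderr)` sends the full stack trace to stderr, which is what the Space's log viewer surfaces; the client still only sees the 500 with `str(e)`. Calling `/generate` might look like the sketch below; the `GenerateIn` schema is not part of this diff, so the `prompt` field is an assumption:

```python
import requests

resp = requests.post(
    "http://localhost:7860/generate",  # hypothetical local URL
    json={"prompt": "Write a haiku about GGUF quantization."},  # assumed field name
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["completion"])  # GenerateOut does expose `completion`
```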