Update app.py
app.py
CHANGED
@@ -4,7 +4,7 @@ from pydantic import BaseModel
 from huggingface_hub import snapshot_download
 from llama_cpp import Llama

-# ---------- pick a writable cache dir ----------
+# ---------- pick a writable cache dir (tries in order) ----------
 def first_writable(paths):
     for p in paths:
         if not p:
@@ -21,9 +21,9 @@ def first_writable(paths):
     raise RuntimeError("No writable cache dir found")

 CACHE_BASE = first_writable([
-    os.getenv("SPACE_CACHE_DIR"),  # optional
-    "/app/.cache",                 # WORKDIR is usually writable on Spaces
-    "/tmp/app_cache",              #
+    os.getenv("SPACE_CACHE_DIR"),  # optional override via Settings → Variables
+    "/app/.cache",                 # WORKDIR is usually writable on HF Spaces
+    "/tmp/app_cache",              # safe fallback
 ])

 HF_HOME = os.path.join(CACHE_BASE, "huggingface")
@@ -31,46 +31,93 @@ MODELS_DIR = os.path.join(CACHE_BASE, "models")
 os.makedirs(HF_HOME, exist_ok=True)
 os.makedirs(MODELS_DIR, exist_ok=True)

+# Tell huggingface_hub to cache under our writable dir
 os.environ["HF_HOME"] = HF_HOME
 os.environ["HF_HUB_CACHE"] = os.path.join(HF_HOME, "hub")

-# ---- Model selection (…
+# ---- Model selection (can be overridden in Settings → Variables) ----
 MODEL_REPO = os.getenv("MODEL_REPO", "Qwen/Qwen2.5-3B-Instruct-GGUF")
-MODEL_FILE = os.getenv("MODEL_FILE", "qwen2.5-3b-instruct-q4_k_m.gguf")
+MODEL_FILE = os.getenv("MODEL_FILE", "qwen2.5-3b-instruct-q4_k_m.gguf")  # optional hint
+MODEL_REV = os.getenv("MODEL_REV")  # optional: pin a commit SHA

-# Inference knobs
-N_CTX …
-N_BATCH …
+# Inference knobs (reduce if memory tight: N_CTX=1024, N_BATCH=32)
+N_CTX = int(os.getenv("N_CTX", 2048))
+N_BATCH = int(os.getenv("N_BATCH", 64))
 N_THREADS = os.cpu_count() or 2

 app = FastAPI(title="Qwen Planner API (CPU)")

 llm = None
 model_loaded = False
+chosen_model_path = None  # for /healthz reporting
+

 def ensure_model():
-    …
+    """
+    Lazy-load the model. Downloads any .gguf if needed, then auto-selects one:
+      1) exact MODEL_FILE if present,
+      2) else a *q4*.gguf,
+      3) else the first .gguf found.
+    """
+    global llm, model_loaded, chosen_model_path
     if llm is not None:
         return
-    …
+    try:
+        local_dir = snapshot_download(
+            repo_id=MODEL_REPO,
+            revision=MODEL_REV,
+            allow_patterns=["*.gguf"],  # be flexible on filenames
+            local_dir=MODELS_DIR,
+            local_dir_use_symlinks=False,
+        )
+
+        # discover gguf files
+        ggufs = []
+        for root, _, files in os.walk(local_dir):
+            for f in files:
+                if f.endswith(".gguf"):
+                    ggufs.append(os.path.join(root, f))
+        if not ggufs:
+            raise FileNotFoundError("No .gguf files found after download.")
+
+        # selection logic
+        model_path = None
+        if MODEL_FILE:
+            cand = os.path.join(local_dir, MODEL_FILE)
+            if os.path.exists(cand):
+                model_path = cand
+        if model_path is None:
+            q4 = [p for p in ggufs if "q4" in os.path.basename(p).lower()]
+            model_path = (q4 or ggufs)[0]
+
+        chosen_model_path = model_path
+        print(f"[loader] Using GGUF: {model_path}")
+
+        llm = Llama(
+            model_path=model_path,
+            n_ctx=N_CTX,
+            n_threads=N_THREADS,
+            n_batch=N_BATCH,
+            logits_all=False,
+        )
+        model_loaded = True
+
+    except Exception as e:
+        # surface a clear error to the HTTP layer
+        raise RuntimeError(f"ensure_model failed: {e}")
+

 @app.get("/healthz")
 def healthz():
-    return {…
+    return {
+        "status": "ok",
+        "loaded": model_loaded,
+        "cache_base": CACHE_BASE,
+        "model_repo": MODEL_REPO,
+        "model_file_hint": MODEL_FILE,
+        "chosen_model_path": chosen_model_path,
+    }
+

 SYSTEM_PROMPT = "You are a concise assistant. Reply briefly in plain text."

@@ -79,7 +126,25 @@ class ChatReq(BaseModel):

 @app.post("/chat")
 def chat(req: ChatReq):
-    …
+    # Load (or reuse) model
+    try:
+        ensure_model()  # may take minutes on first-ever call
+    except Exception as e:
+        raise HTTPException(status_code=503, detail=f"loading_error: {e}")
+
+    # Inference
+    try:
+        full_prompt = (
+            f"<|system|>\n{SYSTEM_PROMPT}\n</|system|>\n"
+            f"<|user|>\n{req.prompt}\n</|user|>\n"
+        )
+        out = llm(
+            prompt=full_prompt,
+            temperature=0.2,
+            top_p=0.9,
+            max_tokens=256,
+            stop=["</s>"],
+        )
+        return {"response": out["choices"][0]["text"].strip()}
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"infer_error: {e}")
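The new /healthz and /chat routes can be exercised from outside the Space once it is running. The sketch below is a hypothetical client, not part of this commit: the base URL is a placeholder, and the "prompt" field is assumed from the handler's use of req.prompt (the full ChatReq definition is not shown in this diff).

    # Hypothetical client for the endpoints above -- not part of the commit.
    # BASE_URL is a placeholder; substitute your own Space URL.
    import requests

    BASE_URL = "https://YOUR-SPACE.hf.space"

    # /healthz reports whether the GGUF has been loaded and which file was chosen.
    print(requests.get(f"{BASE_URL}/healthz", timeout=30).json())

    # /chat lazy-loads the model, so the very first call can block while the GGUF
    # downloads; use a generous timeout. The body assumes ChatReq has a "prompt" field.
    resp = requests.post(
        f"{BASE_URL}/chat",
        json={"prompt": "Suggest a three-step plan for cleaning my desk."},
        timeout=900,
    )
    resp.raise_for_status()
    print(resp.json()["response"])

A 503 with loading_error means ensure_model() raised during download or load; infer_error with a 500 comes from the inference block. Model choice and inference knobs can be changed without editing the file via the MODEL_REPO, MODEL_FILE, MODEL_REV, N_CTX, N_BATCH, and SPACE_CACHE_DIR variables in the Space settings.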