omaryasserhassan committed
Commit 947b2e3 · verified · 1 Parent(s): fc85eed

Update app.py

Files changed (1)
  1. app.py  +95 -30
app.py CHANGED
@@ -4,7 +4,7 @@ from pydantic import BaseModel
 from huggingface_hub import snapshot_download
 from llama_cpp import Llama
 
-# ---------- pick a writable cache dir ----------
+# ---------- pick a writable cache dir (tries in order) ----------
 def first_writable(paths):
     for p in paths:
         if not p:
@@ -21,9 +21,9 @@ def first_writable(paths):
     raise RuntimeError("No writable cache dir found")
 
 CACHE_BASE = first_writable([
-    os.getenv("SPACE_CACHE_DIR"),  # optional env override
-    "/app/.cache",                 # WORKDIR is usually writable on Spaces
-    "/tmp/app_cache",              # always writable fallback
+    os.getenv("SPACE_CACHE_DIR"),  # optional override via Settings → Variables
+    "/app/.cache",                 # WORKDIR is usually writable on HF Spaces
+    "/tmp/app_cache",              # safe fallback
 ])
 
 HF_HOME = os.path.join(CACHE_BASE, "huggingface")
@@ -31,46 +31,93 @@ MODELS_DIR = os.path.join(CACHE_BASE, "models")
 os.makedirs(HF_HOME, exist_ok=True)
 os.makedirs(MODELS_DIR, exist_ok=True)
 
+# Tell huggingface_hub to cache under our writable dir
 os.environ["HF_HOME"] = HF_HOME
 os.environ["HF_HUB_CACHE"] = os.path.join(HF_HOME, "hub")
 
-# ---- Model selection (override in Settings → Variables if needed) ----
+# ---- Model selection (can be overridden in Settings → Variables) ----
 MODEL_REPO = os.getenv("MODEL_REPO", "Qwen/Qwen2.5-3B-Instruct-GGUF")
-MODEL_FILE = os.getenv("MODEL_FILE", "qwen2.5-3b-instruct-q4_k_m.gguf")
+MODEL_FILE = os.getenv("MODEL_FILE", "qwen2.5-3b-instruct-q4_k_m.gguf")  # optional hint
+MODEL_REV = os.getenv("MODEL_REV")  # optional: pin a commit SHA
 
-# Inference knobs
-N_CTX = int(os.getenv("N_CTX", 2048))
-N_BATCH = int(os.getenv("N_BATCH", 64))
+# Inference knobs (reduce if memory tight: N_CTX=1024, N_BATCH=32)
+N_CTX = int(os.getenv("N_CTX", 2048))
+N_BATCH = int(os.getenv("N_BATCH", 64))
 N_THREADS = os.cpu_count() or 2
 
 app = FastAPI(title="Qwen Planner API (CPU)")
 
 llm = None
 model_loaded = False
+chosen_model_path = None  # for /healthz reporting
+
 
 def ensure_model():
-    global llm, model_loaded
+    """
+    Lazy-load the model. Downloads any .gguf if needed, then auto-selects one:
+      1) exact MODEL_FILE if present,
+      2) else a *q4*.gguf,
+      3) else the first .gguf found.
+    """
+    global llm, model_loaded, chosen_model_path
     if llm is not None:
         return
-    local_dir = snapshot_download(
-        repo_id=MODEL_REPO,
-        allow_patterns=[MODEL_FILE],
-        local_dir=MODELS_DIR,
-        local_dir_use_symlinks=False,
-    )
-    model_path = os.path.join(local_dir, MODEL_FILE)
-    llm = Llama(
-        model_path=model_path,
-        n_ctx=N_CTX,
-        n_threads=N_THREADS,
-        n_batch=N_BATCH,
-        logits_all=False,
-    )
-    model_loaded = True
+    try:
+        local_dir = snapshot_download(
+            repo_id=MODEL_REPO,
+            revision=MODEL_REV,
+            allow_patterns=["*.gguf"],  # be flexible on filenames
+            local_dir=MODELS_DIR,
+            local_dir_use_symlinks=False,
+        )
+
+        # discover gguf files
+        ggufs = []
+        for root, _, files in os.walk(local_dir):
+            for f in files:
+                if f.endswith(".gguf"):
+                    ggufs.append(os.path.join(root, f))
+        if not ggufs:
+            raise FileNotFoundError("No .gguf files found after download.")
+
+        # selection logic
+        model_path = None
+        if MODEL_FILE:
+            cand = os.path.join(local_dir, MODEL_FILE)
+            if os.path.exists(cand):
+                model_path = cand
+        if model_path is None:
+            q4 = [p for p in ggufs if "q4" in os.path.basename(p).lower()]
+            model_path = (q4 or ggufs)[0]
+
+        chosen_model_path = model_path
+        print(f"[loader] Using GGUF: {model_path}")
+
+        llm = Llama(
+            model_path=model_path,
+            n_ctx=N_CTX,
+            n_threads=N_THREADS,
+            n_batch=N_BATCH,
+            logits_all=False,
+        )
+        model_loaded = True
+
+    except Exception as e:
+        # surface a clear error to the HTTP layer
+        raise RuntimeError(f"ensure_model failed: {e}")
+
 
 @app.get("/healthz")
 def healthz():
-    return {"status": "ok", "loaded": model_loaded, "cache_base": CACHE_BASE, "model_file": MODEL_FILE}
+    return {
+        "status": "ok",
+        "loaded": model_loaded,
+        "cache_base": CACHE_BASE,
+        "model_repo": MODEL_REPO,
+        "model_file_hint": MODEL_FILE,
+        "chosen_model_path": chosen_model_path,
+    }
+
 
 SYSTEM_PROMPT = "You are a concise assistant. Reply briefly in plain text."
 
@@ -79,7 +126,25 @@ class ChatReq(BaseModel):
 
 @app.post("/chat")
 def chat(req: ChatReq):
-    ensure_model()  # lazy load on first call
-    full_prompt = f"<|system|>\n{SYSTEM_PROMPT}\n</|system|>\n<|user|>\n{req.prompt}\n</|user|>\n"
-    out = llm(prompt=full_prompt, temperature=0.2, top_p=0.9, max_tokens=256, stop=["</s>"])
-    return {"response": out["choices"][0]["text"].strip()}
+    # Load (or reuse) model
+    try:
+        ensure_model()  # may take minutes on first-ever call
+    except Exception as e:
+        raise HTTPException(status_code=503, detail=f"loading_error: {e}")
+
+    # Inference
+    try:
+        full_prompt = (
+            f"<|system|>\n{SYSTEM_PROMPT}\n</|system|>\n"
+            f"<|user|>\n{req.prompt}\n</|user|>\n"
+        )
+        out = llm(
+            prompt=full_prompt,
+            temperature=0.2,
+            top_p=0.9,
+            max_tokens=256,
+            stop=["</s>"],
+        )
+        return {"response": out["choices"][0]["text"].strip()}
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"infer_error: {e}")
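
For reference, a minimal client-side sketch of how the updated endpoints could be exercised once the Space is running. The base URL and port are placeholders (this commit does not pin them), and the requests package on the caller's side is assumed:

import requests

BASE_URL = "http://localhost:7860"  # placeholder; substitute the actual Space URL

# /healthz now also reports model_repo, the MODEL_FILE hint, and the resolved GGUF path
print(requests.get(f"{BASE_URL}/healthz", timeout=10).json())

# The first /chat call triggers the lazy download and load, so allow a long timeout.
# A 503 response means loading failed; a 500 means inference failed.
resp = requests.post(
    f"{BASE_URL}/chat",
    json={"prompt": "Plan a simple three-step morning routine."},
    timeout=600,
)
resp.raise_for_status()
print(resp.json()["response"])

Note that allow_patterns=["*.gguf"] makes the download tolerant of upstream filename changes, but it fetches every GGUF quantization in the repo, so disk use under MODELS_DIR can grow accordingly.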