omaryasserhassan committed
Commit fcea2ac · verified · 1 Parent(s): 947b2e3

Update app.py

Files changed (1)
  1. app.py +156 -53
app.py CHANGED
@@ -1,11 +1,12 @@
-import os, json, re
-from fastapi import FastAPI, HTTPException
+import os, json, re, time
+from typing import Any, Dict, List, Optional
+from fastapi import FastAPI, HTTPException, Header
 from pydantic import BaseModel
-from huggingface_hub import snapshot_download
-from llama_cpp import Llama
 
-# ---------- pick a writable cache dir (tries in order) ----------
-def first_writable(paths):
+# -------------------------------------------------------------------
+# Choose a writable cache dir *before* importing huggingface_hub
+# -------------------------------------------------------------------
+def first_writable(paths: List[Optional[str]]) -> str:
     for p in paths:
         if not p:
             continue
@@ -18,46 +19,64 @@ def first_writable(paths):
             return p
         except Exception:
             continue
-    raise RuntimeError("No writable cache dir found")
+    # final fallback
+    p = "/tmp/app_cache"
+    os.makedirs(p, exist_ok=True)
+    return p
 
 CACHE_BASE = first_writable([
-    os.getenv("SPACE_CACHE_DIR"),  # optional override via Settings → Variables
-    "/app/.cache",                 # WORKDIR is usually writable on HF Spaces
-    "/tmp/app_cache",              # safe fallback
+    os.getenv("SPACE_CACHE_DIR"),  # optional override via Settings → Variables
+    "/app/.cache",                 # WORKDIR is usually writable on HF Spaces
+    "/home/user/.cache",           # typical home dir
+    "/tmp/app_cache",              # safe fallback
 ])
 
 HF_HOME = os.path.join(CACHE_BASE, "huggingface")
-MODELS_DIR = os.path.join(CACHE_BASE, "models")
+os.environ["HF_HOME"] = HF_HOME
+os.environ["HF_HUB_CACHE"] = os.path.join(HF_HOME, "hub")
+os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
 os.makedirs(HF_HOME, exist_ok=True)
+MODELS_DIR = os.path.join(CACHE_BASE, "models")
 os.makedirs(MODELS_DIR, exist_ok=True)
 
-# Tell huggingface_hub to cache under our writable dir
-os.environ["HF_HOME"] = HF_HOME
-os.environ["HF_HUB_CACHE"] = os.path.join(HF_HOME, "hub")
+# Only now import libs that read the env vars
+from huggingface_hub import snapshot_download
+from llama_cpp import Llama
 
-# ---- Model selection (can be overridden in Settings → Variables) ----
+# -------------------------------------------------------------------
+# Config (can be overridden in Settings → Variables)
+# -------------------------------------------------------------------
 MODEL_REPO = os.getenv("MODEL_REPO", "Qwen/Qwen2.5-3B-Instruct-GGUF")
-MODEL_FILE = os.getenv("MODEL_FILE", "qwen2.5-3b-instruct-q4_k_m.gguf")  # optional hint
-MODEL_REV = os.getenv("MODEL_REV")  # optional: pin a commit SHA
+MODEL_FILE = os.getenv("MODEL_FILE", "qwen2.5-3b-instruct-q4_k_m.gguf")  # hint; not mandatory
+MODEL_REV = os.getenv("MODEL_REV")  # optional commit SHA to pin
 
-# Inference knobs (reduce if memory tight: N_CTX=1024, N_BATCH=32)
-N_CTX = int(os.getenv("N_CTX", 2048))
-N_BATCH = int(os.getenv("N_BATCH", 64))
-N_THREADS = os.cpu_count() or 2
+# Tuning (lower if memory is tight: N_CTX=1024, N_BATCH=32)
+N_CTX = int(os.getenv("N_CTX", 2048))
+N_BATCH = int(os.getenv("N_BATCH", 64))
+N_THREADS = os.cpu_count() or 2
 
-app = FastAPI(title="Qwen Planner API (CPU)")
+# Optional bearer auth for endpoints
+API_SECRET = os.getenv("API_SECRET")  # set in Settings → Variables if you want auth
 
-llm = None
-model_loaded = False
-chosen_model_path = None  # for /healthz reporting
+# -------------------------------------------------------------------
+# App + globals
+# -------------------------------------------------------------------
+app = FastAPI(title="Qwen Planner API (CPU)")
 
+llm: Optional[Llama] = None
+model_loaded: bool = False
+chosen_model_path: Optional[str] = None
 
-def ensure_model():
+# -------------------------------------------------------------------
+# Model loader (lazy, robust gguf discovery)
+# -------------------------------------------------------------------
+def ensure_model() -> None:
     """
     Lazy-load the model. Downloads any .gguf if needed, then auto-selects one:
-      1) exact MODEL_FILE if present,
-      2) else a *q4*.gguf,
-      3) else the first .gguf found.
+      1) exact MODEL_FILE if present,
+      2) else a *q4*.gguf,
+      3) else the first .gguf found.
+    Surfaces clear errors to the HTTP layer.
     """
     global llm, model_loaded, chosen_model_path
     if llm is not None:
@@ -66,13 +85,13 @@ def ensure_model():
         local_dir = snapshot_download(
             repo_id=MODEL_REPO,
             revision=MODEL_REV,
-            allow_patterns=["*.gguf"],  # be flexible on filenames
+            allow_patterns=["*.gguf"],  # flexible on filenames
             local_dir=MODELS_DIR,
-            local_dir_use_symlinks=False,
+            local_dir_use_symlinks=False,  # copy instead of symlink
         )
 
-        # discover gguf files
-        ggufs = []
+        # find gguf files
+        ggufs: List[str] = []
         for root, _, files in os.walk(local_dir):
             for f in files:
                 if f.endswith(".gguf"):
@@ -80,33 +99,77 @@ def ensure_model():
         if not ggufs:
             raise FileNotFoundError("No .gguf files found after download.")
 
-        # selection logic
-        model_path = None
+        # choose file
+        path = None
         if MODEL_FILE:
             cand = os.path.join(local_dir, MODEL_FILE)
             if os.path.exists(cand):
-                model_path = cand
-        if model_path is None:
+                path = cand
+        if path is None:
             q4 = [p for p in ggufs if "q4" in os.path.basename(p).lower()]
-            model_path = (q4 or ggufs)[0]
+            path = (q4 or ggufs)[0]
 
-        chosen_model_path = model_path
-        print(f"[loader] Using GGUF: {model_path}")
+        chosen_model_path = path
+        print(f"[loader] Using GGUF: {path}")
 
+        # load model (CPU)
         llm = Llama(
-            model_path=model_path,
+            model_path=path,
             n_ctx=N_CTX,
             n_threads=N_THREADS,
             n_batch=N_BATCH,
             logits_all=False,
+            n_gpu_layers=0,  # ensure CPU
         )
         model_loaded = True
 
     except Exception as e:
-        # surface a clear error to the HTTP layer
         raise RuntimeError(f"ensure_model failed: {e}")
 
+# -------------------------------------------------------------------
+# Helpers
+# -------------------------------------------------------------------
+def require_auth(authorization: Optional[str]) -> None:
+    if API_SECRET and authorization != f"Bearer {API_SECRET}":
+        raise HTTPException(status_code=401, detail="Unauthorized")
+
+def extract_json_block(txt: str) -> str:
+    m = re.search(r"\{.*\}\s*$", txt, flags=re.S)
+    if not m:
+        raise ValueError("No JSON object found in output.")
+    return m.group(0)
+
+# -------------------------------------------------------------------
+# Schemas
+# -------------------------------------------------------------------
+SYSTEM_PROMPT_CHAT = "You are a concise assistant. Reply briefly in plain text."
+
+class ChatReq(BaseModel):
+    prompt: str
 
+class PlanRequest(BaseModel):
+    profile: Dict[str, Any]
+    sample_rows: List[Dict[str, Any]]
+    goal: str = "auto"  # "classification" | "regression" | "auto"
+    constraints: Dict[str, Any] = {}
+
+SYSTEM_PROMPT_PLAN = """You are a data-planning assistant.
+Return ONLY minified JSON matching exactly this schema:
+{
+  "cleaning": [{"op":"impute_mean|impute_mode|drop_col|clip","cols":["..."],"params":{}}],
+  "encoding": [{"op":"one_hot|ordinal|hash|target","cols":["..."],"params":{}}],
+  "scaling": "none|standard|robust|minmax",
+  "target": {"name":"<col_or_empty>","type":"classification|regression|auto"},
+  "split": {"strategy":"random|stratified","test_size":0.2,"cv":5},
+  "metric": "f1|roc_auc|accuracy|mae|rmse|r2",
+  "models": ["lgbm","rf","xgb","logreg","ridge","catboost"],
+  "notes":"<short justification>"
+}
+No prose. No markdown. JSON only."""
+
+# -------------------------------------------------------------------
+# Routes
+# -------------------------------------------------------------------
 @app.get("/healthz")
 def healthz():
     return {
@@ -116,26 +179,22 @@ def healthz():
         "model_repo": MODEL_REPO,
         "model_file_hint": MODEL_FILE,
         "chosen_model_path": chosen_model_path,
+        "n_ctx": N_CTX,
+        "n_batch": N_BATCH,
+        "threads": N_THREADS,
     }
 
-
-SYSTEM_PROMPT = "You are a concise assistant. Reply briefly in plain text."
-
-class ChatReq(BaseModel):
-    prompt: str
-
 @app.post("/chat")
-def chat(req: ChatReq):
-    # Load (or reuse) model
+def chat(req: ChatReq, authorization: Optional[str] = Header(default=None)):
+    require_auth(authorization)
     try:
-        ensure_model()  # may take minutes on first-ever call
+        ensure_model()  # first call may take minutes (download + load)
     except Exception as e:
        raise HTTPException(status_code=503, detail=f"loading_error: {e}")
 
-    # Inference
     try:
         full_prompt = (
-            f"<|system|>\n{SYSTEM_PROMPT}\n</|system|>\n"
+            f"<|system|>\n{SYSTEM_PROMPT_CHAT}\n</|system|>\n"
            f"<|user|>\n{req.prompt}\n</|user|>\n"
         )
         out = llm(
@@ -148,3 +207,47 @@ def chat(req: ChatReq):
         return {"response": out["choices"][0]["text"].strip()}
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"infer_error: {e}")
+
+@app.post("/plan")
+def plan(req: PlanRequest, authorization: Optional[str] = Header(default=None)):
+    require_auth(authorization)
+    try:
+        ensure_model()
+    except Exception as e:
+        raise HTTPException(status_code=503, detail=f"loading_error: {e}")
+
+    try:
+        # Keep inputs small for free tier
+        sample = req.sample_rows[:200]
+        profile_json = json.dumps(req.profile)[:8000]
+        sample_json = json.dumps(sample)[:8000]
+        constraints_json = json.dumps(req.constraints)[:2000]
+
+        user_block = (
+            f"Goal:{req.goal}\n"
+            f"Constraints:{constraints_json}\n"
+            f"Profile:{profile_json}\n"
+            f"Sample:{sample_json}\n"
+        )
+
+        full_prompt = (
+            f"<|system|>\n{SYSTEM_PROMPT_PLAN}\n</|system|>\n"
+            f"<|user|>\n{user_block}\n</|user|>\n"
+        )
+
+        out = llm(
+            prompt=full_prompt,
+            temperature=0.2,
+            top_p=0.9,
+            max_tokens=512,
+            stop=["</s>"],
+        )
+        text = out["choices"][0]["text"]
+        payload = extract_json_block(text)
+        data = json.loads(payload)
+        return {"plan": data}
+
+    except ValueError as e:
+        raise HTTPException(status_code=422, detail=f"bad_json: {e}")
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"infer_error: {e}")
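
For reference, a minimal client sketch for the routes this commit exposes (/healthz, /chat, /plan). The base URL, the bearer secret, and the profile/sample_rows payload below are placeholder values, not part of the commit; only the route names, the Authorization scheme, and the request/response field names come from the code above.

# Hypothetical client for the endpoints defined in app.py.
# BASE_URL and SECRET are placeholders; point them at your own Space
# and the API_SECRET you configured there.
import requests

BASE_URL = "https://<your-space>.hf.space"  # placeholder
SECRET = "change-me"                        # placeholder
HEADERS = {"Authorization": f"Bearer {SECRET}"}

# Liveness / model status
print(requests.get(f"{BASE_URL}/healthz", timeout=30).json())

# Plain-text chat
chat = requests.post(
    f"{BASE_URL}/chat",
    json={"prompt": "Say hello in five words."},
    headers=HEADERS,
    timeout=600,  # first call may download and load the model
)
print(chat.json()["response"])

# Planning request: profile + sample rows in, JSON plan out
plan = requests.post(
    f"{BASE_URL}/plan",
    json={
        "profile": {"n_rows": 1000, "n_cols": 3},          # placeholder profile
        "sample_rows": [{"age": 34, "income": 52000, "churn": 0}],
        "goal": "classification",
        "constraints": {"max_train_minutes": 10},
    },
    headers=HEADERS,
    timeout=600,
)
print(plan.json()["plan"])

When API_SECRET is unset on the Space, require_auth is a no-op, so the Authorization header can be omitted.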