omaryasserhassan committed
Commit dd38c20 · verified · 1 Parent(s): cea5896

Update app.py

Files changed (1):
  1. app.py +73 -152
app.py CHANGED
@@ -1,163 +1,84 @@
- import os, json, re
- from typing import Any, Dict, List, Optional
- from fastapi import FastAPI, HTTPException, Header
+ import os
+ from fastapi import FastAPI, HTTPException
  from pydantic import BaseModel
- from llama_cpp import Llama
-
- # Inference knobs (you can still override via Settings → Variables)
- N_CTX = int(os.getenv("N_CTX", 2048))
- N_BATCH = int(os.getenv("N_BATCH", 64))
- N_THREADS = os.cpu_count() or 2
-
- API_SECRET = os.getenv("API_SECRET") # optional bearer auth
-
- MODELS_DIR = "/app/models" # baked into the image by Dockerfile
- MODEL_FILE_HINT = os.getenv("MODEL_FILE") # for /healthz display only
-
- app = FastAPI(title="Qwen Planner API (CPU)")
-
- llm: Optional[Llama] = None
- model_loaded = False
- chosen_model_path: Optional[str] = None
-
- def require_auth(authorization: Optional[str]) -> None:
-     if API_SECRET and authorization != f"Bearer {API_SECRET}":
-         raise HTTPException(status_code=401, detail="Unauthorized")
-
- def extract_json_block(txt: str) -> str:
-     m = re.search(r"\{.*\}\s*$", txt, flags=re.S)
-     if not m:
-         raise ValueError("No JSON object found in output.")
-     return m.group(0)
-
- def ensure_model():
-     global llm, model_loaded, chosen_model_path
-     if llm is not None:
-         return
-     # discover baked gguf
-     if not os.path.isdir(MODELS_DIR):
-         raise RuntimeError(f"Models directory not found: {MODELS_DIR}")
-     ggufs: List[str] = []
-     for root, _, files in os.walk(MODELS_DIR):
-         for f in files:
-             if f.endswith(".gguf"):
-                 ggufs.append(os.path.join(root, f))
-     if not ggufs:
-         raise RuntimeError("No .gguf files found in /app/models. Rebuild image with model baked in.")
-
-     # prefer q4 if multiple
-     q4 = [p for p in ggufs if "q4" in os.path.basename(p).lower()]
-     chosen_model_path = (q4 or ggufs)[0]
-     print(f"[loader] Loading GGUF: {chosen_model_path}")
+ from typing import List, Optional
+ from huggingface_hub import hf_hub_download
+ from ctransformers import AutoModelForCausalLM
+
+ # ------------------------
+ # Model configuration
+ # ------------------------
+ REPO_ID = "bartowski/Llama-3.2-3B-Instruct-GGUF"
+ FILENAME = "Llama-3.2-3B-Instruct-Q4_K_L.gguf"
+ MODEL_TYPE = "llama"
+
+ # ------------------------
+ # Persistent cache (Docker Spaces -> /data)
+ # ------------------------
+ CACHE_DIR = os.environ.get("HUGGINGFACE_HUB_CACHE", "/data/hf_cache")
+ os.makedirs(CACHE_DIR, exist_ok=True)
+
+ app = FastAPI(title="Llama 3.2 3B Instruct (ctransformers)")
+
+ _model = None
+
+ def get_model():
+     global _model
+     if _model is not None:
+         return _model
+
+     # Download exact GGUF file to persistent cache
+     local_path = hf_hub_download(
+         repo_id=REPO_ID,
+         filename=FILENAME,
+         local_dir=CACHE_DIR,
+         local_dir_use_symlinks=False,
+     )

-     llm = Llama(
-         model_path=chosen_model_path,
-         n_ctx=N_CTX,
-         n_threads=N_THREADS,
-         n_batch=N_BATCH,
-         logits_all=False,
-         n_gpu_layers=0,
+     # Load with ctransformers (CPU by default)
+     _model = AutoModelForCausalLM.from_pretrained(
+         model_path_or_repo_id=os.path.dirname(local_path),
+         model_file=os.path.basename(local_path),
+         model_type=MODEL_TYPE,
+         gpu_layers=int(os.environ.get("GPU_LAYERS", "0")), # set >0 on GPU Spaces
+         context_length=int(os.environ.get("CTX_LEN", "4096")),
      )
-     model_loaded = True
+     return _model

- @app.get("/healthz")
- def healthz():
+ class GenerateIn(BaseModel):
+     prompt: str
+     max_new_tokens: int = 256
+     temperature: float = 0.7
+     top_p: float = 0.95
+     top_k: int = 40
+     repetition_penalty: float = 1.1
+     stop: Optional[List[str]] = None
+
+ class GenerateOut(BaseModel):
+     completion: str
+
+ @app.get("/")
+ def health():
      return {
          "status": "ok",
-         "loaded": model_loaded,
-         "chosen_model_path": chosen_model_path,
-         "model_file_hint": MODEL_FILE_HINT,
-         "n_ctx": N_CTX,
-         "n_batch": N_BATCH,
-         "threads": N_THREADS,
+         "model": {"repo_id": REPO_ID, "filename": FILENAME, "type": MODEL_TYPE},
+         "cache_dir": CACHE_DIR,
+         "endpoints": {"POST /generate": "Generate a completion"},
      }

- SYSTEM_PROMPT_CHAT = "You are a concise assistant. Reply briefly in plain text."
-
- class ChatReq(BaseModel):
-     prompt: str
-
- @app.post("/chat")
- def chat(req: ChatReq, authorization: Optional[str] = Header(default=None)):
-     require_auth(authorization)
-     try:
-         ensure_model()
-     except Exception as e:
-         raise HTTPException(status_code=503, detail=f"loading_error: {e}")
-
+ @app.post("/generate", response_model=GenerateOut)
+ def generate(body: GenerateIn):
      try:
-         full_prompt = (
-             f"<|system|>\n{SYSTEM_PROMPT_CHAT}\n</|system|>\n"
-             f"<|user|>\n{req.prompt}\n</|user|>\n"
-         )
-         out = llm(
-             prompt=full_prompt,
-             temperature=0.2,
-             top_p=0.9,
-             max_tokens=256,
-             stop=["</s>"],
-         )
-         return {"response": out["choices"][0]["text"].strip()}
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=f"infer_error: {e}")
-
- # -------- planner endpoint --------
- class PlanRequest(BaseModel):
-     profile: Dict[str, Any]
-     sample_rows: List[Dict[str, Any]]
-     goal: str = "auto"
-     constraints: Dict[str, Any] = {}
-
- SYSTEM_PROMPT_PLAN = """You are a data-planning assistant.
- Return ONLY minified JSON matching exactly this schema:
- {
- "cleaning":[{"op":"impute_mean|impute_mode|drop_col|clip","cols":["..."],"params":{}}],
- "encoding":[{"op":"one_hot|ordinal|hash|target","cols":["..."],"params":{}}],
- "scaling":"none|standard|robust|minmax",
- "target":{"name":"<col_or_empty>","type":"classification|regression|auto"},
- "split":{"strategy":"random|stratified","test_size":0.2,"cv":5},
- "metric":"f1|roc_auc|accuracy|mae|rmse|r2",
- "models":["lgbm","rf","xgb","logreg","ridge","catboost"],
- "notes":"<short justification>"
- }
- No prose. No markdown. JSON only."""
-
- @app.post("/plan")
- def plan(req: PlanRequest, authorization: Optional[str] = Header(default=None)):
-     require_auth(authorization)
-     try:
-         ensure_model()
-     except Exception as e:
-         raise HTTPException(status_code=503, detail=f"loading_error: {e}")
-
-     try:
-         sample = req.sample_rows[:200]
-         profile_json = json.dumps(req.profile)[:8000]
-         sample_json = json.dumps(sample)[:8000]
-         constraints_json = json.dumps(req.constraints)[:2000]
-
-         user_block = (
-             f"Goal:{req.goal}\n"
-             f"Constraints:{constraints_json}\n"
-             f"Profile:{profile_json}\n"
-             f"Sample:{sample_json}\n"
-         )
-         full_prompt = (
-             f"<|system|>\n{SYSTEM_PROMPT_PLAN}\n</|system|>\n"
-             f"<|user|>\n{user_block}\n</|user|>\n"
-         )
-         out = llm(
-             prompt=full_prompt,
-             temperature=0.2,
-             top_p=0.9,
-             max_tokens=512,
-             stop=["</s>"],
+         model = get_model()
+         text = model(
+             body.prompt,
+             max_new_tokens=body.max_new_tokens,
+             temperature=body.temperature,
+             top_p=body.top_p,
+             top_k=body.top_k,
+             repetition_penalty=body.repetition_penalty,
+             stop=body.stop,
          )
-         text = out["choices"][0]["text"]
-         payload = extract_json_block(text)
-         data = json.loads(payload)
-         return {"plan": data}
-     except ValueError as e:
-         raise HTTPException(status_code=422, detail=f"bad_json: {e}")
+         return GenerateOut(completion=text)
      except Exception as e:
-         raise HTTPException(status_code=500, detail=f"infer_error: {e}")
+         raise HTTPException(status_code=500, detail=str(e))
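
After this commit the Space serves a single POST /generate endpoint (plus the GET / health check) in place of the old /chat and /plan routes. A minimal client sketch, assuming the Space is reachable at a placeholder URL (only "prompt" is required; the remaining fields fall back to the GenerateIn defaults):

import requests

# Placeholder host -- substitute your own Space's URL.
BASE_URL = "https://your-space.hf.space"

payload = {
    "prompt": "Explain what a GGUF file is in one sentence.",
    "max_new_tokens": 128,  # optional; defaults come from GenerateIn
    "stop": ["</s>"],       # optional list of stop strings
}

r = requests.post(f"{BASE_URL}/generate", json=payload, timeout=600)
r.raise_for_status()
print(r.json()["completion"])  # GenerateOut carries a single "completion" field

Note that the first request after a cold start is slow: get_model() must download the GGUF file into CACHE_DIR and load it before generating; later requests reuse the cached module-level _model.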