import os
import json
import re
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException, Header
from pydantic import BaseModel
from llama_cpp import Llama

# Inference knobs (you can still override via Settings → Variables)
N_CTX = int(os.getenv("N_CTX", 2048))
N_BATCH = int(os.getenv("N_BATCH", 64))
N_THREADS = os.cpu_count() or 2
API_SECRET = os.getenv("API_SECRET") # optional bearer auth
MODELS_DIR = "/app/models" # baked into the image by Dockerfile
MODEL_FILE_HINT = os.getenv("MODEL_FILE") # for /healthz display only
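
# Illustrative per-deployment overrides (values and filename invented, not
# defaults of this app):
#   N_CTX=4096 N_BATCH=128 MODEL_FILE=qwen2.5-1.5b-instruct-q4_k_m.gguf
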
app = FastAPI(title="Qwen Planner API (CPU)")
llm: Optional[Llama] = None
model_loaded = False
chosen_model_path: Optional[str] = None

def require_auth(authorization: Optional[str]) -> None:
    # Expects "Authorization: Bearer <API_SECRET>"; auth is a no-op when
    # API_SECRET is unset.
    if API_SECRET and authorization != f"Bearer {API_SECRET}":
        raise HTTPException(status_code=401, detail="Unauthorized")

def extract_json_block(txt: str) -> str:
    m = re.search(r"\{.*\}\s*$", txt, flags=re.S)
    if not m:
        raise ValueError("No JSON object found in output.")
    return m.group(0)
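
# Illustrative behavior (inputs invented for the example):
#   extract_json_block('Here is the plan: {"a": 1}')  -> '{"a": 1}'
#   extract_json_block('no JSON here')                -> raises ValueError
# re.S lets "." span newlines, so a pretty-printed object that ends the
# completion is still captured as a single block.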

def ensure_model():
    global llm, model_loaded, chosen_model_path
    if llm is not None:
        return
    # discover baked gguf
    if not os.path.isdir(MODELS_DIR):
        raise RuntimeError(f"Models directory not found: {MODELS_DIR}")
    ggufs: List[str] = []
    for root, _, files in os.walk(MODELS_DIR):
        for f in files:
            if f.endswith(".gguf"):
                ggufs.append(os.path.join(root, f))
    if not ggufs:
        raise RuntimeError("No .gguf files found in /app/models. Rebuild image with model baked in.")
    # prefer q4 if multiple
    q4 = [p for p in ggufs if "q4" in os.path.basename(p).lower()]
    chosen_model_path = (q4 or ggufs)[0]
    print(f"[loader] Loading GGUF: {chosen_model_path}")
    llm = Llama(
        model_path=chosen_model_path,
        n_ctx=N_CTX,
        n_threads=N_THREADS,
        n_batch=N_BATCH,
        logits_all=False,
        n_gpu_layers=0,
    )
    model_loaded = True
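
# Discovery sketch: with two baked files such as model-q4_k_m.gguf and
# model-q8_0.gguf (filenames illustrative), the q4 file wins via the "q4"
# substring preference above; with no q4 match, the first file found is used.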

@app.get("/healthz")
def healthz():
    return {
        "status": "ok",
        "loaded": model_loaded,
        "chosen_model_path": chosen_model_path,
        "model_file_hint": MODEL_FILE_HINT,
        "n_ctx": N_CTX,
        "n_batch": N_BATCH,
        "threads": N_THREADS,
    }
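
# Illustrative /healthz response before the first model load, assuming the
# default knobs, MODEL_FILE unset, and a 2-vCPU container:
#   {"status": "ok", "loaded": false, "chosen_model_path": null,
#    "model_file_hint": null, "n_ctx": 2048, "n_batch": 64, "threads": 2}
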
SYSTEM_PROMPT_CHAT = "You are a concise assistant. Reply briefly in plain text."

class ChatReq(BaseModel):
    prompt: str

@app.post("/chat")
def chat(req: ChatReq, authorization: Optional[str] = Header(default=None)):
    require_auth(authorization)
    try:
        ensure_model()
    except Exception as e:
        raise HTTPException(status_code=503, detail=f"loading_error: {e}")
    try:
        full_prompt = (
            f"<|system|>\n{SYSTEM_PROMPT_CHAT}\n</|system|>\n"
            f"<|user|>\n{req.prompt}\n</|user|>\n"
        )
        out = llm(
            prompt=full_prompt,
            temperature=0.2,
            top_p=0.9,
            max_tokens=256,
            stop=["</s>"],
        )
        return {"response": out["choices"][0]["text"].strip()}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"infer_error: {e}")
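
# Illustrative call, assuming the app listens on port 7860 (the usual
# Hugging Face Spaces port) and API_SECRET is set; drop the Authorization
# header if auth is disabled:
#   curl -s -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -H "Authorization: Bearer $API_SECRET" \
#        -d '{"prompt": "Summarize FastAPI in one sentence."}'
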
# -------- planner endpoint --------
class PlanRequest(BaseModel):
    profile: Dict[str, Any]
    sample_rows: List[Dict[str, Any]]
    goal: str = "auto"
    constraints: Dict[str, Any] = {}

SYSTEM_PROMPT_PLAN = """You are a data-planning assistant.
Return ONLY minified JSON matching exactly this schema:
{
"cleaning":[{"op":"impute_mean|impute_mode|drop_col|clip","cols":["..."],"params":{}}],
"encoding":[{"op":"one_hot|ordinal|hash|target","cols":["..."],"params":{}}],
"scaling":"none|standard|robust|minmax",
"target":{"name":"<col_or_empty>","type":"classification|regression|auto"},
"split":{"strategy":"random|stratified","test_size":0.2,"cv":5},
"metric":"f1|roc_auc|accuracy|mae|rmse|r2",
"models":["lgbm","rf","xgb","logreg","ridge","catboost"],
"notes":"<short justification>"
}
No prose. No markdown. JSON only."""
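
# A plan the schema admits might look like this (illustrative, not captured
# model output):
#   {"cleaning":[{"op":"impute_mean","cols":["age"],"params":{}}],
#    "encoding":[{"op":"one_hot","cols":["country"],"params":{}}],
#    "scaling":"standard",
#    "target":{"name":"churn","type":"classification"},
#    "split":{"strategy":"stratified","test_size":0.2,"cv":5},
#    "metric":"f1","models":["lgbm","logreg"],
#    "notes":"small tabular dataset; tree models plus a linear baseline"}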

@app.post("/plan")
def plan(req: PlanRequest, authorization: Optional[str] = Header(default=None)):
    require_auth(authorization)
    try:
        ensure_model()
    except Exception as e:
        raise HTTPException(status_code=503, detail=f"loading_error: {e}")
    try:
        sample = req.sample_rows[:200]
        profile_json = json.dumps(req.profile)[:8000]
        sample_json = json.dumps(sample)[:8000]
        constraints_json = json.dumps(req.constraints)[:2000]
        user_block = (
            f"Goal:{req.goal}\n"
            f"Constraints:{constraints_json}\n"
            f"Profile:{profile_json}\n"
            f"Sample:{sample_json}\n"
        )
        full_prompt = (
            f"<|system|>\n{SYSTEM_PROMPT_PLAN}\n</|system|>\n"
            f"<|user|>\n{user_block}\n</|user|>\n"
        )
        out = llm(
            prompt=full_prompt,
            temperature=0.2,
            top_p=0.9,
            max_tokens=512,
            stop=["</s>"],
        )
        text = out["choices"][0]["text"]
        payload = extract_json_block(text)
        data = json.loads(payload)
        return {"plan": data}
    except ValueError as e:
        raise HTTPException(status_code=422, detail=f"bad_json: {e}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"infer_error: {e}")
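
# Illustrative /plan request body (shape follows PlanRequest; all values are
# invented):
#   {
#     "profile": {"n_rows": 5000, "columns": {"age": "float", "churn": "bool"}},
#     "sample_rows": [{"age": 34.0, "churn": false}],
#     "goal": "auto",
#     "constraints": {"max_minutes": 10}
#   }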