import os, json, re
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException, Header
from pydantic import BaseModel
from llama_cpp import Llama

# Inference knobs (you can still override via Settings → Variables)
N_CTX = int(os.getenv("N_CTX", "2048"))
N_BATCH = int(os.getenv("N_BATCH", "64"))
N_THREADS = os.cpu_count() or 2

API_SECRET = os.getenv("API_SECRET")  # optional bearer auth
MODELS_DIR = "/app/models"  # baked into the image by Dockerfile
MODEL_FILE_HINT = os.getenv("MODEL_FILE")  # for /healthz display only
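
# MODELS_DIR assumes the GGUF file was copied into the image at build time,
# e.g. with a Dockerfile step along these lines (illustrative sketch only,
# not the project's actual Dockerfile):
#   COPY models/ /app/models/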

app = FastAPI(title="Qwen Planner API (CPU)")

llm: Optional[Llama] = None
model_loaded = False
chosen_model_path: Optional[str] = None


def require_auth(authorization: Optional[str]) -> None:
    if API_SECRET and authorization != f"Bearer {API_SECRET}":
        raise HTTPException(status_code=401, detail="Unauthorized")


def extract_json_block(txt: str) -> str:
    # Greedy match: from the first "{" to the last "}" that ends the output.
    m = re.search(r"\{.*\}\s*$", txt, flags=re.S)
    if not m:
        raise ValueError("No JSON object found in output.")
    return m.group(0)
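
# For instance, extract_json_block('Sure: {"plan": 1}') returns '{"plan": 1}',
# while text with trailing prose after the closing brace raises ValueError.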


def ensure_model():
    global llm, model_loaded, chosen_model_path
    if llm is not None:
        return
    # discover baked gguf
    if not os.path.isdir(MODELS_DIR):
        raise RuntimeError(f"Models directory not found: {MODELS_DIR}")
    ggufs: List[str] = []
    for root, _, files in os.walk(MODELS_DIR):
        for f in files:
            if f.endswith(".gguf"):
                ggufs.append(os.path.join(root, f))
    if not ggufs:
        raise RuntimeError("No .gguf files found in /app/models. Rebuild image with model baked in.")
    # prefer q4 if multiple
    q4 = [p for p in ggufs if "q4" in os.path.basename(p).lower()]
    chosen_model_path = (q4 or ggufs)[0]
    print(f"[loader] Loading GGUF: {chosen_model_path}")
    llm = Llama(
        model_path=chosen_model_path,
        n_ctx=N_CTX,
        n_threads=N_THREADS,
        n_batch=N_BATCH,
        logits_all=False,
        n_gpu_layers=0,
    )
    model_loaded = True


@app.get("/healthz")
def healthz():
    return {
        "status": "ok",
        "loaded": model_loaded,
        "chosen_model_path": chosen_model_path,
        "model_file_hint": MODEL_FILE_HINT,
        "n_ctx": N_CTX,
        "n_batch": N_BATCH,
        "threads": N_THREADS,
    }
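
# Quick check (hypothetical host/port; Hugging Face Spaces conventionally
# serve on 7860):
#   curl http://localhost:7860/healthz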

SYSTEM_PROMPT_CHAT = "You are a concise assistant. Reply briefly in plain text."


class ChatReq(BaseModel):
    prompt: str


@app.post("/chat")
def chat(req: ChatReq, authorization: Optional[str] = Header(default=None)):
    require_auth(authorization)
    try:
        ensure_model()
    except Exception as e:
        raise HTTPException(status_code=503, detail=f"loading_error: {e}")
    try:
        full_prompt = (
            f"<|system|>\n{SYSTEM_PROMPT_CHAT}\n</|system|>\n"
            f"<|user|>\n{req.prompt}\n</|user|>\n"
        )
        out = llm(
            prompt=full_prompt,
            temperature=0.2,
            top_p=0.9,
            max_tokens=256,
            stop=["</s>"],
        )
        return {"response": out["choices"][0]["text"].strip()}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"infer_error: {e}")
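
# Example call (hypothetical host/port; the Authorization header is only
# required when API_SECRET is set):
#   curl -X POST http://localhost:7860/chat \
#     -H "Content-Type: application/json" \
#     -H "Authorization: Bearer $API_SECRET" \
#     -d '{"prompt": "Say hello"}'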


# -------- planner endpoint --------
class PlanRequest(BaseModel):
    profile: Dict[str, Any]
    sample_rows: List[Dict[str, Any]]
    goal: str = "auto"
    constraints: Dict[str, Any] = {}


SYSTEM_PROMPT_PLAN = """You are a data-planning assistant.
Return ONLY minified JSON matching exactly this schema:
{
"cleaning":[{"op":"impute_mean|impute_mode|drop_col|clip","cols":["..."],"params":{}}],
"encoding":[{"op":"one_hot|ordinal|hash|target","cols":["..."],"params":{}}],
"scaling":"none|standard|robust|minmax",
"target":{"name":"<col_or_empty>","type":"classification|regression|auto"},
"split":{"strategy":"random|stratified","test_size":0.2,"cv":5},
"metric":"f1|roc_auc|accuracy|mae|rmse|r2",
"models":["lgbm","rf","xgb","logreg","ridge","catboost"],
"notes":"<short justification>"
}
No prose. No markdown. JSON only."""


@app.post("/plan")
def plan(req: PlanRequest, authorization: Optional[str] = Header(default=None)):
    require_auth(authorization)
    try:
        ensure_model()
    except Exception as e:
        raise HTTPException(status_code=503, detail=f"loading_error: {e}")
    try:
        # Cap prompt size: at most 200 sample rows and truncated JSON blobs.
        sample = req.sample_rows[:200]
        profile_json = json.dumps(req.profile)[:8000]
        sample_json = json.dumps(sample)[:8000]
        constraints_json = json.dumps(req.constraints)[:2000]
        user_block = (
            f"Goal:{req.goal}\n"
            f"Constraints:{constraints_json}\n"
            f"Profile:{profile_json}\n"
            f"Sample:{sample_json}\n"
        )
        full_prompt = (
            f"<|system|>\n{SYSTEM_PROMPT_PLAN}\n</|system|>\n"
            f"<|user|>\n{user_block}\n</|user|>\n"
        )
        out = llm(
            prompt=full_prompt,
            temperature=0.2,
            top_p=0.9,
            max_tokens=512,
            stop=["</s>"],
        )
        text = out["choices"][0]["text"]
        payload = extract_json_block(text)
        data = json.loads(payload)
        return {"plan": data}
    except ValueError as e:
        # json.JSONDecodeError subclasses ValueError, so a missing JSON block
        # and malformed JSON both surface as a 422 here.
        raise HTTPException(status_code=422, detail=f"bad_json: {e}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"infer_error: {e}")
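
# Local run sketch (assumes uvicorn is installed; 7860 is the usual Spaces
# port, adjust as needed):
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)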