import os, json, re
from typing import Any, Dict, List, Optional
from fastapi import FastAPI, HTTPException, Header
from pydantic import BaseModel
from llama_cpp import Llama

# Inference knobs (you can still override via Settings → Variables)
N_CTX      = int(os.getenv("N_CTX", 2048))
N_BATCH    = int(os.getenv("N_BATCH", 64))
N_THREADS  = os.cpu_count() or 2

API_SECRET = os.getenv("API_SECRET")  # optional bearer auth

MODELS_DIR = "/app/models"  # baked into the image by Dockerfile
MODEL_FILE_HINT = os.getenv("MODEL_FILE")  # for /healthz display only
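# Example variable overrides (hypothetical values; MODEL_FILE is display-only):
#   N_CTX=4096 N_BATCH=128 API_SECRET=change-me \
#   MODEL_FILE=qwen2.5-1.5b-instruct-q4_k_m.gguf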

app = FastAPI(title="Qwen Planner API (CPU)")

llm: Optional[Llama] = None
model_loaded = False
chosen_model_path: Optional[str] = None

def require_auth(authorization: Optional[str]) -> None:
    if API_SECRET and authorization != f"Bearer {API_SECRET}":
        raise HTTPException(status_code=401, detail="Unauthorized")

def extract_json_block(txt: str) -> str:
    # Greedy + DOTALL: capture from the first '{' to the last '}' at the end
    # of the completion, so nested braces and newlines stay inside the match.
    m = re.search(r"\{.*\}\s*$", txt, flags=re.S)
    if not m:
        raise ValueError("No JSON object found in output.")
    return m.group(0)
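
# Illustrative behavior (assumed inputs, not real model output):
#   extract_json_block('Plan follows: {"scaling": "standard"}')  -> '{"scaling": "standard"}'
#   extract_json_block("no json here")                           -> raises ValueError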

def ensure_model():
    """Lazy-load the first baked-in GGUF (preferring a q4 quant) on first use."""
    global llm, model_loaded, chosen_model_path
    if llm is not None:
        return
    # discover baked gguf
    if not os.path.isdir(MODELS_DIR):
        raise RuntimeError(f"Models directory not found: {MODELS_DIR}")
    ggufs: List[str] = []
    for root, _, files in os.walk(MODELS_DIR):
        for f in files:
            if f.endswith(".gguf"):
                ggufs.append(os.path.join(root, f))
    if not ggufs:
        raise RuntimeError("No .gguf files found in /app/models. Rebuild image with model baked in.")

    # prefer q4 if multiple
    q4 = [p for p in ggufs if "q4" in os.path.basename(p).lower()]
    chosen_model_path = (q4 or ggufs)[0]
    print(f"[loader] Loading GGUF: {chosen_model_path}")

    llm = Llama(
        model_path=chosen_model_path,
        n_ctx=N_CTX,
        n_threads=N_THREADS,
        n_batch=N_BATCH,
        logits_all=False,
        n_gpu_layers=0,
    )
    model_loaded = True
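
# Optional eager load at startup (a sketch; the app otherwise loads lazily on
# the first /chat or /plan request so deploys come up fast):
# @app.on_event("startup")
# def _warm():
#     ensure_model()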

@app.get("/healthz")
def healthz():
    return {
        "status": "ok",
        "loaded": model_loaded,
        "chosen_model_path": chosen_model_path,
        "model_file_hint": MODEL_FILE_HINT,
        "n_ctx": N_CTX,
        "n_batch": N_BATCH,
        "threads": N_THREADS,
    }
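
# Example probe (localhost:8000 is an assumed bind address):
#   curl -s http://localhost:8000/healthz
#   -> {"status":"ok","loaded":false,"chosen_model_path":null,...}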

SYSTEM_PROMPT_CHAT = "You are a concise assistant. Reply briefly in plain text."

class ChatReq(BaseModel):
    prompt: str

@app.post("/chat")
def chat(req: ChatReq, authorization: Optional[str] = Header(default=None)):
    require_auth(authorization)
    try:
        ensure_model()
    except Exception as e:
        raise HTTPException(status_code=503, detail=f"loading_error: {e}")

    try:
        # Use the chat template embedded in the GGUF (Qwen models ship ChatML),
        # so prompt tokens and the stop token always match the model; ad-hoc
        # <|system|> tags and a "</s>" stop do not correspond to Qwen's format.
        out = llm.create_chat_completion(
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT_CHAT},
                {"role": "user", "content": req.prompt},
            ],
            temperature=0.2,
            top_p=0.9,
            max_tokens=256,
        )
        text = out["choices"][0]["message"]["content"] or ""
        return {"response": text.strip()}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"infer_error: {e}")
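
# Example call (assumes the service runs on localhost:8000 with API_SECRET set):
#   curl -s -X POST http://localhost:8000/chat \
#     -H "Authorization: Bearer $API_SECRET" \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "Summarize what a GGUF file is."}'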

# -------- planner endpoint --------
class PlanRequest(BaseModel):
    profile: Dict[str, Any]
    sample_rows: List[Dict[str, Any]]
    goal: str = "auto"
    constraints: Dict[str, Any] = {}

SYSTEM_PROMPT_PLAN = """You are a data-planning assistant.
Return ONLY minified JSON matching exactly this schema:
{
 "cleaning":[{"op":"impute_mean|impute_mode|drop_col|clip","cols":["..."],"params":{}}],
 "encoding":[{"op":"one_hot|ordinal|hash|target","cols":["..."],"params":{}}],
 "scaling":"none|standard|robust|minmax",
 "target":{"name":"<col_or_empty>","type":"classification|regression|auto"},
 "split":{"strategy":"random|stratified","test_size":0.2,"cv":5},
 "metric":"f1|roc_auc|accuracy|mae|rmse|r2",
 "models":["lgbm","rf","xgb","logreg","ridge","catboost"],
 "notes":"<short justification>"
}
No prose. No markdown. JSON only."""
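
# An illustrative plan matching the schema above (hand-written example, not
# model output; the column names are hypothetical):
# {"cleaning":[{"op":"impute_mean","cols":["age"],"params":{}}],
#  "encoding":[{"op":"one_hot","cols":["city"],"params":{}}],
#  "scaling":"standard",
#  "target":{"name":"churn","type":"classification"},
#  "split":{"strategy":"stratified","test_size":0.2,"cv":5},
#  "metric":"roc_auc","models":["lgbm","logreg"],"notes":"small tabular dataset"}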

@app.post("/plan")
def plan(req: PlanRequest, authorization: Optional[str] = Header(default=None)):
    require_auth(authorization)
    try:
        ensure_model()
    except Exception as e:
        raise HTTPException(status_code=503, detail=f"loading_error: {e}")

    try:
        # Hard character caps bound the prompt size; a clipped JSON tail is
        # tolerable inside a prompt. Note that with N_CTX=2048 the combined
        # caps can still exceed the context window for large inputs.
        sample = req.sample_rows[:200]
        profile_json = json.dumps(req.profile)[:8000]
        sample_json  = json.dumps(sample)[:8000]
        constraints_json = json.dumps(req.constraints)[:2000]

        user_block = (
            f"Goal:{req.goal}\n"
            f"Constraints:{constraints_json}\n"
            f"Profile:{profile_json}\n"
            f"Sample:{sample_json}\n"
        )
        # Same pattern as /chat: let llama-cpp-python apply the model's own
        # chat template instead of hand-rolled special tokens.
        out = llm.create_chat_completion(
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT_PLAN},
                {"role": "user", "content": user_block},
            ],
            temperature=0.2,
            top_p=0.9,
            max_tokens=512,
        )
        text = out["choices"][0]["message"]["content"] or ""
        payload = extract_json_block(text)
        data = json.loads(payload)
        return {"plan": data}
    except ValueError as e:
        raise HTTPException(status_code=422, detail=f"bad_json: {e}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"infer_error: {e}")
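
# Example call (assumed deployment; the profile/sample shapes are hypothetical,
# since the handler only requires dicts and a list of row dicts):
#   curl -s -X POST http://localhost:8000/plan \
#     -H "Authorization: Bearer $API_SECRET" \
#     -H "Content-Type: application/json" \
#     -d '{"profile": {"n_rows": 1000}, "sample_rows": [{"age": 34, "churn": 0}], "goal": "auto", "constraints": {}}'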