Update app.py
app.py CHANGED
@@ -1,163 +1,84 @@
-import os
-from …
-from fastapi import FastAPI, HTTPException, Header
 from pydantic import BaseModel
-from …
[… 32 removed lines (old lines 6–37) not recoverable …]
-if not os.path.isdir(MODELS_DIR):
-    raise RuntimeError(f"Models directory not found: {MODELS_DIR}")
-ggufs: List[str] = []
-for root, _, files in os.walk(MODELS_DIR):
-    for f in files:
-        if f.endswith(".gguf"):
-            ggufs.append(os.path.join(root, f))
-if not ggufs:
-    raise RuntimeError("No .gguf files found in /app/models. Rebuild image with model baked in.")
-
-# prefer q4 if multiple
-q4 = [p for p in ggufs if "q4" in os.path.basename(p).lower()]
-chosen_model_path = (q4 or ggufs)[0]
-print(f"[loader] Loading GGUF: {chosen_model_path}")

[… 7 removed lines (old lines 53–59) not recoverable …]
     )
-

-
-
     return {
         "status": "ok",
[… 3 removed lines (old lines 67–69) truncated …]
-        "n_ctx": N_CTX,
-        "n_batch": N_BATCH,
-        "threads": N_THREADS,
     }

-
-
-class ChatReq(BaseModel):
-    prompt: str
-
-@app.post("/chat")
-def chat(req: ChatReq, authorization: Optional[str] = Header(default=None)):
-    require_auth(authorization)
-    try:
-        ensure_model()
-    except Exception as e:
-        raise HTTPException(status_code=503, detail=f"loading_error: {e}")
-
     try:
[… 9 removed lines (old lines 89–97) not recoverable …]
-            stop=["</s>"],
-        )
-        return {"response": out["choices"][0]["text"].strip()}
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=f"infer_error: {e}")
-
-# -------- planner endpoint --------
-class PlanRequest(BaseModel):
-    profile: Dict[str, Any]
-    sample_rows: List[Dict[str, Any]]
-    goal: str = "auto"
-    constraints: Dict[str, Any] = {}
-
-SYSTEM_PROMPT_PLAN = """You are a data-planning assistant.
-Return ONLY minified JSON matching exactly this schema:
-{
-"cleaning":[{"op":"impute_mean|impute_mode|drop_col|clip","cols":["..."],"params":{}}],
-"encoding":[{"op":"one_hot|ordinal|hash|target","cols":["..."],"params":{}}],
-"scaling":"none|standard|robust|minmax",
-"target":{"name":"<col_or_empty>","type":"classification|regression|auto"},
-"split":{"strategy":"random|stratified","test_size":0.2,"cv":5},
-"metric":"f1|roc_auc|accuracy|mae|rmse|r2",
-"models":["lgbm","rf","xgb","logreg","ridge","catboost"],
-"notes":"<short justification>"
-}
-No prose. No markdown. JSON only."""
-
-@app.post("/plan")
-def plan(req: PlanRequest, authorization: Optional[str] = Header(default=None)):
-    require_auth(authorization)
-    try:
-        ensure_model()
-    except Exception as e:
-        raise HTTPException(status_code=503, detail=f"loading_error: {e}")
-
-    try:
-        sample = req.sample_rows[:200]
-        profile_json = json.dumps(req.profile)[:8000]
-        sample_json = json.dumps(sample)[:8000]
-        constraints_json = json.dumps(req.constraints)[:2000]
-
-        user_block = (
-            f"Goal:{req.goal}\n"
-            f"Constraints:{constraints_json}\n"
-            f"Profile:{profile_json}\n"
-            f"Sample:{sample_json}\n"
-        )
-        full_prompt = (
-            f"<|system|>\n{SYSTEM_PROMPT_PLAN}\n</|system|>\n"
-            f"<|user|>\n{user_block}\n</|user|>\n"
-        )
-        out = llm(
-            prompt=full_prompt,
-            temperature=0.2,
-            top_p=0.9,
-            max_tokens=512,
-            stop=["</s>"],
         )
-
-        payload = extract_json_block(text)
-        data = json.loads(payload)
-        return {"plan": data}
-    except ValueError as e:
-        raise HTTPException(status_code=422, detail=f"bad_json: {e}")
     except Exception as e:
-        raise HTTPException(status_code=500, detail=…
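The removed /plan endpoint parsed the model output with an extract_json_block helper whose definition is outside this hunk. For reviewer reference, below is a minimal, hypothetical sketch of such a helper; only the name, the ValueError contract (mapped to the 422 "bad_json" response above), and the JSON-only system prompt come from the removed code, everything else is assumed.

# Hypothetical sketch only; the real extract_json_block is not part of this diff.
import json

def extract_json_block(text: str) -> str:
    """Return the first balanced {...} block found in `text`.

    Raises ValueError when no parseable JSON object is present, which the
    removed /plan endpoint translated into an HTTP 422 "bad_json" error.
    """
    start = text.find("{")
    if start == -1:
        raise ValueError("no JSON object found in model output")
    depth = 0
    for i, ch in enumerate(text[start:], start=start):
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                candidate = text[start:i + 1]
                json.loads(candidate)  # validate; JSONDecodeError is a ValueError
                return candidate
    raise ValueError("unbalanced JSON object in model output")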
+import os
+from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
+from typing import List, Optional
+from huggingface_hub import hf_hub_download
+from ctransformers import AutoModelForCausalLM
+
+# ------------------------
+# Model configuration
+# ------------------------
+REPO_ID = "bartowski/Llama-3.2-3B-Instruct-GGUF"
+FILENAME = "Llama-3.2-3B-Instruct-Q4_K_L.gguf"
+MODEL_TYPE = "llama"
+
+# ------------------------
+# Persistent cache (Docker Spaces -> /data)
+# ------------------------
+CACHE_DIR = os.environ.get("HUGGINGFACE_HUB_CACHE", "/data/hf_cache")
+os.makedirs(CACHE_DIR, exist_ok=True)
+
+app = FastAPI(title="Llama 3.2 3B Instruct (ctransformers)")
+
+_model = None
+
+def get_model():
+    global _model
+    if _model is not None:
+        return _model
+
+    # Download exact GGUF file to persistent cache
+    local_path = hf_hub_download(
+        repo_id=REPO_ID,
+        filename=FILENAME,
+        local_dir=CACHE_DIR,
+        local_dir_use_symlinks=False,
+    )

+    # Load with ctransformers (CPU by default)
+    _model = AutoModelForCausalLM.from_pretrained(
+        model_path_or_repo_id=os.path.dirname(local_path),
+        model_file=os.path.basename(local_path),
+        model_type=MODEL_TYPE,
+        gpu_layers=int(os.environ.get("GPU_LAYERS", "0")),  # set >0 on GPU Spaces
+        context_length=int(os.environ.get("CTX_LEN", "4096")),
     )
+    return _model

+class GenerateIn(BaseModel):
+    prompt: str
+    max_new_tokens: int = 256
+    temperature: float = 0.7
+    top_p: float = 0.95
+    top_k: int = 40
+    repetition_penalty: float = 1.1
+    stop: Optional[List[str]] = None
+
+class GenerateOut(BaseModel):
+    completion: str
+
+@app.get("/")
+def health():
     return {
         "status": "ok",
+        "model": {"repo_id": REPO_ID, "filename": FILENAME, "type": MODEL_TYPE},
+        "cache_dir": CACHE_DIR,
+        "endpoints": {"POST /generate": "Generate a completion"},
     }

+@app.post("/generate", response_model=GenerateOut)
+def generate(body: GenerateIn):
     try:
+        model = get_model()
+        text = model(
+            body.prompt,
+            max_new_tokens=body.max_new_tokens,
+            temperature=body.temperature,
+            top_p=body.top_p,
+            top_k=body.top_k,
+            repetition_penalty=body.repetition_penalty,
+            stop=body.stop,
+        )
+        return GenerateOut(completion=text)
     except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
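For quick testing of the new API, a minimal client-side sketch follows. The base URL, port, and use of the requests library are assumptions (a Docker Space listens on port 7860 by default and is reached through its public Space URL once deployed); the request fields mirror the GenerateIn model above.

# Minimal client sketch for the new POST /generate endpoint.
# BASE_URL is an assumption; replace it with your Space's URL.
import requests

BASE_URL = "http://localhost:7860"

payload = {
    "prompt": "Explain what a GGUF file is in one sentence.",
    "max_new_tokens": 128,
    "temperature": 0.7,
    "stop": ["</s>"],
}

resp = requests.post(f"{BASE_URL}/generate", json=payload, timeout=300)
resp.raise_for_status()
print(resp.json()["completion"])

The first request triggers the GGUF download and model load inside get_model(), so it can take several minutes on a cold start; subsequent requests reuse the cached module-level _model.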