Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
# app.py (7-feature aligned, server-safe)
|
2 |
import gradio as gr
|
3 |
import pandas as pd
|
4 |
import joblib
|
@@ -14,8 +13,8 @@ warnings.filterwarnings("ignore")
|
|
14 |
MODEL_PATH = "models/svm_pipeline.joblib"
|
15 |
BG_PATH = "data/bg.csv"
|
16 |
|
17 |
-
# 模型最终需要的
|
18 |
-
FEATURES = ["ALB", "TP", "TBA", "AST_ALT", "CREA", "PNI", "AAPR"]
|
19 |
|
20 |
# 加载模型与背景
|
21 |
pipeline = joblib.load(MODEL_PATH)
|
@@ -52,15 +51,13 @@ def _render_force_plot(base_val: float, shap_1d: np.ndarray, feat_1d: np.ndarray
|
|
52 |
|
53 |
# ——将特征值按需四舍五入,减少标签长度(可根据需要微调每列小数位)
|
54 |
feat = np.asarray(feat_1d, dtype=float).copy()
|
55 |
-
# 这里给出一套通用规则:ALB/TP/CREA/PNI 一位小数;TBA 两位;AST_ALT 与 AAPR 两位
|
56 |
round_map = {
|
57 |
-
"ALB": 2, "TP": 2, "TBA": 2, "AST_ALT": 2, "CREA": 1, "PNI": 1, "AAPR": 3
|
58 |
}
|
59 |
feat_rounded = [
|
60 |
np.round(val, round_map.get(name, 2)) for val, name in zip(feat, fnames)
|
61 |
]
|
62 |
|
63 |
-
# 如需进一步减字,可把特征名缩写或去掉空格,这里保持原名
|
64 |
shap.force_plot(
|
65 |
base_val,
|
66 |
np.asarray(shap_1d).reshape(-1),
|
@@ -80,7 +77,8 @@ def _render_force_plot(base_val: float, shap_1d: np.ndarray, feat_1d: np.ndarray
|
|
80 |
def _coerce_float(x):
|
81 |
return float(x) if x is not None and x != "" else np.nan
|
82 |
|
83 |
-
|
|
|
84 |
status = []
|
85 |
try:
|
86 |
# ---- 1) 取数并校验 ----
|
@@ -91,21 +89,24 @@ def predict_and_explain(ALB, TP, TBA, AST_ALT, CREA, LYM, ALP, nsamples=200):
|
|
91 |
CREA = _coerce_float(CREA)
|
92 |
LYM = _coerce_float(LYM)
|
93 |
ALP = _coerce_float(ALP)
|
|
|
94 |
|
95 |
-
vals = [ALB, TP, TBA, AST_ALT, CREA, LYM, ALP]
|
96 |
if any(np.isnan(v) for v in vals):
|
97 |
return None, None, "Error: 所有输入必须为数值且不可缺失。"
|
98 |
|
99 |
if ALP <= 0:
|
100 |
return None, None, "Error: ALP 必须 > 0(用于计算 AAPR=ALB/ALP)。"
|
|
|
|
|
101 |
|
102 |
# ---- 2) 衍生指标 ----
|
103 |
PNI = ALB + 5.0 * LYM
|
104 |
AAPR = ALB / ALP
|
105 |
-
status.append(f"Derived: PNI={PNI:.1f}, AAPR={AAPR:.
|
106 |
|
107 |
-
# ---- 3) 组装最终
|
108 |
-
x_row = np.array([[ALB, TP, TBA, AST_ALT, CREA, PNI, AAPR]], dtype=np.float64)
|
109 |
|
110 |
if hasattr(pipeline, "predict_proba"):
|
111 |
classes_ = getattr(pipeline, "classes_", None)
|
@@ -162,7 +163,7 @@ def predict_and_explain(ALB, TP, TBA, AST_ALT, CREA, LYM, ALP, nsamples=200):
|
|
162 |
except Exception as e_force:
|
163 |
status.append(f"Force-plot failed: {repr(e_force)}; fallback=bar")
|
164 |
order = np.argsort(np.abs(sv))[::-1]
|
165 |
-
topk = order[:min(
|
166 |
plt.close('all')
|
167 |
fig = plt.figure(figsize=(8, 5), dpi=160)
|
168 |
plt.barh(np.array(FEATURES)[topk], sv[topk])
|
@@ -176,19 +177,21 @@ def predict_and_explain(ALB, TP, TBA, AST_ALT, CREA, LYM, ALP, nsamples=200):
|
|
176 |
except Exception as e:
|
177 |
return None, None, f"Fatal error: {repr(e)}"
|
178 |
|
179 |
-
# ======
|
180 |
-
|
181 |
-
|
182 |
-
|
|
|
|
|
183 |
|
184 |
|
185 |
# ====== Gradio 界面 ======
|
186 |
with gr.Blocks() as demo:
|
187 |
gr.Markdown(
|
188 |
"### Meige Risk Prediction (SVM) with SHAP Explanation\n"
|
189 |
-
"Please enter ALB, TP, TBA, AST/ALT, CREA, LYM,
|
190 |
"**Units**: ALB(g/L), TP(g/L), TBA(μmol/L), AST/ALT(ratio), CREA(μmol/L), "
|
191 |
-
"LYM(×10⁹/L), ALP(U/L)."
|
192 |
)
|
193 |
|
194 |
with gr.Row():
|
@@ -201,6 +204,7 @@ with gr.Blocks() as demo:
|
|
201 |
gr.Number(label="CREA (μmol/L)"),
|
202 |
gr.Number(label="LYM (×10⁹/L)"),
|
203 |
gr.Number(label="ALP (U/L)"),
|
|
|
204 |
]
|
205 |
ns_slider = gr.Slider(100, 500, value=200, step=50, label="SHAP nsamples")
|
206 |
|
@@ -223,4 +227,4 @@ with gr.Blocks() as demo:
|
|
223 |
)
|
224 |
|
225 |
if __name__ == "__main__":
|
226 |
-
demo.launch()
|
|
|
|
|
1 |
import gradio as gr
|
2 |
import pandas as pd
|
3 |
import joblib
|
|
|
13 |
MODEL_PATH = "models/svm_pipeline.joblib"
|
14 |
BG_PATH = "data/bg.csv"
|
15 |
|
16 |
+
# 模型最终需要的 8 个特征(顺序必须与训练一致)
|
17 |
+
FEATURES = ["ALB", "TP", "TBA", "AST_ALT", "CREA", "PNI", "AAPR", "DBIL"]
|
18 |
|
19 |
# 加载模型与背景
|
20 |
pipeline = joblib.load(MODEL_PATH)
|
|
|
51 |
|
52 |
# ——将特征值按需四舍五入,减少标签长度(可根据需要微调每列小数位)
|
53 |
feat = np.asarray(feat_1d, dtype=float).copy()
|
|
|
54 |
round_map = {
|
55 |
+
"ALB": 2, "TP": 2, "TBA": 2, "AST_ALT": 2, "CREA": 1, "PNI": 1, "AAPR": 3, "DBIL": 1
|
56 |
}
|
57 |
feat_rounded = [
|
58 |
np.round(val, round_map.get(name, 2)) for val, name in zip(feat, fnames)
|
59 |
]
|
60 |
|
|
|
61 |
shap.force_plot(
|
62 |
base_val,
|
63 |
np.asarray(shap_1d).reshape(-1),
|
|
|
77 |
def _coerce_float(x):
|
78 |
return float(x) if x is not None and x != "" else np.nan
|
79 |
|
80 |
+
|
81 |
+
def predict_and_explain(ALB, TP, TBA, AST_ALT, CREA, LYM, ALP, DBIL, nsamples=200):
|
82 |
status = []
|
83 |
try:
|
84 |
# ---- 1) 取数并校验 ----
|
|
|
89 |
CREA = _coerce_float(CREA)
|
90 |
LYM = _coerce_float(LYM)
|
91 |
ALP = _coerce_float(ALP)
|
92 |
+
DBIL = _coerce_float(DBIL)
|
93 |
|
94 |
+
vals = [ALB, TP, TBA, AST_ALT, CREA, LYM, ALP, DBIL]
|
95 |
if any(np.isnan(v) for v in vals):
|
96 |
return None, None, "Error: 所有输入必须为数值且不可缺失。"
|
97 |
|
98 |
if ALP <= 0:
|
99 |
return None, None, "Error: ALP 必须 > 0(用于计算 AAPR=ALB/ALP)。"
|
100 |
+
if DBIL < 0:
|
101 |
+
return None, None, "Error: DBIL 不能为负值。"
|
102 |
|
103 |
# ---- 2) 衍生指标 ----
|
104 |
PNI = ALB + 5.0 * LYM
|
105 |
AAPR = ALB / ALP
|
106 |
+
status.append(f"Derived: PNI={PNI:.1f}, AAPR={AAPR:.3f}")
|
107 |
|
108 |
+
# ---- 3) 组装最终 8 特征并预测 ----
|
109 |
+
x_row = np.array([[ALB, TP, TBA, AST_ALT, CREA, PNI, AAPR, DBIL]], dtype=np.float64)
|
110 |
|
111 |
if hasattr(pipeline, "predict_proba"):
|
112 |
classes_ = getattr(pipeline, "classes_", None)
|
|
|
163 |
except Exception as e_force:
|
164 |
status.append(f"Force-plot failed: {repr(e_force)}; fallback=bar")
|
165 |
order = np.argsort(np.abs(sv))[::-1]
|
166 |
+
topk = order[:min(len(FEATURES), sv.shape[0])]
|
167 |
plt.close('all')
|
168 |
fig = plt.figure(figsize=(8, 5), dpi=160)
|
169 |
plt.barh(np.array(FEATURES)[topk], sv[topk])
|
|
|
177 |
except Exception as e:
|
178 |
return None, None, f"Fatal error: {repr(e)}"
|
179 |
|
180 |
+
# ====== 示例输入(8 项 + nsamples)======
|
181 |
+
# 给定示例:
|
182 |
+
# ALB=40.7, TP=61.7, TBA=4.9, AST/ALT=1.0, CREA=54.3, PNI=48.6, AAPR=0.473, DBIL=8.5
|
183 |
+
# 由于前端输入仍为原始 6 项 + DBIL(PNI、AAPR 在后端计算),需“还原” LYM 与 ALP:
|
184 |
+
# LYM=(PNI-ALB)/5=(48.6-40.7)/5=1.58; ALP=ALB/AAPR=40.7/0.473≈86.0465
|
185 |
+
example_values = [40.7, 61.7, 4.9, 1.0, 54.3, 1.58, 86.0465, 8.5, 200]
|
186 |
|
187 |
|
188 |
# ====== Gradio 界面 ======
|
189 |
with gr.Blocks() as demo:
|
190 |
gr.Markdown(
|
191 |
"### Meige Risk Prediction (SVM) with SHAP Explanation\n"
|
192 |
+
"Please enter ALB, TP, TBA, AST/ALT, CREA, LYM, ALP, and DBIL.\n\n"
|
193 |
"**Units**: ALB(g/L), TP(g/L), TBA(μmol/L), AST/ALT(ratio), CREA(μmol/L), "
|
194 |
+
"LYM(×10⁹/L), ALP(U/L), DBIL(μmol/L)."
|
195 |
)
|
196 |
|
197 |
with gr.Row():
|
|
|
204 |
gr.Number(label="CREA (μmol/L)"),
|
205 |
gr.Number(label="LYM (×10⁹/L)"),
|
206 |
gr.Number(label="ALP (U/L)"),
|
207 |
+
gr.Number(label="DBIL (μmol/L)"),
|
208 |
]
|
209 |
ns_slider = gr.Slider(100, 500, value=200, step=50, label="SHAP nsamples")
|
210 |
|
|
|
227 |
)
|
228 |
|
229 |
if __name__ == "__main__":
|
230 |
+
demo.launch()
|