Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
# app.py (
|
2 |
import gradio as gr
|
3 |
import pandas as pd
|
4 |
import joblib
|
@@ -11,21 +11,37 @@ import warnings
|
|
11 |
warnings.filterwarnings("ignore")
|
12 |
|
13 |
# ====== 模型与背景数据 ======
|
14 |
-
MODEL_PATH = "models/
|
15 |
BG_PATH = "data/bg.csv"
|
16 |
|
17 |
-
# 模型最终需要的
|
18 |
-
|
19 |
|
20 |
# 加载模型与背景
|
21 |
pipeline = joblib.load(MODEL_PATH)
|
|
|
22 |
bg_df = pd.read_csv(BG_PATH)
|
23 |
-
|
|
|
|
|
|
|
24 |
|
25 |
-
# 预测函数(供 KernelExplainer
|
26 |
def _predict_proba_nd(x_nd: np.ndarray) -> np.ndarray:
|
27 |
-
df = pd.DataFrame(x_nd, columns=
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
# 只初始化一次 explainer(性能更稳)
|
31 |
explainer = shap.KernelExplainer(_predict_proba_nd, bg_array)
|
@@ -34,57 +50,73 @@ def _render_force_plot(base_val: float, shap_1d: np.ndarray, feat_1d: np.ndarray
|
|
34 |
"""返回 matplotlib Figure(旧接口,服务器端稳定)"""
|
35 |
plt.close('all')
|
36 |
shap.force_plot(
|
37 |
-
base_val,
|
38 |
-
|
|
|
|
|
|
|
39 |
)
|
40 |
fig = plt.gcf()
|
41 |
fig.set_size_inches(8, 4)
|
42 |
plt.tight_layout()
|
43 |
return fig
|
44 |
|
45 |
-
def
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
):
|
50 |
status = []
|
51 |
try:
|
52 |
-
# ---- 1)
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
if
|
63 |
-
return None, None, "Error:
|
64 |
-
|
|
|
|
|
|
|
|
|
65 |
PNI = ALB + 5.0 * LYM
|
66 |
-
HALP = HGB * ALB * LYM / PLT
|
67 |
AAPR = ALB / ALP
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
ns = int(nsamples) if nsamples is not None else 200
|
82 |
shap_out = explainer.shap_values(x_row, nsamples=ns)
|
83 |
|
84 |
-
#
|
85 |
if isinstance(shap_out, list):
|
86 |
-
|
87 |
-
|
|
|
|
|
|
|
88 |
sv = sv[0, :]
|
89 |
else:
|
90 |
sv = np.asarray(shap_out, dtype=np.float64)
|
@@ -94,77 +126,75 @@ def predict_and_explain(
|
|
94 |
sv = sv[0, :]
|
95 |
else:
|
96 |
sv = sv.reshape(-1)
|
97 |
-
status.append(f"SHAP
|
98 |
|
99 |
-
# base value
|
100 |
ev = explainer.expected_value
|
101 |
if isinstance(ev, (list, np.ndarray)):
|
102 |
ev = np.asarray(ev).reshape(-1)
|
103 |
-
|
|
|
|
|
104 |
else:
|
105 |
base_val = float(ev)
|
106 |
|
107 |
-
# ----
|
108 |
try:
|
109 |
-
fig = _render_force_plot(base_val, sv, x_row[0, :],
|
110 |
status.append("Rendered force plot (matplotlib).")
|
111 |
-
return round(prob, 3), fig, "\n".join(status)
|
112 |
except Exception as e_force:
|
113 |
status.append(f"Force-plot failed: {repr(e_force)}; fallback=bar")
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
return round(prob, 3), fig, "\n".join(status)
|
126 |
|
127 |
except Exception as e:
|
128 |
return None, None, f"Fatal error: {repr(e)}"
|
129 |
|
130 |
-
# ======
|
131 |
-
|
132 |
-
|
133 |
-
#
|
134 |
|
135 |
# ====== Gradio 界面 ======
|
136 |
with gr.Blocks() as demo:
|
137 |
gr.Markdown(
|
138 |
"### Meige Risk Prediction (SVM) with SHAP Explanation\n"
|
139 |
-
"
|
140 |
-
"
|
141 |
-
"
|
|
|
142 |
)
|
143 |
|
144 |
with gr.Row():
|
145 |
with gr.Column(scale=1):
|
146 |
inputs = [
|
147 |
-
gr.Number(label="HGB (g/L)"),
|
148 |
-
gr.Number(label="HDL-C (mmol/L)"),
|
149 |
-
gr.Number(label="DBIL (μmol/L)"),
|
150 |
-
gr.Number(label="AST/ALT"),
|
151 |
-
gr.Number(label="UA (μmol/L)"),
|
152 |
-
gr.Number(label="GFR (mL/min/1.73 m²)"),
|
153 |
gr.Number(label="ALB (g/L)"),
|
|
|
|
|
|
|
|
|
154 |
gr.Number(label="LYM (×10⁹/L)"),
|
155 |
-
gr.Number(label="PLT (×10⁹/L)"),
|
156 |
gr.Number(label="ALP (U/L)"),
|
157 |
-
gr.Number(label="CHOL (mmol/L)")
|
158 |
]
|
159 |
-
ns_slider = gr.Slider(100,
|
160 |
|
161 |
btn_fill = gr.Button("Fill Example")
|
162 |
btn_predict = gr.Button("Predict")
|
163 |
|
164 |
with gr.Column(scale=1):
|
165 |
-
out_prob = gr.Number(label="Predicted Probability")
|
166 |
out_plot = gr.Plot(label="SHAP Force Plot (fallback: bar)")
|
167 |
-
out_log = gr.Textbox(label="Status", lines=
|
168 |
|
169 |
def _fill_example():
|
170 |
return tuple(example_values)
|
|
|
1 |
+
# app.py (7-feature aligned, server-safe)
|
2 |
import gradio as gr
|
3 |
import pandas as pd
|
4 |
import joblib
|
|
|
11 |
warnings.filterwarnings("ignore")
|
12 |
|
13 |
# ====== Model and background data ======
MODEL_PATH = "models/svm_pipeline.joblib"
BG_PATH = "data/bg.csv"

# The 7 features the model consumes; per the original note, the order must
# match the order used at training time.
FEATURES = ["ALB", "TP", "TBA", "AST_ALT", "CREA", "PNI", "AAPR"]

# Load the fitted pipeline and the background set used by the SHAP explainer.
pipeline = joblib.load(MODEL_PATH)

bg_df = pd.read_csv(BG_PATH)

# Fail at start-up when the background CSV lacks any model feature, rather
# than later inside the explainer.
missing_bg = [name for name in FEATURES if name not in bg_df.columns]
if len(missing_bg) > 0:
    raise ValueError(f"背景集缺少列: {missing_bg}")

# Background matrix with columns selected in FEATURES order.
bg_array = bg_df[FEATURES].to_numpy(dtype=np.float64)
|
28 |
|
29 |
+
# 预测函数(供 KernelExplainer 调用)——返回正类概率/分数
|
30 |
def _predict_proba_nd(x_nd: np.ndarray) -> np.ndarray:
    """Score a batch of feature rows; used as the KernelExplainer model function.

    The rows are wrapped in a DataFrame whose columns are ``FEATURES`` before
    being handed to the module-level ``pipeline``.

    Args:
        x_nd: Array of feature rows with one column per entry of ``FEATURES``.

    Returns:
        The positive-class column of ``predict_proba`` when the pipeline offers
        it; otherwise the ``decision_function`` output, or failing that the
        ``predict`` output, as a NumPy array.
    """
    frame = pd.DataFrame(x_nd, columns=FEATURES)

    if hasattr(pipeline, "predict_proba"):
        proba = pipeline.predict_proba(frame)
        # The positive class is assumed to carry label 1; change the lookup
        # here if the model was trained with a different positive label.
        labels = getattr(pipeline, "classes_", None)
        if labels is None:
            positive_col = 1
        else:
            positive_col = int(np.where(labels == 1)[0][0])
        return proba[:, positive_col]

    # No probabilities available: fall back to a raw score, then hard labels.
    if hasattr(pipeline, "decision_function"):
        raw = pipeline.decision_function(frame)
    else:
        raw = pipeline.predict(frame)

    if isinstance(raw, np.ndarray):
        return raw
    return np.asarray(raw)
|
45 |
|
46 |
# 只初始化一次 explainer(性能更稳)
|
47 |
explainer = shap.KernelExplainer(_predict_proba_nd, bg_array)
|
|
|
50 |
"""返回 matplotlib Figure(旧接口,服务器端稳定)"""
|
51 |
plt.close('all')
|
52 |
shap.force_plot(
|
53 |
+
base_val,
|
54 |
+
np.asarray(shap_1d).reshape(-1),
|
55 |
+
np.asarray(feat_1d).reshape(-1),
|
56 |
+
feature_names=list(fnames),
|
57 |
+
matplotlib=True, show=False
|
58 |
)
|
59 |
fig = plt.gcf()
|
60 |
fig.set_size_inches(8, 4)
|
61 |
plt.tight_layout()
|
62 |
return fig
|
63 |
|
64 |
+
def _coerce_float(x):
    """Convert a raw form value to ``float``, mapping unusable input to NaN.

    ``None``, empty or whitespace-only strings, and values that ``float()``
    cannot parse all yield ``np.nan``. This lets the caller reject bad input
    with a single NaN check instead of hitting a conversion exception.

    Args:
        x: Value received from an input widget (number, numeric string, None).

    Returns:
        The value as a float, or ``np.nan`` when it is missing or non-numeric.
    """
    if x is None:
        return np.nan
    if isinstance(x, str):
        x = x.strip()
        if not x:
            return np.nan
    try:
        return float(x)
    except (TypeError, ValueError):
        # Non-numeric text or an unsupported type: treat as missing.
        return np.nan
|
66 |
+
|
67 |
+
def predict_and_explain(ALB, TP, TBA, AST_ALT, CREA, LYM, ALP, nsamples=200):
|
|
|
68 |
status = []
|
69 |
try:
|
70 |
+
# ---- 1) 取数并校验 ----
|
71 |
+
ALB = _coerce_float(ALB)
|
72 |
+
TP = _coerce_float(TP)
|
73 |
+
TBA = _coerce_float(TBA)
|
74 |
+
AST_ALT = _coerce_float(AST_ALT)
|
75 |
+
CREA = _coerce_float(CREA)
|
76 |
+
LYM = _coerce_float(LYM)
|
77 |
+
ALP = _coerce_float(ALP)
|
78 |
+
|
79 |
+
vals = [ALB, TP, TBA, AST_ALT, CREA, LYM, ALP]
|
80 |
+
if any(np.isnan(v) for v in vals):
|
81 |
+
return None, None, "Error: 所有输入必须为数值且不可缺失。"
|
82 |
+
|
83 |
+
if ALP <= 0:
|
84 |
+
return None, None, "Error: ALP 必须 > 0(用于计算 AAPR=ALB/ALP)。"
|
85 |
+
|
86 |
+
# ---- 2) 衍生指标 ----
|
87 |
PNI = ALB + 5.0 * LYM
|
|
|
88 |
AAPR = ALB / ALP
|
89 |
+
status.append(f"Derived: PNI={PNI:.3f}, AAPR={AAPR:.3f}")
|
90 |
+
|
91 |
+
# ---- 3) 组装最终 7 特征并预测 ----
|
92 |
+
x_row = np.array([[ALB, TP, TBA, AST_ALT, CREA, PNI, AAPR]], dtype=np.float64)
|
93 |
+
|
94 |
+
if hasattr(pipeline, "predict_proba"):
|
95 |
+
classes_ = getattr(pipeline, "classes_", None)
|
96 |
+
pos_idx = int(np.where(classes_ == 1)[0][0]) if classes_ is not None else 1
|
97 |
+
prob = float(pipeline.predict_proba(pd.DataFrame(x_row, columns=FEATURES))[0, pos_idx])
|
98 |
+
status.append(f"Pred prob: {prob:.3f}")
|
99 |
+
else:
|
100 |
+
# 若无概率,给出分数
|
101 |
+
score = float(
|
102 |
+
pipeline.decision_function(pd.DataFrame(x_row, columns=FEATURES))[0]
|
103 |
+
) if hasattr(pipeline, "decision_function") else float(
|
104 |
+
pipeline.predict(pd.DataFrame(x_row, columns=FEATURES))[0]
|
105 |
+
)
|
106 |
+
prob = score
|
107 |
+
status.append(f"Pred score: {score:.3f}")
|
108 |
+
|
109 |
+
# ---- 4) SHAP 计算 ----
|
110 |
ns = int(nsamples) if nsamples is not None else 200
|
111 |
shap_out = explainer.shap_values(x_row, nsamples=ns)
|
112 |
|
113 |
+
# 统一提取“一维贡献向量”
|
114 |
if isinstance(shap_out, list):
|
115 |
+
# 二分类:list 长度=2,取正类
|
116 |
+
classes_ = getattr(pipeline, "classes_", None)
|
117 |
+
pos_idx = int(np.where(classes_ == 1)[0][0]) if classes_ is not None else 1
|
118 |
+
sv = np.asarray(shap_out[pos_idx], dtype=np.float64)
|
119 |
+
if sv.ndim == 2: # (1, n_features)
|
120 |
sv = sv[0, :]
|
121 |
else:
|
122 |
sv = np.asarray(shap_out, dtype=np.float64)
|
|
|
126 |
sv = sv[0, :]
|
127 |
else:
|
128 |
sv = sv.reshape(-1)
|
129 |
+
status.append(f"SHAP vector shape: {sv.shape}")
|
130 |
|
131 |
+
# base value
|
132 |
ev = explainer.expected_value
|
133 |
if isinstance(ev, (list, np.ndarray)):
|
134 |
ev = np.asarray(ev).reshape(-1)
|
135 |
+
classes_ = getattr(pipeline, "classes_", None)
|
136 |
+
pos_idx = int(np.where(classes_ == 1)[0][0]) if classes_ is not None else 1
|
137 |
+
base_val = float(ev[pos_idx if len(ev) > pos_idx else 0])
|
138 |
else:
|
139 |
base_val = float(ev)
|
140 |
|
141 |
+
# ---- 5) 绘图(优先 force,失败退条形图)----
|
142 |
try:
|
143 |
+
fig = _render_force_plot(base_val, sv, x_row[0, :], FEATURES)
|
144 |
status.append("Rendered force plot (matplotlib).")
|
145 |
+
return round(float(prob), 3), fig, "\n".join(status)
|
146 |
except Exception as e_force:
|
147 |
status.append(f"Force-plot failed: {repr(e_force)}; fallback=bar")
|
148 |
+
order = np.argsort(np.abs(sv))[::-1]
|
149 |
+
topk = order[:min(7, sv.shape[0])]
|
150 |
+
plt.close('all')
|
151 |
+
fig = plt.figure(figsize=(8, 5), dpi=160)
|
152 |
+
plt.barh(np.array(FEATURES)[topk], sv[topk])
|
153 |
+
plt.xlabel("SHAP value")
|
154 |
+
plt.title("Top features (single-sample contribution)")
|
155 |
+
plt.gca().invert_yaxis()
|
156 |
+
plt.tight_layout()
|
157 |
+
status.append("Rendered bar fallback.")
|
158 |
+
return round(float(prob), 3), fig, "\n".join(status)
|
|
|
159 |
|
160 |
except Exception as e:
|
161 |
return None, None, f"Fatal error: {repr(e)}"
|
162 |
|
163 |
+
# ====== 示例输入(仅 7 项 + nsamples)======
|
164 |
+
example_values = [38.0, 68.0, 6.5, 1.0, 75.0, 1.2, 80.0, 200]
|
165 |
+
# 顺序:ALB, TP, TBA, AST_ALT, CREA, LYM, ALP, nsamples
|
166 |
+
# 注:上例将派生 PNI=ALB+5*LYM=44、AAPR=ALB/ALP=0.475,与训练对齐
|
167 |
|
168 |
# ====== Gradio 界面 ======
|
169 |
with gr.Blocks() as demo:
|
170 |
gr.Markdown(
|
171 |
"### Meige Risk Prediction (SVM) with SHAP Explanation\n"
|
172 |
+
"输入 **ALB, TP, TBA, AST/ALT, CREA, LYM, ALP**;应用会内部计算 **PNI=ALB+5×LYM** 与 **AAPR=ALB/ALP**,"
|
173 |
+
"并以这 7 个最终特征喂给模型和 SHAP。\n\n"
|
174 |
+
"**Units**: ALB(g/L), TP(g/L), TBA(μmol/L), AST/ALT(ratio), CREA(μmol/L), "
|
175 |
+
"LYM(×10⁹/L), ALP(U/L)."
|
176 |
)
|
177 |
|
178 |
with gr.Row():
|
179 |
with gr.Column(scale=1):
|
180 |
inputs = [
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
gr.Number(label="ALB (g/L)"),
|
182 |
+
gr.Number(label="TP (g/L)"),
|
183 |
+
gr.Number(label="TBA (μmol/L)"),
|
184 |
+
gr.Number(label="AST/ALT"),
|
185 |
+
gr.Number(label="CREA (μmol/L)"),
|
186 |
gr.Number(label="LYM (×10⁹/L)"),
|
|
|
187 |
gr.Number(label="ALP (U/L)"),
|
|
|
188 |
]
|
189 |
+
ns_slider = gr.Slider(100, 500, value=200, step=50, label="SHAP nsamples")
|
190 |
|
191 |
btn_fill = gr.Button("Fill Example")
|
192 |
btn_predict = gr.Button("Predict")
|
193 |
|
194 |
with gr.Column(scale=1):
|
195 |
+
out_prob = gr.Number(label="Predicted Probability / Score")
|
196 |
out_plot = gr.Plot(label="SHAP Force Plot (fallback: bar)")
|
197 |
+
out_log = gr.Textbox(label="Status", lines=8)
|
198 |
|
199 |
def _fill_example():
|
200 |
return tuple(example_values)
|