Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,34 +1,68 @@
|
|
|
|
1 |
import gradio as gr
|
2 |
import pandas as pd
|
3 |
import joblib
|
4 |
import shap
|
5 |
import numpy as np
|
6 |
import matplotlib
|
7 |
-
matplotlib.use("Agg")
|
8 |
import matplotlib.pyplot as plt
|
9 |
import warnings
|
10 |
warnings.filterwarnings("ignore")
|
11 |
|
12 |
# ====== 模型与背景数据 ======
|
13 |
MODEL_PATH = "models/SVM_pipeline.pkl"
|
14 |
-
BG_PATH
|
15 |
|
|
|
16 |
feature_names = ["HGB", "HDL_C", "DBIL", "AST_ALT", "UA", "GFR", "PNI", "HALP", "AAPR", "conuts"]
|
17 |
|
|
|
18 |
pipeline = joblib.load(MODEL_PATH)
|
19 |
bg_df = pd.read_csv(BG_PATH)
|
20 |
bg_array = bg_df[feature_names].to_numpy(dtype=np.float64)
|
21 |
|
22 |
-
|
|
|
23 |
df = pd.DataFrame(x_nd, columns=feature_names)
|
24 |
return pipeline.predict_proba(df)
|
25 |
|
|
|
26 |
explainer = shap.KernelExplainer(_predict_proba_nd, bg_array)
|
27 |
|
28 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
try:
|
30 |
-
#
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
HALP = HGB * ALB * LYM / PLT
|
33 |
AAPR = ALB / ALP
|
34 |
conuts = (
|
@@ -36,36 +70,79 @@ def predict_and_explain(HGB, HDL_C, DBIL, AST_ALT, UA, GFR, ALB, LYM, PLT, ALP,
|
|
36 |
(0 if LYM >= 1.6 else 1 if LYM >= 1.2 else 2 if LYM >= 0.8 else 3) +
|
37 |
(0 if CHOL >= 4.65 else 1 if CHOL >= 3.10 else 2 if CHOL >= 2.59 else 3)
|
38 |
)
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
prob = float(pipeline.predict_proba(
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
plt.close('all')
|
51 |
-
|
52 |
-
|
53 |
-
|
|
|
|
|
54 |
plt.tight_layout()
|
55 |
-
|
56 |
-
return round(prob, 3), fig, "
|
57 |
|
58 |
except Exception as e:
|
59 |
-
return None, None, f"
|
60 |
|
|
|
|
|
61 |
example_values = [137, 1.76, 8.6, 0.97, 310, 75.4, 33, 2.2, 164, 67.9, 2.8, 200]
|
|
|
62 |
|
63 |
-
|
64 |
with gr.Blocks() as demo:
|
65 |
-
gr.Markdown(
|
|
|
|
|
|
|
|
|
|
|
66 |
|
67 |
with gr.Row():
|
68 |
-
with gr.Column():
|
69 |
inputs = [
|
70 |
gr.Number(label="HGB (g/L)"),
|
71 |
gr.Number(label="HDL-C (mmol/L)"),
|
@@ -80,11 +157,24 @@ with gr.Blocks() as demo:
|
|
80 |
gr.Number(label="CHOL (mmol/L)")
|
81 |
]
|
82 |
ns_slider = gr.Slider(100, 400, value=200, step=50, label="SHAP nsamples")
|
83 |
-
|
84 |
-
gr.Button("
|
85 |
-
|
86 |
-
|
87 |
-
with gr.Column():
|
88 |
-
|
89 |
-
|
90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# app.py (robust, server-safe)
|
2 |
import gradio as gr
|
3 |
import pandas as pd
|
4 |
import joblib
|
5 |
import shap
|
6 |
import numpy as np
|
7 |
import matplotlib
|
8 |
+
matplotlib.use("Agg") # 非交互后端,服务器端更稳
|
9 |
import matplotlib.pyplot as plt
|
10 |
import warnings
|
11 |
warnings.filterwarnings("ignore")
|
12 |
|
13 |
# ====== 模型与背景数据 ======
|
14 |
MODEL_PATH = "models/SVM_pipeline.pkl"
|
15 |
+
BG_PATH = "data/bg.csv"
|
16 |
|
17 |
+
# The 10 features the model ultimately consumes (order must match the order used at training time)
|
18 |
feature_names = ["HGB", "HDL_C", "DBIL", "AST_ALT", "UA", "GFR", "PNI", "HALP", "AAPR", "conuts"]
|
19 |
|
20 |
+
# 加载模型与背景
|
21 |
pipeline = joblib.load(MODEL_PATH)
|
22 |
bg_df = pd.read_csv(BG_PATH)
|
23 |
bg_array = bg_df[feature_names].to_numpy(dtype=np.float64)
|
24 |
|
25 |
+
# Prediction function (called by the KernelExplainer)
|
26 |
+
def _predict_proba_nd(x_nd: np.ndarray) -> np.ndarray:
    """Prediction callback passed to ``shap.KernelExplainer``.

    Labels the columns of ``x_nd`` with ``feature_names`` by wrapping it in a
    DataFrame, then returns ``pipeline.predict_proba`` for that frame.

    Args:
        x_nd: Array of feature rows whose columns follow ``feature_names``
            (assumed 2-D, one row per sample -- TODO confirm against SHAP).

    Returns:
        Whatever ``pipeline.predict_proba`` returns for the labelled frame.
    """
    return pipeline.predict_proba(pd.DataFrame(x_nd, columns=feature_names))
|
29 |
|
30 |
+
# Build the explainer only once, at import time (avoids re-creating it on every request)
|
31 |
explainer = shap.KernelExplainer(_predict_proba_nd, bg_array)
|
32 |
|
33 |
+
def _render_force_plot(base_val: float, shap_1d: np.ndarray, feat_1d: np.ndarray, fnames):
    """Draw a SHAP force plot via the matplotlib renderer and return the Figure.

    Every open pyplot figure is closed first; after ``shap.force_plot`` runs
    with ``matplotlib=True, show=False`` the current pyplot figure is fetched,
    resized to 8x4 inches and passed through ``tight_layout``.

    Args:
        base_val: Base (expected) value forwarded to ``shap.force_plot``.
        shap_1d: SHAP values for one sample; flattened to 1-D before plotting.
        feat_1d: Feature values for the same sample; flattened to 1-D.
        fnames: Iterable of feature names; converted to a list for SHAP.

    Returns:
        The current matplotlib figure after the force plot has been drawn.
    """
    plt.close("all")

    contributions = np.asarray(shap_1d).reshape(-1)
    feature_values = np.asarray(feat_1d).reshape(-1)
    shap.force_plot(
        base_val,
        contributions,
        feature_values,
        feature_names=list(fnames),
        matplotlib=True,
        show=False,
    )

    figure = plt.gcf()
    figure.set_size_inches(8, 4)
    plt.tight_layout()
    return figure
|
44 |
+
|
45 |
+
def predict_and_explain(
|
46 |
+
HGB, HDL_C, DBIL, AST_ALT, UA, GFR,
|
47 |
+
ALB, LYM, PLT, ALP, CHOL,
|
48 |
+
nsamples=200
|
49 |
+
):
|
50 |
+
status = []
|
51 |
try:
|
52 |
+
# ---- 1) 衍生指标(由原始输入计算)----
|
53 |
+
try:
|
54 |
+
HGB = float(HGB); HDL_C = float(HDL_C); DBIL = float(DBIL); AST_ALT = float(AST_ALT)
|
55 |
+
UA = float(UA); GFR = float(GFR)
|
56 |
+
ALB = float(ALB); LYM = float(LYM); PLT = float(PLT)
|
57 |
+
ALP = float(ALP); CHOL = float(CHOL)
|
58 |
+
except Exception:
|
59 |
+
return None, None, "Error: some inputs are not numeric."
|
60 |
+
|
61 |
+
# 防极端值(避免除0)
|
62 |
+
if PLT <= 0 or ALP <= 0:
|
63 |
+
return None, None, "Error: PLT and ALP must be > 0."
|
64 |
+
|
65 |
+
PNI = ALB + 5.0 * LYM
|
66 |
HALP = HGB * ALB * LYM / PLT
|
67 |
AAPR = ALB / ALP
|
68 |
conuts = (
|
|
|
70 |
(0 if LYM >= 1.6 else 1 if LYM >= 1.2 else 2 if LYM >= 0.8 else 3) +
|
71 |
(0 if CHOL >= 4.65 else 1 if CHOL >= 3.10 else 2 if CHOL >= 2.59 else 3)
|
72 |
)
|
73 |
+
x_row = np.array([[HGB, HDL_C, DBIL, AST_ALT, UA, GFR, PNI, HALP, AAPR, conuts]], dtype=np.float64)
|
74 |
+
status.append(f"Derived: PNI={PNI:.3f}, HALP={HALP:.3f}, AAPR={AAPR:.3f}, CONUTS={conuts}")
|
75 |
+
|
76 |
+
# ---- 2) 概率 ----
|
77 |
+
prob = float(pipeline.predict_proba(pd.DataFrame(x_row, columns=feature_names))[0, 1])
|
78 |
+
status.append(f"Pred prob computed: {prob:.3f}")
|
79 |
+
|
80 |
+
# ---- 3) SHAP 计算 ----
|
81 |
+
ns = int(nsamples) if nsamples is not None else 200
|
82 |
+
shap_out = explainer.shap_values(x_row, nsamples=ns)
|
83 |
+
|
84 |
+
# 统一提取“正类”一维向量
|
85 |
+
if isinstance(shap_out, list):
|
86 |
+
sv = np.asarray(shap_out[1], dtype=np.float64)
|
87 |
+
if sv.ndim == 2:
|
88 |
+
sv = sv[0, :]
|
89 |
+
else:
|
90 |
+
sv = np.asarray(shap_out, dtype=np.float64)
|
91 |
+
if sv.ndim == 3: # (1, n_features, n_classes)
|
92 |
+
sv = sv[0, :, 1]
|
93 |
+
elif sv.ndim == 2: # (1, n_features)
|
94 |
+
sv = sv[0, :]
|
95 |
+
else:
|
96 |
+
sv = sv.reshape(-1)
|
97 |
+
status.append(f"SHAP 1D shape: {sv.shape}; features: {x_row.shape[1:]}")
|
98 |
+
|
99 |
+
# base value 取正类
|
100 |
+
ev = explainer.expected_value
|
101 |
+
if isinstance(ev, (list, np.ndarray)):
|
102 |
+
ev = np.asarray(ev).reshape(-1)
|
103 |
+
base_val = float(ev[1] if len(ev) > 1 else ev[0])
|
104 |
+
else:
|
105 |
+
base_val = float(ev)
|
106 |
+
|
107 |
+
# ---- 4) 绘图:优先力图;失败则条形图兜底 ----
|
108 |
+
try:
|
109 |
+
fig = _render_force_plot(base_val, sv, x_row[0, :], feature_names)
|
110 |
+
status.append("Rendered force plot (matplotlib).")
|
111 |
+
return round(prob, 3), fig, "\n".join(status)
|
112 |
+
except Exception as e_force:
|
113 |
+
status.append(f"Force-plot failed: {repr(e_force)}; fallback=bar")
|
114 |
+
|
115 |
+
order = np.argsort(np.abs(sv))[::-1]
|
116 |
+
topk = order[:min(10, sv.shape[0])]
|
117 |
plt.close('all')
|
118 |
+
fig = plt.figure(figsize=(8, 5), dpi=160)
|
119 |
+
plt.barh(np.array(feature_names)[topk], sv[topk])
|
120 |
+
plt.xlabel("SHAP value")
|
121 |
+
plt.title("Top features (single-sample contribution)")
|
122 |
+
plt.gca().invert_yaxis()
|
123 |
plt.tight_layout()
|
124 |
+
status.append("Rendered bar fallback.")
|
125 |
+
return round(prob, 3), fig, "\n".join(status)
|
126 |
|
127 |
except Exception as e:
|
128 |
+
return None, None, f"Fatal error: {repr(e)}"
|
129 |
|
130 |
+
# ====== Example: a set of raw indicators that reproduces the derived PNI/HALP/AAPR/CONUTS values ======
|
131 |
+
# Corresponds to: PNI=44, HALP≈60.6, AAPR≈0.486, CONUTS=4
|
132 |
example_values = [137, 1.76, 8.6, 0.97, 310, 75.4, 33, 2.2, 164, 67.9, 2.8, 200]
|
133 |
+
# Order: HGB, HDL_C, DBIL, AST_ALT, UA, GFR, ALB, LYM, PLT, ALP, CHOL, nsamples
|
134 |
|
135 |
+
# ====== Gradio 界面 ======
|
136 |
with gr.Blocks() as demo:
|
137 |
+
gr.Markdown(
|
138 |
+
"### Meige Risk Prediction (SVM) with SHAP Explanation\n"
|
139 |
+
"Enter **original clinical indicators**; the app will derive PNI/HALP/AAPR/CONUTS internally.\n\n"
|
140 |
+
"**Units**: HGB (g/L), HDL‑C (mmol/L), DBIL (μmol/L), AST/ALT (ratio), UA (μmol/L), "
|
141 |
+
"GFR (mL/min/1.73 m²), ALB (g/L), LYM (×10⁹/L), PLT (×10⁹/L), ALP (U/L), CHOL (mmol/L)."
|
142 |
+
)
|
143 |
|
144 |
with gr.Row():
|
145 |
+
with gr.Column(scale=1):
|
146 |
inputs = [
|
147 |
gr.Number(label="HGB (g/L)"),
|
148 |
gr.Number(label="HDL-C (mmol/L)"),
|
|
|
157 |
gr.Number(label="CHOL (mmol/L)")
|
158 |
]
|
159 |
ns_slider = gr.Slider(100, 400, value=200, step=50, label="SHAP nsamples")
|
160 |
+
|
161 |
+
btn_fill = gr.Button("Fill Example")
|
162 |
+
btn_predict = gr.Button("Predict")
|
163 |
+
|
164 |
+
with gr.Column(scale=1):
|
165 |
+
out_prob = gr.Number(label="Predicted Probability")
|
166 |
+
out_plot = gr.Plot(label="SHAP Force Plot (fallback: bar)")
|
167 |
+
out_log = gr.Textbox(label="Status", lines=6)
|
168 |
+
|
169 |
+
def _fill_example():
|
170 |
+
return tuple(example_values)
|
171 |
+
|
172 |
+
btn_fill.click(fn=_fill_example, outputs=[*inputs, ns_slider])
|
173 |
+
btn_predict.click(
|
174 |
+
fn=predict_and_explain,
|
175 |
+
inputs=[*inputs, ns_slider],
|
176 |
+
outputs=[out_prob, out_plot, out_log]
|
177 |
+
)
|
178 |
+
|
179 |
+
if __name__ == "__main__":
|
180 |
+
demo.launch()
|