# Meige / app.py
# Last change: "Update app.py" by Multiple123 (commit 7aa125b, verified)
# app.py (robust, server-safe)
import warnings
from pathlib import Path

import gradio as gr
import pandas as pd
import joblib
import shap
import numpy as np
import matplotlib
matplotlib.use("Agg") # 非交互后端,服务器端更稳
import matplotlib.pyplot as plt
warnings.filterwarnings("ignore")
# ====== Model and background data ======
# Resolve data files relative to this script instead of the process working
# directory, so the app also starts when launched from elsewhere
# (e.g. `python some/dir/app.py`). Kept as str for callers expecting a path string.
_APP_DIR = Path(__file__).resolve().parent
MODEL_PATH = str(_APP_DIR / "models" / "SVM_pipeline.pkl")
BG_PATH = str(_APP_DIR / "data" / "bg.csv")

# The 10 features the model consumes; the order must match the training order.
feature_names = ["HGB", "HDL_C", "DBIL", "AST_ALT", "UA", "GFR", "PNI", "HALP", "AAPR", "conuts"]

# Load the fitted pipeline and the SHAP background sample once, at import time.
# NOTE: joblib.load unpickles the file -- only ever point it at trusted model files.
pipeline = joblib.load(MODEL_PATH)
bg_df = pd.read_csv(BG_PATH)
# Select (and order) the model's columns; raises KeyError if bg.csv lacks any of them.
bg_array = bg_df[feature_names].to_numpy(dtype=np.float64)
def _predict_proba_nd(x_nd: np.ndarray) -> np.ndarray:
    """Score a raw 2-D feature array with the loaded pipeline (adapter for KernelExplainer).

    The explainer supplies plain ndarrays. They are wrapped in a DataFrame whose
    columns are ``feature_names``, presumably so the pipeline sees the same column
    labels it was fitted with (confirm against the training code).

    Returns:
        The result of ``pipeline.predict_proba`` on that DataFrame.
    """
    frame = pd.DataFrame(data=x_nd, columns=feature_names)
    probabilities = pipeline.predict_proba(frame)
    return probabilities
# Build the KernelExplainer a single time at import and reuse it for every
# request, so per-request work is limited to shap_values().
explainer = shap.KernelExplainer(model=_predict_proba_nd, data=bg_array)
def _render_force_plot(base_val: float, shap_1d: np.ndarray, feat_1d: np.ndarray, fnames):
    """Draw a single-sample SHAP force plot and return its matplotlib Figure.

    Uses shap's legacy ``matplotlib=True`` rendering path (no JavaScript), which
    fits the non-interactive backend this file selects for server use.

    Args:
        base_val: Explainer expected value for the explained class.
        shap_1d: Per-feature SHAP values; flattened to 1-D before plotting.
        feat_1d: The sample's feature values; flattened to 1-D before plotting.
        fnames: Feature labels, converted to a list.

    Returns:
        The current pyplot figure, resized to 8x4 inches with a tight layout.
    """
    # Drop any previous figures so gcf() below refers to the force plot's figure.
    plt.close('all')
    contributions = np.asarray(shap_1d).reshape(-1)
    sample_values = np.asarray(feat_1d).reshape(-1)
    shap.force_plot(
        base_val,
        contributions,
        sample_values,
        feature_names=list(fnames),
        matplotlib=True,
        show=False,
    )
    figure = plt.gcf()
    figure.set_size_inches(8, 4)
    plt.tight_layout()
    return figure
def _conut_score(alb: float, lym: float, chol: float) -> int:
    """Return the CONUT nutritional score (0-12) for one patient.

    Units follow the UI: albumin in g/L, lymphocytes in x10^9/L, total
    cholesterol in mmol/L. Each component is banded and the points are summed.
    """
    alb_pts = 0 if alb >= 35 else 2 if alb >= 30 else 4 if alb >= 25 else 6
    lym_pts = 0 if lym >= 1.6 else 1 if lym >= 1.2 else 2 if lym >= 0.8 else 3
    # NOTE(review): the published CONUT cholesterol bands are 180/140/100 mg/dL
    # (~4.65/3.62/2.59 mmol/L); the middle cut-off here is 3.10 mmol/L (~120 mg/dL).
    # Left unchanged because the model must see the same encoding it was trained
    # on -- confirm which thresholds produced the training data's `conuts`.
    chol_pts = 0 if chol >= 4.65 else 1 if chol >= 3.10 else 2 if chol >= 2.59 else 3
    return alb_pts + lym_pts + chol_pts


def _positive_class_shap_1d(shap_out) -> np.ndarray:
    """Reduce single-sample ``shap_values`` output to a 1-D vector for class index 1.

    Handles a per-class list as well as a single stacked array shaped
    (1, n_features, n_classes), (1, n_features) or already 1-D; the layout
    differs between shap versions.
    """
    if isinstance(shap_out, list):
        sv = np.asarray(shap_out[1], dtype=np.float64)
        return sv[0, :] if sv.ndim == 2 else sv
    sv = np.asarray(shap_out, dtype=np.float64)
    if sv.ndim == 3:  # (1, n_features, n_classes)
        return sv[0, :, 1]
    if sv.ndim == 2:  # (1, n_features)
        return sv[0, :]
    return sv.reshape(-1)


def _positive_class_base_value(ev) -> float:
    """Return the explainer's expected value for class index 1 (or the only value if scalar/length-1)."""
    if isinstance(ev, (list, np.ndarray)):
        flat = np.asarray(ev).reshape(-1)
        return float(flat[1] if len(flat) > 1 else flat[0])
    return float(ev)


def _render_bar_fallback(sv: np.ndarray):
    """Plot the (up to) 10 largest-|SHAP| features as horizontal bars and return the Figure."""
    order = np.argsort(np.abs(sv))[::-1]
    topk = order[:min(10, sv.shape[0])]
    plt.close('all')
    fig = plt.figure(figsize=(8, 5), dpi=160)
    plt.barh(np.array(feature_names)[topk], sv[topk])
    plt.xlabel("SHAP value")
    plt.title("Top features (single-sample contribution)")
    plt.gca().invert_yaxis()  # largest contribution at the top
    plt.tight_layout()
    return fig


def predict_and_explain(
    HGB, HDL_C, DBIL, AST_ALT, UA, GFR,
    ALB, LYM, PLT, ALP, CHOL,
    nsamples=200
):
    """Predict risk from raw clinical indicators and explain it with SHAP.

    PNI, HALP, AAPR and the CONUT score are derived from the raw inputs, then the
    10-feature row is scored by ``pipeline`` and explained by ``explainer``.

    Args:
        HGB, HDL_C, DBIL, AST_ALT, UA, GFR, ALB, LYM, PLT, ALP, CHOL: Raw
            indicators (units as listed in the UI); anything ``float()`` accepts.
        nsamples: Passed to ``explainer.shap_values``; ``None`` means 200.

    Returns:
        ``(prob, fig, status)`` where ``prob`` is column 1 of ``predict_proba``
        rounded to 3 decimals (presumably the positive class -- confirm
        ``pipeline.classes_``), ``fig`` is a matplotlib Figure, and ``status`` is a
        newline-joined log. On invalid input ``prob`` and ``fig`` are None.
        If only the explanation fails, ``prob`` is still returned with
        ``fig=None``.
    """
    status = []
    try:
        # ---- 1) Parse and validate raw inputs, then derive indices ----
        raw_inputs = (HGB, HDL_C, DBIL, AST_ALT, UA, GFR, ALB, LYM, PLT, ALP, CHOL)
        try:
            parsed = [float(v) for v in raw_inputs]
        except (TypeError, ValueError, OverflowError):
            return None, None, "Error: some inputs are not numeric."
        # NaN/inf would otherwise flow silently into the derived features and the model.
        if not np.isfinite(parsed).all():
            return None, None, "Error: inputs must be finite numbers."
        HGB, HDL_C, DBIL, AST_ALT, UA, GFR, ALB, LYM, PLT, ALP, CHOL = parsed

        # PLT and ALP are divisors below; guard against division by zero / sign flips.
        if PLT <= 0 or ALP <= 0:
            return None, None, "Error: PLT and ALP must be > 0."
        PNI = ALB + 5.0 * LYM
        HALP = HGB * ALB * LYM / PLT
        AAPR = ALB / ALP
        conuts = _conut_score(ALB, LYM, CHOL)
        x_row = np.array([[HGB, HDL_C, DBIL, AST_ALT, UA, GFR, PNI, HALP, AAPR, conuts]], dtype=np.float64)
        status.append(f"Derived: PNI={PNI:.3f}, HALP={HALP:.3f}, AAPR={AAPR:.3f}, CONUTS={conuts}")

        # ---- 2) Probability ----
        prob = float(pipeline.predict_proba(pd.DataFrame(x_row, columns=feature_names))[0, 1])
        status.append(f"Pred prob computed: {prob:.3f}")
    except Exception as e:
        return None, None, f"Fatal error: {repr(e)}"

    # ---- 3) SHAP values; a failure here must not discard the prediction ----
    try:
        ns = int(nsamples) if nsamples is not None else 200
        sv = _positive_class_shap_1d(explainer.shap_values(x_row, nsamples=ns))
        status.append(f"SHAP 1D shape: {sv.shape}; features: {x_row.shape[1:]}")
        base_val = _positive_class_base_value(explainer.expected_value)
    except Exception as e_shap:
        status.append(f"SHAP failed: {repr(e_shap)}; returning probability only")
        return round(prob, 3), None, "\n".join(status)

    # ---- 4) Plot: force plot first, bar chart as fallback ----
    try:
        fig = _render_force_plot(base_val, sv, x_row[0, :], feature_names)
        status.append("Rendered force plot (matplotlib).")
    except Exception as e_force:
        status.append(f"Force-plot failed: {repr(e_force)}; fallback=bar")
        try:
            fig = _render_bar_fallback(sv)
            status.append("Rendered bar fallback.")
        except Exception as e_bar:
            status.append(f"Bar fallback failed: {repr(e_bar)}")
            fig = None
    return round(prob, 3), fig, "\n".join(status)
# ====== Example: one set of raw indicators for the "Fill Example" button ======
# Derived values for these inputs: PNI = 33 + 5*2.2 = 44, HALP = 167*33*2.2/164 ≈ 73.9,
# AAPR = 33/68 ≈ 0.485, CONUTS = 2 (ALB) + 0 (LYM) + 2 (CHOL) = 4.
# NOTE(review): an earlier comment claimed HALP ≈ 60.8, which these values do not
# produce (that would need e.g. HGB ≈ 137 or PLT ≈ 199) -- confirm the intended example.
# Order: HGB, HDL_C, DBIL, AST_ALT, UA, GFR, ALB, LYM, PLT, ALP, CHOL, nsamples
example_values = [167, 1.76, 8.6, 0.97, 310, 75, 33, 2.2, 164, 68, 2.8, 200]
# ====== Gradio UI ======
_INTRO_MD = (
    "### Meige Risk Prediction (SVM) with SHAP Explanation\n"
    "Enter **original clinical indicators**; the app will derive PNI/HALP/AAPR/CONUTS internally.\n\n"
    "**Units**: HGB (g/L), HDL‑C (mmol/L), DBIL (μmol/L), AST/ALT (ratio), UA (μmol/L), "
    "GFR (mL/min/1.73 m²), ALB (g/L), LYM (×10⁹/L), PLT (×10⁹/L), ALP (U/L), CHOL (mmol/L)."
)

# Labels for the raw-indicator inputs, in the positional order predict_and_explain expects.
_INPUT_LABELS = (
    "HGB (g/L)",
    "HDL-C (mmol/L)",
    "DBIL (μmol/L)",
    "AST/ALT",
    "UA (μmol/L)",
    "GFR (mL/min/1.73 m²)",
    "ALB (g/L)",
    "LYM (×10⁹/L)",
    "PLT (×10⁹/L)",
    "ALP (U/L)",
    "CHOL (mmol/L)",
)

with gr.Blocks() as demo:
    gr.Markdown(_INTRO_MD)
    with gr.Row():
        with gr.Column(scale=1):
            inputs = [gr.Number(label=text) for text in _INPUT_LABELS]
            ns_slider = gr.Slider(100, 400, value=200, step=50, label="SHAP nsamples")
            btn_fill = gr.Button("Fill Example")
            btn_predict = gr.Button("Predict")
        with gr.Column(scale=1):
            out_prob = gr.Number(label="Predicted Probability")
            out_plot = gr.Plot(label="SHAP Force Plot (fallback: bar)")
            out_log = gr.Textbox(label="Status", lines=6)

    def _fill_example():
        """Return the example inputs: eleven raw indicators followed by nsamples."""
        return tuple(example_values)

    # Both handlers address the same component list: the 11 inputs plus the slider.
    form_fields = inputs + [ns_slider]
    btn_fill.click(fn=_fill_example, outputs=form_fields)
    btn_predict.click(
        fn=predict_and_explain,
        inputs=form_fields,
        outputs=[out_prob, out_plot, out_log],
    )

if __name__ == "__main__":
    demo.launch()