# Meige / app.py
# Last commit: "Update app.py" by Multiple123 (f2b3a5b, verified); file size 8.39 kB
# app.py (7-feature aligned, server-safe)
import gradio as gr
import pandas as pd
import joblib
import shap
import numpy as np
import matplotlib
matplotlib.use("Agg") # 非交互后端,服务器端更稳
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")
# ====== Model and background data ======
MODEL_PATH = "models/svm_pipeline.joblib"
BG_PATH = "data/bg.csv"

# The 7 features the model finally consumes (order must match training).
FEATURES = ["ALB", "TP", "TBA", "AST_ALT", "CREA", "PNI", "AAPR"]

# Load the fitted pipeline and the SHAP background set once, at import time.
pipeline = joblib.load(MODEL_PATH)
bg_df = pd.read_csv(BG_PATH)

# Fail fast when the background CSV lacks any required feature column.
missing_bg = [name for name in FEATURES if name not in bg_df.columns]
if missing_bg:
    raise ValueError(f"背景集缺少列: {missing_bg}")

# Background matrix as float64, with columns arranged in FEATURES order.
bg_array = bg_df.loc[:, FEATURES].to_numpy(dtype=np.float64)
# Prediction function handed to KernelExplainer: returns the positive-class
# probability, or a score when the pipeline cannot produce probabilities.
def _predict_proba_nd(x_nd: np.ndarray) -> np.ndarray:
    """Score a 2-D feature array with the loaded pipeline.

    Args:
        x_nd: Array with one row per sample and columns in ``FEATURES`` order.

    Returns:
        One value per row. When the pipeline exposes ``predict_proba`` this is
        the probability column of the positive class; otherwise it is the
        ``decision_function`` output, or as a last resort the ``predict``
        output.

    The positive class is taken to be the label ``1``. If ``classes_`` is
    unavailable or does not contain ``1``, column 1 is used (for a binary
    scikit-learn classifier that is ``classes_[1]``, the class a positive
    ``decision_function`` score points to) instead of failing with an
    IndexError.
    """
    # The pipeline is called with named columns, so rebuild a DataFrame.
    df = pd.DataFrame(x_nd, columns=FEATURES)

    if hasattr(pipeline, "predict_proba"):
        proba = np.asarray(pipeline.predict_proba(df))
        pos_idx = 1
        classes_ = getattr(pipeline, "classes_", None)
        if classes_ is not None:
            # np.asarray makes the comparison elementwise even for list-typed classes_.
            hits = np.flatnonzero(np.asarray(classes_) == 1)
            if hits.size:
                pos_idx = int(hits[0])
        return proba[:, pos_idx]

    if hasattr(pipeline, "decision_function"):
        return np.asarray(pipeline.decision_function(df))

    return np.asarray(pipeline.predict(df))
# Build the explainer a single time at import so that every request reuses it
# (keeps per-request performance steadier).
explainer = shap.KernelExplainer(model=_predict_proba_nd, data=bg_array)
def _render_force_plot(base_val: float, shap_1d: np.ndarray, feat_1d: np.ndarray, fnames):
    """Draw a single-sample SHAP force plot and return the matplotlib Figure.

    Uses the matplotlib rendering path of ``shap.force_plot`` (the older
    interface, kept because it behaves reliably on a server).

    Args:
        base_val: Base (expected) value the contributions start from.
        shap_1d: Per-feature SHAP contributions; flattened to 1-D before use.
        feat_1d: Feature values of the sample; flattened to 1-D before use.
        fnames: Iterable of feature names, converted to a list.

    Returns:
        The current matplotlib figure, resized to 8x4 inches.
    """
    # Drop figures left over from earlier calls so gcf() below is the new plot.
    plt.close('all')

    contributions = np.asarray(shap_1d).reshape(-1)
    feature_values = np.asarray(feat_1d).reshape(-1)
    names = list(fnames)

    shap.force_plot(
        base_val,
        contributions,
        feature_values,
        feature_names=names,
        matplotlib=True,
        show=False,
    )

    figure = plt.gcf()
    figure.set_size_inches(8, 4)
    plt.tight_layout()
    return figure
def _coerce_float(x):
    """Convert a raw form value to ``float``, mapping unusable input to NaN.

    Args:
        x: Value from an input widget (number, numeric string, ``None`` or
            an empty string).

    Returns:
        ``float(x)`` when the conversion succeeds; ``np.nan`` when ``x`` is
        ``None``, an empty string, or cannot be parsed as a number. Returning
        NaN (rather than raising) lets the caller's NaN check report a
        regular validation error for non-numeric text.
    """
    if x is None or x == "":
        return np.nan
    try:
        return float(x)
    except (TypeError, ValueError):
        # Non-numeric text (e.g. "abc" or whitespace) or an unsupported type.
        return np.nan
def _positive_class_index() -> int:
    """Return the output column index of the positive class (label ``1``).

    Falls back to 1 when the pipeline has no ``classes_`` or when ``1`` is not
    among its labels, instead of raising an IndexError.
    """
    classes = getattr(pipeline, "classes_", None)
    if classes is not None:
        hits = np.flatnonzero(np.asarray(classes) == 1)
        if hits.size:
            return int(hits[0])
    return 1


def _shap_vector(shap_out, pos_idx: int) -> np.ndarray:
    """Reduce the explainer output for one sample to a 1-D per-feature vector."""
    if isinstance(shap_out, list):
        # One array per output/class: pick the positive one (or the first if absent).
        chosen = shap_out[pos_idx if len(shap_out) > pos_idx else 0]
        sv = np.asarray(chosen, dtype=np.float64)
        return sv[0, :] if sv.ndim == 2 else sv.reshape(-1)

    sv = np.asarray(shap_out, dtype=np.float64)
    if sv.ndim == 3:  # (1, n_features, n_outputs)
        return sv[0, :, pos_idx if sv.shape[2] > pos_idx else 0]
    if sv.ndim == 2:  # (1, n_features)
        return sv[0, :]
    return sv.reshape(-1)


def _base_value(expected, pos_idx: int) -> float:
    """Extract the scalar base value matching the positive output."""
    if isinstance(expected, (list, np.ndarray)):
        flat = np.asarray(expected).reshape(-1)
        return float(flat[pos_idx if len(flat) > pos_idx else 0])
    return float(expected)


def _render_bar_fallback(sv: np.ndarray):
    """Draw a horizontal bar chart of the largest-magnitude SHAP contributions."""
    order = np.argsort(np.abs(sv))[::-1]
    topk = order[:min(7, sv.shape[0])]
    plt.close('all')
    fig = plt.figure(figsize=(8, 5), dpi=160)
    plt.barh(np.array(FEATURES)[topk], sv[topk])
    plt.xlabel("SHAP value")
    plt.title("Top features (single-sample contribution)")
    plt.gca().invert_yaxis()  # largest contribution on top
    plt.tight_layout()
    return fig


def predict_and_explain(ALB, TP, TBA, AST_ALT, CREA, LYM, ALP, nsamples=200):
    """Predict for one sample and explain the prediction with SHAP.

    The seven raw inputs are validated, the derived features
    ``PNI = ALB + 5 * LYM`` and ``AAPR = ALB / ALP`` are computed, and the
    final feature row ``[ALB, TP, TBA, AST_ALT, CREA, PNI, AAPR]`` is passed
    to the pipeline and to the SHAP explainer.

    Args:
        ALB, TP, TBA, AST_ALT, CREA, LYM, ALP: Raw input values; each is
            converted with ``_coerce_float``. All must be finite numbers and
            ``ALP`` must be greater than 0.
        nsamples: Number of samples for ``explainer.shap_values``; ``None``
            means 200.

    Returns:
        A tuple ``(value, figure, status_text)``. ``value`` is the
        positive-class probability (or the decision score / prediction when
        the pipeline has no ``predict_proba``) rounded to 3 decimals,
        ``figure`` is a force plot or, if that fails, a bar chart. On a
        validation failure or any unexpected exception the first two items
        are ``None`` and ``status_text`` carries the error message; nothing
        is raised.
    """
    status = []
    try:
        # ---- 1) Read and validate the inputs ----
        ALB, TP, TBA, AST_ALT, CREA, LYM, ALP = (
            _coerce_float(raw) for raw in (ALB, TP, TBA, AST_ALT, CREA, LYM, ALP)
        )
        vals = [ALB, TP, TBA, AST_ALT, CREA, LYM, ALP]
        # isfinite rejects NaN (missing) as well as +/-inf, which would poison the ratios.
        if not all(np.isfinite(v) for v in vals):
            return None, None, "Error: 所有输入必须为数值且不可缺失。"
        if ALP <= 0:
            return None, None, "Error: ALP 必须 > 0(用于计算 AAPR=ALB/ALP)。"

        # ---- 2) Derived indices ----
        PNI = ALB + 5.0 * LYM
        AAPR = ALB / ALP
        status.append(f"Derived: PNI={PNI:.3f}, AAPR={AAPR:.3f}")

        # ---- 3) Assemble the final 7 features and predict ----
        x_row = np.array([[ALB, TP, TBA, AST_ALT, CREA, PNI, AAPR]], dtype=np.float64)
        frame = pd.DataFrame(x_row, columns=FEATURES)
        pos_idx = _positive_class_index()
        if hasattr(pipeline, "predict_proba"):
            prob = float(pipeline.predict_proba(frame)[0, pos_idx])
            status.append(f"Pred prob: {prob:.3f}")
        else:
            # No probabilities available: report a score instead.
            if hasattr(pipeline, "decision_function"):
                raw_out = pipeline.decision_function(frame)
            else:
                raw_out = pipeline.predict(frame)
            prob = float(raw_out[0])
            status.append(f"Pred score: {prob:.3f}")

        # ---- 4) SHAP computation ----
        ns = int(nsamples) if nsamples is not None else 200
        shap_out = explainer.shap_values(x_row, nsamples=ns)
        sv = _shap_vector(shap_out, pos_idx)
        status.append(f"SHAP vector shape: {sv.shape}")
        base_val = _base_value(explainer.expected_value, pos_idx)

        # ---- 5) Plot (force plot preferred, bar chart as fallback) ----
        try:
            fig = _render_force_plot(base_val, sv, x_row[0, :], FEATURES)
            status.append("Rendered force plot (matplotlib).")
        except Exception as e_force:
            status.append(f"Force-plot failed: {repr(e_force)}; fallback=bar")
            fig = _render_bar_fallback(sv)
            status.append("Rendered bar fallback.")
        return round(float(prob), 3), fig, "\n".join(status)
    except Exception as e:
        # UI boundary: surface the failure in the status box instead of raising.
        return None, None, f"Fatal error: {repr(e)}"
# ====== Example inputs (the 7 raw items + nsamples) ======
# Order: ALB, TP, TBA, AST_ALT, CREA, LYM, ALP, nsamples
# Note: these values derive PNI = ALB + 5*LYM = 44 and AAPR = ALB/ALP = 0.475,
# the same derivations the app applies before calling the model.
example_values = [
    38.0,  # ALB
    68.0,  # TP
    6.5,   # TBA
    1.0,   # AST_ALT
    75.0,  # CREA
    1.2,   # LYM
    80.0,  # ALP
    200,   # nsamples
]
# ====== Gradio UI ======
with gr.Blocks() as demo:
    # Header: what to enter, which features are derived, and the expected units.
    gr.Markdown(
        "### Meige Risk Prediction (SVM) with SHAP Explanation\n"
        "输入 **ALB, TP, TBA, AST/ALT, CREA, LYM, ALP**;应用会内部计算 **PNI=ALB+5×LYM** 与 **AAPR=ALB/ALP**,"
        "并以这 7 个最终特征喂给模型和 SHAP。\n\n"
        "**Units**: ALB(g/L), TP(g/L), TBA(μmol/L), AST/ALT(ratio), CREA(μmol/L), "
        "LYM(×10⁹/L), ALP(U/L)."
    )

    with gr.Row():
        # Left column: the seven raw inputs, the SHAP sample count and the buttons.
        with gr.Column(scale=1):
            inputs = [
                gr.Number(label=caption)
                for caption in (
                    "ALB (g/L)",
                    "TP (g/L)",
                    "TBA (μmol/L)",
                    "AST/ALT",
                    "CREA (μmol/L)",
                    "LYM (×10⁹/L)",
                    "ALP (U/L)",
                )
            ]
            ns_slider = gr.Slider(100, 500, value=200, step=50, label="SHAP nsamples")
            btn_fill = gr.Button("Fill Example")
            btn_predict = gr.Button("Predict")

        # Right column: predicted value, explanation plot and status log.
        with gr.Column(scale=1):
            out_prob = gr.Number(label="Predicted Probability / Score")
            out_plot = gr.Plot(label="SHAP Force Plot (fallback: bar)")
            out_log = gr.Textbox(label="Status", lines=8)

    def _fill_example():
        """Return the example values, one per input widget plus the slider."""
        return tuple(example_values)

    # Both handlers address the seven inputs followed by the nsamples slider.
    btn_fill.click(fn=_fill_example, outputs=inputs + [ns_slider])
    btn_predict.click(
        fn=predict_and_explain,
        inputs=inputs + [ns_slider],
        outputs=[out_prob, out_plot, out_log],
    )

if __name__ == "__main__":
    demo.launch()