import gradio as gr
import pandas as pd
import joblib
import shap
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; more reliable on a server
import matplotlib.pyplot as plt
import warnings

# NOTE(review): this silences every warning process-wide, including ones
# raised by scikit-learn / SHAP. Kept as in the original.
warnings.filterwarnings("ignore")

# ====== Model and background data ======
MODEL_PATH = "models/svm_pipeline.joblib"
BG_PATH = "data/bg.csv"

# The 8 features the model consumes (order must match training).
FEATURES = ["ALB", "TP", "TBA", "AST_ALT", "CREA", "PNI", "AAPR", "DBIL"]

# Load the model and the SHAP background set.
# NOTE(review): joblib.load unpickles the file; only load trusted artifacts.
pipeline = joblib.load(MODEL_PATH)
bg_df = pd.read_csv(BG_PATH)

missing_bg = [c for c in FEATURES if c not in bg_df.columns]
if missing_bg:
    raise ValueError(f"背景集缺少列: {missing_bg}")

bg_array = bg_df[FEATURES].to_numpy(dtype=np.float64)


def _positive_class_index() -> int:
    """Return the ``predict_proba`` column index of the positive class.

    The positive class is assumed to carry label ``1``; change the label
    here if the model was trained with a different one. When the pipeline
    exposes no ``classes_`` attribute, column 1 is used.

    Raises:
        ValueError: if ``classes_`` exists but does not contain label 1.
    """
    classes = getattr(pipeline, "classes_", None)
    if classes is None:
        return 1
    hits = np.flatnonzero(np.asarray(classes) == 1)
    if hits.size == 0:
        raise ValueError(
            f"Positive class label 1 not found in pipeline.classes_: {classes!r}"
        )
    return int(hits[0])


def _output_index(n_outputs: int) -> int:
    """Pick which of ``n_outputs`` SHAP/model outputs describes the positive class.

    A single output is always index 0; otherwise the positive-class index
    is used when it is in range, falling back to 0.
    """
    if n_outputs <= 1:
        return 0
    idx = _positive_class_index()
    return idx if idx < n_outputs else 0


def _predict_proba_nd(x_nd: np.ndarray) -> np.ndarray:
    """Score rows of raw feature values; this is the function KernelExplainer explains.

    Args:
        x_nd: 2-D array whose columns are in ``FEATURES`` order.

    Returns:
        One value per row: the positive-class probability when the pipeline
        has ``predict_proba``; otherwise the ``decision_function`` output;
        otherwise the ``predict`` output.
    """
    df = pd.DataFrame(x_nd, columns=FEATURES)
    if hasattr(pipeline, "predict_proba"):
        proba = pipeline.predict_proba(df)
        return proba[:, _positive_class_index()]
    if hasattr(pipeline, "decision_function"):
        return np.asarray(pipeline.decision_function(df))
    return np.asarray(pipeline.predict(df))


# Build the explainer once at import time rather than per request.
explainer = shap.KernelExplainer(_predict_proba_nd, bg_array)


def _render_force_plot(base_val: float, shap_1d: np.ndarray, feat_1d: np.ndarray, fnames):
    """Draw a SHAP force plot with the matplotlib backend and return the Figure.

    Args:
        base_val: Explainer base (expected) value.
        shap_1d: Per-feature SHAP contributions for one sample.
        feat_1d: The sample's feature values, aligned with ``fnames``.
        fnames: Feature names, aligned with ``shap_1d`` / ``feat_1d``.

    Returns:
        The current matplotlib Figure after ``shap.force_plot`` has drawn on it.
    """
    # Closes every pyplot figure (global state) before drawing a new one.
    plt.close('all')

    # Round displayed feature values per column to keep the labels short
    # (decimal places can be tuned per feature; default is 2).
    feat = np.asarray(feat_1d, dtype=float).copy()
    round_map = {
        "ALB": 2, "TP": 2, "TBA": 2, "AST_ALT": 2,
        "CREA": 1, "PNI": 1, "AAPR": 3, "DBIL": 1
    }
    feat_rounded = [
        np.round(val, round_map.get(name, 2))
        for val, name in zip(feat, fnames)
    ]

    shap.force_plot(
        base_val,
        np.asarray(shap_1d).reshape(-1),
        np.asarray(feat_rounded).reshape(-1),
        feature_names=list(fnames),
        matplotlib=True,
        show=False
    )
    fig = plt.gcf()
    fig.set_size_inches(14, 3.6)  # wider and somewhat shorter
    fig.set_dpi(180)              # higher resolution
    plt.tight_layout(pad=0.6)     # tighter layout
    return fig


def _render_bar_fallback(sv: np.ndarray):
    """Draw a horizontal bar chart of SHAP values, largest magnitude on top."""
    order = np.argsort(np.abs(sv))[::-1]
    topk = order[:min(len(FEATURES), sv.shape[0])]
    plt.close('all')
    fig = plt.figure(figsize=(8, 5), dpi=160)
    plt.barh(np.array(FEATURES)[topk], sv[topk])
    plt.xlabel("SHAP value")
    plt.title("Top features (single-sample contribution)")
    plt.gca().invert_yaxis()
    plt.tight_layout()
    return fig


def _coerce_float(x):
    """Convert ``x`` to float; ``None`` and the empty string become NaN."""
    return float(x) if x is not None and x != "" else np.nan


def _extract_shap_vector(shap_out) -> np.ndarray:
    """Reduce whatever ``explainer.shap_values`` returned to one 1-D vector.

    Handles a list of per-output arrays, a 3-D array shaped like
    (samples, features, outputs), a 2-D array (samples, features) and
    anything else (flattened). Only the first sample is used.
    """
    if isinstance(shap_out, list):
        chosen = shap_out[_output_index(len(shap_out))]
        sv = np.asarray(chosen, dtype=np.float64)
        return sv[0, :] if sv.ndim == 2 else sv.reshape(-1)

    sv = np.asarray(shap_out, dtype=np.float64)
    if sv.ndim == 3:
        return sv[0, :, _output_index(sv.shape[2])]
    if sv.ndim == 2:
        return sv[0, :]
    return sv.reshape(-1)


def _extract_base_value(ev) -> float:
    """Reduce ``explainer.expected_value`` (scalar, list or array) to a float."""
    if isinstance(ev, (list, np.ndarray)):
        flat = np.asarray(ev).reshape(-1)
        return float(flat[_output_index(len(flat))])
    return float(ev)


def predict_and_explain(ALB, TP, TBA, AST_ALT, CREA, LYM, ALP, DBIL, nsamples=200):
    """Predict for one patient and explain the prediction with SHAP.

    PNI (``ALB + 5 * LYM``) and AAPR (``ALB / ALP``) are derived here; the
    model itself consumes the 8 columns listed in ``FEATURES``.

    Args:
        ALB, TP, TBA, AST_ALT, CREA, LYM, ALP, DBIL: Raw inputs from the UI;
            each must convert to a finite float. ALP must be > 0 and DBIL
            must be >= 0.
        nsamples: Value passed to ``explainer.shap_values`` as ``nsamples``;
            ``None`` means 200.

    Returns:
        ``(value, figure, status_text)``. ``value`` is the positive-class
        probability (or a score when the pipeline has no ``predict_proba``)
        rounded to 3 decimals. On any error, ``value`` and ``figure`` are
        ``None`` and ``status_text`` carries the message; this function
        does not raise.
    """
    status = []
    try:
        # ---- 1) Read and validate ----
        raw = [_coerce_float(v) for v in (ALB, TP, TBA, AST_ALT, CREA, LYM, ALP, DBIL)]
        # isfinite rejects NaN (missing input) and also +/-inf, which would
        # otherwise flow into the model and SHAP.
        if not all(np.isfinite(v) for v in raw):
            return None, None, "Error: 所有输入必须为数值且不可缺失。"
        ALB, TP, TBA, AST_ALT, CREA, LYM, ALP, DBIL = raw

        if ALP <= 0:
            return None, None, "Error: ALP 必须 > 0(用于计算 AAPR=ALB/ALP)。"
        if DBIL < 0:
            return None, None, "Error: DBIL 不能为负值。"

        # ---- 2) Derived indices ----
        PNI = ALB + 5.0 * LYM
        AAPR = ALB / ALP
        status.append(f"Derived: PNI={PNI:.1f}, AAPR={AAPR:.3f}")

        # ---- 3) Assemble the final 8 features and predict ----
        x_row = np.array([[ALB, TP, TBA, AST_ALT, CREA, PNI, AAPR, DBIL]], dtype=np.float64)
        # Use the same function the explainer wraps so the displayed value
        # and the explained value cannot drift apart.
        prob = float(_predict_proba_nd(x_row)[0])
        if hasattr(pipeline, "predict_proba"):
            status.append(f"Pred prob: {prob:.3f}")
        else:
            # No probabilities available: report the raw score instead.
            status.append(f"Pred score: {prob:.3f}")

        # ---- 4) SHAP ----
        ns = int(nsamples) if nsamples is not None else 200
        shap_out = explainer.shap_values(x_row, nsamples=ns)
        sv = _extract_shap_vector(shap_out)
        status.append(f"SHAP vector shape: {sv.shape}")
        base_val = _extract_base_value(explainer.expected_value)

        # ---- 5) Plot (force plot first; bar chart if that fails) ----
        try:
            fig = _render_force_plot(base_val, sv, x_row[0, :], FEATURES)
            status.append("Rendered force plot (matplotlib).")
            return round(float(prob), 3), fig, "\n".join(status)
        except Exception as e_force:
            status.append(f"Force-plot failed: {repr(e_force)}; fallback=bar")
            fig = _render_bar_fallback(sv)
            status.append("Rendered bar fallback.")
            return round(float(prob), 3), fig, "\n".join(status)

    except Exception as e:
        # UI boundary: report the failure in the status box instead of raising.
        return None, None, f"Fatal error: {repr(e)}"


# ====== Example input (8 fields + nsamples) ======
# Reference case:
#   ALB=40.7, TP=61.7, TBA=4.9, AST/ALT=1.0, CREA=54.3, PNI=48.6, AAPR=0.473, DBIL=8.5
# The form takes raw measurements (PNI and AAPR are derived in the backend),
# so LYM and ALP are back-calculated from the reference case:
#   LYM = (PNI - ALB) / 5 = 1.58,  ALP = ALB / AAPR ~= 86
example_values = [40.7, 61.7, 4.9, 1.0, 54.3, 1.58, 86, 8.5, 200]

# ====== Gradio UI ======
with gr.Blocks() as demo:
    gr.Markdown(
        "### Meige Risk Prediction (SVM) with SHAP Explanation\n"
        "Please enter ALB, TP, TBA, AST/ALT, CREA, LYM, ALP, and DBIL.\n\n"
        "**Units**: ALB(g/L), TP(g/L), TBA(μmol/L), AST/ALT(ratio), CREA(μmol/L), "
        "LYM(×10⁹/L), ALP(U/L), DBIL(μmol/L)."
    )

    with gr.Row():
        with gr.Column(scale=1):
            # Order must match the positional parameters of predict_and_explain.
            inputs = [
                gr.Number(label="ALB (g/L)"),
                gr.Number(label="TP (g/L)"),
                gr.Number(label="TBA (μmol/L)"),
                gr.Number(label="AST/ALT"),
                gr.Number(label="CREA (μmol/L)"),
                gr.Number(label="LYM (×10⁹/L)"),
                gr.Number(label="ALP (U/L)"),
                gr.Number(label="DBIL (μmol/L)"),
            ]
            ns_slider = gr.Slider(100, 500, value=200, step=50, label="SHAP nsamples")
            btn_fill = gr.Button("Fill Example")
            btn_predict = gr.Button("Predict")

        with gr.Column(scale=1):
            out_prob = gr.Number(label="Predicted Probability / Score")
            out_plot = gr.Plot(label="SHAP Force Plot (fallback: bar)")
            out_log = gr.Textbox(label="Status", lines=8)

    def _fill_example():
        """Return the example values, one per input widget plus the slider."""
        return tuple(example_values)

    btn_fill.click(fn=_fill_example, outputs=[*inputs, ns_slider])
    btn_predict.click(
        fn=predict_and_explain,
        inputs=[*inputs, ns_slider],
        outputs=[out_prob, out_plot, out_log]
    )

if __name__ == "__main__":
    demo.launch()