# Gradio.py
import gradio as gr
import pandas as pd
import joblib
import shap
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend is more stable; selected before pyplot is imported
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")

# ====== Paths and features ======
MODEL_PATH = "models/SVM_pipeline.pkl"
BG_PATH = "data/bg.csv"
# Widget labels (with units); positions line up one-to-one with feature_names.
feature_labels = [
    "HGB (g/L)", "HDL-C (mmol/L)", "DBIL (μmol/L)", "AST/ALT", "UA (μmol/L)",
    "GFR (mL/min/1.73 m²)", "PNI", "HALP", "AAPR", "conuts",
]
feature_names = ["HGB", "HDL_C", "DBIL", "AST_ALT", "UA", "GFR", "PNI", "HALP", "AAPR", "conuts"]

# ====== Load the model and the background data ======
# NOTE: joblib.load unpickles the file, so MODEL_PATH must point to a trusted artifact.
pipeline = joblib.load(MODEL_PATH)
bg_df = pd.read_csv(BG_PATH)
bg_array = bg_df[feature_names].to_numpy(dtype=np.float64)
# Per-feature fill values for blank inputs. nanmedian is used so that a missing
# cell in the background file cannot turn the fill value itself into NaN.
_bg_medians = pd.Series(np.nanmedian(bg_array, axis=0), index=feature_names)


# ====== Global KernelExplainer (built only once) ======
def _predict_proba_nd(x_nd: np.ndarray) -> np.ndarray:
    """Wrap a plain array in a DataFrame with the feature names and return the pipeline's class probabilities."""
    df = pd.DataFrame(x_nd, columns=feature_names)
    return pipeline.predict_proba(df)


explainer = shap.KernelExplainer(_predict_proba_nd, bg_array)


def _positive_class_shap(shap_out) -> np.ndarray:
    """Reduce the output of ``explainer.shap_values`` to a 1-D vector for one sample.

    Handles both a list with one array per model output and a single array that
    is 3-D ``(samples, features, outputs)``, 2-D ``(samples, features)`` or
    already flat. Output index 1 (the positive class) is used when it exists;
    otherwise the only available output is used, mirroring how the base value
    is selected in ``_positive_base_value``.
    """
    if isinstance(shap_out, list):
        values = np.asarray(shap_out[1 if len(shap_out) > 1 else 0], dtype=np.float64)
    else:
        values = np.asarray(shap_out, dtype=np.float64)
        if values.ndim == 3:
            # Keep the first sample and the positive-class channel.
            values = values[0, :, 1 if values.shape[2] > 1 else 0]
    if values.ndim == 2:
        values = values[0, :]  # first (and only) sample
    return values.reshape(-1)


def _positive_base_value(expected_value) -> float:
    """Return the explainer's expected value for output 1, or the single value when there is only one."""
    flat = np.asarray(expected_value, dtype=np.float64).reshape(-1)
    return float(flat[1] if flat.size > 1 else flat[0])


def predict_and_shap(HGB, HDL_C, DBIL, AST_ALT, UA, GFR, PNI, HALP, AAPR, conuts, nsamples=200):
    """Predict the positive-class probability for one sample and plot its SHAP explanation.

    Args:
        HGB, HDL_C, DBIL, AST_ALT, UA, GFR, PNI, HALP, AAPR, conuts: Raw feature
            values in the order of ``feature_names``. Anything that cannot be
            parsed as a number is replaced by the background median of that feature.
        nsamples: Passed (as an int) to ``explainer.shap_values``.

    Returns:
        A tuple ``(probability, figure, status_log)``: the probability rounded
        to 3 decimals, a matplotlib Figure holding the force plot (or a bar
        chart if the force plot fails), and a newline-joined status log.
        If anything outside the force-plot step raises, returns
        ``(None, None, "Fatal error: ...")`` instead of propagating.
    """
    status_msgs = []
    try:
        # 1) Build the single-row input and impute blanks
        raw_values = [HGB, HDL_C, DBIL, AST_ALT, UA, GFR, PNI, HALP, AAPR, conuts]
        input_df = pd.DataFrame([raw_values], columns=feature_names).apply(
            pd.to_numeric, errors="coerce"
        )
        if input_df.isnull().values.any():
            input_df = input_df.fillna(_bg_medians)
            status_msgs.append("Missing values filled with background medians.")

        # 2) Probability of the positive class (column 1 of predict_proba)
        prob = float(pipeline.predict_proba(input_df)[0, 1])
        status_msgs.append(f"Pred prob computed: {prob:.3f}")

        # 3) SHAP values for this row
        x_row = input_df.to_numpy(dtype=np.float64)  # (1, n_features)
        shap_out = explainer.shap_values(x_row, nsamples=int(nsamples))
        sv = _positive_class_shap(shap_out)
        x_1d = x_row[0, :].astype(np.float64)
        status_msgs.append(f"SHAP 1D shape: {sv.shape}; features: {x_1d.shape}")

        base_val = _positive_base_value(explainer.expected_value)
        fnames = [str(name) for name in feature_names]

        # 4) Force plot. Do not create a figure up front: let SHAP draw, then
        #    pick up the figure it actually used via plt.gcf().
        try:
            plt.close('all')  # drop stale figures so gcf() cannot return an old one
            shap.force_plot(base_val, sv, x_1d, feature_names=fnames,
                            matplotlib=True, show=False)
            fig = plt.gcf()
            fig.set_size_inches(8, 4)
            plt.tight_layout()
            status_msgs.append("Rendered force plot (matplotlib) on current figure.")
            return round(prob, 3), fig, "\n".join(status_msgs)
        except Exception as e_force:
            status_msgs.append(f"Force-plot failed: {repr(e_force)}; fallback=bar")

        # 5) Bar-chart fallback: up to 10 features with the largest |SHAP value|
        order = np.argsort(np.abs(sv))[::-1]
        topk = order[:10]  # slicing already clamps to the number of features
        plt.close('all')  # discard whatever the failed force plot left behind
        fig = plt.figure(figsize=(8, 5), dpi=160)
        ax = fig.gca()
        ax.barh(np.array(fnames)[topk], sv[topk])
        ax.set_xlabel("SHAP value")
        ax.set_title("Top features (single-sample contribution)")
        ax.invert_yaxis()  # largest contribution at the top
        fig.tight_layout()
        status_msgs.append("Rendered bar fallback.")
        return round(prob, 3), fig, "\n".join(status_msgs)

    except Exception as e:
        # UI boundary: report the failure in the status box instead of raising.
        return None, None, f"Fatal error: {repr(e)}"


# ====== Blocks UI ======
# Ten feature values followed by the nsamples slider value.
example_values = [137, 1.76, 8.6, 0.97, 310, 75.4, 44, 60.8, 0.486, 4, 200]


def fill_example():
    """Return the example values as a tuple, one entry per output component."""
    return tuple(example_values)


with gr.Blocks() as demo:
    gr.Markdown(
        "### SVM Meige Risk Prediction & SHAP Explanation\n"
        "Enter 10 indicators with **units** to predict risk and view an individualized explanation.\n\n"
        "**Example**: HGB=137 g/L, HDL‑C=1.76 mmol/L, DBIL=8.6 μmol/L, AST/ALT=0.97, UA=310 μmol/L, "
        "GFR=75.4 mL/min/1.73 m², PNI=44, HALP=60.8, AAPR=0.486, conuts=4."
    )
    with gr.Row():
        with gr.Column(scale=1):
            # One numeric input per feature, in feature_names order.
            num_inputs = [gr.Number(label=label, precision=3) for label in feature_labels]
            ns_slider = gr.Slider(100, 400, value=200, step=50, label="SHAP nsamples")
            btn_fill = gr.Button("Fill with Example")
            btn_predict = gr.Button("Predict")
        with gr.Column(scale=1):
            out_prob = gr.Number(label="Predicted Probability")
            out_plot = gr.Plot(label="SHAP Force (fallback: bar)")  # switched to a Plot component
            out_log = gr.Textbox(label="Status", lines=6)

    # "Fill with Example" populates the inputs and then runs a prediction on them.
    fill_evt = btn_fill.click(fn=fill_example, outputs=[*num_inputs, ns_slider])
    fill_evt.then(fn=predict_and_shap,
                  inputs=[*num_inputs, ns_slider],
                  outputs=[out_prob, out_plot, out_log])

    btn_predict.click(fn=predict_and_shap,
                      inputs=[*num_inputs, ns_slider],
                      outputs=[out_prob, out_plot, out_log])

if __name__ == "__main__":
    demo.launch()  # do not pass server_port / share