Spaces:
Running
Running
# Gradio.py | |
import gradio as gr | |
import pandas as pd | |
import joblib | |
import shap | |
import numpy as np | |
import matplotlib | |
matplotlib.use("Agg") # 无交互后端更稳 | |
import matplotlib.pyplot as plt | |
import warnings | |
warnings.filterwarnings("ignore") | |
# ====== Paths and features ======
MODEL_PATH = "models/SVM_pipeline.pkl"
BG_PATH = "data/bg.csv"

# Single source of truth: (model column name, UI label with units).
# Both public lists below are derived from this table, so the column order the
# model sees and the label order shown in the UI cannot drift apart.
_FEATURE_TABLE = (
    ("HGB", "HGB (g/L)"),
    ("HDL_C", "HDL-C (mmol/L)"),
    ("DBIL", "DBIL (μmol/L)"),
    ("AST_ALT", "AST/ALT"),
    ("UA", "UA (μmol/L)"),
    ("GFR", "GFR (mL/min/1.73 m²)"),
    ("PNI", "PNI"),
    ("HALP", "HALP"),
    ("AAPR", "AAPR"),
    ("conuts", "conuts"),
)
# Must stay plain lists: pandas treats a tuple passed to df[...] as one key.
feature_labels = [label for _name, label in _FEATURE_TABLE]
feature_names = [name for name, _label in _FEATURE_TABLE]
# ====== Load the fitted model and the SHAP background sample ======
# NOTE(security): joblib.load unpickles the file, which can execute arbitrary
# code on load; MODEL_PATH must only ever point at a trusted artifact.
pipeline = joblib.load(MODEL_PATH)
bg_df = pd.read_csv(BG_PATH)
# Select exactly the model's feature columns, in model order, as a float64
# matrix (a KeyError here means the CSV is missing one of feature_names).
bg_array = bg_df.loc[:, feature_names].to_numpy(dtype=np.float64)
# ====== Global KernelExplainer (built once, at import time) ======
def _predict_proba_nd(x_nd: np.ndarray) -> np.ndarray:
    """Adapter that lets KernelExplainer call the pipeline with a raw array.

    The explainer passes plain 2-D ndarrays (rows x features); this wraps them
    in a DataFrame whose columns are ``feature_names`` before calling
    ``pipeline.predict_proba`` (presumably because the pipeline was fitted on
    named columns — TODO confirm), and returns the class-probability matrix.
    """
    frame = pd.DataFrame(data=x_nd, columns=feature_names)
    return pipeline.predict_proba(frame)
# Model-agnostic explainer over the full background matrix.
# NOTE(review): KernelExplainer's cost grows with the number of background rows;
# if bg.csv is large, summarising it (shap.sample / shap.kmeans) would speed up
# every request, but it also changes the SHAP values, so validate before switching.
explainer = shap.KernelExplainer(model=_predict_proba_nd, data=bg_array)
def _positive_class_shap(shap_out) -> np.ndarray:
    """Reduce KernelExplainer output to the class-1 SHAP vector of shape (n_features,).

    Depending on the SHAP version, ``shap_values`` returns either a per-class
    list of ``(1, n_features)`` arrays or a single ``(1, n_features, n_classes)``
    array; both layouts are handled and class index 1 is selected.
    """
    if isinstance(shap_out, list):
        sv = np.asarray(shap_out[1], dtype=np.float64)
        return sv[0, :] if sv.ndim == 2 else sv
    sv = np.asarray(shap_out, dtype=np.float64)
    if sv.ndim == 3:  # (1, n_features, n_classes): keep the positive-class channel
        return sv[0, :, 1]
    if sv.ndim == 2:  # (1, n_features): single-output model
        return sv[0, :]
    return sv.reshape(-1)


def _positive_base_value(ev) -> float:
    """Return the explainer's expected value for class index 1 as a float.

    A scalar is converted directly; a per-class list/array uses entry 1, or its
    only entry when it has length 1.
    """
    if isinstance(ev, (list, np.ndarray)):
        flat = np.asarray(ev).reshape(-1)
        return float(flat[1] if len(flat) > 1 else flat[0])
    return float(ev)


def _render_force_plot(base_val, sv, x_1d, fnames):
    """Draw a matplotlib SHAP force plot and return the Figure SHAP drew on."""
    # NOTE(review): relies on pyplot's global "current figure"; safe only while
    # this handler is not run concurrently (Gradio's per-event concurrency
    # limit — confirm for the deployed Gradio version).
    # Do not pre-create a figure: SHAP makes its own. Close leftovers first so
    # plt.gcf() afterwards refers to SHAP's figure, not a stale one.
    plt.close('all')
    shap.force_plot(base_val, sv, x_1d,
                    feature_names=fnames,
                    matplotlib=True, show=False)
    fig = plt.gcf()
    fig.set_size_inches(8, 4)
    plt.tight_layout()
    return fig


def _render_bar_plot(sv, fnames):
    """Fallback: horizontal bars for the (up to) 10 features with the largest |SHAP|."""
    order = np.argsort(np.abs(sv))[::-1]
    topk = order[:min(10, sv.shape[0])]
    plt.close('all')
    fig = plt.figure(figsize=(8, 5), dpi=160)
    plt.barh(np.array(fnames)[topk], sv[topk])
    plt.xlabel("SHAP value")
    plt.title("Top features (single-sample contribution)")
    plt.gca().invert_yaxis()  # largest contribution on top
    plt.tight_layout()
    return fig


def predict_and_shap(HGB, HDL_C, DBIL, AST_ALT, UA, GFR, PNI, HALP, AAPR, conuts, nsamples=200):
    """Predict the risk probability for one patient and explain it with SHAP.

    Args:
        HGB, HDL_C, DBIL, AST_ALT, UA, GFR, PNI, HALP, AAPR, conuts: The ten
            feature values, in ``feature_names`` order. Anything that cannot be
            parsed as a number (e.g. ``None``) becomes NaN and is replaced by
            that feature's median over the background data (NaNs in the
            background are ignored when taking the median).
        nsamples: Passed (as ``int``) to ``KernelExplainer.shap_values``; it
            sets how many model re-evaluations the SHAP estimate may use.

    Returns:
        ``(probability, figure, status_log)``: probability is column 1 of
        ``predict_proba`` rounded to 3 decimals; figure is a SHAP force plot,
        or a top-10 |SHAP| bar chart if the force plot fails. On any other
        error this returns ``(None, None, log)`` where ``log`` is the status
        collected so far followed by ``"Fatal error: ..."``. It never raises,
        so the UI always receives a status message.
    """
    status_msgs = []
    try:
        # 1) Parse inputs; fill unparseable/missing values from the background.
        input_df = pd.DataFrame([[HGB, HDL_C, DBIL, AST_ALT, UA, GFR, PNI, HALP, AAPR, conuts]],
                                columns=feature_names).apply(pd.to_numeric, errors="coerce")
        if input_df.isnull().values.any():
            # nanmedian: a NaN in the background must not make the fill value NaN.
            med = pd.Series(np.nanmedian(bg_array, axis=0), index=feature_names)
            input_df = input_df.fillna(med)
            status_msgs.append("Missing values filled with background medians.")

        # 2) Probability of class column 1.
        prob = float(pipeline.predict_proba(input_df)[0, 1])
        status_msgs.append(f"Pred prob computed: {prob:.3f}")

        # 3) SHAP values and base value for the same class.
        x_row = input_df.to_numpy(dtype=np.float64)  # (1, n_features)
        sv = _positive_class_shap(explainer.shap_values(x_row, nsamples=int(nsamples)))
        x_1d = x_row[0, :].astype(np.float64)
        status_msgs.append(f"SHAP 1D shape: {sv.shape}; features: {x_1d.shape}")
        base_val = _positive_base_value(explainer.expected_value)
        fnames = [str(f) for f in feature_names]

        # 4) Force plot, with a bar chart as best-effort fallback.
        try:
            fig = _render_force_plot(base_val, sv, x_1d, fnames)
            status_msgs.append("Rendered force plot (matplotlib) on current figure.")
        except Exception as e_force:
            status_msgs.append(f"Force-plot failed: {repr(e_force)}; fallback=bar")
            fig = _render_bar_plot(sv, fnames)
            status_msgs.append("Rendered bar fallback.")
        return round(prob, 3), fig, "\n".join(status_msgs)
    except Exception as e:
        # UI boundary: report the error, keeping whatever progress was logged.
        status_msgs.append(f"Fatal error: {repr(e)}")
        return None, None, "\n".join(status_msgs)
# ====== Gradio Blocks UI ======
# One value per feature (same order as feature_labels), followed by the value
# for the "SHAP nsamples" slider as the 11th entry.
example_values = [137, 1.76, 8.6, 0.97, 310, 75.4, 44, 60.8, 0.486, 4, 200]
with gr.Blocks() as demo:
    gr.Markdown(
        "### SVM Meige Risk Prediction & SHAP Explanation\n"
        "Enter 10 indicators with **units** to predict risk and view an individualized explanation.\n\n"
        "**Example**: HGB=137 g/L, HDL‑C=1.76 mmol/L, DBIL=8.6 μmol/L, AST/ALT=0.97, UA=310 μmol/L, "
        "GFR=75.4 mL/min/1.73 m², PNI=44, HALP=60.8, AAPR=0.486, conuts=4."
    )
    with gr.Row():
        with gr.Column(scale=1):
            # One numeric box per feature, in feature_labels order; this order
            # must match predict_and_shap's positional parameters.
            num_inputs = [gr.Number(label=label, precision=3) for label in feature_labels]
            ns_slider = gr.Slider(100, 400, value=200, step=50, label="SHAP nsamples")
            btn_fill = gr.Button("Fill with Example")
            btn_predict = gr.Button("Predict")
        with gr.Column(scale=1):
            out_prob = gr.Number(label="Predicted Probability")
            out_plot = gr.Plot(label="SHAP Force (fallback: bar)")  # receives a matplotlib Figure
            out_log = gr.Textbox(label="Status", lines=6)

    def fill_example():
        """Return example_values as a tuple, one value per input component."""
        return tuple(example_values)

    # Same component lists feed both the example-fill and the predict events.
    predict_inputs = [*num_inputs, ns_slider]
    predict_outputs = [out_prob, out_plot, out_log]

    # Filling the example chains straight into a prediction.
    fill_evt = btn_fill.click(fn=fill_example, outputs=predict_inputs)
    fill_evt.then(fn=predict_and_shap, inputs=predict_inputs, outputs=predict_outputs)
    btn_predict.click(fn=predict_and_shap, inputs=predict_inputs, outputs=predict_outputs)

if __name__ == "__main__":
    # Intentionally no server_port / share arguments — presumably because the
    # hosting platform (this looks like a Hugging Face Space) supplies them; confirm.
    demo.launch()