Meige / Gradio.py
Multiple123's picture
Upload 4 files
0683b82 verified
raw
history blame
6 kB
# Gradio.py
import gradio as gr
import pandas as pd
import joblib
import shap
import numpy as np
import matplotlib
matplotlib.use("Agg") # 无交互后端更稳
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")
# ====== Paths and feature schema ======
MODEL_PATH = "models/SVM_pipeline.pkl"  # serialized pipeline, loaded with joblib below
BG_PATH = "data/bg.csv"  # background table; must contain every column in feature_names

# Display labels (with units) for the ten inputs. The order matches
# ``feature_names`` position by position.
feature_labels = [
    "HGB (g/L)",
    "HDL-C (mmol/L)",
    "DBIL (μmol/L)",
    "AST/ALT",
    "UA (μmol/L)",
    "GFR (mL/min/1.73 m²)",
    "PNI",
    "HALP",
    "AAPR",
    "conuts",
]

# Column names used whenever a DataFrame is built for the pipeline.
feature_names = [
    "HGB",
    "HDL_C",
    "DBIL",
    "AST_ALT",
    "UA",
    "GFR",
    "PNI",
    "HALP",
    "AAPR",
    "conuts",
]

# ====== Load the model and the background data ======
pipeline = joblib.load(MODEL_PATH)
bg_df = pd.read_csv(BG_PATH)
# Select the feature columns in ``feature_names`` order and convert them to a
# float64 matrix.
bg_array = bg_df[feature_names].to_numpy(dtype=np.float64)
# ====== Global KernelExplainer (built only once, at import time) ======
def _predict_proba_nd(x_nd: np.ndarray) -> np.ndarray:
    """Run ``pipeline.predict_proba`` on a plain 2-D array of feature rows.

    The array is wrapped in a DataFrame whose columns are ``feature_names``
    before it is handed to the pipeline; the pipeline's output is returned
    unchanged.
    """
    labelled_rows = pd.DataFrame(x_nd, columns=feature_names)
    probabilities = pipeline.predict_proba(labelled_rows)
    return probabilities


explainer = shap.KernelExplainer(_predict_proba_nd, bg_array)
def _build_input_frame(values):
    """Return (one-row numeric DataFrame, whether any cell had to be filled)."""
    frame = pd.DataFrame([list(values)], columns=feature_names).apply(
        pd.to_numeric, errors="coerce"
    )
    had_missing = bool(frame.isnull().values.any())
    if had_missing:
        # Anything that could not be read as a number is replaced by the
        # per-column median of the background data.
        medians = pd.Series(np.median(bg_array, axis=0), index=feature_names)
        frame = frame.fillna(medians)
    return frame, had_missing


def _positive_class_shap(shap_out):
    """Reduce the output of ``explainer.shap_values`` to a 1-D positive-class vector."""
    if isinstance(shap_out, list):
        # List form: one array per class; index 1 is the positive class.
        per_class = np.asarray(shap_out[1], dtype=np.float64)
        return per_class[0, :] if per_class.ndim == 2 else per_class

    stacked = np.asarray(shap_out, dtype=np.float64)
    if stacked.ndim == 3:  # (1, n_features, n_classes)
        return stacked[0, :, 1]  # positive-class channel
    if stacked.ndim == 2:  # (1, n_features)
        return stacked[0, :]
    return stacked.reshape(-1)


def _positive_class_base_value(expected_value):
    """Return the positive-class base value as a float (scalar or per-class input)."""
    if isinstance(expected_value, (list, np.ndarray)):
        flat = np.asarray(expected_value).reshape(-1)
        return float(flat[1] if len(flat) > 1 else flat[0])
    return float(expected_value)


def _render_force_plot(base_val, sv, x_1d, fnames):
    """Draw the SHAP force plot with matplotlib and return the resulting Figure."""
    # Do not create a figure up front: let SHAP draw, then pick up the figure
    # it actually used via plt.gcf().
    plt.close("all")  # drop stale figures so gcf() cannot return an old one
    shap.force_plot(
        base_val,
        sv,
        x_1d,
        feature_names=fnames,
        matplotlib=True,
        show=False,
    )
    fig = plt.gcf()
    fig.set_size_inches(8, 4)
    plt.tight_layout()
    return fig


def _render_bar_fallback(sv, fnames):
    """Draw a horizontal bar chart of the largest |SHAP| contributions (at most 10)."""
    order = np.argsort(np.abs(sv))[::-1]
    topk = order[:min(10, sv.shape[0])]
    plt.close("all")
    fig = plt.figure(figsize=(8, 5), dpi=160)
    plt.barh(np.array(fnames)[topk], sv[topk])
    plt.xlabel("SHAP value")
    plt.title("Top features (single-sample contribution)")
    plt.gca().invert_yaxis()  # largest contribution at the top
    plt.tight_layout()
    return fig


def predict_and_shap(HGB, HDL_C, DBIL, AST_ALT, UA, GFR, PNI, HALP, AAPR, conuts, nsamples=200):
    """Predict the positive-class probability for one sample and plot its SHAP explanation.

    Args:
        HGB, HDL_C, DBIL, AST_ALT, UA, GFR, PNI, HALP, AAPR, conuts: the ten
            feature values, in ``feature_names`` order. Values that cannot be
            parsed as numbers are replaced by the background median of that
            column.
        nsamples: forwarded (as ``int``) to ``explainer.shap_values``.

    Returns:
        A ``(probability, figure, status)`` tuple. ``probability`` is rounded
        to three decimals; ``figure`` is the SHAP force plot, or a bar chart
        when the force plot fails; ``status`` is a newline-joined log of the
        stages that ran. If anything outside the force-plot step fails, the
        function does not raise: it returns ``(None, None, status)`` where
        ``status`` ends with a ``Fatal error: ...`` line, preceded by the log
        of the stages that completed before the failure.
    """
    status_msgs = []
    try:
        # 1) Input coercion and missing-value fill
        input_df, had_missing = _build_input_frame(
            (HGB, HDL_C, DBIL, AST_ALT, UA, GFR, PNI, HALP, AAPR, conuts)
        )
        if had_missing:
            status_msgs.append("Missing values filled with background medians.")

        # 2) Probability of the positive class (column 1 of predict_proba)
        prob = float(pipeline.predict_proba(input_df)[0, 1])
        status_msgs.append(f"Pred prob computed: {prob:.3f}")

        # 3) SHAP values for this single row
        x_row = input_df.to_numpy(dtype=np.float64)  # (1, n_features)
        shap_out = explainer.shap_values(x_row, nsamples=int(nsamples))
        sv = _positive_class_shap(shap_out)
        x_1d = x_row[0]
        status_msgs.append(f"SHAP 1D shape: {sv.shape}; features: {x_1d.shape}")

        base_val = _positive_class_base_value(explainer.expected_value)
        fnames = [str(f) for f in feature_names]

        # 4) Force plot, 5) bar chart as fallback if the force plot fails
        try:
            fig = _render_force_plot(base_val, sv, x_1d, fnames)
        except Exception as e_force:
            status_msgs.append(f"Force-plot failed: {e_force!r}; fallback=bar")
            fig = _render_bar_fallback(sv, fnames)
            status_msgs.append("Rendered bar fallback.")
        else:
            status_msgs.append("Rendered force plot (matplotlib) on current figure.")

        return round(prob, 3), fig, "\n".join(status_msgs)
    except Exception as e:
        # UI boundary: report the failure instead of raising, and keep the
        # log collected so far so the failing stage can be identified.
        return None, None, "\n".join([*status_msgs, f"Fatal error: {e!r}"])
# ====== Blocks UI ======
# Ten feature values followed by the value for the "SHAP nsamples" slider.
example_values = [137, 1.76, 8.6, 0.97, 310, 75.4, 44, 60.8, 0.486, 4, 200]

with gr.Blocks() as demo:
    gr.Markdown(
        "### SVM Meige Risk Prediction & SHAP Explanation\n"
        "Enter 10 indicators with **units** to predict risk and view an individualized explanation.\n\n"
        "**Example**: HGB=137 g/L, HDL‑C=1.76 mmol/L, DBIL=8.6 μmol/L, AST/ALT=0.97, UA=310 μmol/L, "
        "GFR=75.4 mL/min/1.73 m², PNI=44, HALP=60.8, AAPR=0.486, conuts=4."
    )

    with gr.Row():
        # Left column: the ten numeric inputs, the nsamples slider, the buttons.
        with gr.Column(scale=1):
            num_inputs = [
                gr.Number(label=label, precision=3) for label in feature_labels
            ]
            ns_slider = gr.Slider(100, 400, value=200, step=50, label="SHAP nsamples")
            btn_fill = gr.Button("Fill with Example")
            btn_predict = gr.Button("Predict")

        # Right column: probability, explanation plot, status log.
        with gr.Column(scale=1):
            out_prob = gr.Number(label="Predicted Probability")
            out_plot = gr.Plot(label="SHAP Force (fallback: bar)")  # switched to Plot
            out_log = gr.Textbox(label="Status", lines=6)

    def fill_example():
        """Return the example values as a tuple, one entry per form component."""
        return tuple(example_values)

    # Components wired to predict_and_shap, in its parameter / return order.
    form_inputs = [*num_inputs, ns_slider]
    result_outputs = [out_prob, out_plot, out_log]

    # "Fill with Example" populates the form and then runs the prediction.
    fill_evt = btn_fill.click(fn=fill_example, outputs=form_inputs)
    fill_evt.then(fn=predict_and_shap, inputs=form_inputs, outputs=result_outputs)
    btn_predict.click(fn=predict_and_shap, inputs=form_inputs, outputs=result_outputs)

if __name__ == "__main__":
    demo.launch()  # do not pass server_port / share