# Spaces:
# Sleeping
# Sleeping
import gradio as gr
import pandas as pd
import joblib
import shap
import numpy as np
import matplotlib

# Agg is matplotlib's non-interactive raster backend; it is selected here,
# before pyplot is imported, so figures can be rendered without a display.
matplotlib.use("Agg")
import matplotlib.pyplot as plt  # noqa: E402  (kept after matplotlib.use)
import warnings

# Suppress every warning for the rest of the process.
# NOTE(review): this also hides warnings raised while unpickling the model
# below (e.g. library version-mismatch warnings) -- consider narrowing it.
warnings.filterwarnings("ignore")

# ====== Model and background data ======
# Both paths are relative to the current working directory.
MODEL_PATH = "models/SVM_pipeline.pkl"
BG_PATH = "data/bg.csv"

# Model input columns; every DataFrame handed to the pipeline uses this order.
feature_names = [
    "HGB", "HDL_C", "DBIL", "AST_ALT", "UA",
    "GFR", "PNI", "HALP", "AAPR", "conuts",
]

# Fitted pipeline, loaded once at import time. joblib.load unpickles the
# file, so MODEL_PATH must point at a trusted artifact.
pipeline = joblib.load(MODEL_PATH)

# Background rows for SHAP: only the model's feature columns, as float64.
bg_df = pd.read_csv(BG_PATH)
bg_array = bg_df.loc[:, feature_names].to_numpy(dtype=np.float64)
def _predict_proba_nd(x_nd):
    """Call the pipeline's ``predict_proba`` on a plain 2-D array of samples.

    SHAP's KernelExplainer evaluates the model on matrices of samples
    (rows x features). This adapter re-attaches the ``feature_names`` column
    labels so the pipeline receives a DataFrame, and returns the
    pipeline's probability matrix unchanged.
    """
    frame = pd.DataFrame(data=x_nd, columns=feature_names)
    probabilities = pipeline.predict_proba(frame)
    return probabilities


# Kernel SHAP explainer over the predicted probabilities, using bg_array as
# the background dataset.
explainer = shap.KernelExplainer(_predict_proba_nd, bg_array)
def _conuts_score(ALB, LYM, CHOL):
    """Return the CONUT-style score (0-12) built from ALB, LYM and CHOL.

    Each value is binned and the bin points are summed: ALB in g/L
    (0/2/4/6 points), LYM in 10^9/L (0/1/2/3), CHOL in mmol/L (0/1/2/3).
    """
    albumin_points = 0 if ALB >= 35 else 2 if ALB >= 30 else 4 if ALB >= 25 else 6
    lymphocyte_points = 0 if LYM >= 1.6 else 1 if LYM >= 1.2 else 2 if LYM >= 0.8 else 3
    # NOTE(review): published CONUT cholesterol cut-offs are 180/140/100 mg/dL
    # (~4.65/3.62/2.59 mmol/L); the middle cut-off here is 3.10. Left as-is
    # because the model was presumably trained on features derived this way --
    # confirm against the training pipeline.
    cholesterol_points = 0 if CHOL >= 4.65 else 1 if CHOL >= 3.10 else 2 if CHOL >= 2.59 else 3
    return albumin_points + lymphocyte_points + cholesterol_points


def _derived_features(HGB, ALB, LYM, PLT, ALP, CHOL):
    """Return the derived model inputs ``(PNI, HALP, AAPR, conuts)``."""
    PNI = ALB + 5 * LYM
    HALP = HGB * ALB * LYM / PLT
    AAPR = ALB / ALP
    return PNI, HALP, AAPR, _conuts_score(ALB, LYM, CHOL)


def _positive_class_shap(shap_out):
    """Return the 1-D SHAP vector of the first explained row for class 1.

    Accepts the layouts ``KernelExplainer.shap_values`` can produce: a list
    with one (rows, features) array per class (older shap releases), a
    single (rows, features, outputs) array (newer releases), or a
    (rows, features) array for a single-output model.
    """
    if isinstance(shap_out, list):
        return np.asarray(shap_out[1])[0]
    values = np.asarray(shap_out)
    if values.ndim == 3:
        return values[0, :, 1]
    return values[0]


def _positive_class_base(expected_value):
    """Return the scalar SHAP base value for class 1.

    For a multi-output model ``expected_value`` holds one entry per output
    (an array, not a list), so the class-1 entry is selected explicitly;
    a scalar or single-entry value is returned as a float.
    """
    base = np.asarray(expected_value, dtype=np.float64)
    if base.ndim == 0:
        return float(base)
    flat = base.ravel()
    return float(flat[1] if flat.size > 1 else flat[0])


def predict_and_explain(HGB, HDL_C, DBIL, AST_ALT, UA, GFR, ALB, LYM, PLT, ALP, CHOL, nsamples=200):
    """Predict the risk for one patient and draw a SHAP force plot for it.

    The arguments are the raw lab values entered in the UI. PNI, HALP, AAPR
    and conuts are derived from them before calling the model. ``nsamples``
    is forwarded to ``KernelExplainer.shap_values``.

    Returns:
        ``(risk, figure, status)``: the class-1 probability rounded to 3
        decimals, the matplotlib figure holding the force plot, and
        ``"Success"``. On any failure it returns
        ``(None, None, "Error: <message>")`` instead of raising, so the
        message can be shown in the UI.
    """
    try:
        raw_inputs = {
            "HGB": HGB, "HDL_C": HDL_C, "DBIL": DBIL, "AST_ALT": AST_ALT,
            "UA": UA, "GFR": GFR, "ALB": ALB, "LYM": LYM,
            "PLT": PLT, "ALP": ALP, "CHOL": CHOL,
        }
        missing = [name for name, value in raw_inputs.items() if value is None]
        if missing:
            raise ValueError(f"missing input(s): {', '.join(missing)}")
        # PLT and ALP are divisors of the derived HALP and AAPR ratios.
        if PLT == 0:
            raise ValueError("PLT must be non-zero (HALP divides by PLT)")
        if ALP == 0:
            raise ValueError("ALP must be non-zero (AAPR divides by ALP)")

        PNI, HALP, AAPR, conuts = _derived_features(HGB, ALB, LYM, PLT, ALP, CHOL)
        x_row = [[HGB, HDL_C, DBIL, AST_ALT, UA, GFR, PNI, HALP, AAPR, conuts]]
        input_df = pd.DataFrame(x_row, columns=feature_names)
        prob = float(pipeline.predict_proba(input_df)[0, 1])

        # The slider value may arrive as a float; cast so the sample count is an int.
        shap_out = explainer.shap_values(np.array(x_row), nsamples=int(nsamples))
        sv = _positive_class_shap(shap_out)
        base_val = _positive_class_base(explainer.expected_value)

        plt.close("all")  # drop figures left over from earlier requests
        shap.force_plot(base_val, sv, x_row[0], feature_names=feature_names, matplotlib=True, show=False)
        fig = plt.gcf()
        fig.set_size_inches(8, 4)
        plt.tight_layout()
        return round(prob, 3), fig, "Success"
    except Exception as e:
        # UI boundary: report every failure in the Status box instead of crashing.
        return None, None, f"Error: {e}"
# Values loaded by the "Fill Example" button: one per input component below,
# in the same order, followed by the SHAP nsamples slider value.
example_values = [137, 1.76, 8.6, 0.97, 310, 75.4, 33, 2.2, 164, 67.9, 2.8, 200]

with gr.Blocks() as demo:
    # NOTE(review): MODEL_PATH loads "SVM_pipeline.pkl" but this heading says
    # "Logistic Regression" -- confirm which model is actually served.
    gr.Markdown("### Logistic Regression Risk Prediction with SHAP Explanation")
    with gr.Row():
        # Left column: raw lab inputs, SHAP sample count and action buttons.
        with gr.Column():
            inputs = [
                gr.Number(label="HGB (g/L)"),
                gr.Number(label="HDL-C (mmol/L)"),
                gr.Number(label="DBIL (μmol/L)"),
                gr.Number(label="AST/ALT"),
                gr.Number(label="UA (μmol/L)"),
                gr.Number(label="GFR (mL/min/1.73 m²)"),
                gr.Number(label="ALB (g/L)"),
                gr.Number(label="LYM (×10⁹/L)"),
                gr.Number(label="PLT (×10⁹/L)"),
                gr.Number(label="ALP (U/L)"),
                gr.Number(label="CHOL (mmol/L)"),
            ]
            ns_slider = gr.Slider(100, 400, value=200, step=50, label="SHAP nsamples")
            fill_button = gr.Button("Fill Example")
            predict_button = gr.Button("Predict")
        # Right column: prediction outputs. These were previously created
        # inline in the left column, leaving this column empty.
        with gr.Column():
            risk_output = gr.Number(label="Risk")
            plot_output = gr.Plot()
            status_output = gr.Textbox(label="Status", lines=4)

    # Previously referenced the undefined name `example_vals` (NameError on click).
    fill_button.click(lambda: tuple(example_values), outputs=[*inputs, ns_slider])
    predict_button.click(
        fn=predict_and_explain,
        inputs=[*inputs, ns_slider],
        outputs=[risk_output, plot_output, status_output],
    )

demo.launch()