# Meige / app.py
# Multiple123's picture
# Update app.py
# b3ed7ae verified
# raw
# history blame
# 3.35 kB
import gradio as gr
import pandas as pd
import joblib
import shap
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")
# ====== Model and background data ======
MODEL_PATH = "models/SVM_pipeline.pkl"
BG_PATH = "data/bg.csv"

# Model input columns; every frame/array fed to the pipeline uses this order.
feature_names = [
    "HGB", "HDL_C", "DBIL", "AST_ALT", "UA",
    "GFR", "PNI", "HALP", "AAPR", "conuts",
]

# joblib.load unpickles arbitrary objects: MODEL_PATH must only ever point at a
# trusted artifact, never at a user-supplied file.
pipeline = joblib.load(MODEL_PATH)

# Background sample for KernelExplainer, restricted to the model columns and
# coerced to float64.
bg_df = pd.read_csv(BG_PATH)
bg_array = bg_df.loc[:, feature_names].to_numpy(dtype=np.float64)
def _predict_proba_nd(x_nd):
    """Adapter that lets SHAP call the pipeline with a bare 2-D NumPy array.

    KernelExplainer hands over plain arrays. Here they are wrapped in a
    DataFrame labelled with ``feature_names`` before ``pipeline.predict_proba``
    is called. Presumably the pipeline was fitted on a DataFrame and needs the
    column names (confirm).

    Returns:
        Whatever ``pipeline.predict_proba`` returns for those rows.
    """
    frame = pd.DataFrame(data=x_nd, columns=feature_names)
    probabilities = pipeline.predict_proba(frame)
    return probabilities
# Model-agnostic Kernel SHAP over the probability adapter, using every row of
# bg_array as background data.
# NOTE(review): KernelExplainer runtime grows with the background size. If bg.csv
# is large, consider shap.sample / shap.kmeans (that changes the explanations).
explainer = shap.KernelExplainer(model=_predict_proba_nd, data=bg_array)
def _derive_features(HGB, ALB, LYM, PLT, ALP, CHOL):
    """Compute the composite indices the model expects from raw lab values.

    Units follow the UI labels: HGB and ALB in g/L, LYM and PLT in 10^9/L,
    ALP in U/L, CHOL in mmol/L.

    Returns:
        ``(PNI, HALP, AAPR, conuts)``.

    Raises:
        ValueError: if PLT or ALP is not positive (both are used as divisors).
    """
    if PLT <= 0 or ALP <= 0:
        raise ValueError("PLT and ALP must be positive (they are used as divisors)")
    pni = ALB + 5 * LYM
    halp = HGB * ALB * LYM / PLT
    aapr = ALB / ALP
    # CONUT-style score: albumin points + lymphocyte points + cholesterol points.
    # NOTE(review): the textbook CONUT middle cholesterol cut-off is 140 mg/dL
    # (~3.62 mmol/L). 3.10 is kept unchanged so the score matches however
    # `conuts` was derived for training — confirm against the training code.
    alb_pts = 0 if ALB >= 35 else 2 if ALB >= 30 else 4 if ALB >= 25 else 6
    lym_pts = 0 if LYM >= 1.6 else 1 if LYM >= 1.2 else 2 if LYM >= 0.8 else 3
    chol_pts = 0 if CHOL >= 4.65 else 1 if CHOL >= 3.10 else 2 if CHOL >= 2.59 else 3
    return pni, halp, aapr, alb_pts + lym_pts + chol_pts


def _positive_class_shap(shap_out, expected_value):
    """Return ``(base_value, shap_vector)`` for class 1 of the first explained row.

    Accepts the layouts KernelExplainer produces across SHAP versions:
      * a list with one ``(n_rows, n_features)`` array per class (older SHAP);
      * one ``(n_rows, n_features, n_classes)`` array (newer SHAP);
      * one ``(n_rows, n_features)`` array (single-output model).

    ``expected_value`` may be a list, a 1-D array with one entry per class, or a
    scalar. For a two-column ``predict_proba`` it is an ndarray, not a list.
    """
    if isinstance(shap_out, list):
        sv = np.asarray(shap_out[1])[0]
    else:
        arr = np.asarray(shap_out)
        sv = arr[0, :, 1] if arr.ndim == 3 else arr[0]
    ev = np.atleast_1d(np.asarray(expected_value, dtype=np.float64))
    base = float(ev[1]) if ev.size > 1 else float(ev[0])
    return base, sv


def predict_and_explain(HGB, HDL_C, DBIL, AST_ALT, UA, GFR, ALB, LYM, PLT, ALP, CHOL, nsamples=200):
    """Predict class-1 risk for one patient and draw a SHAP force plot for it.

    PNI, HALP, AAPR and the CONUT-style score are derived from the raw inputs.
    The resulting 10-feature row is then scored by ``pipeline`` and explained by
    ``explainer``.

    Args:
        HGB .. CHOL: raw lab values as entered in the UI (see the UI labels for units).
        nsamples: KernelExplainer sample budget; cast to ``int`` because UI
            sliders may deliver floats.

    Returns:
        ``(risk, figure, "Success")``, where ``risk`` is the class-1 probability
        rounded to 3 decimals. On any failure, returns
        ``(None, None, "Error: <message>")`` so the UI can show the problem.
    """
    try:
        raw = dict(HGB=HGB, HDL_C=HDL_C, DBIL=DBIL, AST_ALT=AST_ALT, UA=UA, GFR=GFR,
                   ALB=ALB, LYM=LYM, PLT=PLT, ALP=ALP, CHOL=CHOL)
        missing = [name for name, value in raw.items() if value is None]
        if missing:
            raise ValueError(f"missing input(s): {', '.join(missing)}")

        PNI, HALP, AAPR, conuts = _derive_features(HGB, ALB, LYM, PLT, ALP, CHOL)
        x_row = [[HGB, HDL_C, DBIL, AST_ALT, UA, GFR, PNI, HALP, AAPR, conuts]]
        input_df = pd.DataFrame(x_row, columns=feature_names)
        prob = float(pipeline.predict_proba(input_df)[0, 1])

        shap_out = explainer.shap_values(np.array(x_row, dtype=np.float64), nsamples=int(nsamples))
        base_val, sv = _positive_class_shap(shap_out, explainer.expected_value)

        # Drop figures left over from previous requests so gcf() below is the force plot.
        plt.close("all")
        shap.force_plot(base_val, sv, x_row[0], feature_names=feature_names, matplotlib=True, show=False)
        fig = plt.gcf()
        fig.set_size_inches(8, 4)
        plt.tight_layout()
        return round(prob, 3), fig, "Success"
    except Exception as e:  # UI boundary: report the failure instead of crashing the handler
        return None, None, f"Error: {e}"
# Demo values for the "Fill Example" button, in UI order: the 11 lab inputs
# followed by the SHAP nsamples slider value.
example_values = [
    137,    # HGB (g/L)
    1.76,   # HDL-C (mmol/L)
    8.6,    # DBIL (μmol/L)
    0.97,   # AST/ALT
    310,    # UA (μmol/L)
    75.4,   # GFR (mL/min/1.73 m²)
    33,     # ALB (g/L)
    2.2,    # LYM (×10⁹/L)
    164,    # PLT (×10⁹/L)
    67.9,   # ALP (U/L)
    2.8,    # CHOL (mmol/L)
    200,    # SHAP nsamples
]
# Gradio UI: 11 lab inputs + SHAP sample slider on the left, a second (currently
# empty) column on the right.
with gr.Blocks() as demo:
    # NOTE(review): the title says "Logistic Regression" but MODEL_PATH loads
    # "SVM_pipeline.pkl" — confirm which estimator is deployed and align the title.
    gr.Markdown("### Logistic Regression Risk Prediction with SHAP Explanation")
    with gr.Row():
        with gr.Column():
            # Positional order must match predict_and_explain's parameters.
            inputs = [
                gr.Number(label="HGB (g/L)"),
                gr.Number(label="HDL-C (mmol/L)"),
                gr.Number(label="DBIL (μmol/L)"),
                gr.Number(label="AST/ALT"),
                gr.Number(label="UA (μmol/L)"),
                gr.Number(label="GFR (mL/min/1.73 m²)"),
                gr.Number(label="ALB (g/L)"),
                gr.Number(label="LYM (×10⁹/L)"),
                gr.Number(label="PLT (×10⁹/L)"),
                gr.Number(label="ALP (U/L)"),
                gr.Number(label="CHOL (mmol/L)"),
            ]
            ns_slider = gr.Slider(100, 400, value=200, step=50, label="SHAP nsamples")

            # Fill the 11 inputs plus the slider (12 outputs) from example_values.
            # The original referenced the undefined name `example_vals` (NameError on click).
            gr.Button("Fill Example").click(lambda: tuple(example_values), outputs=[*inputs, ns_slider])

            # Outputs are created right after the Predict button, as before, so render order is unchanged.
            predict_btn = gr.Button("Predict")
            risk_out = gr.Number(label="Risk")
            plot_out = gr.Plot()
            status_out = gr.Textbox(label="Status", lines=4)
            predict_btn.click(
                fn=predict_and_explain,
                inputs=[*inputs, ns_slider],
                outputs=[risk_out, plot_out, status_out],
            )
        with gr.Column():
            pass  # placeholder column; renders nothing at present
demo.launch()