File size: 7,046 Bytes
7876a01
b3ed7ae
 
 
 
 
 
7876a01
b3ed7ae
 
 
 
 
 
7876a01
b3ed7ae
7876a01
b3ed7ae
 
7876a01
b3ed7ae
 
 
 
7876a01
 
b3ed7ae
 
 
7876a01
b3ed7ae
 
7876a01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3ed7ae
7876a01
 
 
 
 
 
 
 
 
 
 
 
 
 
b3ed7ae
 
 
 
 
 
 
7876a01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3ed7ae
7876a01
 
 
 
 
b3ed7ae
7876a01
 
b3ed7ae
 
7876a01
b3ed7ae
7876a01
 
7aa125b
7876a01
b3ed7ae
7876a01
b3ed7ae
7876a01
 
 
 
 
 
b3ed7ae
 
7876a01
b3ed7ae
 
 
 
 
 
 
 
 
 
 
 
 
 
7876a01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
# app.py (robust, server-safe)
import gradio as gr
import pandas as pd
import joblib
import shap
import numpy as np
import matplotlib
matplotlib.use("Agg")  # 非交互后端,服务器端更稳
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")

# ====== Model and background data ======
MODEL_PATH = "models/SVM_pipeline.pkl"
BG_PATH = "data/bg.csv"

# The 10 features the model consumes; the column order must match training.
feature_names = [
    "HGB", "HDL_C", "DBIL", "AST_ALT", "UA",
    "GFR", "PNI", "HALP", "AAPR", "conuts",
]

# Load the fitted pipeline and the SHAP background sample once, at import time.
# NOTE(review): joblib.load unpickles the file - only ever point MODEL_PATH at a
# trusted artifact.
pipeline = joblib.load(MODEL_PATH)
bg_df = pd.read_csv(BG_PATH)
# Background matrix for KernelExplainer, restricted to (and ordered by) feature_names.
bg_array = bg_df.loc[:, feature_names].to_numpy(dtype=np.float64)

# Prediction callback handed to KernelExplainer.
def _predict_proba_nd(x_nd: np.ndarray) -> np.ndarray:
    """Score a raw 2-D feature matrix with the fitted pipeline.

    The matrix is wrapped in a DataFrame labelled with ``feature_names`` so the
    pipeline receives named columns, and ``pipeline.predict_proba`` is returned.
    """
    frame = pd.DataFrame(data=x_nd, columns=feature_names)
    probabilities = pipeline.predict_proba(frame)
    return probabilities

# Build the KernelExplainer once at import time and reuse it for every request,
# instead of reconstructing it (and re-wrapping the background data) per call.
explainer = shap.KernelExplainer(_predict_proba_nd, bg_array)

def _render_force_plot(base_val: float, shap_1d: np.ndarray, feat_1d: np.ndarray, fnames):
    """Draw a single-sample SHAP force plot through pyplot and return the Figure.

    Uses SHAP's matplotlib rendering path (``matplotlib=True, show=False``), so
    the result is a regular matplotlib Figure rather than an HTML/JS widget.
    All previously open pyplot figures are closed first.
    """
    plt.close('all')  # drop figures left over from earlier calls
    contributions = np.asarray(shap_1d).reshape(-1)
    values = np.asarray(feat_1d).reshape(-1)
    shap.force_plot(
        base_val,
        contributions,
        values,
        feature_names=list(fnames),
        matplotlib=True,
        show=False,
    )
    figure = plt.gcf()
    figure.set_size_inches(8, 4)
    plt.tight_layout()
    return figure

def _conut_score(ALB: float, LYM: float, CHOL: float) -> int:
    """Return the CONUT-style score: banded albumin + lymphocyte + cholesterol points.

    Inputs use the UI units: ALB in g/L, LYM in x10^9/L, CHOL in mmol/L.
    """
    alb_pts = 0 if ALB >= 35 else 2 if ALB >= 30 else 4 if ALB >= 25 else 6
    lym_pts = 0 if LYM >= 1.6 else 1 if LYM >= 1.2 else 2 if LYM >= 0.8 else 3
    # NOTE(review): 4.65 and 2.59 mmol/L correspond to 180 and 100 mg/dL, but the
    # middle cut-off 3.10 mmol/L is ~120 mg/dL; the usual CONUT band edge is
    # 140 mg/dL (~3.62 mmol/L). Left unchanged because the model was presumably
    # trained on features derived with this exact rule - confirm against the
    # training code before changing it.
    chol_pts = 0 if CHOL >= 4.65 else 1 if CHOL >= 3.10 else 2 if CHOL >= 2.59 else 3
    return alb_pts + lym_pts + chol_pts

def _positive_class_shap(shap_out) -> np.ndarray:
    """Reduce ``explainer.shap_values`` output to a 1-D positive-class vector.

    Accepts either a per-class list of arrays or a single ndarray shaped
    (1, n_features, n_classes) / (1, n_features); anything else is flattened.
    """
    if isinstance(shap_out, list):
        sv = np.asarray(shap_out[1], dtype=np.float64)
        return sv[0, :] if sv.ndim == 2 else sv
    sv = np.asarray(shap_out, dtype=np.float64)
    if sv.ndim == 3:      # (1, n_features, n_classes)
        return sv[0, :, 1]
    if sv.ndim == 2:      # (1, n_features)
        return sv[0, :]
    return sv.reshape(-1)

def _positive_base_value(ev) -> float:
    """Return the positive-class expected value (the only entry if there is just one)."""
    if isinstance(ev, (list, np.ndarray)):
        flat = np.asarray(ev).reshape(-1)
        return float(flat[1] if len(flat) > 1 else flat[0])
    return float(ev)

def _render_bar_plot(sv: np.ndarray, fnames):
    """Fallback plot: horizontal bars for the (up to) 10 largest |SHAP| contributions."""
    order = np.argsort(np.abs(sv))[::-1]
    topk = order[:min(10, sv.shape[0])]
    plt.close('all')
    fig = plt.figure(figsize=(8, 5), dpi=160)
    plt.barh(np.array(fnames)[topk], sv[topk])
    plt.xlabel("SHAP value")
    plt.title("Top features (single-sample contribution)")
    plt.gca().invert_yaxis()  # largest contribution on top
    plt.tight_layout()
    return fig

def predict_and_explain(
    HGB, HDL_C, DBIL, AST_ALT, UA, GFR,
    ALB, LYM, PLT, ALP, CHOL,
    nsamples=200
):
    """Predict the risk for one patient and explain it with SHAP.

    The raw clinical indicators are converted to floats and validated. The
    derived indices PNI / HALP / AAPR / CONUTS are computed, and the
    10-feature row (in ``feature_names`` order) is scored with
    ``pipeline.predict_proba``. Positive-class SHAP values come from the
    module-level ``explainer``. A force plot is rendered, falling back to a
    bar chart if that fails.

    Args:
        HGB, HDL_C, DBIL, AST_ALT, UA, GFR, ALB, LYM, PLT, ALP, CHOL: raw
            indicator values (anything ``float()`` accepts).
        nsamples: KernelExplainer sample budget; ``None`` means 200.

    Returns:
        ``(prob, fig, status)``: the positive-class probability rounded to 3
        decimals, a matplotlib Figure, and a newline-joined status log. On
        invalid input or any failure: ``(None, None, message)``.
    """
    status = []
    try:
        # ---- 1) Validate raw inputs ----
        raw_inputs = (HGB, HDL_C, DBIL, AST_ALT, UA, GFR, ALB, LYM, PLT, ALP, CHOL)
        try:
            values = [float(v) for v in raw_inputs]
        except (TypeError, ValueError, OverflowError):
            return None, None, "Error: some inputs are not numeric."
        HGB, HDL_C, DBIL, AST_ALT, UA, GFR, ALB, LYM, PLT, ALP, CHOL = values

        # NaN slips past the "> 0" checks below (NaN <= 0 is False) and inf would
        # produce degenerate ratios, so reject non-finite values explicitly.
        if not np.isfinite(values).all():
            return None, None, "Error: inputs must be finite numbers."

        # Guard the denominators of HALP and AAPR.
        if PLT <= 0 or ALP <= 0:
            return None, None, "Error: PLT and ALP must be > 0."

        # ---- 2) Derived indices ----
        PNI = ALB + 5.0 * LYM
        HALP = HGB * ALB * LYM / PLT
        AAPR = ALB / ALP
        conuts = _conut_score(ALB, LYM, CHOL)
        x_row = np.array([[HGB, HDL_C, DBIL, AST_ALT, UA, GFR, PNI, HALP, AAPR, conuts]], dtype=np.float64)
        status.append(f"Derived: PNI={PNI:.3f}, HALP={HALP:.3f}, AAPR={AAPR:.3f}, CONUTS={conuts}")

        # ---- 3) Probability (column 1 = positive class) ----
        prob = float(pipeline.predict_proba(pd.DataFrame(x_row, columns=feature_names))[0, 1])
        status.append(f"Pred prob computed: {prob:.3f}")

        # ---- 4) SHAP values and base value for the positive class ----
        ns = int(nsamples) if nsamples is not None else 200
        shap_out = explainer.shap_values(x_row, nsamples=ns)
        sv = _positive_class_shap(shap_out)
        status.append(f"SHAP 1D shape: {sv.shape}; features: {x_row.shape[1:]}")
        base_val = _positive_base_value(explainer.expected_value)

        # ---- 5) Plot: force plot first; any failure degrades to a bar chart ----
        try:
            fig = _render_force_plot(base_val, sv, x_row[0, :], feature_names)
            status.append("Rendered force plot (matplotlib).")
        except Exception as e_force:
            status.append(f"Force-plot failed: {repr(e_force)}; fallback=bar")
            fig = _render_bar_plot(sv, feature_names)
            status.append("Rendered bar fallback.")
        return round(prob, 3), fig, "\n".join(status)

    except Exception as e:
        # UI boundary: surface the failure in the status box instead of crashing.
        return None, None, f"Fatal error: {repr(e)}"

# ====== Example: one set of raw indicators for the "Fill Example" button ======
# Derived values for these inputs: PNI = 44, HALP ≈ 73.9, AAPR ≈ 0.485, CONUTS = 4.
# NOTE(review): an earlier comment claimed these reproduce HALP ≈ 60.8 and
# AAPR ≈ 0.486. HALP ≈ 60.8 would require HGB ≈ 137 g/L (with ALB=33, LYM=2.2,
# PLT=164) - confirm which HGB value was intended.
example_values = [167, 1.76, 8.6, 0.97, 310, 75, 33, 2.2, 164, 68, 2.8, 200]
# Order: HGB, HDL_C, DBIL, AST_ALT, UA, GFR, ALB, LYM, PLT, ALP, CHOL, nsamples

# ====== Gradio UI ======
_INTRO_MD = (
    "### Meige Risk Prediction (SVM) with SHAP Explanation\n"
    "Enter **original clinical indicators**; the app will derive PNI/HALP/AAPR/CONUTS internally.\n\n"
    "**Units**: HGB (g/L), HDL‑C (mmol/L), DBIL (μmol/L), AST/ALT (ratio), UA (μmol/L), "
    "GFR (mL/min/1.73 m²), ALB (g/L), LYM (×10⁹/L), PLT (×10⁹/L), ALP (U/L), CHOL (mmol/L)."
)

# One numeric field per raw indicator, in predict_and_explain's parameter order.
_INPUT_LABELS = (
    "HGB (g/L)",
    "HDL-C (mmol/L)",
    "DBIL (μmol/L)",
    "AST/ALT",
    "UA (μmol/L)",
    "GFR (mL/min/1.73 m²)",
    "ALB (g/L)",
    "LYM (×10⁹/L)",
    "PLT (×10⁹/L)",
    "ALP (U/L)",
    "CHOL (mmol/L)",
)

with gr.Blocks() as demo:
    gr.Markdown(_INTRO_MD)

    with gr.Row():
        # Left column: inputs and action buttons.
        with gr.Column(scale=1):
            inputs = [gr.Number(label=text) for text in _INPUT_LABELS]
            ns_slider = gr.Slider(100, 400, value=200, step=50, label="SHAP nsamples")
            btn_fill = gr.Button("Fill Example")
            btn_predict = gr.Button("Predict")

        # Right column: prediction, explanation plot, status log.
        with gr.Column(scale=1):
            out_prob = gr.Number(label="Predicted Probability")
            out_plot = gr.Plot(label="SHAP Force Plot (fallback: bar)")
            out_log = gr.Textbox(label="Status", lines=6)

    def _fill_example():
        """Return example_values as a tuple: one value per input field plus the slider."""
        return tuple(example_values)

    btn_fill.click(fn=_fill_example, outputs=inputs + [ns_slider])
    btn_predict.click(
        fn=predict_and_explain,
        inputs=inputs + [ns_slider],
        outputs=[out_prob, out_plot, out_log],
    )

if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    demo.launch()