File size: 9,053 Bytes
b3ed7ae
 
 
 
 
 
7876a01
b3ed7ae
 
 
 
 
f2b3a5b
7876a01
b3ed7ae
bdab798
7dcfbfd
b3ed7ae
7876a01
b3ed7ae
f2b3a5b
b3ed7ae
f2b3a5b
 
 
 
b3ed7ae
f2b3a5b
7876a01
f2b3a5b
 
 
 
 
 
 
 
 
 
 
 
 
 
b3ed7ae
7876a01
b3ed7ae
 
7876a01
 
 
e6a0696
 
 
 
7dcfbfd
e6a0696
 
 
 
 
7876a01
f2b3a5b
 
e6a0696
f2b3a5b
e6a0696
 
7876a01
e6a0696
7876a01
e6a0696
 
 
7876a01
 
e6a0696
f2b3a5b
 
 
bdab798
7dcfbfd
7876a01
b3ed7ae
f2b3a5b
 
 
 
7dcfbfd
f2b3a5b
 
 
bdab798
f2b3a5b
7dcfbfd
f2b3a5b
 
 
 
 
bdab798
 
f2b3a5b
 
7876a01
b3ed7ae
bdab798
f2b3a5b
bdab798
7dcfbfd
f2b3a5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7876a01
 
 
f2b3a5b
7876a01
f2b3a5b
 
 
 
 
7876a01
 
 
 
 
 
 
 
 
f2b3a5b
7876a01
f2b3a5b
7876a01
 
 
f2b3a5b
 
 
7876a01
 
 
f2b3a5b
7876a01
f2b3a5b
7876a01
f2b3a5b
7876a01
 
f2b3a5b
bdab798
f2b3a5b
 
 
 
 
 
 
 
 
b3ed7ae
 
7876a01
b3ed7ae
bdab798
 
 
 
9603b83
 
d816f90
b3ed7ae
7876a01
b3ed7ae
7876a01
 
bdab798
f2b3a5b
bdab798
7876a01
b3ed7ae
 
7876a01
b3ed7ae
 
f2b3a5b
 
 
 
b3ed7ae
 
bdab798
b3ed7ae
f2b3a5b
7876a01
 
 
 
 
f2b3a5b
7876a01
f2b3a5b
7876a01
 
 
 
 
 
 
 
 
 
 
 
7dcfbfd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
import gradio as gr
import pandas as pd
import joblib
import shap
import numpy as np
import matplotlib
matplotlib.use("Agg")  # Non-interactive backend (more reliable on a server); selected before pyplot is imported
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")  # NOTE: installs a filter that silences every warning in this process

# ====== Model and background data ======
MODEL_PATH = "models/svm_pipeline.joblib"
BG_PATH    = "data/bg.csv"

# The 8 features the model consumes (order must match the training order)
FEATURES = ["ALB", "TP", "TBA", "AST_ALT", "CREA", "PNI", "AAPR", "DBIL"]

# Load the model and the background data set (both happen at import time).
# NOTE: joblib.load unpickles the file, so MODEL_PATH must point to a trusted artifact.
pipeline = joblib.load(MODEL_PATH)

bg_df = pd.read_csv(BG_PATH)
# Fail fast when the background CSV lacks any of the model's feature columns
missing_bg = [c for c in FEATURES if c not in bg_df.columns]
if missing_bg:
    raise ValueError(f"背景集缺少列: {missing_bg}")
# Background matrix as float64, with columns selected in FEATURES order
bg_array = bg_df[FEATURES].to_numpy(dtype=np.float64)

# Prediction function (called by KernelExplainer): returns the positive-class probability / score
def _predict_proba_nd(x_nd: np.ndarray) -> np.ndarray:
    """Score a 2-D array of samples with the loaded pipeline.

    The array is wrapped in a DataFrame whose columns are ``FEATURES`` before
    it is handed to the pipeline.

    Returns, in order of preference:
      * the ``predict_proba`` column of the positive class (label ``1``;
        column ``1`` when the pipeline exposes no ``classes_``),
      * the output of ``decision_function``,
      * the output of ``predict``.
    """
    frame = pd.DataFrame(x_nd, columns=FEATURES)

    def _ensure_ndarray(values):
        # Hand back ndarrays untouched; convert anything else.
        if isinstance(values, np.ndarray):
            return values
        return np.asarray(values)

    # No probability support: fall back to a raw score, then to the hard prediction.
    if not hasattr(pipeline, "predict_proba"):
        if hasattr(pipeline, "decision_function"):
            return _ensure_ndarray(pipeline.decision_function(frame))
        return _ensure_ndarray(pipeline.predict(frame))

    probabilities = pipeline.predict_proba(frame)
    # Positive class is assumed to carry the label 1; adjust here if it does not.
    labels = getattr(pipeline, "classes_", None)
    if labels is None:
        positive_col = 1
    else:
        positive_col = int(np.where(labels == 1)[0][0])
    return probabilities[:, positive_col]

# Build the explainer only once, at import time, and reuse it for every request
# (avoids re-initialising it per prediction).
explainer = shap.KernelExplainer(_predict_proba_nd, bg_array)

def _render_force_plot(base_val: float, shap_1d: np.ndarray, feat_1d: np.ndarray, fnames):
    """Draw a single-sample SHAP force plot and return the matplotlib Figure.

    Uses the matplotlib rendering path of ``shap.force_plot`` with
    ``show=False`` and then picks up the current figure.

    Args:
        base_val: Base (expected) value the contributions are added to.
        shap_1d: Per-feature SHAP contributions; flattened to 1-D.
        feat_1d: Feature values of the sample; rounded per feature for display.
        fnames: Feature names, paired positionally with ``feat_1d``.

    Returns:
        The current matplotlib figure, resized to 14 x 3.6 inches at 180 dpi.
    """
    # Drop figures left open by earlier calls before drawing a new one.
    plt.close('all')

    # Decimal places per feature, to keep the value labels short
    # (features not listed here use 2).
    decimals_by_feature = {
        "ALB": 2, "TP": 2, "TBA": 2, "AST_ALT": 2, "CREA": 1, "PNI": 1, "AAPR": 3, "DBIL": 1
    }
    values = np.asarray(feat_1d, dtype=float).copy()
    rounded_values = []
    for value, feature_name in zip(values, fnames):
        decimals = decimals_by_feature.get(feature_name, 2)
        rounded_values.append(np.round(value, decimals))

    contributions = np.asarray(shap_1d).reshape(-1)
    display_values = np.asarray(rounded_values).reshape(-1)
    shap.force_plot(
        base_val,
        contributions,
        display_values,
        feature_names=list(fnames),
        matplotlib=True,
        show=False
    )

    figure = plt.gcf()
    figure.set_size_inches(14, 3.6)   # wider, slightly shorter
    figure.set_dpi(180)               # higher resolution
    plt.tight_layout(pad=0.6)         # tighter padding
    return figure


def _coerce_float(x):
    """Convert a raw form value to ``float``, mapping unusable input to NaN.

    ``None``, empty or whitespace-only strings, and values that ``float()``
    cannot parse all yield ``np.nan`` instead of raising, so callers can
    detect every kind of bad input with a single NaN check.

    Args:
        x: A number, a numeric string, or ``None``.

    Returns:
        The value as ``float``, or ``np.nan`` when it is missing or not numeric.
    """
    if x is None:
        return np.nan
    if isinstance(x, str):
        x = x.strip()
        if not x:
            return np.nan
    try:
        return float(x)
    except (TypeError, ValueError):
        # Non-numeric text or an unsupported type: treat it as missing.
        return np.nan


def _positive_class_index() -> int:
    """Return the model-output column index of the positive class (label ``1``).

    Returns ``1`` when the pipeline exposes no ``classes_`` attribute.

    Raises:
        IndexError: ``classes_`` exists but does not contain the label ``1``.
    """
    classes_ = getattr(pipeline, "classes_", None)
    if classes_ is None:
        return 1
    return int(np.where(np.asarray(classes_) == 1)[0][0])


def predict_and_explain(ALB, TP, TBA, AST_ALT, CREA, LYM, ALP, DBIL, nsamples=200):
    """Predict from one set of raw inputs and explain the result with SHAP.

    The eight raw inputs are coerced to float and validated, the derived
    features ``PNI = ALB + 5 * LYM`` and ``AAPR = ALB / ALP`` are computed,
    and the resulting 8-feature row (in ``FEATURES`` order) is scored by the
    pipeline and explained by the module-level KernelExplainer.

    Args:
        ALB, TP, TBA, AST_ALT, CREA, LYM, ALP, DBIL: Raw input values. Each
            must be a finite number; ``ALP`` must be > 0 and ``DBIL`` >= 0.
        nsamples: Number of samples passed to ``explainer.shap_values``;
            ``None`` means 200.

    Returns:
        A tuple ``(value, figure, status_text)``. ``value`` is the
        positive-class probability (or the decision score / prediction when
        the pipeline has no ``predict_proba``) rounded to 3 decimals,
        ``figure`` is a force plot or, if that fails, a bar chart. On a
        validation failure or any exception the tuple is
        ``(None, None, message)``.
    """
    status = []
    try:
        # ---- 1) Coerce and validate the inputs ----
        ALB     = _coerce_float(ALB)
        TP      = _coerce_float(TP)
        TBA     = _coerce_float(TBA)
        AST_ALT = _coerce_float(AST_ALT)
        CREA    = _coerce_float(CREA)
        LYM     = _coerce_float(LYM)
        ALP     = _coerce_float(ALP)
        DBIL    = _coerce_float(DBIL)

        vals = [ALB, TP, TBA, AST_ALT, CREA, LYM, ALP, DBIL]
        # isfinite rejects NaN (missing) and +/-inf, which would otherwise
        # slip through and only fail later inside the model.
        if not all(np.isfinite(v) for v in vals):
            return None, None, "Error: 所有输入必须为数值且不可缺失。"

        if ALP <= 0:
            return None, None, "Error: ALP 必须 > 0(用于计算 AAPR=ALB/ALP)。"
        if DBIL < 0:
            return None, None, "Error: DBIL 不能为负值。"

        # ---- 2) Derived features ----
        PNI  = ALB + 5.0 * LYM
        AAPR = ALB / ALP
        status.append(f"Derived: PNI={PNI:.1f}, AAPR={AAPR:.3f}")

        # ---- 3) Assemble the final 8 features and predict ----
        x_row = np.array([[ALB, TP, TBA, AST_ALT, CREA, PNI, AAPR, DBIL]], dtype=np.float64)
        x_df = pd.DataFrame(x_row, columns=FEATURES)

        if hasattr(pipeline, "predict_proba"):
            pos_idx = _positive_class_index()
            prob = float(pipeline.predict_proba(x_df)[0, pos_idx])
            status.append(f"Pred prob: {prob:.3f}")
        else:
            # No probabilities available: report a score instead.
            if hasattr(pipeline, "decision_function"):
                score = float(pipeline.decision_function(x_df)[0])
            else:
                score = float(pipeline.predict(x_df)[0])
            prob = score
            status.append(f"Pred score: {score:.3f}")

        # ---- 4) SHAP values ----
        ns = int(nsamples) if nsamples is not None else 200
        shap_out = explainer.shap_values(x_row, nsamples=ns)

        # Reduce whatever shape came back to a 1-D contribution vector.
        if isinstance(shap_out, list):
            # One array per class: take the positive class.
            sv = np.asarray(shap_out[_positive_class_index()], dtype=np.float64)
            if sv.ndim == 2:  # (1, n_features)
                sv = sv[0, :]
        else:
            sv = np.asarray(shap_out, dtype=np.float64)
            if sv.ndim == 3:      # (1, n_features, n_classes)
                # Use the same positive-class index as every other branch.
                sv = sv[0, :, _positive_class_index()]
            elif sv.ndim == 2:    # (1, n_features)
                sv = sv[0, :]
            else:
                sv = sv.reshape(-1)
        status.append(f"SHAP vector shape: {sv.shape}")

        # Base value: scalar, or one entry per class.
        ev = explainer.expected_value
        if isinstance(ev, (list, np.ndarray)):
            ev = np.asarray(ev).reshape(-1)
            pos_idx = _positive_class_index()
            # Fall back to the first entry when there is no per-class entry.
            base_val = float(ev[pos_idx if len(ev) > pos_idx else 0])
        else:
            base_val = float(ev)

        # ---- 5) Plot (force plot first, bar chart on failure) ----
        try:
            fig = _render_force_plot(base_val, sv, x_row[0, :], FEATURES)
            status.append("Rendered force plot (matplotlib).")
        except Exception as e_force:
            status.append(f"Force-plot failed: {repr(e_force)}; fallback=bar")
            # Features ordered by absolute contribution, largest first.
            order = np.argsort(np.abs(sv))[::-1]
            topk = order[:min(len(FEATURES), sv.shape[0])]
            plt.close('all')
            fig = plt.figure(figsize=(8, 5), dpi=160)
            plt.barh(np.array(FEATURES)[topk], sv[topk])
            plt.xlabel("SHAP value")
            plt.title("Top features (single-sample contribution)")
            plt.gca().invert_yaxis()
            plt.tight_layout()
            status.append("Rendered bar fallback.")
        return round(float(prob), 3), fig, "\n".join(status)

    except Exception as e:
        # UI boundary: surface any failure as a status message rather than a traceback.
        return None, None, f"Fatal error: {repr(e)}"

# ====== Example input (8 raw values + nsamples) ======
# Target feature values for the example:
# ALB=40.7, TP=61.7, TBA=4.9, AST/ALT=1.0, CREA=54.3, PNI=48.6, AAPR=0.473, DBIL=8.5
# The form takes LYM and ALP instead of PNI and AAPR (those two are derived in the backend
# as PNI = ALB + 5*LYM and AAPR = ALB/ALP), so LYM and ALP are back-computed from the targets:
#   LYM = (PNI - ALB) / 5 = (48.6 - 40.7) / 5 = 1.58
#   ALP = ALB / AAPR = 40.7 / 0.473 ≈ 86
# Order: ALB, TP, TBA, AST/ALT, CREA, LYM, ALP, DBIL, nsamples

example_values = [40.7, 61.7, 4.9, 1.0, 54.3, 1.58, 86, 8.5, 200]


# ====== Gradio interface ======
with gr.Blocks() as demo:
    # Header with the title, the list of required inputs and their units
    gr.Markdown(
        "### Meige Risk Prediction (SVM) with SHAP Explanation\n"
        "Please enter ALB, TP, TBA, AST/ALT, CREA, LYM, ALP, and DBIL.\n\n"
        "**Units**: ALB(g/L), TP(g/L), TBA(μmol/L), AST/ALT(ratio), CREA(μmol/L), "
        "LYM(×10⁹/L), ALP(U/L), DBIL(μmol/L)."
    )

    with gr.Row():
        # Left column: inputs and action buttons
        with gr.Column(scale=1):
            # Listed in the positional parameter order of predict_and_explain
            inputs = [
                gr.Number(label="ALB (g/L)"),
                gr.Number(label="TP (g/L)"),
                gr.Number(label="TBA (μmol/L)"),
                gr.Number(label="AST/ALT"),
                gr.Number(label="CREA (μmol/L)"),
                gr.Number(label="LYM (×10⁹/L)"),
                gr.Number(label="ALP (U/L)"),
                gr.Number(label="DBIL (μmol/L)"),
            ]
            # Passed to predict_and_explain as its nsamples argument
            ns_slider = gr.Slider(100, 500, value=200, step=50, label="SHAP nsamples")

            btn_fill = gr.Button("Fill Example")
            btn_predict = gr.Button("Predict")

        # Right column: prediction value, explanation plot and status log
        with gr.Column(scale=1):
            out_prob = gr.Number(label="Predicted Probability / Score")
            out_plot = gr.Plot(label="SHAP Force Plot (fallback: bar)")
            out_log  = gr.Textbox(label="Status", lines=8)

    def _fill_example():
        """Return the example values as a tuple: one per input, then nsamples."""
        return tuple(example_values)

    # "Fill Example" writes the example values into the 8 inputs and the slider
    btn_fill.click(fn=_fill_example, outputs=[*inputs, ns_slider])
    # "Predict" sends the 8 inputs plus the slider value to predict_and_explain
    btn_predict.click(
        fn=predict_and_explain,
        inputs=[*inputs, ns_slider],
        outputs=[out_prob, out_plot, out_log]
    )

# Launch the Gradio app only when this file is executed as a script
if __name__ == "__main__":
    demo.launch()