Multiple123 committed on
Commit
7876a01
·
verified ·
1 Parent(s): 886a0e2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -34
app.py CHANGED
@@ -1,34 +1,68 @@
 
1
  import gradio as gr
2
  import pandas as pd
3
  import joblib
4
  import shap
5
  import numpy as np
6
  import matplotlib
7
- matplotlib.use("Agg")
8
  import matplotlib.pyplot as plt
9
  import warnings
10
  warnings.filterwarnings("ignore")
11
 
12
  # ====== 模型与背景数据 ======
13
  MODEL_PATH = "models/SVM_pipeline.pkl"
14
- BG_PATH = "data/bg.csv"
15
 
 
16
  feature_names = ["HGB", "HDL_C", "DBIL", "AST_ALT", "UA", "GFR", "PNI", "HALP", "AAPR", "conuts"]
17
 
 
18
  pipeline = joblib.load(MODEL_PATH)
19
  bg_df = pd.read_csv(BG_PATH)
20
  bg_array = bg_df[feature_names].to_numpy(dtype=np.float64)
21
 
22
- def _predict_proba_nd(x_nd):
 
23
  df = pd.DataFrame(x_nd, columns=feature_names)
24
  return pipeline.predict_proba(df)
25
 
 
26
  explainer = shap.KernelExplainer(_predict_proba_nd, bg_array)
27
 
28
- def predict_and_explain(HGB, HDL_C, DBIL, AST_ALT, UA, GFR, ALB, LYM, PLT, ALP, CHOL, nsamples=200):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  try:
30
- # 自动派生变量
31
- PNI = ALB + 5 * LYM
 
 
 
 
 
 
 
 
 
 
 
 
32
  HALP = HGB * ALB * LYM / PLT
33
  AAPR = ALB / ALP
34
  conuts = (
@@ -36,36 +70,79 @@ def predict_and_explain(HGB, HDL_C, DBIL, AST_ALT, UA, GFR, ALB, LYM, PLT, ALP,
36
  (0 if LYM >= 1.6 else 1 if LYM >= 1.2 else 2 if LYM >= 0.8 else 3) +
37
  (0 if CHOL >= 4.65 else 1 if CHOL >= 3.10 else 2 if CHOL >= 2.59 else 3)
38
  )
39
-
40
- x_row = [[HGB, HDL_C, DBIL, AST_ALT, UA, GFR, PNI, HALP, AAPR, conuts]]
41
- input_df = pd.DataFrame(x_row, columns=feature_names)
42
-
43
- prob = float(pipeline.predict_proba(input_df)[0, 1])
44
-
45
- shap_out = explainer.shap_values(np.array(x_row), nsamples=nsamples)
46
- sv = shap_out[1][0] if isinstance(shap_out, list) else shap_out[0]
47
-
48
- base_val = explainer.expected_value[1] if isinstance(explainer.expected_value, list) else explainer.expected_value
49
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  plt.close('all')
51
- shap.force_plot(base_val, sv, x_row[0], feature_names=feature_names, matplotlib=True, show=False)
52
- fig = plt.gcf()
53
- fig.set_size_inches(8, 4)
 
 
54
  plt.tight_layout()
55
-
56
- return round(prob, 3), fig, "Success"
57
 
58
  except Exception as e:
59
- return None, None, f"Error: {e}"
60
 
 
 
61
  example_values = [137, 1.76, 8.6, 0.97, 310, 75.4, 33, 2.2, 164, 67.9, 2.8, 200]
 
62
 
63
-
64
  with gr.Blocks() as demo:
65
- gr.Markdown("### Logistic Regression Risk Prediction with SHAP Explanation")
 
 
 
 
 
66
 
67
  with gr.Row():
68
- with gr.Column():
69
  inputs = [
70
  gr.Number(label="HGB (g/L)"),
71
  gr.Number(label="HDL-C (mmol/L)"),
@@ -80,11 +157,24 @@ with gr.Blocks() as demo:
80
  gr.Number(label="CHOL (mmol/L)")
81
  ]
82
  ns_slider = gr.Slider(100, 400, value=200, step=50, label="SHAP nsamples")
83
- gr.Button("Fill Example").click(lambda: tuple(example_values), outputs=[*inputs, ns_slider])
84
- gr.Button("Predict").click(fn=predict_and_explain,
85
- inputs=[*inputs, ns_slider],
86
- outputs=[gr.Number(label="Risk"), gr.Plot(), gr.Textbox(label="Status", lines=4)])
87
- with gr.Column():
88
- pass
89
-
90
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py (robust, server-safe)
2
  import gradio as gr
3
  import pandas as pd
4
  import joblib
5
  import shap
6
  import numpy as np
7
  import matplotlib
8
+ matplotlib.use("Agg") # 非交互后端,服务器端更稳
9
  import matplotlib.pyplot as plt
10
  import warnings
11
  warnings.filterwarnings("ignore")
12
 
13
  # ====== 模型与背景数据 ======
14
  MODEL_PATH = "models/SVM_pipeline.pkl"
15
+ BG_PATH = "data/bg.csv"
16
 
17
+ # 模型最终需要的10个特征(顺序必须与训练一致)
18
  feature_names = ["HGB", "HDL_C", "DBIL", "AST_ALT", "UA", "GFR", "PNI", "HALP", "AAPR", "conuts"]
19
 
20
+ # 加载模型与背景
21
  pipeline = joblib.load(MODEL_PATH)
22
  bg_df = pd.read_csv(BG_PATH)
23
  bg_array = bg_df[feature_names].to_numpy(dtype=np.float64)
24
 
25
+ # 预测函数(供 KernelExplainer 调用)
26
def _predict_proba_nd(x_nd: np.ndarray) -> np.ndarray:
    """Adapt the fitted pipeline's ``predict_proba`` to plain array input.

    This is the callable handed to ``shap.KernelExplainer``. The incoming
    array is wrapped in a DataFrame whose columns are ``feature_names`` before
    it is passed to the pipeline.

    Args:
        x_nd: Array of feature rows, with columns ordered as in
            ``feature_names``.

    Returns:
        Whatever ``pipeline.predict_proba`` returns for those rows.
    """
    frame = pd.DataFrame(data=x_nd, columns=feature_names)
    probabilities = pipeline.predict_proba(frame)
    return probabilities
29
 
30
+ # 只初始化一次 explainer(性能更稳)
31
  explainer = shap.KernelExplainer(_predict_proba_nd, bg_array)
32
 
33
def _render_force_plot(base_val: float, shap_1d: np.ndarray, feat_1d: np.ndarray, fnames):
    """Draw a SHAP force plot for one sample and return it as a matplotlib Figure.

    The plot is produced through shap's matplotlib rendering path
    (``matplotlib=True``, ``show=False``), so the result is an ordinary Figure
    object rather than an interactive widget.

    Args:
        base_val: Base (expected) value the force plot starts from.
        shap_1d: SHAP values for the sample; flattened to 1-D before plotting.
        feat_1d: Feature values for the sample; flattened to 1-D before plotting.
        fnames: Iterable of feature names; copied into a list for shap.

    Returns:
        The current pyplot figure after plotting, resized to 8 x 4 inches
        with a tight layout applied.
    """
    # Close every open figure first -- presumably so that gcf() below refers
    # to the figure created by this force plot and not to a stale one.
    plt.close('all')

    contributions = np.asarray(shap_1d).reshape(-1)
    feature_values = np.asarray(feat_1d).reshape(-1)
    labels = list(fnames)
    shap.force_plot(
        base_val,
        contributions,
        feature_values,
        feature_names=labels,
        matplotlib=True,
        show=False,
    )

    figure = plt.gcf()
    figure.set_size_inches(8, 4)
    plt.tight_layout()
    return figure
44
+
45
+ def predict_and_explain(
46
+ HGB, HDL_C, DBIL, AST_ALT, UA, GFR,
47
+ ALB, LYM, PLT, ALP, CHOL,
48
+ nsamples=200
49
+ ):
50
+ status = []
51
  try:
52
+ # ---- 1) 衍生指标(由原始输入计算)----
53
+ try:
54
+ HGB = float(HGB); HDL_C = float(HDL_C); DBIL = float(DBIL); AST_ALT = float(AST_ALT)
55
+ UA = float(UA); GFR = float(GFR)
56
+ ALB = float(ALB); LYM = float(LYM); PLT = float(PLT)
57
+ ALP = float(ALP); CHOL = float(CHOL)
58
+ except Exception:
59
+ return None, None, "Error: some inputs are not numeric."
60
+
61
+ # 防极端值(避免除0)
62
+ if PLT <= 0 or ALP <= 0:
63
+ return None, None, "Error: PLT and ALP must be > 0."
64
+
65
+ PNI = ALB + 5.0 * LYM
66
  HALP = HGB * ALB * LYM / PLT
67
  AAPR = ALB / ALP
68
  conuts = (
 
70
  (0 if LYM >= 1.6 else 1 if LYM >= 1.2 else 2 if LYM >= 0.8 else 3) +
71
  (0 if CHOL >= 4.65 else 1 if CHOL >= 3.10 else 2 if CHOL >= 2.59 else 3)
72
  )
73
+ x_row = np.array([[HGB, HDL_C, DBIL, AST_ALT, UA, GFR, PNI, HALP, AAPR, conuts]], dtype=np.float64)
74
+ status.append(f"Derived: PNI={PNI:.3f}, HALP={HALP:.3f}, AAPR={AAPR:.3f}, CONUTS={conuts}")
75
+
76
+ # ---- 2) 概率 ----
77
+ prob = float(pipeline.predict_proba(pd.DataFrame(x_row, columns=feature_names))[0, 1])
78
+ status.append(f"Pred prob computed: {prob:.3f}")
79
+
80
+ # ---- 3) SHAP 计算 ----
81
+ ns = int(nsamples) if nsamples is not None else 200
82
+ shap_out = explainer.shap_values(x_row, nsamples=ns)
83
+
84
+ # 统一提取“正类”一维向量
85
+ if isinstance(shap_out, list):
86
+ sv = np.asarray(shap_out[1], dtype=np.float64)
87
+ if sv.ndim == 2:
88
+ sv = sv[0, :]
89
+ else:
90
+ sv = np.asarray(shap_out, dtype=np.float64)
91
+ if sv.ndim == 3: # (1, n_features, n_classes)
92
+ sv = sv[0, :, 1]
93
+ elif sv.ndim == 2: # (1, n_features)
94
+ sv = sv[0, :]
95
+ else:
96
+ sv = sv.reshape(-1)
97
+ status.append(f"SHAP 1D shape: {sv.shape}; features: {x_row.shape[1:]}")
98
+
99
+ # base value 取正类
100
+ ev = explainer.expected_value
101
+ if isinstance(ev, (list, np.ndarray)):
102
+ ev = np.asarray(ev).reshape(-1)
103
+ base_val = float(ev[1] if len(ev) > 1 else ev[0])
104
+ else:
105
+ base_val = float(ev)
106
+
107
+ # ---- 4) 绘图:优先力图;失败则条形图兜底 ----
108
+ try:
109
+ fig = _render_force_plot(base_val, sv, x_row[0, :], feature_names)
110
+ status.append("Rendered force plot (matplotlib).")
111
+ return round(prob, 3), fig, "\n".join(status)
112
+ except Exception as e_force:
113
+ status.append(f"Force-plot failed: {repr(e_force)}; fallback=bar")
114
+
115
+ order = np.argsort(np.abs(sv))[::-1]
116
+ topk = order[:min(10, sv.shape[0])]
117
  plt.close('all')
118
+ fig = plt.figure(figsize=(8, 5), dpi=160)
119
+ plt.barh(np.array(feature_names)[topk], sv[topk])
120
+ plt.xlabel("SHAP value")
121
+ plt.title("Top features (single-sample contribution)")
122
+ plt.gca().invert_yaxis()
123
  plt.tight_layout()
124
+ status.append("Rendered bar fallback.")
125
+ return round(prob, 3), fig, "\n".join(status)
126
 
127
  except Exception as e:
128
+ return None, None, f"Fatal error: {repr(e)}"
129
 
130
+ # ====== 示例:一组“原始指标”可复现你之前的 PNI/HALP/AAPR/CONUTS ======
131
+ # 对应:PNI=44, HALP≈60.8, AAPR≈0.486, CONUTS=4
132
  example_values = [137, 1.76, 8.6, 0.97, 310, 75.4, 33, 2.2, 164, 67.9, 2.8, 200]
133
+ # 顺序:HGB, HDL_C, DBIL, AST_ALT, UA, GFR, ALB, LYM, PLT, ALP, CHOL, nsamples
134
 
135
+ # ====== Gradio 界面 ======
136
  with gr.Blocks() as demo:
137
+ gr.Markdown(
138
+ "### Meige Risk Prediction (SVM) with SHAP Explanation\n"
139
+ "Enter **original clinical indicators**; the app will derive PNI/HALP/AAPR/CONUTS internally.\n\n"
140
+ "**Units**: HGB (g/L), HDL‑C (mmol/L), DBIL (μmol/L), AST/ALT (ratio), UA (μmol/L), "
141
+ "GFR (mL/min/1.73 m²), ALB (g/L), LYM (×10⁹/L), PLT (×10⁹/L), ALP (U/L), CHOL (mmol/L)."
142
+ )
143
 
144
  with gr.Row():
145
+ with gr.Column(scale=1):
146
  inputs = [
147
  gr.Number(label="HGB (g/L)"),
148
  gr.Number(label="HDL-C (mmol/L)"),
 
157
  gr.Number(label="CHOL (mmol/L)")
158
  ]
159
  ns_slider = gr.Slider(100, 400, value=200, step=50, label="SHAP nsamples")
160
+
161
+ btn_fill = gr.Button("Fill Example")
162
+ btn_predict = gr.Button("Predict")
163
+
164
+ with gr.Column(scale=1):
165
+ out_prob = gr.Number(label="Predicted Probability")
166
+ out_plot = gr.Plot(label="SHAP Force Plot (fallback: bar)")
167
+ out_log = gr.Textbox(label="Status", lines=6)
168
+
169
def _fill_example():
    """Return the example inputs as a tuple for the "Fill Example" button.

    The tuple is built from ``example_values``; the click handler maps its
    items onto the input components followed by the nsamples slider.
    """
    example_row = tuple(example_values)
    return example_row
171
+
172
+ btn_fill.click(fn=_fill_example, outputs=[*inputs, ns_slider])
173
+ btn_predict.click(
174
+ fn=predict_and_explain,
175
+ inputs=[*inputs, ns_slider],
176
+ outputs=[out_prob, out_plot, out_log]
177
+ )
178
+
179
+ if __name__ == "__main__":
180
+ demo.launch()