Multiple123 commited on
Commit
f2b3a5b
·
verified ·
1 Parent(s): 7aa125b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -81
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # app.py (robust, server-safe)
2
  import gradio as gr
3
  import pandas as pd
4
  import joblib
@@ -11,21 +11,37 @@ import warnings
11
  warnings.filterwarnings("ignore")
12
 
13
  # ====== 模型与背景数据 ======
14
- MODEL_PATH = "models/SVM_pipeline.pkl"
15
  BG_PATH = "data/bg.csv"
16
 
17
- # 模型最终需要的10个特征(顺序必须与训练一致)
18
- feature_names = ["HGB", "HDL_C", "DBIL", "AST_ALT", "UA", "GFR", "PNI", "HALP", "AAPR", "conuts"]
19
 
20
  # 加载模型与背景
21
  pipeline = joblib.load(MODEL_PATH)
 
22
  bg_df = pd.read_csv(BG_PATH)
23
- bg_array = bg_df[feature_names].to_numpy(dtype=np.float64)
 
 
 
24
 
25
- # 预测函数(供 KernelExplainer 调用)
26
  def _predict_proba_nd(x_nd: np.ndarray) -> np.ndarray:
27
- df = pd.DataFrame(x_nd, columns=feature_names)
28
- return pipeline.predict_proba(df)
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  # 只初始化一次 explainer(性能更稳)
31
  explainer = shap.KernelExplainer(_predict_proba_nd, bg_array)
@@ -34,57 +50,73 @@ def _render_force_plot(base_val: float, shap_1d: np.ndarray, feat_1d: np.ndarray
34
  """返回 matplotlib Figure(旧接口,服务器端稳定)"""
35
  plt.close('all')
36
  shap.force_plot(
37
- base_val, np.asarray(shap_1d).reshape(-1), np.asarray(feat_1d).reshape(-1),
38
- feature_names=list(fnames), matplotlib=True, show=False
 
 
 
39
  )
40
  fig = plt.gcf()
41
  fig.set_size_inches(8, 4)
42
  plt.tight_layout()
43
  return fig
44
 
45
- def predict_and_explain(
46
- HGB, HDL_C, DBIL, AST_ALT, UA, GFR,
47
- ALB, LYM, PLT, ALP, CHOL,
48
- nsamples=200
49
- ):
50
  status = []
51
  try:
52
- # ---- 1) 衍生指标(由原始输入计算)----
53
- try:
54
- HGB = float(HGB); HDL_C = float(HDL_C); DBIL = float(DBIL); AST_ALT = float(AST_ALT)
55
- UA = float(UA); GFR = float(GFR)
56
- ALB = float(ALB); LYM = float(LYM); PLT = float(PLT)
57
- ALP = float(ALP); CHOL = float(CHOL)
58
- except Exception:
59
- return None, None, "Error: some inputs are not numeric."
60
-
61
- # 防极端值(避免除0)
62
- if PLT <= 0 or ALP <= 0:
63
- return None, None, "Error: PLT and ALP must be > 0."
64
-
 
 
 
 
65
  PNI = ALB + 5.0 * LYM
66
- HALP = HGB * ALB * LYM / PLT
67
  AAPR = ALB / ALP
68
- conuts = (
69
- (0 if ALB >= 35 else 2 if ALB >= 30 else 4 if ALB >= 25 else 6) +
70
- (0 if LYM >= 1.6 else 1 if LYM >= 1.2 else 2 if LYM >= 0.8 else 3) +
71
- (0 if CHOL >= 4.65 else 1 if CHOL >= 3.10 else 2 if CHOL >= 2.59 else 3)
72
- )
73
- x_row = np.array([[HGB, HDL_C, DBIL, AST_ALT, UA, GFR, PNI, HALP, AAPR, conuts]], dtype=np.float64)
74
- status.append(f"Derived: PNI={PNI:.3f}, HALP={HALP:.3f}, AAPR={AAPR:.3f}, CONUTS={conuts}")
75
-
76
- # ---- 2) 概率 ----
77
- prob = float(pipeline.predict_proba(pd.DataFrame(x_row, columns=feature_names))[0, 1])
78
- status.append(f"Pred prob computed: {prob:.3f}")
79
-
80
- # ---- 3) SHAP 计算 ----
 
 
 
 
 
 
 
 
81
  ns = int(nsamples) if nsamples is not None else 200
82
  shap_out = explainer.shap_values(x_row, nsamples=ns)
83
 
84
- # 统一提取“正类”一维向量
85
  if isinstance(shap_out, list):
86
- sv = np.asarray(shap_out[1], dtype=np.float64)
87
- if sv.ndim == 2:
 
 
 
88
  sv = sv[0, :]
89
  else:
90
  sv = np.asarray(shap_out, dtype=np.float64)
@@ -94,77 +126,75 @@ def predict_and_explain(
94
  sv = sv[0, :]
95
  else:
96
  sv = sv.reshape(-1)
97
- status.append(f"SHAP 1D shape: {sv.shape}; features: {x_row.shape[1:]}")
98
 
99
- # base value 取正类
100
  ev = explainer.expected_value
101
  if isinstance(ev, (list, np.ndarray)):
102
  ev = np.asarray(ev).reshape(-1)
103
- base_val = float(ev[1] if len(ev) > 1 else ev[0])
 
 
104
  else:
105
  base_val = float(ev)
106
 
107
- # ---- 4) 绘图:优先力图;失败则条形图兜底 ----
108
  try:
109
- fig = _render_force_plot(base_val, sv, x_row[0, :], feature_names)
110
  status.append("Rendered force plot (matplotlib).")
111
- return round(prob, 3), fig, "\n".join(status)
112
  except Exception as e_force:
113
  status.append(f"Force-plot failed: {repr(e_force)}; fallback=bar")
114
-
115
- order = np.argsort(np.abs(sv))[::-1]
116
- topk = order[:min(10, sv.shape[0])]
117
- plt.close('all')
118
- fig = plt.figure(figsize=(8, 5), dpi=160)
119
- plt.barh(np.array(feature_names)[topk], sv[topk])
120
- plt.xlabel("SHAP value")
121
- plt.title("Top features (single-sample contribution)")
122
- plt.gca().invert_yaxis()
123
- plt.tight_layout()
124
- status.append("Rendered bar fallback.")
125
- return round(prob, 3), fig, "\n".join(status)
126
 
127
  except Exception as e:
128
  return None, None, f"Fatal error: {repr(e)}"
129
 
130
- # ====== 示例:一组“原始指标”可复现你之前的 PNI/HALP/AAPR/CONUTS ======
131
- # 对应:PNI=44, HALP≈60.8, AAPR≈0.486, CONUTS=4
132
- example_values = [167, 1.76, 8.6, 0.97, 310, 75, 33, 2.2, 164, 68, 2.8, 200]
133
- # 顺序:HGB, HDL_C, DBIL, AST_ALT, UA, GFR, ALB, LYM, PLT, ALP, CHOL, nsamples
134
 
135
  # ====== Gradio 界面 ======
136
  with gr.Blocks() as demo:
137
  gr.Markdown(
138
  "### Meige Risk Prediction (SVM) with SHAP Explanation\n"
139
- "Enter **original clinical indicators**; the app will derive PNI/HALP/AAPR/CONUTS internally.\n\n"
140
- "**Units**: HGB (g/L), HDL‑C (mmol/L), DBIL (μmol/L), AST/ALT (ratio), UA (μmol/L), "
141
- "GFR (mL/min/1.73 m²), ALB (g/L), LYM (×10⁹/L), PLT (×10⁹/L), ALP (U/L), CHOL (mmol/L)."
 
142
  )
143
 
144
  with gr.Row():
145
  with gr.Column(scale=1):
146
  inputs = [
147
- gr.Number(label="HGB (g/L)"),
148
- gr.Number(label="HDL-C (mmol/L)"),
149
- gr.Number(label="DBIL (μmol/L)"),
150
- gr.Number(label="AST/ALT"),
151
- gr.Number(label="UA (μmol/L)"),
152
- gr.Number(label="GFR (mL/min/1.73 m²)"),
153
  gr.Number(label="ALB (g/L)"),
 
 
 
 
154
  gr.Number(label="LYM (×10⁹/L)"),
155
- gr.Number(label="PLT (×10⁹/L)"),
156
  gr.Number(label="ALP (U/L)"),
157
- gr.Number(label="CHOL (mmol/L)")
158
  ]
159
- ns_slider = gr.Slider(100, 400, value=200, step=50, label="SHAP nsamples")
160
 
161
  btn_fill = gr.Button("Fill Example")
162
  btn_predict = gr.Button("Predict")
163
 
164
  with gr.Column(scale=1):
165
- out_prob = gr.Number(label="Predicted Probability")
166
  out_plot = gr.Plot(label="SHAP Force Plot (fallback: bar)")
167
- out_log = gr.Textbox(label="Status", lines=6)
168
 
169
  def _fill_example():
170
  return tuple(example_values)
 
1
+ # app.py (7-feature aligned, server-safe)
2
  import gradio as gr
3
  import pandas as pd
4
  import joblib
 
11
  warnings.filterwarnings("ignore")
12
 
13
  # ====== 模型与背景数据 ======
14
+ MODEL_PATH = "models/svm_pipeline.joblib"
15
  BG_PATH = "data/bg.csv"
16
 
17
+ # 模型最终需要的 7 个特征(顺序必须与训练一致)
18
+ FEATURES = ["ALB", "TP", "TBA", "AST_ALT", "CREA", "PNI", "AAPR"]
19
 
20
  # 加载模型与背景
21
  pipeline = joblib.load(MODEL_PATH)
22
+
23
  bg_df = pd.read_csv(BG_PATH)
24
+ missing_bg = [c for c in FEATURES if c not in bg_df.columns]
25
+ if missing_bg:
26
+ raise ValueError(f"背景集缺少列: {missing_bg}")
27
+ bg_array = bg_df[FEATURES].to_numpy(dtype=np.float64)
28
 
29
+ # 预测函数(供 KernelExplainer 调用)——返回正类概率/分数
30
  def _predict_proba_nd(x_nd: np.ndarray) -> np.ndarray:
31
+ df = pd.DataFrame(x_nd, columns=FEATURES)
32
+ # 若模型有 predict_proba:取正类概率;否则退回 decision_function / predict
33
+ if hasattr(pipeline, "predict_proba"):
34
+ proba = pipeline.predict_proba(df)
35
+ # 确定正类索引(假定正类标签为 1;若不是,请在此处修改)
36
+ classes_ = getattr(pipeline, "classes_", None)
37
+ pos_idx = int(np.where(classes_ == 1)[0][0]) if classes_ is not None else 1
38
+ return proba[:, pos_idx]
39
+ elif hasattr(pipeline, "decision_function"):
40
+ score = pipeline.decision_function(df)
41
+ return score if isinstance(score, np.ndarray) else np.asarray(score)
42
+ else:
43
+ pred = pipeline.predict(df)
44
+ return pred if isinstance(pred, np.ndarray) else np.asarray(pred)
45
 
46
  # 只初始化一次 explainer(性能更稳)
47
  explainer = shap.KernelExplainer(_predict_proba_nd, bg_array)
 
50
  """返回 matplotlib Figure(旧接口,服务器端稳定)"""
51
  plt.close('all')
52
  shap.force_plot(
53
+ base_val,
54
+ np.asarray(shap_1d).reshape(-1),
55
+ np.asarray(feat_1d).reshape(-1),
56
+ feature_names=list(fnames),
57
+ matplotlib=True, show=False
58
  )
59
  fig = plt.gcf()
60
  fig.set_size_inches(8, 4)
61
  plt.tight_layout()
62
  return fig
63
 
64
+ def _coerce_float(x):
65
+ return float(x) if x is not None and x != "" else np.nan
66
+
67
+ def predict_and_explain(ALB, TP, TBA, AST_ALT, CREA, LYM, ALP, nsamples=200):
 
68
  status = []
69
  try:
70
+ # ---- 1) 取数并校验 ----
71
+ ALB = _coerce_float(ALB)
72
+ TP = _coerce_float(TP)
73
+ TBA = _coerce_float(TBA)
74
+ AST_ALT = _coerce_float(AST_ALT)
75
+ CREA = _coerce_float(CREA)
76
+ LYM = _coerce_float(LYM)
77
+ ALP = _coerce_float(ALP)
78
+
79
+ vals = [ALB, TP, TBA, AST_ALT, CREA, LYM, ALP]
80
+ if any(np.isnan(v) for v in vals):
81
+ return None, None, "Error: 所有输入必须为数值且不可缺失。"
82
+
83
+ if ALP <= 0:
84
+ return None, None, "Error: ALP 必须 > 0(用于计算 AAPR=ALB/ALP)。"
85
+
86
+ # ---- 2) 衍生指标 ----
87
  PNI = ALB + 5.0 * LYM
 
88
  AAPR = ALB / ALP
89
+ status.append(f"Derived: PNI={PNI:.3f}, AAPR={AAPR:.3f}")
90
+
91
+ # ---- 3) 组装最终 7 特征并预测 ----
92
+ x_row = np.array([[ALB, TP, TBA, AST_ALT, CREA, PNI, AAPR]], dtype=np.float64)
93
+
94
+ if hasattr(pipeline, "predict_proba"):
95
+ classes_ = getattr(pipeline, "classes_", None)
96
+ pos_idx = int(np.where(classes_ == 1)[0][0]) if classes_ is not None else 1
97
+ prob = float(pipeline.predict_proba(pd.DataFrame(x_row, columns=FEATURES))[0, pos_idx])
98
+ status.append(f"Pred prob: {prob:.3f}")
99
+ else:
100
+ # 若无概率,给出分数
101
+ score = float(
102
+ pipeline.decision_function(pd.DataFrame(x_row, columns=FEATURES))[0]
103
+ ) if hasattr(pipeline, "decision_function") else float(
104
+ pipeline.predict(pd.DataFrame(x_row, columns=FEATURES))[0]
105
+ )
106
+ prob = score
107
+ status.append(f"Pred score: {score:.3f}")
108
+
109
+ # ---- 4) SHAP 计算 ----
110
  ns = int(nsamples) if nsamples is not None else 200
111
  shap_out = explainer.shap_values(x_row, nsamples=ns)
112
 
113
+ # 统一提取“一维贡献向量”
114
  if isinstance(shap_out, list):
115
+ # 二分类:list 长度=2,取正类
116
+ classes_ = getattr(pipeline, "classes_", None)
117
+ pos_idx = int(np.where(classes_ == 1)[0][0]) if classes_ is not None else 1
118
+ sv = np.asarray(shap_out[pos_idx], dtype=np.float64)
119
+ if sv.ndim == 2: # (1, n_features)
120
  sv = sv[0, :]
121
  else:
122
  sv = np.asarray(shap_out, dtype=np.float64)
 
126
  sv = sv[0, :]
127
  else:
128
  sv = sv.reshape(-1)
129
+ status.append(f"SHAP vector shape: {sv.shape}")
130
 
131
+ # base value
132
  ev = explainer.expected_value
133
  if isinstance(ev, (list, np.ndarray)):
134
  ev = np.asarray(ev).reshape(-1)
135
+ classes_ = getattr(pipeline, "classes_", None)
136
+ pos_idx = int(np.where(classes_ == 1)[0][0]) if classes_ is not None else 1
137
+ base_val = float(ev[pos_idx if len(ev) > pos_idx else 0])
138
  else:
139
  base_val = float(ev)
140
 
141
+ # ---- 5) 绘图(优先 force,失败退条形图)----
142
  try:
143
+ fig = _render_force_plot(base_val, sv, x_row[0, :], FEATURES)
144
  status.append("Rendered force plot (matplotlib).")
145
+ return round(float(prob), 3), fig, "\n".join(status)
146
  except Exception as e_force:
147
  status.append(f"Force-plot failed: {repr(e_force)}; fallback=bar")
148
+ order = np.argsort(np.abs(sv))[::-1]
149
+ topk = order[:min(7, sv.shape[0])]
150
+ plt.close('all')
151
+ fig = plt.figure(figsize=(8, 5), dpi=160)
152
+ plt.barh(np.array(FEATURES)[topk], sv[topk])
153
+ plt.xlabel("SHAP value")
154
+ plt.title("Top features (single-sample contribution)")
155
+ plt.gca().invert_yaxis()
156
+ plt.tight_layout()
157
+ status.append("Rendered bar fallback.")
158
+ return round(float(prob), 3), fig, "\n".join(status)
 
159
 
160
  except Exception as e:
161
  return None, None, f"Fatal error: {repr(e)}"
162
 
163
+ # ====== 示例输入(仅 7 项 + nsamples)======
164
+ example_values = [38.0, 68.0, 6.5, 1.0, 75.0, 1.2, 80.0, 200]
165
+ # 顺序:ALB, TP, TBA, AST_ALT, CREA, LYM, ALP, nsamples
166
+ # 注:上例将派生 PNI=ALB+5*LYM=44、AAPR=ALB/ALP=0.475,与训练对齐
167
 
168
  # ====== Gradio 界面 ======
169
  with gr.Blocks() as demo:
170
  gr.Markdown(
171
  "### Meige Risk Prediction (SVM) with SHAP Explanation\n"
172
+ "输入 **ALB, TP, TBA, AST/ALT, CREA, LYM, ALP**;应用会内部计算 **PNI=ALB+5×LYM** 与 **AAPR=ALB/ALP**,"
173
+ "并以这 7 个最终特征喂给模型和 SHAP。\n\n"
174
+ "**Units**: ALB(g/L), TP(g/L), TBA(μmol/L), AST/ALT(ratio), CREA(μmol/L), "
175
+ "LYM(×10⁹/L), ALP(U/L)."
176
  )
177
 
178
  with gr.Row():
179
  with gr.Column(scale=1):
180
  inputs = [
 
 
 
 
 
 
181
  gr.Number(label="ALB (g/L)"),
182
+ gr.Number(label="TP (g/L)"),
183
+ gr.Number(label="TBA (μmol/L)"),
184
+ gr.Number(label="AST/ALT"),
185
+ gr.Number(label="CREA (μmol/L)"),
186
  gr.Number(label="LYM (×10⁹/L)"),
 
187
  gr.Number(label="ALP (U/L)"),
 
188
  ]
189
+ ns_slider = gr.Slider(100, 500, value=200, step=50, label="SHAP nsamples")
190
 
191
  btn_fill = gr.Button("Fill Example")
192
  btn_predict = gr.Button("Predict")
193
 
194
  with gr.Column(scale=1):
195
+ out_prob = gr.Number(label="Predicted Probability / Score")
196
  out_plot = gr.Plot(label="SHAP Force Plot (fallback: bar)")
197
+ out_log = gr.Textbox(label="Status", lines=8)
198
 
199
  def _fill_example():
200
  return tuple(example_values)