Multiple123 commited on
Commit
bdab798
·
verified ·
1 Parent(s): 5225542

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -19
app.py CHANGED
@@ -1,4 +1,3 @@
1
- # app.py (7-feature aligned, server-safe)
2
  import gradio as gr
3
  import pandas as pd
4
  import joblib
@@ -14,8 +13,8 @@ warnings.filterwarnings("ignore")
14
  MODEL_PATH = "models/svm_pipeline.joblib"
15
  BG_PATH = "data/bg.csv"
16
 
17
- # 模型最终需要的 7 个特征(顺序必须与训练一致)
18
- FEATURES = ["ALB", "TP", "TBA", "AST_ALT", "CREA", "PNI", "AAPR"]
19
 
20
  # 加载模型与背景
21
  pipeline = joblib.load(MODEL_PATH)
@@ -52,15 +51,13 @@ def _render_force_plot(base_val: float, shap_1d: np.ndarray, feat_1d: np.ndarray
52
 
53
  # ——将特征值按需四舍五入,减少标签长度(可根据需要微调每列小数位)
54
  feat = np.asarray(feat_1d, dtype=float).copy()
55
- # 这里给出一套通用规则:ALB/TP/CREA/PNI 一位小数;TBA 两位;AST_ALT 与 AAPR 两位
56
  round_map = {
57
- "ALB": 2, "TP": 2, "TBA": 2, "AST_ALT": 2, "CREA": 1, "PNI": 1, "AAPR": 3
58
  }
59
  feat_rounded = [
60
  np.round(val, round_map.get(name, 2)) for val, name in zip(feat, fnames)
61
  ]
62
 
63
- # 如需进一步减字,可把特征名缩写或去掉空格,这里保持原名
64
  shap.force_plot(
65
  base_val,
66
  np.asarray(shap_1d).reshape(-1),
@@ -80,7 +77,8 @@ def _render_force_plot(base_val: float, shap_1d: np.ndarray, feat_1d: np.ndarray
80
  def _coerce_float(x):
81
  return float(x) if x is not None and x != "" else np.nan
82
 
83
- def predict_and_explain(ALB, TP, TBA, AST_ALT, CREA, LYM, ALP, nsamples=200):
 
84
  status = []
85
  try:
86
  # ---- 1) 取数并校验 ----
@@ -91,21 +89,24 @@ def predict_and_explain(ALB, TP, TBA, AST_ALT, CREA, LYM, ALP, nsamples=200):
91
  CREA = _coerce_float(CREA)
92
  LYM = _coerce_float(LYM)
93
  ALP = _coerce_float(ALP)
 
94
 
95
- vals = [ALB, TP, TBA, AST_ALT, CREA, LYM, ALP]
96
  if any(np.isnan(v) for v in vals):
97
  return None, None, "Error: 所有输入必须为数值且不可缺失。"
98
 
99
  if ALP <= 0:
100
  return None, None, "Error: ALP 必须 > 0(用于计算 AAPR=ALB/ALP)。"
 
 
101
 
102
  # ---- 2) 衍生指标 ----
103
  PNI = ALB + 5.0 * LYM
104
  AAPR = ALB / ALP
105
- status.append(f"Derived: PNI={PNI:.1f}, AAPR={AAPR:.2f}")
106
 
107
- # ---- 3) 组装最终 7 特征并预测 ----
108
- x_row = np.array([[ALB, TP, TBA, AST_ALT, CREA, PNI, AAPR]], dtype=np.float64)
109
 
110
  if hasattr(pipeline, "predict_proba"):
111
  classes_ = getattr(pipeline, "classes_", None)
@@ -162,7 +163,7 @@ def predict_and_explain(ALB, TP, TBA, AST_ALT, CREA, LYM, ALP, nsamples=200):
162
  except Exception as e_force:
163
  status.append(f"Force-plot failed: {repr(e_force)}; fallback=bar")
164
  order = np.argsort(np.abs(sv))[::-1]
165
- topk = order[:min(7, sv.shape[0])]
166
  plt.close('all')
167
  fig = plt.figure(figsize=(8, 5), dpi=160)
168
  plt.barh(np.array(FEATURES)[topk], sv[topk])
@@ -176,19 +177,21 @@ def predict_and_explain(ALB, TP, TBA, AST_ALT, CREA, LYM, ALP, nsamples=200):
176
  except Exception as e:
177
  return None, None, f"Fatal error: {repr(e)}"
178
 
179
- # ====== 示例输入(仅 7 项 + nsamples)======
180
- example_values = [41.7, 64.9, 0.870, 0.890, 55.9, 1.95, 51.96, 200]
181
-
182
-
 
 
183
 
184
 
185
  # ====== Gradio 界面 ======
186
  with gr.Blocks() as demo:
187
  gr.Markdown(
188
  "### Meige Risk Prediction (SVM) with SHAP Explanation\n"
189
- "Please enter ALB, TP, TBA, AST/ALT, CREA, LYM, and ALP. The values for PNI and AAPR will be calculated automatically.\n\n"
190
  "**Units**: ALB(g/L), TP(g/L), TBA(μmol/L), AST/ALT(ratio), CREA(μmol/L), "
191
- "LYM(×10⁹/L), ALP(U/L)."
192
  )
193
 
194
  with gr.Row():
@@ -201,6 +204,7 @@ with gr.Blocks() as demo:
201
  gr.Number(label="CREA (μmol/L)"),
202
  gr.Number(label="LYM (×10⁹/L)"),
203
  gr.Number(label="ALP (U/L)"),
 
204
  ]
205
  ns_slider = gr.Slider(100, 500, value=200, step=50, label="SHAP nsamples")
206
 
@@ -223,4 +227,4 @@ with gr.Blocks() as demo:
223
  )
224
 
225
  if __name__ == "__main__":
226
- demo.launch()
 
 
1
  import gradio as gr
2
  import pandas as pd
3
  import joblib
 
13
  MODEL_PATH = "models/svm_pipeline.joblib"
14
  BG_PATH = "data/bg.csv"
15
 
16
+ # 模型最终需要的 8 个特征(顺序必须与训练一致)
17
+ FEATURES = ["ALB", "TP", "TBA", "AST_ALT", "CREA", "PNI", "AAPR", "DBIL"]
18
 
19
  # 加载模型与背景
20
  pipeline = joblib.load(MODEL_PATH)
 
51
 
52
  # ——将特征值按需四舍五入,减少标签长度(可根据需要微调每列小数位)
53
  feat = np.asarray(feat_1d, dtype=float).copy()
 
54
  round_map = {
55
+ "ALB": 2, "TP": 2, "TBA": 2, "AST_ALT": 2, "CREA": 1, "PNI": 1, "AAPR": 3, "DBIL": 1
56
  }
57
  feat_rounded = [
58
  np.round(val, round_map.get(name, 2)) for val, name in zip(feat, fnames)
59
  ]
60
 
 
61
  shap.force_plot(
62
  base_val,
63
  np.asarray(shap_1d).reshape(-1),
 
77
  def _coerce_float(x):
78
  return float(x) if x is not None and x != "" else np.nan
79
 
80
+
81
+ def predict_and_explain(ALB, TP, TBA, AST_ALT, CREA, LYM, ALP, DBIL, nsamples=200):
82
  status = []
83
  try:
84
  # ---- 1) 取数并校验 ----
 
89
  CREA = _coerce_float(CREA)
90
  LYM = _coerce_float(LYM)
91
  ALP = _coerce_float(ALP)
92
+ DBIL = _coerce_float(DBIL)
93
 
94
+ vals = [ALB, TP, TBA, AST_ALT, CREA, LYM, ALP, DBIL]
95
  if any(np.isnan(v) for v in vals):
96
  return None, None, "Error: 所有输入必须为数值且不可缺失。"
97
 
98
  if ALP <= 0:
99
  return None, None, "Error: ALP 必须 > 0(用于计算 AAPR=ALB/ALP)。"
100
+ if DBIL < 0:
101
+ return None, None, "Error: DBIL 不能为负值。"
102
 
103
  # ---- 2) 衍生指标 ----
104
  PNI = ALB + 5.0 * LYM
105
  AAPR = ALB / ALP
106
+ status.append(f"Derived: PNI={PNI:.1f}, AAPR={AAPR:.3f}")
107
 
108
+ # ---- 3) 组装最终 8 特征并预测 ----
109
+ x_row = np.array([[ALB, TP, TBA, AST_ALT, CREA, PNI, AAPR, DBIL]], dtype=np.float64)
110
 
111
  if hasattr(pipeline, "predict_proba"):
112
  classes_ = getattr(pipeline, "classes_", None)
 
163
  except Exception as e_force:
164
  status.append(f"Force-plot failed: {repr(e_force)}; fallback=bar")
165
  order = np.argsort(np.abs(sv))[::-1]
166
+ topk = order[:min(len(FEATURES), sv.shape[0])]
167
  plt.close('all')
168
  fig = plt.figure(figsize=(8, 5), dpi=160)
169
  plt.barh(np.array(FEATURES)[topk], sv[topk])
 
177
  except Exception as e:
178
  return None, None, f"Fatal error: {repr(e)}"
179
 
180
+ # ====== 示例输入(8 项 + nsamples)======
181
+ # 给定示例:
182
+ # ALB=40.7, TP=61.7, TBA=4.9, AST/ALT=1.0, CREA=54.3, PNI=48.6, AAPR=0.473, DBIL=8.5
183
+ # 由于前端输入仍为原始 6 项 + DBIL(PNI、AAPR 在后端计算),需“还原” LYM 与 ALP:
184
+ # LYM=(PNI-ALB)/5=(48.6-40.7)/5=1.58; ALP=ALB/AAPR=40.7/0.473≈86.0465
185
+ example_values = [40.7, 61.7, 4.9, 1.0, 54.3, 1.58, 86.0465, 8.5, 200]
186
 
187
 
188
  # ====== Gradio 界面 ======
189
  with gr.Blocks() as demo:
190
  gr.Markdown(
191
  "### Meige Risk Prediction (SVM) with SHAP Explanation\n"
192
+ "Please enter ALB, TP, TBA, AST/ALT, CREA, LYM, ALP, and DBIL.\n\n"
193
  "**Units**: ALB(g/L), TP(g/L), TBA(μmol/L), AST/ALT(ratio), CREA(μmol/L), "
194
+ "LYM(×10⁹/L), ALP(U/L), DBIL(μmol/L)."
195
  )
196
 
197
  with gr.Row():
 
204
  gr.Number(label="CREA (μmol/L)"),
205
  gr.Number(label="LYM (×10⁹/L)"),
206
  gr.Number(label="ALP (U/L)"),
207
+ gr.Number(label="DBIL (μmol/L)"),
208
  ]
209
  ns_slider = gr.Slider(100, 500, value=200, step=50, label="SHAP nsamples")
210
 
 
227
  )
228
 
229
  if __name__ == "__main__":
230
+ demo.launch()