Spaces:
Sleeping
Sleeping
File size: 9,053 Bytes
b3ed7ae 7876a01 b3ed7ae f2b3a5b 7876a01 b3ed7ae bdab798 7dcfbfd b3ed7ae 7876a01 b3ed7ae f2b3a5b b3ed7ae f2b3a5b b3ed7ae f2b3a5b 7876a01 f2b3a5b b3ed7ae 7876a01 b3ed7ae 7876a01 e6a0696 7dcfbfd e6a0696 7876a01 f2b3a5b e6a0696 f2b3a5b e6a0696 7876a01 e6a0696 7876a01 e6a0696 7876a01 e6a0696 f2b3a5b bdab798 7dcfbfd 7876a01 b3ed7ae f2b3a5b 7dcfbfd f2b3a5b bdab798 f2b3a5b 7dcfbfd f2b3a5b bdab798 f2b3a5b 7876a01 b3ed7ae bdab798 f2b3a5b bdab798 7dcfbfd f2b3a5b 7876a01 f2b3a5b 7876a01 f2b3a5b 7876a01 f2b3a5b 7876a01 f2b3a5b 7876a01 f2b3a5b 7876a01 f2b3a5b 7876a01 f2b3a5b 7876a01 f2b3a5b 7876a01 f2b3a5b bdab798 f2b3a5b b3ed7ae 7876a01 b3ed7ae bdab798 9603b83 d816f90 b3ed7ae 7876a01 b3ed7ae 7876a01 bdab798 f2b3a5b bdab798 7876a01 b3ed7ae 7876a01 b3ed7ae f2b3a5b b3ed7ae bdab798 b3ed7ae f2b3a5b 7876a01 f2b3a5b 7876a01 f2b3a5b 7876a01 7dcfbfd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 |
import gradio as gr
import pandas as pd
import joblib
import shap
import numpy as np
import matplotlib
matplotlib.use("Agg") # 非交互后端,服务器端更稳
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")
# ====== Model and background data ======
MODEL_PATH = "models/svm_pipeline.joblib"
BG_PATH = "data/bg.csv"

# The 8 features the model consumes. Their order must match the order used in training.
FEATURES = ["ALB", "TP", "TBA", "AST_ALT", "CREA", "PNI", "AAPR", "DBIL"]

# Load the fitted pipeline and the SHAP background sample once, at import time.
# NOTE(review): joblib.load unpickles the file - only ever point MODEL_PATH at a trusted artifact.
pipeline = joblib.load(MODEL_PATH)
bg_df = pd.read_csv(BG_PATH)

# Fail fast if the background CSV lacks any model feature (reported in FEATURES order).
missing_bg = [column for column in FEATURES if column not in bg_df.columns]
if missing_bg:
    raise ValueError(f"背景集缺少列: {missing_bg}")

# Background matrix in FEATURES column order, as float64.
bg_array = bg_df.loc[:, FEATURES].to_numpy(dtype=np.float64)
# Prediction function used by KernelExplainer: returns one positive-class probability / score per row.
def _positive_class_index(classes_, n_columns: int = 2) -> int:
    """Return the ``predict_proba`` column to report as the positive class.

    The column whose class label equals 1 is preferred (this app treats label 1
    as positive). If ``classes_`` is None or contains no label equal to 1, fall
    back to column 1. That is scikit-learn's binary convention: ``classes_[1]``
    is the class a positive ``decision_function`` score points to. When there
    is only one column, column 0 is used instead.

    Args:
        classes_: The fitted estimator's class labels, or None if unavailable.
        n_columns: Number of probability columns returned by ``predict_proba``.

    Returns:
        A valid column index into the ``predict_proba`` output.
    """
    if classes_ is not None:
        # flatnonzero also copes with a scalar False from label types that cannot be compared to 1.
        matches = np.flatnonzero(np.asarray(classes_) == 1)
        if matches.size:
            return int(matches[0])
    return 1 if n_columns > 1 else 0


def _predict_proba_nd(x_nd: np.ndarray) -> np.ndarray:
    """Score each row of ``x_nd`` with the loaded pipeline.

    Args:
        x_nd: 2-D array whose columns are the model features in ``FEATURES`` order.

    Returns:
        1-D array with one value per row:
        - the positive-class probability if the pipeline has ``predict_proba``;
        - otherwise the ``decision_function`` score;
        - otherwise the raw ``predict`` output.
    """
    # Re-attach the feature names (FEATURES order) before calling the pipeline.
    df = pd.DataFrame(x_nd, columns=FEATURES)
    if hasattr(pipeline, "predict_proba"):
        proba = np.asarray(pipeline.predict_proba(df))
        pos_idx = _positive_class_index(getattr(pipeline, "classes_", None), proba.shape[1])
        return proba[:, pos_idx]
    if hasattr(pipeline, "decision_function"):
        return np.asarray(pipeline.decision_function(df))
    return np.asarray(pipeline.predict(df))
# Build the SHAP KernelExplainer a single time at import, rather than per request.
# It explains _predict_proba_nd (one value per row), using bg_array as the background dataset.
# NOTE(review): KernelExplainer cost grows with the number of background rows - confirm bg.csv is kept small.
explainer = shap.KernelExplainer(model=_predict_proba_nd, data=bg_array)
def _render_force_plot(base_val: float, shap_1d: np.ndarray, feat_1d: np.ndarray, fnames):
    """Draw a single-sample SHAP force plot and return it as a matplotlib Figure.

    Uses SHAP's legacy matplotlib renderer (``matplotlib=True, show=False``)
    instead of the HTML/JS one. Feature values are rounded for display only.
    Each feature has its own number of decimals; unlisted names get 2.

    Args:
        base_val: Expected value the SHAP contributions start from.
        shap_1d: SHAP contributions for one sample (flattened before plotting).
        feat_1d: Raw feature values for the same sample, aligned with ``fnames``.
        fnames: Feature names, aligned with ``shap_1d`` and ``feat_1d``.

    Returns:
        The current pyplot figure after SHAP has drawn into it, resized to
        14 x 3.6 inches at 180 dpi.
    """
    # Reset pyplot state. Note that this closes every open figure, not only earlier force plots.
    plt.close('all')

    # Decimals used when showing each feature's value on the plot.
    display_decimals = {
        "ALB": 2, "TP": 2, "TBA": 2, "AST_ALT": 2,
        "CREA": 1, "PNI": 1, "AAPR": 3, "DBIL": 1,
    }
    raw_values = np.asarray(feat_1d, dtype=float).copy()
    shown = []
    for value, name in zip(raw_values, fnames):
        shown.append(np.round(value, display_decimals.get(name, 2)))

    contributions = np.asarray(shap_1d).reshape(-1)
    shown_values = np.asarray(shown).reshape(-1)
    shap.force_plot(
        base_val,
        contributions,
        shown_values,
        feature_names=list(fnames),
        matplotlib=True,
        show=False,
    )

    # SHAP draws into the current figure; resize and tighten it for the UI.
    figure = plt.gcf()
    figure.set_size_inches(14, 3.6)
    figure.set_dpi(180)
    plt.tight_layout(pad=0.6)
    return figure
def _coerce_float(x):
    """Convert a raw UI value to ``float``, mapping missing or unparseable input to NaN.

    These all become NaN:
    - ``None``;
    - empty or whitespace-only strings;
    - any value ``float()`` rejects.
    As a result, predict_and_explain's missing-value check reports a
    validation message instead of an exception reaching its generic
    "Fatal error" handler.

    Args:
        x: Value received from an input widget (number, numeric string, or None).

    Returns:
        ``float(x)``, or ``np.nan`` when ``x`` is missing or not numeric.
    """
    if x is None:
        return np.nan
    if isinstance(x, str):
        x = x.strip()
        if not x:
            return np.nan
    try:
        return float(x)
    except (TypeError, ValueError):
        return np.nan
def _shap_vector_for_sample(shap_out, pos_idx: int) -> np.ndarray:
    """Reduce ``KernelExplainer.shap_values`` output for ONE sample to a 1-D vector.

    Layouts handled:
    - list of per-output arrays: take entry ``pos_idx``, or entry 0 when the
      list is too short (a single-output model yields a 1-element list); then
      row 0 if the entry is 2-D;
    - ndarray ``(1, n_features, n_outputs)``: same output selection on the last axis;
    - ndarray ``(1, n_features)``: row 0;
    - anything else: flattened.
    """
    if isinstance(shap_out, list):
        pick = pos_idx if len(shap_out) > pos_idx else 0
        sv = np.asarray(shap_out[pick], dtype=np.float64)
        return sv[0, :] if sv.ndim == 2 else sv.reshape(-1)
    sv = np.asarray(shap_out, dtype=np.float64)
    if sv.ndim == 3:
        pick = pos_idx if sv.shape[2] > pos_idx else 0
        return sv[0, :, pick]
    if sv.ndim == 2:
        return sv[0, :]
    return sv.reshape(-1)


def _base_value_for_output(ev, pos_idx: int) -> float:
    """Return the explainer's expected value for the selected output as a float.

    A list/array ``ev`` is indexed at ``pos_idx``, or at 0 if it is too short.
    A scalar is returned as-is.
    """
    if isinstance(ev, (list, np.ndarray)):
        flat = np.asarray(ev).reshape(-1)
        return float(flat[pos_idx if len(flat) > pos_idx else 0])
    return float(ev)


def _render_shap_bar(sv: np.ndarray, fnames):
    """Fallback plot: horizontal bars of SHAP values, largest |value| at the top."""
    names = np.array(fnames)
    order = np.argsort(np.abs(sv))[::-1]
    topk = order[:min(len(names), sv.shape[0])]
    plt.close('all')
    fig = plt.figure(figsize=(8, 5), dpi=160)
    ax = fig.add_subplot()
    ax.barh(names[topk], sv[topk])
    ax.set_xlabel("SHAP value")
    ax.set_title("Top features (single-sample contribution)")
    ax.invert_yaxis()
    fig.tight_layout()
    return fig


def predict_and_explain(ALB, TP, TBA, AST_ALT, CREA, LYM, ALP, DBIL, nsamples=200):
    """Predict risk for one patient and explain the prediction with SHAP.

    PNI and AAPR are derived from the raw inputs before the 8 model features are assembled:
    - PNI = ALB + 5 * LYM
    - AAPR = ALB / ALP

    Args:
        ALB, TP, TBA, AST_ALT, CREA, LYM, ALP, DBIL: Raw lab values from the UI.
        nsamples: Sampling budget passed to ``explainer.shap_values`` (None -> 200).

    Returns:
        Tuple ``(value, figure, status_text)``:
        - ``value`` is the positive-class probability (or score, if the model
          has no ``predict_proba``) rounded to 3 decimals;
        - ``figure`` is a SHAP force plot, or a bar chart if the force plot fails.
        On validation failure or any unexpected exception this returns
        ``(None, None, message)`` instead of raising.
    """
    status = []
    try:
        # ---- 1) Read and validate the inputs ----
        vals = [_coerce_float(v) for v in (ALB, TP, TBA, AST_ALT, CREA, LYM, ALP, DBIL)]
        # isfinite rejects NaN (missing/unparseable) as well as +/-inf.
        if not all(np.isfinite(v) for v in vals):
            return None, None, "Error: 所有输入必须为数值且不可缺失。"
        ALB, TP, TBA, AST_ALT, CREA, LYM, ALP, DBIL = vals
        if ALP <= 0:
            return None, None, "Error: ALP 必须 > 0(用于计算 AAPR=ALB/ALP)。"
        if DBIL < 0:
            return None, None, "Error: DBIL 不能为负值。"

        # ---- 2) Derived indices ----
        PNI = ALB + 5.0 * LYM
        AAPR = ALB / ALP
        status.append(f"Derived: PNI={PNI:.1f}, AAPR={AAPR:.3f}")

        # ---- 3) Assemble the 8 model features and predict ----
        x_row = np.array([[ALB, TP, TBA, AST_ALT, CREA, PNI, AAPR, DBIL]], dtype=np.float64)
        # Use the same function the SHAP explainer wraps, so the reported value
        # and the explained quantity are computed identically.
        prob = float(_predict_proba_nd(x_row)[0])
        label = "Pred prob" if hasattr(pipeline, "predict_proba") else "Pred score"
        status.append(f"{label}: {prob:.3f}")

        # ---- 4) SHAP values for this sample ----
        ns = int(nsamples) if nsamples is not None else 200
        shap_out = explainer.shap_values(x_row, nsamples=ns)
        # Position of class label 1. Defaults to 1 when classes_ is absent or has no label 1.
        classes_ = getattr(pipeline, "classes_", None)
        hits = np.flatnonzero(np.asarray(classes_) == 1) if classes_ is not None else ()
        pos_idx = int(hits[0]) if len(hits) else 1
        sv = _shap_vector_for_sample(shap_out, pos_idx)
        status.append(f"SHAP vector shape: {sv.shape}")
        base_val = _base_value_for_output(explainer.expected_value, pos_idx)

        # ---- 5) Plot: force plot first, bar chart as fallback ----
        try:
            fig = _render_force_plot(base_val, sv, x_row[0, :], FEATURES)
            status.append("Rendered force plot (matplotlib).")
        except Exception as e_force:
            status.append(f"Force-plot failed: {repr(e_force)}; fallback=bar")
            fig = _render_shap_bar(sv, FEATURES)
            status.append("Rendered bar fallback.")
        return round(float(prob), 3), fig, "\n".join(status)
    except Exception as e:
        # UI boundary: surface any unexpected failure in the status box instead of raising.
        return None, None, f"Fatal error: {repr(e)}"
# ====== Example input (8 raw inputs + SHAP nsamples) ======
# Reference case: ALB=40.7, TP=61.7, TBA=4.9, AST/ALT=1.0, CREA=54.3,
# PNI=48.6, AAPR=0.473, DBIL=8.5.
# The form asks for LYM and ALP instead of PNI and AAPR, which the backend
# derives as PNI = ALB + 5*LYM and AAPR = ALB/ALP. The example therefore uses:
# - LYM=1.58, since 40.7 + 5*1.58 = 48.6;
# - ALP=86, since 40.7/86 ≈ 0.473.
# Order matches the "Fill Example" outputs: the 8 number inputs, then the slider.
example_values = [
    40.7,   # ALB
    61.7,   # TP
    4.9,    # TBA
    1.0,    # AST/ALT
    54.3,   # CREA
    1.58,   # LYM
    86,     # ALP
    8.5,    # DBIL
    200,    # SHAP nsamples
]
# ====== Gradio UI ======
with gr.Blocks() as demo:
    gr.Markdown(
        "### Meige Risk Prediction (SVM) with SHAP Explanation\n"
        "Please enter ALB, TP, TBA, AST/ALT, CREA, LYM, ALP, and DBIL.\n\n"
        "**Units**: ALB(g/L), TP(g/L), TBA(μmol/L), AST/ALT(ratio), CREA(μmol/L), "
        "LYM(×10⁹/L), ALP(U/L), DBIL(μmol/L)."
    )
    with gr.Row():
        # Left column: the 8 raw inputs (same order as predict_and_explain's
        # parameters), the SHAP sampling budget, and the action buttons.
        with gr.Column(scale=1):
            inputs = [
                gr.Number(label=field_label)
                for field_label in (
                    "ALB (g/L)",
                    "TP (g/L)",
                    "TBA (μmol/L)",
                    "AST/ALT",
                    "CREA (μmol/L)",
                    "LYM (×10⁹/L)",
                    "ALP (U/L)",
                    "DBIL (μmol/L)",
                )
            ]
            ns_slider = gr.Slider(minimum=100, maximum=500, value=200, step=50, label="SHAP nsamples")
            btn_fill = gr.Button(value="Fill Example")
            btn_predict = gr.Button(value="Predict")
        # Right column: the three values returned by predict_and_explain.
        with gr.Column(scale=1):
            out_prob = gr.Number(label="Predicted Probability / Score")
            out_plot = gr.Plot(label="SHAP Force Plot (fallback: bar)")
            out_log = gr.Textbox(label="Status", lines=8)

    def _fill_example():
        """Return the example row, one value per component in [*inputs, ns_slider]."""
        return tuple(example_values)

    btn_fill.click(fn=_fill_example, outputs=[*inputs, ns_slider])
    btn_predict.click(
        fn=predict_and_explain,
        inputs=[*inputs, ns_slider],
        outputs=[out_prob, out_plot, out_log],
    )

if __name__ == "__main__":
    demo.launch()
|