| import gradio as gr | |
| import random | |
| import matplotlib | |
| import matplotlib.pyplot as plt | |
| import pandas as pd | |
| import shap | |
| import xgboost as xgb | |
| from datasets import load_dataset | |
# Select the non-interactive Agg backend before any figure is created.
matplotlib.use("Agg")

# Load the Adult Census Income training split into a pandas DataFrame.
dataset = load_dataset("scikit-learn/adult-census-income")
X_train = dataset["train"].to_pandas()

# Drop the columns that are not used as model features.
X_train = X_train.drop(columns=["fnlwgt", "race"])

# Binary target: 1 when income is ">50K", 0 otherwise.
y_train = X_train.pop("income").eq(">50K").astype(int)

# Columns handed to XGBoost as categorical features.
categorical_columns = [
    "workclass",
    "education",
    "marital.status",
    "occupation",
    "relationship",
    "sex",
    "native.country",
]
X_train = X_train.astype(dict.fromkeys(categorical_columns, "category"))

# Train a binary logistic booster and build a SHAP explainer on top of it.
data = xgb.DMatrix(X_train, label=y_train, enable_categorical=True)
model = xgb.train({"objective": "binary:logistic"}, data)
explainer = shap.TreeExplainer(model)
def predict(*args):
    """Return the predicted income-class probabilities for one person.

    Args:
        *args: One value per feature, in the same order as
            ``X_train.columns``.

    Returns:
        A dict with the keys ``">50K"`` and ``"<=50K"`` mapped to the model's
        probability for each class (the two values sum to 1).
    """
    df = pd.DataFrame([args], columns=X_train.columns)
    # Reuse the categorical dtypes of the training frame. Casting a one-row
    # frame to a plain "category" dtype would give every column a single
    # category (code 0), so the integer codes would not match the encoding
    # the model was trained on.
    df = df.astype({col: X_train[col].dtype for col in categorical_columns})
    pos_pred = model.predict(xgb.DMatrix(df, enable_categorical=True))
    prob_high = float(pos_pred[0])
    return {">50K": prob_high, "<=50K": 1 - prob_high}
def interpret(*args):
    """Plot the SHAP value of every feature for one person.

    Args:
        *args: One value per feature, in the same order as
            ``X_train.columns``.

    Returns:
        A matplotlib figure with one horizontal bar per feature, sorted by
        SHAP value.
    """
    df = pd.DataFrame([args], columns=X_train.columns)
    # Reuse the categorical dtypes of the training frame. Casting a one-row
    # frame to a plain "category" dtype would give every column a single
    # category (code 0), so the integer codes would not match the encoding
    # the model was trained on.
    df = df.astype({col: X_train[col].dtype for col in categorical_columns})
    shap_values = explainer.shap_values(xgb.DMatrix(df, enable_categorical=True))

    # Ascending sort: barh draws the first entry at the bottom, so the
    # largest SHAP value ends up at the top of the chart.
    scores_desc = sorted(zip(shap_values[0], X_train.columns))

    # Draw through the Axes object instead of pyplot's implicit
    # "current figure" so the plot does not depend on global state.
    fig_m, ax = plt.subplots(tight_layout=True)
    ax.barh(
        [feature for _, feature in scores_desc],
        [score for score, _ in scores_desc],
    )
    ax.set_title("Feature Shap Values")
    # Horizontal bars: SHAP values run along x, feature names along y.
    ax.set_xlabel("Shap Value")
    ax.set_ylabel("Feature")

    # Unregister the figure from pyplot so repeated calls do not accumulate
    # open figures; the returned Figure object can still be rendered.
    plt.close(fig_m)
    return fig_m
def _sorted_unique(column):
    """Return the distinct values of ``X_train[column]`` in sorted order."""
    return sorted(X_train[column].unique())


# Sorted distinct values of each categorical column.
unique_class = _sorted_unique("workclass")
unique_education = _sorted_unique("education")
unique_marital_status = _sorted_unique("marital.status")
unique_relationship = _sorted_unique("relationship")
unique_occupation = _sorted_unique("occupation")
unique_sex = _sorted_unique("sex")
unique_country = _sorted_unique("native.country")
# Build the UI: feature inputs in the left column, prediction label, SHAP plot
# and action buttons in the right column.
with gr.Blocks() as demo:
    gr.Markdown("""
    **Income Classification with XGBoost 💰**: This demo uses an XGBoost classifier to predict income based on demographic factors, along with Shapley value-based *explanations*. The [source code for this Gradio demo is here](https://huggingface.co/spaces/gradio/xgboost-income-prediction-with-explainability/blob/main/app.py).
    """)
    with gr.Row():
        with gr.Column():
            age = gr.Slider(label="Age", minimum=17, maximum=90, step=1, randomize=True)
            # Dropdown defaults are callables, so a random choice is drawn
            # when the value is evaluated rather than when this module runs.
            work_class = gr.Dropdown(
                label="Workclass",
                choices=unique_class,
                value=lambda: random.choice(unique_class),
            )
            education = gr.Dropdown(
                label="Education Level",
                choices=unique_education,
                value=lambda: random.choice(unique_education),
            )
            years = gr.Slider(
                label="Years of schooling",
                minimum=1,
                maximum=16,
                step=1,
                randomize=True,
            )
            marital_status = gr.Dropdown(
                label="Marital Status",
                choices=unique_marital_status,
                value=lambda: random.choice(unique_marital_status),
            )
            occupation = gr.Dropdown(
                label="Occupation",
                choices=unique_occupation,
                value=lambda: random.choice(unique_occupation),
            )
            relationship = gr.Dropdown(
                label="Relationship Status",
                choices=unique_relationship,
                value=lambda: random.choice(unique_relationship),
            )
            sex = gr.Dropdown(
                label="Sex", choices=unique_sex, value=lambda: random.choice(unique_sex)
            )
            capital_gain = gr.Slider(
                label="Capital Gain",
                minimum=0,
                maximum=100000,
                step=500,
                randomize=True,
            )
            capital_loss = gr.Slider(
                label="Capital Loss", minimum=0, maximum=10000, step=500, randomize=True
            )
            hours_per_week = gr.Slider(
                label="Hours Per Week Worked", minimum=1, maximum=99, step=1
            )
            country = gr.Dropdown(
                label="Native Country",
                choices=unique_country,
                value=lambda: random.choice(unique_country),
            )
        with gr.Column():
            label = gr.Label()
            plot = gr.Plot()
            with gr.Row():
                predict_btn = gr.Button(value="Predict")
                interpret_btn = gr.Button(value="Explain")

    # Both handlers build a DataFrame from their positional arguments using
    # X_train.columns, so this order must match the training column order.
    # Defining the list once keeps the two callbacks from drifting apart.
    feature_inputs = [
        age,
        work_class,
        education,
        years,
        marital_status,
        occupation,
        relationship,
        sex,
        capital_gain,
        capital_loss,
        hours_per_week,
        country,
    ]
    predict_btn.click(predict, inputs=feature_inputs, outputs=[label])
    interpret_btn.click(interpret, inputs=feature_inputs, outputs=[plot])

demo.launch()