Spaces:
Sleeping
Sleeping
File size: 5,067 Bytes
7c5d1d0 2750f6c 7a63bb7 a719e13 2750f6c 3e47c80 cdda9e4 2ab8b05 2750f6c 7c5d1d0 2750f6c cdda9e4 2750f6c 7a63bb7 2750f6c 2ab8b05 2750f6c 3e47c80 2750f6c 7a63bb7 a719e13 7a63bb7 2ab8b05 a719e13 2d5fce6 3e47c80 a719e13 2ab8b05 2d5fce6 a719e13 7a63bb7 2750f6c 7a63bb7 a719e13 2750f6c b444f01 3e47c80 7a63bb7 2750f6c 7c5d1d0 cdda9e4 3e47c80 cdda9e4 3e47c80 cdda9e4 3e47c80 cdda9e4 3e47c80 7c5d1d0 3e47c80 2ab8b05 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import gradio as gr
import xgboost as xgb
import numpy as np
import pandas as pd
import joblib
import os
import warnings
import shap
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, precision_recall_curve
from imblearn.over_sampling import SMOTE
# Suppress XGBoost warnings
warnings.filterwarnings("ignore", category=UserWarning, message=".*WARNING.*")


# Load the saved XGBoost booster (native XGBoost format only).
def load_model(model_path="xgboost_model.json"):
    """Load a saved XGBoost ``Booster`` from *model_path*.

    Only the native XGBoost format is handled (via ``Booster.load_model``);
    no joblib/pickle loading happens here.

    Args:
        model_path: Path to the saved model file. Defaults to
            ``"xgboost_model.json"``, the file name this app expects.

    Returns:
        The loaded ``xgb.Booster``, or ``None`` when the file does not exist.
        Returning ``None`` (instead of raising at import time) lets the
        prediction function report the missing model in the UI.
    """
    if not os.path.exists(model_path):
        print("❌ Model file not found.")
        return None

    booster = xgb.Booster()
    booster.load_model(model_path)
    print("✅ Model loaded successfully.")
    return booster


# Loaded once at import time; the prediction and SHAP functions read this global.
model = load_model()
# Prediction function with a user-adjustable decision threshold.
def _build_input_frame(satisfaction_level, last_evaluation, number_project,
                       average_monthly_hours, time_spent_company,
                       work_accident, promotion_last_5years, salary,
                       department):
    """Assemble the one-row feature DataFrame the model is scored on.

    Column names and their order are fixed here. XGBoost checks DataFrame
    feature names against the booster at predict time, so they must match
    the training features. Department is one-hot encoded; a department
    outside the known list (or ``None``) leaves every ``department_*`` flag at 0.

    Raises:
        ValueError: if ``number_project`` is zero or missing, because the
            derived ``work_balance`` feature divides by it.
    """
    departments = [
        'RandD', 'accounting', 'hr', 'management', 'marketing',
        'product_mng', 'sales', 'support', 'technical',
    ]
    department_features = {f"department_{dept}": 0 for dept in departments}
    if department in departments:
        department_features[f"department_{department}"] = 1

    # Guard the division below: 0 (or an empty Number field -> None) would
    # otherwise raise a bare ZeroDivisionError/TypeError.
    if not number_project:
        raise ValueError("Number of projects must be greater than zero.")

    # Derived interaction features.
    satisfaction_evaluation = satisfaction_level * last_evaluation
    work_balance = average_monthly_hours / number_project

    input_data = {
        "satisfaction_level": [satisfaction_level],
        "last_evaluation": [last_evaluation],
        "number_project": [number_project],
        "average_monthly_hours": [average_monthly_hours],
        "time_spent_company": [time_spent_company],
        "Work_accident": [work_accident],
        "promotion_last_5years": [promotion_last_5years],
        "salary": [salary],
        "satisfaction_evaluation": [satisfaction_evaluation],
        "work_balance": [work_balance],
        **department_features,
    }
    return pd.DataFrame(input_data)


def predict_employee_status(satisfaction_level, last_evaluation, number_project,
                            average_monthly_hours, time_spent_company,
                            work_accident, promotion_last_5years, salary,
                            department, threshold=0.5):
    """Score one employee profile and return a human-readable verdict.

    Args:
        satisfaction_level, last_evaluation, number_project,
        average_monthly_hours, time_spent_company, work_accident,
        promotion_last_5years, salary, department: raw profile values as
            supplied by the UI components, in the same positional order.
        threshold: the "quit" verdict is given when the model output is
            >= this value.

    Returns:
        A text message with the verdict, the model output formatted as a
        percentage, and the SHAP status line. Failures are returned as text
        prefixed with "❌" rather than raised, so the UI always shows a message.

    Note:
        Assumes the booster outputs a probability (e.g. a ``binary:logistic``
        objective) -- TODO confirm against the training setup.
    """
    # Nothing can be scored without a model, so check this first.
    if model is None:
        return "❌ No model found. Please upload the model file."

    try:
        input_df = _build_input_frame(
            satisfaction_level, last_evaluation, number_project,
            average_monthly_hours, time_spent_company,
            work_accident, promotion_last_5years, salary, department,
        )
        prediction = model.predict(xgb.DMatrix(input_df))
        prediction_prob = prediction[0]

        # Apply the user-selected threshold.
        if prediction_prob >= threshold:
            result = "✅ Employee is likely to quit."
        else:
            result = "✅ Employee is likely to stay."

        explanation = explain_prediction(input_df)
        return f"{result} (Probability: {prediction_prob:.2%})\n\nExplanation:\n{explanation}"
    except Exception as e:
        # UI boundary: surface any failure (bad input, feature mismatch) as text.
        return f"❌ Error: {str(e)}"
# SHAP Explainability (Directly Integrated)
def explain_prediction(input_df, output_path="shap_explanation.png"):
    """Render a SHAP waterfall plot for the first row of *input_df* and save it.

    Uses the module-level ``model`` (an XGBoost booster) with
    ``shap.TreeExplainer``. The figure is written to *output_path* and then
    closed, so repeated calls in a long-running server do not accumulate
    open matplotlib figures.

    Args:
        input_df: Feature frame, built the same way as for prediction; only
            row 0 is plotted.
        output_path: Where the PNG is written. Defaults to the original
            ``"shap_explanation.png"``.

    Returns:
        A short status string. Errors are reported as a "❌ Error in SHAP: ..."
        string instead of being raised, so a plotting failure never hides
        the prediction itself.
    """
    try:
        explainer = shap.TreeExplainer(model)
        shap_values = explainer.shap_values(input_df)
        row_explanation = shap.Explanation(
            values=shap_values[0],
            base_values=explainer.expected_value,
            data=input_df.iloc[0].values,
            feature_names=list(input_df.columns),
        )

        plt.figure()
        # show=False keeps the figure open instead of handing it to
        # plt.show(), so the figure saved below is the one just drawn.
        shap.waterfall_plot(row_explanation, show=False)
        # Waterfall labels extend past the axes; "tight" keeps them in the image.
        plt.savefig(output_path, bbox_inches="tight")
        return "SHAP explanation generated for this prediction."
    except Exception as e:
        return f"❌ Error in SHAP: {str(e)}"
    finally:
        # Always release the figure(s), on success or failure.
        plt.close("all")
# Gradio interface with dynamic threshold and SHAP
def gradio_interface():
    """Build the Gradio front end for ``predict_employee_status`` and launch it.

    Gradio passes component values to ``fn`` positionally, so the components
    are listed in the same order as the prediction function's parameters:
    numeric profile fields, then the 0/1/2 choice fields, then department,
    then the decision threshold.
    """
    department_names = [
        'RandD', 'accounting', 'hr', 'management', 'marketing',
        'product_mng', 'sales', 'support', 'technical',
    ]

    profile_inputs = [
        gr.Number(label="Satisfaction Level (0.0 - 1.0)"),
        gr.Number(label="Last Evaluation (0.0 - 1.0)"),
        gr.Number(label="Number of Projects (1 - 10)"),
        gr.Number(label="Average Monthly Hours (80 - 320)"),
        gr.Number(label="Time Spent at Company (Years)"),
    ]
    choice_inputs = [
        gr.Radio([0, 1], label="Work Accident (0 = No, 1 = Yes)"),
        gr.Radio([0, 1], label="Promotion in Last 5 Years (0 = No, 1 = Yes)"),
        gr.Radio([0, 1, 2], label="Salary (0 = Low, 1 = Medium, 2 = High)"),
    ]
    department_input = gr.Dropdown(department_names, label="Department")
    threshold_input = gr.Slider(
        0.1, 0.9, value=0.5, step=0.05, label="Prediction Threshold"
    )

    app = gr.Interface(
        fn=predict_employee_status,
        inputs=profile_inputs + choice_inputs + [department_input, threshold_input],
        outputs="text",
        title="Employee Retention Prediction System (With SHAP & ROC Threshold)",
        description=(
            "Predict whether an employee is likely to stay or quit "
            "based on their profile. Adjust the threshold for accurate predictions."
        ),
        theme="dark",
    )
    app.launch()


gradio_interface()
|