Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -7,13 +7,14 @@ import os
|
|
7 |
import warnings
|
8 |
import shap
|
9 |
import matplotlib.pyplot as plt
|
|
|
10 |
|
11 |
# Suppress XGBoost warnings
|
12 |
warnings.filterwarnings("ignore", category=UserWarning, message=".*WARNING.*")
|
13 |
|
14 |
# Load your model (automatically detect XGBoost or joblib model)
|
15 |
def load_model():
|
16 |
-
model_path = "
|
17 |
if os.path.exists(model_path):
|
18 |
model = xgb.Booster()
|
19 |
model.load_model(model_path)
|
@@ -25,6 +26,16 @@ def load_model():
|
|
25 |
|
26 |
model = load_model()
|
27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
# Prediction function with dynamic threshold
|
29 |
def predict_employee_status(satisfaction_level, last_evaluation, number_project,
|
30 |
average_monthly_hours, time_spent_company,
|
@@ -70,22 +81,25 @@ def predict_employee_status(satisfaction_level, last_evaluation, number_project,
|
|
70 |
except Exception as e:
|
71 |
return f"❌ Error: {str(e)}"
|
72 |
|
73 |
-
# SHAP Explainability
|
74 |
def explain_prediction(input_df):
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
|
|
|
|
|
|
87 |
|
88 |
-
# Gradio interface with dynamic threshold
|
89 |
def gradio_interface():
|
90 |
interface = gr.Interface(
|
91 |
fn=predict_employee_status,
|
@@ -106,7 +120,7 @@ def gradio_interface():
|
|
106 |
gr.Slider(0.1, 0.9, value=0.5, step=0.05, label="Prediction Threshold")
|
107 |
],
|
108 |
outputs="text",
|
109 |
-
title="Employee Retention Prediction System (With SHAP
|
110 |
description="Predict whether an employee is likely to stay or quit based on their profile. Adjust the threshold for accurate predictions.",
|
111 |
theme="dark"
|
112 |
)
|
|
|
7 |
import warnings
|
8 |
import shap
|
9 |
import matplotlib.pyplot as plt
|
10 |
+
from sklearn.metrics import roc_curve, precision_recall_curve
|
11 |
|
12 |
# Suppress XGBoost warnings
|
13 |
warnings.filterwarnings("ignore", category=UserWarning, message=".*WARNING.*")
|
14 |
|
15 |
# Load your model (automatically detect XGBoost or joblib model)
|
16 |
def load_model():
|
17 |
+
model_path = "xgboost_model.json" # Ensure this matches your file name
|
18 |
if os.path.exists(model_path):
|
19 |
model = xgb.Booster()
|
20 |
model.load_model(model_path)
|
|
|
26 |
|
27 |
model = load_model()
|
28 |
|
29 |
+
# Automatically find the best threshold using ROC
|
30 |
+
def optimize_threshold(X_test, y_test):
    """Return the ROC threshold that maximises Youden's J statistic (TPR - FPR).

    The module-level ``model`` scores ``X_test`` and the resulting ROC curve
    is searched for the cut-off with the largest gap between the true-positive
    rate and the false-positive rate.

    Args:
        X_test: Feature matrix; it is passed straight to ``xgb.DMatrix``.
        y_test: Ground-truth binary labels aligned with ``X_test``.

    Returns:
        The selected threshold as a plain ``float``.

    NOTE(review): assumes ``model.predict`` on a ``DMatrix`` yields
    positive-class scores/probabilities -- confirm against the model's
    training objective.
    """
    dmatrix = xgb.DMatrix(X_test)
    y_prob = model.predict(dmatrix)

    fpr, tpr, thresholds = roc_curve(y_test, y_prob)

    # roc_curve prepends an artificial threshold at index 0 (meaning "predict
    # nothing as positive"; it is ``inf`` in recent scikit-learn releases).
    # Its J value is always 0, so when no real cut-off beats chance a plain
    # argmax would select it and hand back an unusable threshold. Search only
    # the real thresholds instead; whenever a real threshold has J > 0 this
    # picks exactly the same index as before.
    youden_j = tpr[1:] - fpr[1:]
    optimal_idx = int(np.argmax(youden_j)) + 1
    return float(thresholds[optimal_idx])
|
38 |
+
|
39 |
# Prediction function with dynamic threshold
|
40 |
def predict_employee_status(satisfaction_level, last_evaluation, number_project,
|
41 |
average_monthly_hours, time_spent_company,
|
|
|
81 |
except Exception as e:
|
82 |
return f"❌ Error: {str(e)}"
|
83 |
|
84 |
+
# SHAP Explainability (Directly Integrated)
|
85 |
def explain_prediction(input_df):
    """Render a SHAP waterfall plot for the first row of ``input_df``.

    The plot is written to ``shap_explanation.png`` in the current working
    directory. Errors are not raised; they are reported through the returned
    string so the caller can display them directly.

    Args:
        input_df: DataFrame of model features; only row 0 is explained.

    Returns:
        A status message: a success notice, or ``"❌ Error in SHAP: ..."``
        describing the failure.

    NOTE(review): assumes ``explainer.expected_value`` is a scalar for this
    model (single-output objective) -- confirm.
    """
    try:
        explainer = shap.TreeExplainer(model)
        shap_values = explainer.shap_values(input_df)

        explanation = shap.Explanation(
            values=shap_values[0],
            base_values=explainer.expected_value,
            data=input_df.iloc[0].values,
            feature_names=list(input_df.columns),
        )

        # shap.initjs() is intentionally not called: it only loads the
        # JavaScript bundle for interactive notebook plots and is not needed
        # for the matplotlib-based waterfall plot.
        fig = plt.figure()
        try:
            # show=False keeps the figure alive so it can be saved; with the
            # default show=True the plot is displayed before savefig runs.
            shap.waterfall_plot(explanation, show=False)
            plt.savefig("shap_explanation.png")
        finally:
            # Close whatever figure ended up current as well as the one
            # created above, so repeated calls do not accumulate figures.
            plt.close(plt.gcf())
            plt.close(fig)

        return "SHAP explanation generated for this prediction."
    except Exception as e:
        return f"❌ Error in SHAP: {str(e)}"
|
101 |
|
102 |
+
# Gradio interface with dynamic threshold and SHAP
|
103 |
def gradio_interface():
|
104 |
interface = gr.Interface(
|
105 |
fn=predict_employee_status,
|
|
|
120 |
gr.Slider(0.1, 0.9, value=0.5, step=0.05, label="Prediction Threshold")
|
121 |
],
|
122 |
outputs="text",
|
123 |
+
title="Employee Retention Prediction System (With SHAP & ROC Threshold)",
|
124 |
description="Predict whether an employee is likely to stay or quit based on their profile. Adjust the threshold for accurate predictions.",
|
125 |
theme="dark"
|
126 |
)
|