Zeyadd-Mostaffa committed on
Commit
cdda9e4
·
verified ·
1 Parent(s): 3da7c41

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -16
app.py CHANGED
@@ -7,13 +7,14 @@ import os
7
  import warnings
8
  import shap
9
  import matplotlib.pyplot as plt
 
10
 
11
  # Suppress XGBoost warnings
12
  warnings.filterwarnings("ignore", category=UserWarning, message=".*WARNING.*")
13
 
14
  # Load your model (automatically detect XGBoost or joblib model)
15
  def load_model():
16
- model_path = "best_model.json" # Ensure this matches your file name
17
  if os.path.exists(model_path):
18
  model = xgb.Booster()
19
  model.load_model(model_path)
@@ -25,6 +26,16 @@ def load_model():
25
 
26
  model = load_model()
27
 
 
 
 
 
 
 
 
 
 
 
28
  # Prediction function with dynamic threshold
29
  def predict_employee_status(satisfaction_level, last_evaluation, number_project,
30
  average_monthly_hours, time_spent_company,
@@ -70,22 +81,25 @@ def predict_employee_status(satisfaction_level, last_evaluation, number_project,
70
  except Exception as e:
71
  return f"❌ Error: {str(e)}"
72
 
73
# SHAP Explainability
def explain_prediction(input_df):
    """Render a SHAP waterfall plot for the first row of ``input_df``.

    The plot is written to ``shap_explanation.png`` in the working
    directory, overwriting any previous image.

    Args:
        input_df: pandas DataFrame holding the feature row to explain.
            Only the first row is plotted.

    Returns:
        A fixed confirmation message once the image has been written.

    Raises:
        Any exception raised by SHAP or matplotlib is propagated unchanged.
    """
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(input_df)

    explanation = shap.Explanation(
        values=shap_values[0],
        base_values=explainer.expected_value,
        data=input_df.iloc[0].values,
        feature_names=list(input_df.columns),
    )

    # NOTE: shap.initjs() is intentionally not called. It only injects the
    # JavaScript needed by interactive notebook plots, requires IPython, and
    # contributes nothing to a static matplotlib image.
    fig = plt.figure()
    try:
        # show=False keeps SHAP from calling plt.show() before the figure
        # has been written to disk.
        shap.waterfall_plot(explanation, show=False)
        # Save the *current* figure: depending on the SHAP version the plot
        # is drawn either on `fig` or on a figure SHAP created itself.
        plt.savefig("shap_explanation.png", bbox_inches="tight")
    finally:
        # Close what was drawn so repeated calls do not accumulate figures.
        current = plt.gcf()
        plt.close(current)
        if current is not fig:
            plt.close(fig)
    return "SHAP explanation generated for this prediction."
 
 
 
87
 
88
- # Gradio interface with dynamic threshold
89
  def gradio_interface():
90
  interface = gr.Interface(
91
  fn=predict_employee_status,
@@ -106,7 +120,7 @@ def gradio_interface():
106
  gr.Slider(0.1, 0.9, value=0.5, step=0.05, label="Prediction Threshold")
107
  ],
108
  outputs="text",
109
- title="Employee Retention Prediction System (With SHAP Explainability)",
110
  description="Predict whether an employee is likely to stay or quit based on their profile. Adjust the threshold for accurate predictions.",
111
  theme="dark"
112
  )
 
7
  import warnings
8
  import shap
9
  import matplotlib.pyplot as plt
10
+ from sklearn.metrics import roc_curve, precision_recall_curve
11
 
12
  # Suppress XGBoost warnings
13
  warnings.filterwarnings("ignore", category=UserWarning, message=".*WARNING.*")
14
 
15
  # Load your model (automatically detect XGBoost or joblib model)
16
  def load_model():
17
+ model_path = "xgboost_model.json" # Ensure this matches your file name
18
  if os.path.exists(model_path):
19
  model = xgb.Booster()
20
  model.load_model(model_path)
 
26
 
27
  model = load_model()
28
 
29
+ # Automatically find the best threshold using ROC
30
def optimize_threshold(X_test, y_test):
    """Pick the decision threshold that maximises Youden's J (TPR - FPR).

    The module-level ``model`` scores ``X_test``, a ROC curve is computed
    against ``y_test``, and the threshold with the largest ``tpr - fpr``
    is returned.

    Args:
        X_test: Feature matrix accepted by ``xgb.DMatrix``.
        y_test: True binary labels aligned with the rows of ``X_test``.

    Returns:
        The selected threshold as a plain Python ``float``.
    """
    dmatrix = xgb.DMatrix(X_test)
    y_prob = model.predict(dmatrix)

    fpr, tpr, thresholds = roc_curve(y_test, y_prob)
    j_scores = tpr - fpr

    # roc_curve prepends a sentinel threshold (np.inf in recent scikit-learn,
    # max(score) + 1 in older releases) for the "predict nothing positive"
    # point. It is not a usable cut-off, yet a plain argmax selects it when
    # no real threshold beats chance (J <= 0 everywhere), so skip it.
    if len(thresholds) > 1:
        optimal_idx = int(np.argmax(j_scores[1:])) + 1
    else:
        optimal_idx = 0

    optimal_threshold = float(thresholds[optimal_idx])
    return optimal_threshold
38
+
39
  # Prediction function with dynamic threshold
40
  def predict_employee_status(satisfaction_level, last_evaluation, number_project,
41
  average_monthly_hours, time_spent_company,
 
81
  except Exception as e:
82
  return f"❌ Error: {str(e)}"
83
 
84
# SHAP Explainability (Directly Integrated)
def explain_prediction(input_df):
    """Render a SHAP waterfall plot for the first row of ``input_df``.

    The plot is written to ``shap_explanation.png`` in the working
    directory, overwriting any previous image.

    Args:
        input_df: pandas DataFrame holding the feature row to explain.
            Only the first row is plotted.

    Returns:
        A confirmation message on success, or a string starting with
        ``"❌ Error in SHAP: "`` describing the failure. Exceptions are
        never propagated to the caller.
    """
    try:
        explainer = shap.TreeExplainer(model)
        shap_values = explainer.shap_values(input_df)

        explanation = shap.Explanation(
            values=shap_values[0],
            base_values=explainer.expected_value,
            data=input_df.iloc[0].values,
            feature_names=list(input_df.columns),
        )

        # NOTE: shap.initjs() is intentionally not called. It only injects
        # the JavaScript needed by interactive notebook plots, requires
        # IPython, and contributes nothing to a static matplotlib image.
        fig = plt.figure()
        try:
            # show=False keeps SHAP from calling plt.show() before the
            # figure has been written to disk.
            shap.waterfall_plot(explanation, show=False)
            # Save the *current* figure: depending on the SHAP version the
            # plot is drawn on `fig` or on a figure SHAP created itself.
            plt.savefig("shap_explanation.png", bbox_inches="tight")
        finally:
            # Close what was drawn so repeated requests in the long-running
            # app do not accumulate open figures.
            current = plt.gcf()
            plt.close(current)
            if current is not fig:
                plt.close(fig)
        return "SHAP explanation generated for this prediction."
    except Exception as e:
        # Top-level boundary for the UI: report the failure as text.
        return f"❌ Error in SHAP: {str(e)}"
101
 
102
+ # Gradio interface with dynamic threshold and SHAP
103
  def gradio_interface():
104
  interface = gr.Interface(
105
  fn=predict_employee_status,
 
120
  gr.Slider(0.1, 0.9, value=0.5, step=0.05, label="Prediction Threshold")
121
  ],
122
  outputs="text",
123
+ title="Employee Retention Prediction System (With SHAP & ROC Threshold)",
124
  description="Predict whether an employee is likely to stay or quit based on their profile. Adjust the threshold for accurate predictions.",
125
  theme="dark"
126
  )