Zeyadd-Mostaffa committed on
Commit
3e47c80
·
verified ·
1 Parent(s): 2633968

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -29
app.py CHANGED
@@ -5,6 +5,8 @@ import pandas as pd
5
  import joblib
6
  import os
7
  import warnings
 
 
8
 
9
  # Suppress XGBoost warnings
10
  warnings.filterwarnings("ignore", category=UserWarning, message=".*WARNING.*")
@@ -23,10 +25,10 @@ def load_model():
23
 
24
  model = load_model()
25
 
26
- # Prediction function
27
  def predict_employee_status(satisfaction_level, last_evaluation, number_project,
28
  average_monthly_hours, time_spent_company,
29
- work_accident, promotion_last_5years, salary, department):
30
 
31
  # One-hot encode the department
32
  departments = [
@@ -43,7 +45,7 @@ def predict_employee_status(satisfaction_level, last_evaluation, number_project,
43
  "last_evaluation": [last_evaluation],
44
  "number_project": [number_project],
45
  "average_monthly_hours": [average_monthly_hours],
46
- "time_spend_company": [time_spent_company],
47
  "Work_accident": [work_accident],
48
  "promotion_last_5years": [promotion_last_5years],
49
  "salary": [salary],
@@ -60,33 +62,54 @@ def predict_employee_status(satisfaction_level, last_evaluation, number_project,
60
  dmatrix = xgb.DMatrix(input_df)
61
  prediction = model.predict(dmatrix)
62
  prediction_prob = prediction[0]
63
- result = "✅ Employee is likely to quit." if prediction_prob > 0.3 else "✅ Employee is likely to stay."
64
- return f"{result}"
 
 
 
65
  except Exception as e:
66
  return f"❌ Error: {str(e)}"
67
 
68
- # Gradio interface
69
- interface = gr.Interface(
70
- fn=predict_employee_status,
71
- inputs=[
72
- gr.Number(label="Satisfaction Level (0.0 - 1.0)"),
73
- gr.Number(label="Last Evaluation (0.0 - 1.0)"),
74
- gr.Number(label="Number of Projects (1 - 10)"),
75
- gr.Number(label="Average Monthly Hours (80 - 320)"),
76
- gr.Number(label="Time Spent at Company (Years)"),
77
- gr.Radio([0, 1], label="Work Accident (0 = No, 1 = Yes)"),
78
- gr.Radio([0, 1], label="Promotion in Last 5 Years (0 = No, 1 = Yes)"),
79
- gr.Radio([0, 1, 2], label="Salary (0 = Low, 1 = Medium, 2 = High)"),
80
- gr.Dropdown(
81
- ['RandD', 'accounting', 'hr', 'management', 'marketing',
82
- 'product_mng', 'sales', 'support', 'technical'],
83
- label="Department"
84
- )
85
- ],
86
- outputs="text",
87
- title="Employee Retention Prediction System",
88
- description="Predict whether an employee is likely to stay or quit based on their profile.",
89
- theme="dark"
90
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
- interface.launch()
 
5
  import joblib
6
  import os
7
  import warnings
8
+ import shap
9
+ import matplotlib.pyplot as plt
10
 
11
  # Suppress XGBoost warnings
12
  warnings.filterwarnings("ignore", category=UserWarning, message=".*WARNING.*")
 
25
 
26
  model = load_model()
27
 
28
+ # Prediction function with dynamic threshold
29
  def predict_employee_status(satisfaction_level, last_evaluation, number_project,
30
  average_monthly_hours, time_spent_company,
31
+ work_accident, promotion_last_5years, salary, department, threshold=0.5):
32
 
33
  # One-hot encode the department
34
  departments = [
 
45
  "last_evaluation": [last_evaluation],
46
  "number_project": [number_project],
47
  "average_monthly_hours": [average_monthly_hours],
48
+ "time_spent_company": [time_spent_company],
49
  "Work_accident": [work_accident],
50
  "promotion_last_5years": [promotion_last_5years],
51
  "salary": [salary],
 
62
  dmatrix = xgb.DMatrix(input_df)
63
  prediction = model.predict(dmatrix)
64
  prediction_prob = prediction[0]
65
+
66
+ # Apply the dynamic threshold
67
+ result = "✅ Employee is likely to quit." if prediction_prob >= threshold else "✅ Employee is likely to stay."
68
+ explanation = explain_prediction(input_df)
69
+ return f"{result} (Probability: {prediction_prob:.2%})\n\nExplanation:\n{explanation}"
70
  except Exception as e:
71
  return f"❌ Error: {str(e)}"
72
 
73
+ # SHAP Explainability
74
def explain_prediction(input_df):
    """Render a SHAP waterfall chart for one prediction and save it to disk.

    Builds a ``shap.TreeExplainer`` around the module-level ``model``,
    computes SHAP values for ``input_df`` and writes a waterfall plot of the
    first row to ``shap_explanation.png`` in the current working directory.

    Args:
        input_df: Feature frame that was fed to the model. Only row 0 is
            plotted.

    Returns:
        A fixed status string confirming that the plot was produced.
    """
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(input_df)

    # shap.initjs() is intentionally not called: it only injects JavaScript
    # for notebook-based plots and requires IPython, neither of which is
    # relevant to a static matplotlib image rendered in a server process.
    fig = plt.figure()
    try:
        shap.waterfall_plot(
            shap.Explanation(
                values=shap_values[0],
                base_values=explainer.expected_value,
                data=input_df.iloc[0].values,
                feature_names=list(input_df.columns),
            ),
            # With the default show=True the plot is shown (and the figure
            # released) before savefig runs, which can leave a blank image.
            show=False,
        )
        # Tight bounding box so long feature labels are not clipped.
        plt.savefig("shap_explanation.png", bbox_inches="tight")
    finally:
        # Each call serves one request; close the figure so they do not
        # accumulate in memory for the lifetime of the app.
        plt.close(fig)
    return "SHAP explanation generated for this prediction."
87
+
88
+ # Gradio interface with dynamic threshold
89
def gradio_interface():
    """Assemble the Gradio form for the retention predictor and launch it.

    The input widgets are listed in the same order as the parameters of
    ``predict_employee_status``; the trailing slider supplies ``threshold``.
    """
    department_choices = [
        'RandD', 'accounting', 'hr', 'management', 'marketing',
        'product_mng', 'sales', 'support', 'technical',
    ]

    # Numeric profile fields first, then the categorical selectors.
    form_inputs = [
        gr.Number(label="Satisfaction Level (0.0 - 1.0)"),
        gr.Number(label="Last Evaluation (0.0 - 1.0)"),
        gr.Number(label="Number of Projects (1 - 10)"),
        gr.Number(label="Average Monthly Hours (80 - 320)"),
        gr.Number(label="Time Spent at Company (Years)"),
        gr.Radio([0, 1], label="Work Accident (0 = No, 1 = Yes)"),
        gr.Radio([0, 1], label="Promotion in Last 5 Years (0 = No, 1 = Yes)"),
        gr.Radio([0, 1, 2], label="Salary (0 = Low, 1 = Medium, 2 = High)"),
        gr.Dropdown(department_choices, label="Department"),
        gr.Slider(0.1, 0.9, value=0.5, step=0.05, label="Prediction Threshold"),
    ]

    app = gr.Interface(
        fn=predict_employee_status,
        inputs=form_inputs,
        outputs="text",
        title="Employee Retention Prediction System (With SHAP Explainability)",
        description="Predict whether an employee is likely to stay or quit based on their profile. Adjust the threshold for accurate predictions.",
        theme="dark",
    )
    app.launch()
114
 
115
+ gradio_interface()