Zeyadd-Mostaffa commited on
Commit
ef2280b
Β·
verified Β·
1 Parent(s): 903e139

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -31
app.py CHANGED
@@ -7,8 +7,6 @@ import os
7
  import warnings
8
  import shap
9
  import matplotlib.pyplot as plt
10
- from sklearn.metrics import roc_curve, precision_recall_curve
11
- from imblearn.over_sampling import SMOTE
12
 
13
  # Suppress XGBoost warnings
14
  warnings.filterwarnings("ignore", category=UserWarning, message=".*WARNING.*")
@@ -27,7 +25,7 @@ def load_model():
27
 
28
  model = load_model()
29
 
30
- # Prediction function with dynamic threshold and balanced data
31
  def predict_employee_status(satisfaction_level, last_evaluation, number_project,
32
  average_monthly_hours, time_spent_company,
33
  work_accident, promotion_last_5years, salary, department, threshold=0.5):
@@ -41,22 +39,22 @@ def predict_employee_status(satisfaction_level, last_evaluation, number_project,
41
  if department in departments:
42
  department_features[f"department_{department}"] = 1
43
 
44
- # Automatically Generate Interaction Features
45
  satisfaction_evaluation = satisfaction_level * last_evaluation
46
  work_balance = average_monthly_hours / number_project
47
 
48
- # Prepare the input with all expected features as a DataFrame with column names
49
  input_data = {
50
  "satisfaction_level": [satisfaction_level],
51
  "last_evaluation": [last_evaluation],
52
  "number_project": [number_project],
53
  "average_monthly_hours": [average_monthly_hours],
54
- "time_spent_company": [time_spent_company],
55
  "Work_accident": [work_accident],
56
  "promotion_last_5years": [promotion_last_5years],
57
  "salary": [salary],
58
- "satisfaction_evaluation": [satisfaction_evaluation],
59
- "work_balance": [work_balance],
60
  **department_features
61
  }
62
 
@@ -73,30 +71,11 @@ def predict_employee_status(satisfaction_level, last_evaluation, number_project,
73
 
74
  # Apply the dynamic threshold
75
  result = "βœ… Employee is likely to quit." if prediction_prob >= threshold else "βœ… Employee is likely to stay."
76
- explanation = explain_prediction(input_df)
77
- return f"{result} (Probability: {prediction_prob:.2%})\n\nExplanation:\n{explanation}"
78
  except Exception as e:
79
  return f"❌ Error: {str(e)}"
80
 
81
- # SHAP Explainability (Directly Integrated)
82
- def explain_prediction(input_df):
83
- try:
84
- explainer = shap.TreeExplainer(model)
85
- shap_values = explainer.shap_values(input_df)
86
-
87
- # Generate and save SHAP explanation as an image
88
- shap.initjs()
89
- plt.figure()
90
- shap.waterfall_plot(shap.Explanation(values=shap_values[0],
91
- base_values=explainer.expected_value,
92
- data=input_df.iloc[0].values,
93
- feature_names=input_df.columns))
94
- plt.savefig("shap_explanation.png")
95
- return "SHAP explanation generated for this prediction."
96
- except Exception as e:
97
- return f"❌ Error in SHAP: {str(e)}"
98
-
99
- # Gradio interface with dynamic threshold and SHAP
100
  def gradio_interface():
101
  interface = gr.Interface(
102
  fn=predict_employee_status,
@@ -105,7 +84,7 @@ def gradio_interface():
105
  gr.Number(label="Last Evaluation (0.0 - 1.0)"),
106
  gr.Number(label="Number of Projects (1 - 10)"),
107
  gr.Number(label="Average Monthly Hours (80 - 320)"),
108
- gr.Number(label="Time Spent at Company (Years)"),
109
  gr.Radio([0, 1], label="Work Accident (0 = No, 1 = Yes)"),
110
  gr.Radio([0, 1], label="Promotion in Last 5 Years (0 = No, 1 = Yes)"),
111
  gr.Radio([0, 1, 2], label="Salary (0 = Low, 1 = Medium, 2 = High)"),
@@ -124,4 +103,3 @@ def gradio_interface():
124
  interface.launch()
125
 
126
  gradio_interface()
127
-
 
7
  import warnings
8
  import shap
9
  import matplotlib.pyplot as plt
 
 
10
 
11
  # Suppress XGBoost warnings
12
  warnings.filterwarnings("ignore", category=UserWarning, message=".*WARNING.*")
 
25
 
26
  model = load_model()
27
 
28
+ # Prediction function with consistent feature names
29
  def predict_employee_status(satisfaction_level, last_evaluation, number_project,
30
  average_monthly_hours, time_spent_company,
31
  work_accident, promotion_last_5years, salary, department, threshold=0.5):
 
39
  if department in departments:
40
  department_features[f"department_{department}"] = 1
41
 
42
+ # Generate Interaction Features
43
  satisfaction_evaluation = satisfaction_level * last_evaluation
44
  work_balance = average_monthly_hours / number_project
45
 
46
+ # Prepare the input with all expected features
47
  input_data = {
48
  "satisfaction_level": [satisfaction_level],
49
  "last_evaluation": [last_evaluation],
50
  "number_project": [number_project],
51
  "average_monthly_hours": [average_monthly_hours],
52
+ "time_spent_company": [time_spent_company], # Corrected
53
  "Work_accident": [work_accident],
54
  "promotion_last_5years": [promotion_last_5years],
55
  "salary": [salary],
56
+ "satisfaction_evaluation": [satisfaction_evaluation], # Added
57
+ "work_balance": [work_balance], # Added
58
  **department_features
59
  }
60
 
 
71
 
72
  # Apply the dynamic threshold
73
  result = "βœ… Employee is likely to quit." if prediction_prob >= threshold else "βœ… Employee is likely to stay."
74
+ return f"{result} (Probability: {prediction_prob:.2%})"
 
75
  except Exception as e:
76
  return f"❌ Error: {str(e)}"
77
 
78
+ # Gradio interface with consistent feature names
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  def gradio_interface():
80
  interface = gr.Interface(
81
  fn=predict_employee_status,
 
84
  gr.Number(label="Last Evaluation (0.0 - 1.0)"),
85
  gr.Number(label="Number of Projects (1 - 10)"),
86
  gr.Number(label="Average Monthly Hours (80 - 320)"),
87
+ gr.Number(label="Time Spent at Company (Years)"), # Corrected
88
  gr.Radio([0, 1], label="Work Accident (0 = No, 1 = Yes)"),
89
  gr.Radio([0, 1], label="Promotion in Last 5 Years (0 = No, 1 = Yes)"),
90
  gr.Radio([0, 1, 2], label="Salary (0 = Low, 1 = Medium, 2 = High)"),
 
103
  interface.launch()
104
 
105
  gradio_interface()