Zeyadd-Mostaffa commited on
Commit
3e635fa
·
verified ·
1 Parent(s): 96a2bd1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -6
app.py CHANGED
@@ -27,12 +27,13 @@ def predict_employee_status(
27
  average_monthly_hours, time_spend_company,
28
  work_accident, promotion_last_5years, salary, department, threshold=0.5
29
  ):
 
30
  departments = [
31
- 'sales', 'accounting', 'hr', 'technical', 'support',
32
- 'management', 'IT', 'product_mng', 'marketing', 'RandD'
33
  ]
34
-
35
- # One-hot encode department (include department_IT explicitly)
36
  department_features = {f"department_{dept}": 0 for dept in departments}
37
  if department in departments:
38
  department_features[f"department_{department}"] = 1
@@ -79,8 +80,8 @@ def gradio_interface():
79
  gr.Radio([0, 1], label="Promotion in Last 5 Years (0 = No, 1 = Yes)"),
80
  gr.Radio([0, 1, 2], label="Salary (0 = Low, 1 = Medium, 2 = High)"),
81
  gr.Dropdown(
82
- ['sales', 'accounting', 'hr', 'technical', 'support',
83
- 'management', 'IT', 'product_mng', 'marketing', 'RandD'],
84
  label="Department"
85
  ),
86
  gr.Slider(0.1, 0.9, value=0.5, step=0.05, label="Prediction Threshold")
@@ -93,3 +94,4 @@ def gradio_interface():
93
  interface.launch()
94
 
95
  gradio_interface()
 
 
27
  average_monthly_hours, time_spend_company,
28
  work_accident, promotion_last_5years, salary, department, threshold=0.5
29
  ):
30
+ # List of all departments as encoded during training
31
  departments = [
32
+ 'IT', 'RandD', 'accounting', 'hr', 'management',
33
+ 'marketing', 'product_mng', 'sales', 'support', 'technical'
34
  ]
35
+
36
+ # One-hot encode department
37
  department_features = {f"department_{dept}": 0 for dept in departments}
38
  if department in departments:
39
  department_features[f"department_{department}"] = 1
 
80
  gr.Radio([0, 1], label="Promotion in Last 5 Years (0 = No, 1 = Yes)"),
81
  gr.Radio([0, 1, 2], label="Salary (0 = Low, 1 = Medium, 2 = High)"),
82
  gr.Dropdown(
83
+ ['IT', 'RandD', 'accounting', 'hr', 'management',
84
+ 'marketing', 'product_mng', 'sales', 'support', 'technical'],
85
  label="Department"
86
  ),
87
  gr.Slider(0.1, 0.9, value=0.5, step=0.05, label="Prediction Threshold")
 
94
  interface.launch()
95
 
96
  gradio_interface()
97
+