File size: 3,358 Bytes
7c5d1d0
7a63bb7
a719e13
2750f6c
 
 
11d6e4a
2750f6c
de496ae
7c5d1d0
2750f6c
11d6e4a
 
 
 
 
 
 
 
 
 
 
2750f6c
11d6e4a
 
2750f6c
7a63bb7
2750f6c
 
de496ae
2750f6c
de496ae
3e47c80
de496ae
7a63bb7
 
 
 
a719e13
 
 
7a63bb7
de496ae
2ab8b05
 
 
de496ae
a719e13
 
 
 
2d5fce6
de496ae
a719e13
 
 
de496ae
 
2d5fce6
a719e13
 
 
7a63bb7
de496ae
2750f6c
de496ae
2750f6c
7a63bb7
de496ae
 
 
7a63bb7
de496ae
7c5d1d0
de496ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c5d1d0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import gradio as gr
import numpy as np
import pandas as pd
import joblib
import os
import warnings
import zipfile

# Silence every warning category for the whole process.
# NOTE(review): this also hides any warning emitted later while unpickling the
# model in load_model (e.g. library-version mismatch warnings, if any) --
# consider narrowing the category.
warnings.simplefilter("ignore")

def load_model(zip_path="final_ensemble_model.zip", pkl_path="final_ensemble_model.pkl"):
    """Load the pickled ensemble model, extracting it from a zip archive first if needed.

    If ``pkl_path`` does not exist, every member of ``zip_path`` is extracted
    into the directory containing ``pkl_path`` (the current directory for a
    bare file name) before loading.

    Args:
        zip_path: Archive expected to contain the pickled model.
        pkl_path: Path of the joblib-pickled model file.

    Returns:
        The loaded model object, or ``None`` if extraction or loading failed
        (including a missing archive); the failure reason is printed.
    """
    try:
        if not os.path.exists(pkl_path):
            print("📦 Extracting model from zip...")
            # Extraction is inside the try so a missing/corrupt archive yields
            # None instead of an uncaught exception at import time.
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall(os.path.dirname(pkl_path) or ".")
        # NOTE: joblib.load unpickles, which can execute arbitrary code --
        # only point this at a model file you built or trust.
        loaded = joblib.load(pkl_path)
    except Exception as e:
        # Deliberate best-effort: callers treat None as "no model available".
        print(f"❌ Failed to load model: {e}")
        return None
    print("✅ Ensemble model loaded.")
    return loaded

# Load the model once at import time. predict_employee_status reads this
# module-level name and returns an error message instead of predicting when
# it is None (load_model returns None on failure).
model = load_model()

# Prediction function
def predict_employee_status(satisfaction_level, last_evaluation, number_project,
                            average_monthly_hours, time_spend_company,
                            work_accident, promotion_last_5years, salary, department, threshold=0.5):
    """Predict whether an employee is likely to quit and describe it as text.

    Builds a single-row DataFrame from the form inputs (raw features, two
    engineered features and one-hot ``department_*`` columns) and scores it
    with the module-level ``model`` via ``predict_proba``.

    Args:
        satisfaction_level: Satisfaction score (the UI labels it 0.0 - 1.0).
        last_evaluation: Last evaluation score (the UI labels it 0.0 - 1.0).
        number_project: Number of projects; must be non-zero because it is
            the divisor of the ``work_balance`` feature.
        average_monthly_hours: Average monthly working hours.
        time_spend_company: Years spent at the company.
        work_accident: 0/1 flag, sent to the model as ``Work_accident``.
        promotion_last_5years: 0/1 flag.
        salary: Encoded salary level (the UI offers 0, 1, 2).
        department: Department name; a name outside the known list leaves
            every ``department_*`` column at 0.
        threshold: Probability at or above which the employee is classified
            as likely to quit.

    Returns:
        A message with the predicted label and the probability of class 1,
        or an error message for missing/invalid inputs, a missing model, or
        a failure while building features or predicting.
    """
    departments = [
        'RandD', 'accounting', 'hr', 'management', 'marketing',
        'product_mng', 'sales', 'support', 'technical'
    ]

    # Gradio passes None for empty fields; reject those up front rather than
    # letting the feature arithmetic below raise TypeError.
    required = {
        "satisfaction_level": satisfaction_level,
        "last_evaluation": last_evaluation,
        "number_project": number_project,
        "average_monthly_hours": average_monthly_hours,
        "time_spend_company": time_spend_company,
        "work_accident": work_accident,
        "promotion_last_5years": promotion_last_5years,
        "salary": salary,
    }
    missing = [name for name, value in required.items() if value is None]
    if missing:
        return f"❌ Missing input(s): {', '.join(missing)}"
    # number_project is the divisor of work_balance below.
    if number_project == 0:
        return "❌ Number of projects must be non-zero."

    if model is None:
        return "❌ No model loaded."

    try:
        department_features = {
            f"department_{dept}": int(dept == department) for dept in departments
        }

        # Feature engineering
        satisfaction_evaluation = satisfaction_level * last_evaluation
        work_balance = average_monthly_hours / number_project

        # Column names and order must match what the model was trained on.
        input_df = pd.DataFrame({
            "satisfaction_level": [satisfaction_level],
            "last_evaluation": [last_evaluation],
            "number_project": [number_project],
            "average_monthly_hours": [average_monthly_hours],
            "time_spend_company": [time_spend_company],
            "Work_accident": [work_accident],
            "promotion_last_5years": [promotion_last_5years],
            "salary": [salary],
            "satisfaction_evaluation": [satisfaction_evaluation],
            "work_balance": [work_balance],
            **department_features,
        })

        # Column 1 of predict_proba is taken as the "quit" class.
        prob = model.predict_proba(input_df)[0][1]
        label = "✅ Employee is likely to quit." if prob >= threshold else "✅ Employee is likely to stay."
        return f"{label} (Probability: {prob:.2%})"
    except Exception as e:
        return f"❌ Error during prediction: {str(e)}"

# Launch Gradio Interface
# Department choices for the dropdown. This must match the one-hot list inside
# predict_employee_status; that list is local to the function, so it cannot be
# referenced here (doing so raised NameError when the module ran).
DEPARTMENTS = [
    'RandD', 'accounting', 'hr', 'management', 'marketing',
    'product_mng', 'sales', 'support', 'technical'
]

# Inputs are passed to predict_employee_status positionally, in this order.
gr.Interface(
    fn=predict_employee_status,
    inputs=[
        gr.Number(label="Satisfaction Level (0.0 - 1.0)"),
        gr.Number(label="Last Evaluation (0.0 - 1.0)"),
        gr.Number(label="Number of Projects (1 - 10)"),
        gr.Number(label="Average Monthly Hours (80 - 320)"),
        gr.Number(label="Time Spend at Company (Years)"),
        gr.Radio([0, 1], label="Work Accident (0 = No, 1 = Yes)"),
        gr.Radio([0, 1], label="Promotion in Last 5 Years (0 = No, 1 = Yes)"),
        gr.Radio([0, 1, 2], label="Salary (0 = Low, 1 = Medium, 2 = High)"),
        gr.Dropdown(DEPARTMENTS, label="Department"),
        gr.Slider(0.1, 0.9, value=0.5, step=0.05, label="Prediction Threshold")
    ],
    outputs="text",
    title="Employee Retention Prediction System (Voting Ensemble)",
    description="Predict whether an employee will stay or quit. Adjust threshold for sensitivity.",
    theme="dark"
).launch()