| | import argparse |
| | import os |
| | from data_loader import load_and_process_data, CATEGORICAL_COLUMNS |
| | from model_trainer import train_models |
| | from model_manager import save_models, load_models |
| | from model_predictor import predict |
| | from config import MODEL_DIR, CATBOOST_PARAMS, XGB_PARAMS, RF_PARAMS |
| | import wandb |
| | from sklearn.metrics import accuracy_score, balanced_accuracy_score, classification_report |
| | import pandas as pd |
| |
|
| | |
| | |
| | |
| |
|
def _validation_log_payload(y_val, predictions_val, param_grid):
    """Build the dict of validation metrics and tables that is logged to W&B.

    Args:
        y_val: True validation labels; must support ``.tolist()``.
        predictions_val: DataFrame returned by ``predict`` for the validation
            features; its ``is_click_predicted`` column is scored.
        param_grid: Hyper-parameter dict, logged alongside the metrics.

    Returns:
        A dict ready to be passed to ``run.log``.
    """
    predicted = predictions_val["is_click_predicted"]
    report = classification_report(y_val, predicted, output_dict=True)
    # Transposed so that each entry of the report becomes one row.
    # NOTE(review): wandb.Table may not carry over the DataFrame index (the
    # report entry names) - confirm, and reset_index() first if it does not.
    report_df = pd.DataFrame(report).transpose()
    return {
        "param_grid": param_grid,
        "accuracy_val": accuracy_score(y_val, predicted),
        "balanced_accuracy_val": balanced_accuracy_score(y_val, predicted),
        "classification_report_val_table": wandb.Table(dataframe=report_df),
        "predictions_val_table": wandb.Table(dataframe=predictions_val),
        "y_val": y_val.tolist(),
    }


def main(train=True, retrain=False):
    """Train or load the models, score the validation set, predict the test set.

    Args:
        train: When true, fit the models on the training split and save them.
        retrain: Handled exactly like ``train``; either flag triggers a fresh
            training pass. When both are false, saved models are loaded.

    Side effects:
        Creates ``MODEL_DIR`` if it is missing, logs one run to the W&B
        project ``is_click_predictor`` and writes ``final_predictions.csv``
        to the current working directory.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(MODEL_DIR, exist_ok=True)

    # NOTE(review): the "π" prefix in the messages below looks like a
    # mis-decoded emoji; the intended glyph cannot be recovered from here.
    print("\nπ Loading data...")
    X_train, X_val, y_train, y_val, test_df = load_and_process_data()

    if train or retrain:
        print("\nπ Training models...")
        models = train_models(X_train, y_train, CATEGORICAL_COLUMNS)
        save_models(models)
    else:
        print("\nπ Loading existing models...")
        models = load_models()

    param_grid = {
        "CATBOOST_PARAMS": CATBOOST_PARAMS,
        "XGB_PARAMS": XGB_PARAMS,
        "RF_PARAMS": RF_PARAMS,
    }

    # No explicit API-key lookup here: reading WANDB_API_KEY and discarding
    # the value did nothing, so authentication is left to wandb itself.
    # The context manager closes the run even if prediction or scoring fails.
    with wandb.init(project="is_click_predictor", config=param_grid) as run:
        print("\nπ Makings predictions for validation set...")
        predictions_val = predict(models, X_val)
        payload = _validation_log_payload(y_val, predictions_val, param_grid)

        print("\nπ Making predictions for test set...")
        predictions = predict(models, test_df)

        run.log(payload)

    predictions.to_csv("final_predictions.csv", index=False)
    print("\n✅ Predictions saved successfully as 'final_predictions.csv'!")
| |
|
| | |
| | |
| | |
if __name__ == "__main__":
    # Script entry point: run the whole pipeline with a fresh training pass.
    main(train=True, retrain=False)