# demo_active_learning / train_ensemble_models_main.py
# (uploaded by bndl, "Upload 3 files", commit 19b61e8, 4.73 kB)
import argparse
import pandas as pd
import os
import numpy as np
import pickle
from train_model_main import prepare_data, set_all_seeds, train_model, save_shap_explainer
from utils import EnsembleModel, unpickle_file
def run_ensemble_models_training(
    data,
    columns_numerical,
    columns_target,
    main_folder,
    model_name,
    data_type="path",
    lr=0.01,
    n_models=5,
    save_explainer_single=False,
    save_explainer_ensemble=True,
    seed_train_test_split=None,
):
    """Train several models that differ only in their seed and bundle them as an ensemble.

    For every seed ``s`` in ``range(n_models)`` the data is prepared with
    ``prepare_data`` and a model is trained with ``train_model``. Both receive
    the folder ``<main_folder>/seed<s>``. The trained models and their
    histories are wrapped in an ``EnsembleModel``. It is pickled to
    ``<main_folder>/ensemble_<stem of model_name>.pkl``.

    Args:
        data: Input data forwarded to ``prepare_data`` (interpreted according to ``data_type``).
        columns_numerical: Numerical column names forwarded to ``prepare_data``.
        columns_target: Target column names forwarded to ``prepare_data`` and ``train_model``.
        main_folder: Root output folder; per-seed folders ``seed<s>`` are joined under it.
        model_name: Model file name forwarded to ``train_model``; its stem names the ensemble pickle.
        data_type: Forwarded to ``prepare_data``.
        lr: Learning rate forwarded to ``train_model``.
        n_models: Number of models (seeds) to train; must be at least 1.
        save_explainer_single: If True, save a SHAP explainer for each model in its seed folder.
        save_explainer_ensemble: If True, save a SHAP explainer for the ensemble in ``main_folder``.
        seed_train_test_split: If given, every model uses this seed for its train/test split;
            otherwise model ``s`` uses split seed ``s``.

    Returns:
        list: The trained models, in seed order.

    Raises:
        ValueError: If ``n_models`` is less than 1.
    """
    if n_models < 1:
        # With no models there is no split, scaler or ensemble to build; fail early
        # with a clear message instead of a NameError further down.
        raise ValueError(f"n_models must be at least 1, got {n_models}")

    model_list = []
    history_list = []
    for s in range(n_models):
        print("-----------------------")
        print(f"Training model {s + 1}/{n_models}")
        main_seed_folder = os.path.join(main_folder, f"seed{s}")
        # A fixed split seed gives all models the same train/test split;
        # otherwise each model also sees a different split.
        seed_split = s if seed_train_test_split is None else seed_train_test_split
        X_train, X_test, y_train, y_test = prepare_data(
            data, columns_numerical, columns_target, main_seed_folder, data_type=data_type, seed=seed_split
        )
        model, history = train_model(
            X_train,
            X_test,
            y_train,
            y_test,
            columns_target,
            main_seed_folder,
            model_name,
            lr=lr,
            seed=s,
            get_history=True,
        )
        model_list.append(model)
        history_list.append(history)
        if save_explainer_single:
            save_shap_explainer(model.predict, X_train, X_test, main_seed_folder)

    # NOTE(review): the target scaler is read from the last seed folder only. When
    # seed_train_test_split is None the splits differ per seed, so the per-seed
    # scalers may differ too. Confirm one shared scaler is acceptable for EnsembleModel.
    scaler_targets = unpickle_file(os.path.join(main_seed_folder, "minmax_scaler_targets.pickle"))
    ensemble_model = EnsembleModel(model_list, history_list, scaler_targets=scaler_targets)

    # splitext strips only the final extension, so "model.v2.h5" -> "model.v2"
    # and "./model.h5" -> "./model" (split('.')[0] would give "model" and "").
    ensemble_stem = os.path.splitext(model_name)[0]
    with open(os.path.join(main_folder, f"ensemble_{ensemble_stem}.pkl"), "wb") as file:
        pickle.dump(ensemble_model, file)

    # For now just gets the last X_train, X_test, but should be changed to a better solution
    X_train_all = X_train.copy()
    X_test_all = X_test.copy()
    if save_explainer_ensemble:
        save_shap_explainer(ensemble_model.predict, X_train_all, X_test_all, main_folder)
    return model_list
def train_ensemble_models_from_split(
    X_train, X_test, y_train, y_test, columns_target, main_folder, model_path, lr=0.01, n_models=5, save_explainer=True
):
    """Train ``n_models`` seeded models on one fixed, already-prepared train/test split.

    Assumes the train set is the same for all the models. Model ``s`` is trained
    with ``seed=s``. Its model name is derived from ``model_path`` by inserting
    ``_s<s>`` before the extension (``model.h5`` -> ``model_s0.h5``,
    ``model_s1.h5``, ...).

    Args:
        X_train, X_test, y_train, y_test: The shared split, forwarded unchanged to ``train_model``.
        columns_target: Target column names forwarded to ``train_model``.
        main_folder: Folder forwarded to ``train_model`` and ``save_shap_explainer``.
        model_path: Base model file name used to derive each per-seed model name.
        lr: Learning rate forwarded to ``train_model``.
        n_models: Number of models (seeds) to train.
        save_explainer: If True, save a SHAP explainer after training each model.

    Returns:
        list: Whatever ``train_model`` returned for each seed, in seed order.
    """
    # os.path.splitext splits only at the final extension. Paths such as
    # "../models/model.h5" or "model.v2.h5" keep their full stem, and a name
    # without an extension no longer raises IndexError.
    model_stem, model_ext = os.path.splitext(model_path)
    model_ls = []
    for s in range(n_models):
        model_name = f"{model_stem}_s{s}{model_ext}"
        model = train_model(X_train, X_test, y_train, y_test, columns_target, main_folder, model_name, lr=lr, seed=s)
        model_ls.append(model)
        if save_explainer:
            # Pass the prediction callable, matching how run_ensemble_models_training
            # calls save_shap_explainer.
            # NOTE(review): every model writes its explainer to the same main_folder,
            # so presumably only the last one is kept. Confirm this is intended.
            save_shap_explainer(model.predict, X_train, X_test, main_folder)
    return model_ls
def parse_column_list(value):
    """Split a comma-separated command-line value into a list of column names.

    Whitespace around each name is stripped and empty entries are dropped, so
    ``"A, B,"`` gives ``["A", "B"]``. An empty string or ``None`` gives ``[]``.
    """
    if not value:
        return []
    return [name.strip() for name in value.split(",") if name.strip()]


def main(argv=None):
    """Parse command-line arguments and train the seeded model ensemble.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read ``sys.argv[1:]``.

    Returns:
        The list of trained models returned by ``run_ensemble_models_training``.
    """
    parser = argparse.ArgumentParser(description="Process parameters")
    parser.add_argument(
        "--data_path",
        type=str,
        help="The path to your input data file",
        default="preprocessed_data.csv",
        required=False,
    )
    parser.add_argument(
        "--main_folder", type=str, help="Folder to save model files", default="../models/hardness", required=False
    )
    parser.add_argument(
        "--model_name", type=str, help="Path to save model", default="model_hardness.h5", required=False
    )
    parser.add_argument("--columns_target", type=str, help="List of target columns", default="H", required=False)
    parser.add_argument(
        "--columns_numerical",
        type=str,
        help="List of data columns with numeric values",
        default="%A,%B,%C,%D,%E,%F,%Phase_A,%Phase_B,%Phase_C,%Phase_D,%Phase_E,%Phase_F,%A_Matrice,%B_Matrice,%C_Matrice,%D_Matrice,%E_Matrice,%F_Matrice,H,Temperature_C",
        required=False,
    )
    parser.add_argument("--learning_rate", "-lr", type=float, help="Learning rate", default=0.01, required=False)
    parser.add_argument("--n_models", "-n", type=int, help="Number of models to run", default=2, required=False)
    args = parser.parse_args(argv)

    # Tolerate "A, B" style input: a plain split(",") would yield " B", which does
    # not match the column name.
    columns_numerical = parse_column_list(args.columns_numerical)
    columns_target = parse_column_list(args.columns_target)
    return run_ensemble_models_training(
        args.data_path,
        columns_numerical,
        columns_target,
        args.main_folder,
        args.model_name,
        lr=args.learning_rate,
        n_models=args.n_models,
    )


if __name__ == "__main__":
    main()