| | import optuna |
| | import os |
| | import tempfile |
| | import time |
| | import json |
| | import subprocess |
| | import logging |
| | from beam_search_utils import ( |
| | write_seglst_jsons, |
| | run_mp_beam_search_decoding, |
| | convert_nemo_json_to_seglst, |
| | ) |
| | from hydra.core.config_store import ConfigStore |
| |
|
| |
|
def _run_meeteval_cpwer(hyp_seglst_json, ref_seglst_json):
    """Run ``meeteval-wer cpwer`` on one hypothesis/reference SegLST pair.

    A non-zero exit status is logged as a warning rather than raised, so the
    caller's subsequent "output file not found" error stays the failure that
    propagates, while the log explains why the file is missing.
    """
    cmd = ["meeteval-wer", "cpwer", "-h", hyp_seglst_json, "-r", ref_seglst_json]
    completed = subprocess.run(cmd)
    if completed.returncode != 0:
        logging.warning(
            "Command '%s' exited with return code %d", " ".join(cmd), completed.returncode
        )


def _read_error_rate(cpwer_json_file):
    """Load a cpWER result JSON and return the value stored under ``error_rate``."""
    try:
        with open(cpwer_json_file, "r", encoding="utf-8") as file:
            data = json.load(file)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"Output JSON: {cpwer_json_file}\nfile not found.") from err
    return data["error_rate"]


def evaluate(cfg, temp_out_dir, workspace_dir, asrdiar_file_name, source_info_dict, hypothesis_sessions_dict, reference_info_dict):
    """Score hypothesis and source transcripts against the reference with cpWER.

    The three session dictionaries are written into ``temp_out_dir`` through
    ``write_seglst_jsons`` and ``meeteval-wer cpwer`` is executed twice:
    hypothesis vs. reference and source vs. reference. Both error rates are
    printed and logged; only the hypothesis one is returned.

    Parameters:
        cfg: Configuration object; ``input_error_src_list_path`` and
            ``groundtruth_ref_list_path`` are read from it.
        temp_out_dir (str): Directory receiving the SegLST files and the scorer output.
        workspace_dir: Not used by this function; kept for interface compatibility.
        asrdiar_file_name (str): Base name used to build the
            ``<name>.{src,hyp,ref}.seglst.json`` paths.
        source_info_dict: Source sessions, written with extension tag ``src``.
        hypothesis_sessions_dict: Hypothesis sessions, written with extension tag ``hyp``.
        reference_info_dict: Reference sessions, written with extension tag ``ref``.

    Returns:
        The ``error_rate`` value read from the hypothesis cpWER JSON.

    Raises:
        FileNotFoundError: If either ``*_cpwer.json`` result file is missing
            after the scorer has run.
    """
    write_seglst_jsons(hypothesis_sessions_dict, input_error_src_list_path=cfg.input_error_src_list_path, diar_out_path=temp_out_dir, ext_str='hyp')
    write_seglst_jsons(reference_info_dict, input_error_src_list_path=cfg.groundtruth_ref_list_path, diar_out_path=temp_out_dir, ext_str='ref')
    write_seglst_jsons(source_info_dict, input_error_src_list_path=cfg.groundtruth_ref_list_path, diar_out_path=temp_out_dir, ext_str='src')

    # Input files for the scorer.
    # NOTE(review): assumes write_seglst_jsons names its output
    # "<asrdiar_file_name>.<ext>.seglst.json" -- confirm against beam_search_utils.
    src_seglst_json = os.path.join(temp_out_dir, f"{asrdiar_file_name}.src.seglst.json")
    hyp_seglst_json = os.path.join(temp_out_dir, f"{asrdiar_file_name}.hyp.seglst.json")
    ref_seglst_json = os.path.join(temp_out_dir, f"{asrdiar_file_name}.ref.seglst.json")

    # Result files the scorer is expected to leave next to its inputs.
    output_cpwer_hyp_json_file = os.path.join(temp_out_dir, f"{asrdiar_file_name}.hyp.seglst_cpwer.json")
    output_cpwer_src_json_file = os.path.join(temp_out_dir, f"{asrdiar_file_name}.src.seglst_cpwer.json")

    # Run both scoring passes before reading any result, as the original did.
    _run_meeteval_cpwer(hyp_seglst_json, ref_seglst_json)
    _run_meeteval_cpwer(src_seglst_json, ref_seglst_json)

    cpwer = _read_error_rate(output_cpwer_hyp_json_file)
    print("Hypothesis cpWER:", cpwer)
    logging.info(f"-> HYPOTHESIS cpWER={cpwer:.4f}")

    source_cpwer = _read_error_rate(output_cpwer_src_json_file)
    print("Source cpWER:", source_cpwer)
    logging.info(f"-> SOURCE cpWER={source_cpwer:.4f}")

    return cpwer
| |
|
| |
|
def optuna_suggest_params(cfg, trial):
    """Sample one set of beam-search hyper-parameters from ``trial`` into ``cfg``.

    ``cfg`` is modified in place and also returned. ``use_ngram`` is not
    sampled; it is always set to ``True``.

    Parameters:
        cfg: Configuration object whose attributes are overwritten.
        trial: Object providing ``suggest_float(name, low, high)`` and
            ``suggest_int(name, low, high)``.

    Returns:
        The same ``cfg`` object, with the sampled values assigned.
    """
    draw_float = trial.suggest_float
    draw_int = trial.suggest_int
    # (attribute name, sampler, bounds); a sampler of None marks a fixed value.
    search_space = (
        ("alpha", draw_float, (0.01, 5.0)),
        ("beta", draw_float, (0.001, 2.0)),
        ("beam_width", draw_int, (4, 64)),
        ("word_window", draw_int, (16, 64)),
        ("use_ngram", None, True),
        ("parallel_chunk_word_len", draw_int, (50, 300)),
        ("peak_prob", draw_float, (0.9, 1.0)),
    )
    for attr_name, sampler, spec in search_space:
        value = spec if sampler is None else sampler(attr_name, *spec)
        setattr(cfg, attr_name, value)
    return cfg
| |
|
def beamsearch_objective(
    trial,
    cfg,
    speaker_beam_search_decoder,
    loaded_kenlm_model,
    div_trans_info_dict,
    org_trans_info_dict,
    source_info_dict,
    reference_info_dict,
):
    """Optuna objective: decode with one sampled parameter set and return its cpWER.

    A temporary directory (prefix ``GenSEC_``) is created under
    ``cfg.temp_out_dir`` for the evaluation files and is removed when the
    trial finishes.

    Parameters:
        trial: Optuna trial used to sample parameters; ``trial.number`` is logged.
        cfg: Configuration object. It is updated in place by
            ``optuna_suggest_params``; ``temp_out_dir``, ``port``,
            ``workspace_dir`` and ``asrdiar_file_name`` are also read from it.
        speaker_beam_search_decoder: Decoder forwarded to ``run_mp_beam_search_decoding``.
        loaded_kenlm_model: Language model forwarded to ``run_mp_beam_search_decoding``.
        div_trans_info_dict: Forwarded to ``run_mp_beam_search_decoding``.
        org_trans_info_dict: Forwarded to ``run_mp_beam_search_decoding``.
        source_info_dict: Source sessions forwarded to ``evaluate``.
        reference_info_dict: Reference sessions forwarded to ``evaluate``.

    Returns:
        The hypothesis cpWER returned by ``evaluate``.
    """
    with tempfile.TemporaryDirectory(dir=cfg.temp_out_dir, prefix="GenSEC_") as local_temp_out_dir:
        start_time = time.time()
        cfg = optuna_suggest_params(cfg, trial)
        # NOTE(review): alpha, beta, beam_width and peak_prob are sampled onto cfg
        # but are not passed to the decoding call below; presumably the decoder
        # reads them elsewhere -- confirm.
        trans_info_dict = run_mp_beam_search_decoding(
            speaker_beam_search_decoder,
            loaded_kenlm_model=loaded_kenlm_model,
            div_trans_info_dict=div_trans_info_dict,
            org_trans_info_dict=org_trans_info_dict,
            div_mp=True,
            win_len=cfg.parallel_chunk_word_len,
            word_window=cfg.word_window,
            port=cfg.port,
            use_ngram=cfg.use_ngram,
        )
        hypothesis_sessions_dict = convert_nemo_json_to_seglst(trans_info_dict)
        cpwer = evaluate(
            cfg,
            local_temp_out_dir,
            cfg.workspace_dir,
            cfg.asrdiar_file_name,
            source_info_dict,
            hypothesis_sessions_dict,
            reference_info_dict,
        )
        elapsed_mins = (time.time() - start_time) / 60
        # Log the trial number; interpolating the trial object itself only
        # yields an uninformative object repr.
        logging.info(f"Beam Search time taken for trial {trial.number}: {elapsed_mins:.2f} mins")
        logging.info(f"Trial: {trial.number}")
        logging.info(f"[ cpWER={cpwer:.4f} ]")
        logging.info("-----------------------------------------------")

    return cpwer
| |
|
| |
|
def optuna_hyper_optim(
    cfg,
    speaker_beam_search_decoder,
    loaded_kenlm_model,
    div_trans_info_dict,
    org_trans_info_dict,
    source_info_dict,
    reference_info_dict,
):
    """
    Optuna hyper-parameter optimization function.

    Creates (or loads, when it already exists) a minimizing Optuna study and
    runs ``cfg.optuna_n_trials`` trials of ``beamsearch_objective`` on it.
    The root logger is set to INFO and gets a stream handler, plus a file
    handler in append mode when ``cfg.output_log_file`` is not None.

    Parameters:
        cfg: Configuration object accessed by attribute; ``optuna_study_name``,
            ``storage``, ``output_log_file`` and ``optuna_n_trials`` are read here.
        speaker_beam_search_decoder: Forwarded unchanged to ``beamsearch_objective``.
        loaded_kenlm_model: Forwarded unchanged to ``beamsearch_objective``.
        div_trans_info_dict: Forwarded unchanged to ``beamsearch_objective``.
        org_trans_info_dict: Forwarded unchanged to ``beamsearch_objective``.
        source_info_dict: Forwarded unchanged to ``beamsearch_objective``.
        reference_info_dict: Forwarded unchanged to ``beamsearch_objective``.

    """

    def run_single_trial(trial):
        # Bind the fixed arguments so Optuna only has to supply the trial.
        return beamsearch_objective(
            trial=trial,
            cfg=cfg,
            speaker_beam_search_decoder=speaker_beam_search_decoder,
            loaded_kenlm_model=loaded_kenlm_model,
            div_trans_info_dict=div_trans_info_dict,
            org_trans_info_dict=org_trans_info_dict,
            source_info_dict=source_info_dict,
            reference_info_dict=reference_info_dict,
        )

    study = optuna.create_study(
        direction="minimize",
        study_name=cfg.optuna_study_name,
        storage=cfg.storage,
        load_if_exists=True,
    )

    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    log_handlers = []
    if cfg.output_log_file is not None:
        log_handlers.append(logging.FileHandler(cfg.output_log_file, mode="a"))
    log_handlers.append(logging.StreamHandler())
    for handler in log_handlers:
        root_logger.addHandler(handler)

    # Route Optuna's own log records to the root logger configured above.
    optuna.logging.enable_propagation()
    study.optimize(run_single_trial, n_trials=cfg.optuna_n_trials)