from datetime import datetime
from pathlib import Path
import re
from typing import Any, Tuple

import pandas as pd
from hydra import TaskFunction
from hydra.core.hydra_config import HydraConfig
from hydra.core.override_parser.overrides_parser import OverridesParser
from hydra.core.utils import _save_config
from hydra.experimental.callbacks import Callback
from hydra.types import RunMode
from hydra._internal.config_loader_impl import ConfigLoaderImpl
from omegaconf import DictConfig, OmegaConf
from omegaconf.errors import MissingMandatoryValue

from deepscreen.utils import get_logger

log = get_logger(__name__)


class CSVExperimentSummary(Callback):
    """Aggregate the results from each job's metrics.csv into a cumulative experiment summary CSV
    (experiment_summary.csv by default), updated at the end of every job in a run or multirun."""

    def __init__(self, filename: str = 'experiment_summary.csv', prefix: str | Tuple[str, ...] = 'test/'):
        self.filename = filename
        self.prefix = prefix if isinstance(prefix, str) else tuple(prefix)
        self.input_experiment_summary = None
        self.time = {}

    def on_multirun_start(self, config: DictConfig, **kwargs: Any) -> None:
        # If `ckpt_path` points to a delimited summary file, expand it into a sweep over the checkpoints it lists
        if config.hydra.get('overrides') and config.hydra.overrides.get('task'):
            for i, override in enumerate(config.hydra.overrides.task):
                if override.startswith("ckpt_path"):
                    ckpt_path = override.split('=', 1)[1]
                    if ckpt_path.endswith(('.csv', '.txt', '.tsv', '.ssv', '.psv')):
                        config.hydra.overrides.task[i] = self.parse_ckpt_path_from_experiment_summary(ckpt_path)
                        log.info(ckpt_path)
                    break

        if config.hydra.sweeper.get('params'):
            if config.hydra.sweeper.params.get('ckpt_path'):
                ckpt_path = str(config.hydra.sweeper.params.ckpt_path).strip("'\"")
                if ckpt_path.endswith(('.csv', '.txt', '.tsv', '.ssv', '.psv')):
                    config.hydra.sweeper.params.ckpt_path = self.parse_ckpt_path_from_experiment_summary(ckpt_path)
                    log.info(ckpt_path)

    def on_job_start(self, config: DictConfig, *, task_function: TaskFunction, **kwargs: Any) -> None:
        self.time['start'] = datetime.now()

    def on_job_end(self, config: DictConfig, job_return, **kwargs: Any) -> None:
        # Skip the callback if the job is a DDP subprocess
        if "ddp" in job_return.hydra_cfg.hydra.job.name:
            return
        try:
            self.time['end'] = datetime.now()
            if config.hydra.mode == RunMode.RUN:
                summary_file_path = Path(config.hydra.run.dir) / self.filename
            elif config.hydra.mode == RunMode.MULTIRUN:
                summary_file_path = Path(config.hydra.sweep.dir) / self.filename
            else:
                raise RuntimeError('Invalid Hydra `RunMode`.')

            if summary_file_path.is_file():
                summary_df = pd.read_csv(summary_file_path)
            else:
                summary_df = pd.DataFrame()

            # Add job and override info
            info_dict = {}
            if job_return.overrides:
                info_dict = dict(override.split('=', 1) for override in job_return.overrides)
            info_dict['job_status'] = job_return.status.name
            info_dict['job_id'] = job_return.hydra_cfg.hydra.job.id
            info_dict['wall_time'] = str(self.time['end'] - self.time['start'])

            # Add checkpoint info
            if info_dict.get('ckpt_path'):
                info_dict['ckpt_path'] = str(info_dict['ckpt_path']).strip("'\"")
            ckpt_path = str(job_return.cfg.ckpt_path).strip("'\"")
            if Path(ckpt_path).is_file():
                if info_dict.get('ckpt_path') and ckpt_path != info_dict['ckpt_path']:
                    info_dict['previous_ckpt_path'] = info_dict['ckpt_path']
                info_dict['ckpt_path'] = ckpt_path
            if info_dict.get('ckpt_path'):
                if (epoch := re.search(r'epoch_(\d+)', info_dict['ckpt_path'])) is not None:
                    info_dict['best_epoch'] = int(epoch.group(1))

            # Add metrics info
            metrics_df = pd.DataFrame()
            if config.get('logger'):
                output_dir = Path(config.hydra.runtime.output_dir).resolve()
                csv_metrics_path = output_dir / config.logger.csv.name / "metrics.csv"
                if csv_metrics_path.is_file():
                    log.info(f"Summarizing metrics with prefix `{self.prefix}` from {csv_metrics_path}")
                    metrics_df = pd.read_csv(csv_metrics_path)
                    # Find rows where 'test/' columns are not null and reset their epoch to the best model epoch
                    if info_dict.get('best_epoch'):
                        test_columns = [col for col in metrics_df.columns if col.startswith('test/')]
                        if test_columns:
                            mask = metrics_df[test_columns].notna().any(axis=1)
                            metrics_df.loc[mask, 'epoch'] = info_dict['best_epoch']
                        # Group and filter by the best epoch
                        metrics_df = metrics_df.groupby('epoch').first()
                        metrics_df = metrics_df[metrics_df.index == info_dict['best_epoch']]
                else:
                    log.info(f"No metrics.csv found in {output_dir}")

            if metrics_df.empty:
                metrics_df = pd.DataFrame(data=info_dict, index=[0])
            else:
                metrics_df = metrics_df.assign(**info_dict)
                metrics_df.index = [0]

            # Add extra info from the input batch experiment summary
            if self.input_experiment_summary is not None and 'ckpt_path' in metrics_df.columns:
                log.info(self.input_experiment_summary['ckpt_path'])
                log.info(metrics_df['ckpt_path'])
                orig_meta = self.input_experiment_summary[
                    self.input_experiment_summary['ckpt_path'] == metrics_df['ckpt_path'][0]
                ].head(1)
                if not orig_meta.empty:
                    orig_meta.index = [0]
                    metrics_df = metrics_df.astype('O').combine_first(orig_meta.astype('O'))

            summary_df = pd.concat([summary_df, metrics_df])
            # Drop empty columns
            summary_df.dropna(inplace=True, axis=1, how='all')
            summary_df.to_csv(summary_file_path, index=False, mode='w')
            log.info(f"Experiment summary saved to {summary_file_path}")
        except Exception as e:
            log.exception("Unable to save the experiment summary due to an error.", exc_info=e)

    def parse_ckpt_path_from_experiment_summary(self, ckpt_path):
        """Read checkpoint paths from a previous experiment summary file and return them as a
        comma-separated override value for a checkpoint sweep."""
        log.info(ckpt_path)
        try:
            self.input_experiment_summary = pd.read_csv(
                ckpt_path, usecols=lambda col: not col.startswith(self.prefix)
            )
            self.input_experiment_summary['ckpt_path'] = self.input_experiment_summary['ckpt_path'].apply(
                lambda x: x.strip("'\"")
            )
            ckpt_list = list(set(self.input_experiment_summary['ckpt_path']))
            parsed_ckpt_path = ','.join([f"'{ckpt}'" for ckpt in ckpt_list])
            return parsed_ckpt_path
        except Exception as e:
            log.exception(
                f'Error in parsing checkpoint paths from experiment_summary file ({ckpt_path}).',
                exc_info=e
            )
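
# Usage note (illustrative sketch, not part of the original module): Hydra callbacks such as
# CSVExperimentSummary are registered under `hydra.callbacks` in the config tree. The `_target_`
# module path and file names below are assumptions for demonstration only; adjust them to where
# this module actually lives in the project.
#
#   hydra:
#     callbacks:
#       csv_experiment_summary:
#         _target_: deepscreen.utils.hydra.CSVExperimentSummary
#         filename: experiment_summary.csv
#         prefix: test/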


def checkpoint_rerun_config(config: DictConfig):
    hydra_cfg = HydraConfig.get()

    if not Path(config.ckpt_path).is_file():
        raise FileNotFoundError(f'Not a valid checkpoint file: {config.ckpt_path}')

    if hydra_cfg.get('output_subdir'):
        ckpt_cfg_path = Path(config.ckpt_path).parents[1] / hydra_cfg.output_subdir / 'config.yaml'
        hydra_output = Path(hydra_cfg.runtime.output_dir) / hydra_cfg.output_subdir
        if ckpt_cfg_path.is_file():
            log.info(f"Found config file for the checkpoint at {str(ckpt_cfg_path)}; "
                     f"merging config overrides with checkpoint config...")
            ckpt_cfg = OmegaConf.load(ckpt_cfg_path)
            for key, value in ckpt_cfg.items():
                OmegaConf.update(config, key, value, merge=False, force_add=True)
            # Recompose merged config with overrides
            if hydra_cfg.overrides.get('task'):
                parser = OverridesParser.create()
                parsed_overrides = parser.parse_overrides(overrides=hydra_cfg.overrides.task)
                filtered_overrides = []
                for override in parsed_overrides:
                    if override.is_force_add() or override.key_or_group.split('.')[0] in config:
                        filtered_overrides.append(override)
                ConfigLoaderImpl._apply_overrides_to_config(filtered_overrides, config)
            _save_config(config, "config.yaml", hydra_output)

    return config
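

# Illustrative usage sketch (assumption, not part of the original API): a minimal example of how
# `checkpoint_rerun_config` might be wired into a Hydra task function when rerunning from a
# checkpoint. The decorator arguments, config names, and the trailing trainer call are hypothetical
# placeholders; `import hydra` would be required to run it.
#
#   @hydra.main(version_base="1.3", config_path="configs", config_name="test")
#   def example_task(cfg: DictConfig) -> None:
#       if cfg.get('ckpt_path'):
#           # Rebuild the run config from the checkpoint's saved config plus the current overrides
#           # before instantiating any objects, so the rerun reproduces the original setup.
#           cfg = checkpoint_rerun_config(cfg)
#       ...  # instantiate datamodule/model/trainer, then e.g. trainer.test(ckpt_path=cfg.ckpt_path)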