Spaces:
Sleeping
Sleeping
Delete deepscreen/utils/hydra.py.bak
Browse files- deepscreen/utils/hydra.py.bak +0 -182
deepscreen/utils/hydra.py.bak
DELETED
|
@@ -1,182 +0,0 @@
|
|
| 1 |
-
from datetime import datetime
|
| 2 |
-
from pathlib import Path
|
| 3 |
-
import re
|
| 4 |
-
from typing import Any, Tuple
|
| 5 |
-
|
| 6 |
-
import pandas as pd
|
| 7 |
-
from hydra import TaskFunction
|
| 8 |
-
from hydra.core.hydra_config import HydraConfig
|
| 9 |
-
from hydra.core.override_parser.overrides_parser import OverridesParser
|
| 10 |
-
from hydra.core.utils import _save_config
|
| 11 |
-
from hydra.experimental.callbacks import Callback
|
| 12 |
-
from hydra.types import RunMode
|
| 13 |
-
from hydra._internal.config_loader_impl import ConfigLoaderImpl
|
| 14 |
-
from omegaconf import DictConfig, OmegaConf
|
| 15 |
-
from omegaconf.errors import MissingMandatoryValue
|
| 16 |
-
|
| 17 |
-
from deepscreen.utils import get_logger
|
| 18 |
-
|
| 19 |
-
log = get_logger(__name__)
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
class CSVExperimentSummary(Callback):
    """On multirun end, aggregate the results from each job's metrics.csv and save them in metrics_summary.csv."""

    def __init__(self, filename: str = 'experiment_summary.csv', prefix: str | Tuple[str] = 'test/'):
        # Name of the aggregated summary file written into the (sweep) run dir.
        self.filename = filename
        # Metric-column prefix(es) to treat as result columns; normalized to a
        # str or tuple so it can be passed to str.startswith directly.
        self.prefix = prefix if isinstance(prefix, str) else tuple(prefix)
        # Populated lazily by parse_ckpt_path_from_experiment_summary when the
        # sweep is driven by a previous experiment-summary CSV.
        self.input_experiment_summary = None
        # Per-job wall-clock bookkeeping: {'start': ..., 'end': ...}.
        self.time = {}

    def on_multirun_start(self, config: DictConfig, **kwargs: Any) -> None:
        """If a `ckpt_path` override/sweep param points at a tabular file
        (an earlier experiment summary), expand it into a comma-separated
        list of quoted checkpoint paths so the sweeper fans out over them.
        """
        if config.hydra.get('overrides') and config.hydra.overrides.get('task'):
            for i, override in enumerate(config.hydra.overrides.task):
                if override.startswith("ckpt_path"):
                    ckpt_path = override.split('=', 1)[1]
                    if ckpt_path.endswith(('.csv', '.txt', '.tsv', '.ssv', '.psv')):
                        # Replace the file path with "'ckpt1','ckpt2',..." in place.
                        config.hydra.overrides.task[i] = self.parse_ckpt_path_from_experiment_summary(ckpt_path)
                        log.info(ckpt_path)
                    # Only the first ckpt_path override is considered.
                    break
        if config.hydra.sweeper.get('params'):
            if config.hydra.sweeper.params.get('ckpt_path'):
                ckpt_path = str(config.hydra.sweeper.params.ckpt_path).strip("'\"")
                if ckpt_path.endswith(('.csv', '.txt', '.tsv', '.ssv', '.psv')):
                    config.hydra.sweeper.params.ckpt_path = self.parse_ckpt_path_from_experiment_summary(ckpt_path)
                    log.info(ckpt_path)

    def on_job_start(self, config: DictConfig, *, task_function: TaskFunction, **kwargs: Any) -> None:
        # Record job start time for the wall_time column in on_job_end.
        self.time['start'] = datetime.now()

    def on_job_end(self, config: DictConfig, job_return, **kwargs: Any) -> None:
        """Append one summary row for the finished job to the summary CSV.

        The row combines: parsed CLI overrides, job status/id, wall time,
        checkpoint info (best epoch parsed from the ckpt filename), the
        best-epoch row of the job's metrics.csv, and — when the sweep was
        seeded from a previous experiment summary — that summary's metadata
        for the matching ckpt_path. Any failure is logged, never raised.
        """
        # Skip callback if job is DDP subprocess
        if "ddp" in job_return.hydra_cfg.hydra.job.name:
            return

        try:
            self.time['end'] = datetime.now()
            # Summary lives next to the run (RUN) or at the sweep root (MULTIRUN).
            if config.hydra.mode == RunMode.RUN:
                summary_file_path = Path(config.hydra.run.dir) / self.filename
            elif config.hydra.mode == RunMode.MULTIRUN:
                summary_file_path = Path(config.hydra.sweep.dir) / self.filename
            else:
                raise RuntimeError('Invalid Hydra `RunMode`.')

            # Accumulate across jobs: reload any summary written by earlier jobs.
            if summary_file_path.is_file():
                summary_df = pd.read_csv(summary_file_path)
            else:
                summary_df = pd.DataFrame()

            # Add job and override info
            info_dict = {}
            if job_return.overrides:
                # Each override is "key=value"; split only on the first '='.
                info_dict = dict(override.split('=', 1) for override in job_return.overrides)
            info_dict['job_status'] = job_return.status.name
            info_dict['job_id'] = job_return.hydra_cfg.hydra.job.id
            info_dict['wall_time'] = str(self.time['end'] - self.time['start'])

            # Add checkpoint info
            if info_dict.get('ckpt_path'):
                info_dict['ckpt_path'] = str(info_dict['ckpt_path']).strip("'\"")

            # Prefer the resolved ckpt_path from the job config when it exists
            # on disk (e.g. the best checkpoint produced by this job); keep the
            # override-supplied one as previous_ckpt_path when they differ.
            ckpt_path = str(job_return.cfg.ckpt_path).strip("'\"")
            if Path(ckpt_path).is_file():
                if info_dict.get('ckpt_path') and ckpt_path != info_dict['ckpt_path']:
                    info_dict['previous_ckpt_path'] = info_dict['ckpt_path']
                info_dict['ckpt_path'] = ckpt_path
            if info_dict.get('ckpt_path'):
                # Checkpoint filenames are expected to contain "epoch_<N>";
                # re.search returning None would raise here and be caught below.
                info_dict['best_epoch'] = int(re.search(r'epoch_(\d+)', info_dict['ckpt_path']).group(1))

            # Add metrics info
            metrics_df = pd.DataFrame()
            if config.get('logger'):
                output_dir = Path(config.hydra.runtime.output_dir).resolve()
                csv_metrics_path = output_dir / config.logger.csv.name / "metrics.csv"
                if csv_metrics_path.is_file():
                    log.info(f"Summarizing metrics with prefix `{self.prefix}` from {csv_metrics_path}")
                    metrics_df = pd.read_csv(csv_metrics_path)
                    # Find rows where 'test/' columns are not null and reset its epoch to the best model epoch
                    test_columns = [col for col in metrics_df.columns if col.startswith('test/')]
                    if test_columns:
                        mask = metrics_df[test_columns].notna().any(axis=1)
                        # NOTE(review): relies on 'best_epoch' having been set above;
                        # a missing key raises KeyError and is swallowed by the outer except.
                        metrics_df.loc[mask, 'epoch'] = info_dict['best_epoch']
                    # Group and filter by best epoch
                    metrics_df = metrics_df.groupby('epoch').first()
                    metrics_df = metrics_df[metrics_df.index == info_dict['best_epoch']]
                else:
                    log.info(f"No metrics.csv found in {output_dir}")

            # Collapse to a single row (index 0) combining metrics and job info.
            if metrics_df.empty:
                metrics_df = pd.DataFrame(data=info_dict, index=[0])
            else:
                metrics_df = metrics_df.assign(**info_dict)
                metrics_df.index = [0]

            # Add extra info from the input batch experiment summary
            if self.input_experiment_summary is not None and 'ckpt_path' in metrics_df.columns:
                log.info(self.input_experiment_summary['ckpt_path'])
                log.info(metrics_df['ckpt_path'])
                orig_meta = self.input_experiment_summary[
                    self.input_experiment_summary['ckpt_path'] == metrics_df['ckpt_path'][0]
                ].head(1)
                if not orig_meta.empty:
                    orig_meta.index = [0]
                    # astype('O') so combine_first does not coerce mixed dtypes;
                    # existing values in metrics_df win over the input summary.
                    metrics_df = metrics_df.astype('O').combine_first(orig_meta.astype('O'))

            summary_df = pd.concat([summary_df, metrics_df])

            # Drop empty columns
            summary_df.dropna(inplace=True, axis=1, how='all')
            # Rewrite the whole summary each time (mode='w'), not append.
            summary_df.to_csv(summary_file_path, index=False, mode='w')
            log.info(f"Experiment summary saved to {summary_file_path}")
        except Exception as e:
            # Deliberately best-effort: a summary failure must not fail the job.
            log.exception("Unable to save the experiment summary due to an error.", exc_info=e)

    def parse_ckpt_path_from_experiment_summary(self, ckpt_path):
        """Read a previous experiment-summary CSV at `ckpt_path` and return its
        unique checkpoint paths as a comma-separated, single-quoted list
        suitable for a Hydra sweep value.

        Side effect: caches the loaded summary (metric columns matching
        `self.prefix` excluded) on `self.input_experiment_summary` for reuse
        in `on_job_end`. Returns None if reading/parsing fails (logged).
        """
        log.info(ckpt_path)
        try:
            self.input_experiment_summary = pd.read_csv(
                ckpt_path, usecols=lambda col: not col.startswith(self.prefix)
            )
            # Stored paths may be quoted ('...' or "..."); normalize them.
            self.input_experiment_summary['ckpt_path'] = self.input_experiment_summary['ckpt_path'].apply(
                lambda x: x.strip("'\"")
            )
            # De-duplicate; NOTE(review): set() makes the resulting order
            # non-deterministic across runs.
            ckpt_list = list(set(self.input_experiment_summary['ckpt_path']))
            parsed_ckpt_path = ','.join([f"'{ckpt}'" for ckpt in ckpt_list])
            return parsed_ckpt_path

        except Exception as e:
            log.exception(
                f'Error in parsing checkpoint paths from experiment_summary file ({ckpt_path}).',
                exc_info=e
            )
| 151 |
-
|
| 152 |
-
|
| 153 |
-
def checkpoint_rerun_config(config: DictConfig):
    """Merge the current run's overrides with the config stored next to
    `config.ckpt_path`, and save the recomposed config into this run's
    Hydra output subdir.

    Expects the checkpoint config at
    `<ckpt_path>/../../<output_subdir>/config.yaml` (two levels above the
    checkpoint file). Returns `config` (mutated in place via Hydra's
    private `_apply_overrides_to_config`).
    """
    hydra_cfg = HydraConfig.get()

    # output_subdir is None when config saving is disabled; nothing to do then.
    if hydra_cfg.get('output_subdir'):
        ckpt_cfg_path = Path(config.ckpt_path).parents[1] / hydra_cfg.output_subdir / 'config.yaml'
        hydra_output = Path(hydra_cfg.runtime.output_dir) / hydra_cfg.output_subdir

        if ckpt_cfg_path.is_file():
            log.info(f"Found config file for the checkpoint at {str(ckpt_cfg_path)}; "
                     f"merging config overrides with checkpoint config...")
            ckpt_cfg = OmegaConf.load(ckpt_cfg_path)

            # Recompose checkpoint config with overrides
            if hydra_cfg.overrides.get('task'):
                parser = OverridesParser.create()
                parsed_overrides = parser.parse_overrides(overrides=hydra_cfg.overrides.task)
                filtered_overrides = []
                for override in parsed_overrides:
                    # Force-add ("++") overrides are skipped; plain overrides are
                    # applied to the checkpoint config and kept for re-application.
                    if not override.is_force_add():
                        OmegaConf.update(ckpt_cfg, override.key_or_group, override.value())
                        filtered_overrides.append(override)
                log.info(filtered_overrides)
                # NOTE(review): private Hydra API — may break across Hydra versions.
                ConfigLoaderImpl._apply_overrides_to_config(filtered_overrides, config)

        # NOTE(review): nesting reconstructed from a mangled diff — this save may
        # originally have been inside the `ckpt_cfg_path.is_file()` branch; confirm
        # against the repository history.
        _save_config(config, "config.yaml", hydra_output)

    return config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|