# Uploaded by libokj ("Upload 358 files", commit 05ca42f; 8.94 kB).
from pathlib import Path
import re
from typing import Any, Tuple
import pandas as pd
from hydra.core.hydra_config import HydraConfig
from hydra.core.utils import _save_config
from hydra.experimental.callbacks import Callback
from hydra.types import RunMode
from omegaconf import DictConfig, OmegaConf
from deepscreen.utils import get_logger
# Module-level logger created via the project's `get_logger` helper, named after this module.
log = get_logger(__name__)
class CSVExperimentSummary(Callback):
    """Hydra callback that appends one row per finished job to a CSV experiment summary.

    At the end of every job (single run or multirun), the job's overrides, its status, the
    resolved checkpoint path and epoch, and the metrics that the CSV logger wrote to
    ``<output_dir>/<logger.csv.name>/metrics.csv`` are appended as one row to
    ``<run or sweep dir>/<filename>``.

    At the start of a multirun, a ``ckpt_path`` that points to a delimited table (``.csv``,
    ``.txt``, ``.tsv``, ``.ssv`` or ``.psv``), such as a previous experiment summary, is
    expanded into a sweep over the checkpoint paths in its ``ckpt_path`` column. The table's
    other non-metric columns are then copied into the rows written for the matching jobs.

    Args:
        filename: Name of the summary CSV written in the run/sweep directory.
        prefix: Metric column prefix(es). Columns starting with it are ignored when an input
            summary table is read.
    """

    # Extensions that mark `ckpt_path` as a summary table rather than a checkpoint file.
    _SUMMARY_EXTENSIONS = ('.csv', '.txt', '.tsv', '.ssv', '.psv')
    # Column separator per extension. Unlisted extensions are read as comma-separated;
    # `None` lets pandas sniff the delimiter (e.g. space- or semicolon-separated `.ssv`).
    _SUMMARY_SEPARATORS = {'.tsv': '\t', '.psv': '|', '.ssv': None}

    def __init__(self, filename: str = 'experiment_summary.csv', prefix: str | Tuple[str] = 'test/'):
        self.filename = filename
        # `str.startswith` accepts a str or a tuple of str, so normalize other iterables to a tuple.
        self.prefix = prefix if isinstance(prefix, str) else tuple(prefix)
        # Non-metric columns of the input summary table, set by
        # `parse_ckpt_path_from_experiment_summary`; None when no table was given.
        self.input_experiment_summary = None

    def on_multirun_start(self, config: DictConfig, **kwargs: Any) -> None:
        """Expand a ``ckpt_path`` that points to a summary table into a sweep over its checkpoints.

        The task overrides in ``hydra.overrides.task`` are checked first. If none of them sets
        ``ckpt_path``, ``hydra.sweeper.params.ckpt_path`` is checked instead. The config is left
        untouched when the table cannot be parsed.
        """
        task_overrides = config.hydra.overrides.task if config.hydra.get('overrides') else None
        for i, override in enumerate(task_overrides or []):
            key, sep, value = override.partition('=')
            # Match `ckpt_path=`, `+ckpt_path=` and `++ckpt_path=` exactly (not e.g. `ckpt_path_x=`).
            if not sep or key.lstrip('+') != 'ckpt_path':
                continue
            ckpt_path = value.strip("'\"")
            if ckpt_path.endswith(self._SUMMARY_EXTENSIONS):
                parsed_ckpt_path = self.parse_ckpt_path_from_experiment_summary(ckpt_path)
                if parsed_ckpt_path is not None:
                    # Keep the original key: a Hydra override must have the form `key=value`.
                    config.hydra.overrides.task[i] = f'{key}={parsed_ckpt_path}'
            # Only the first `ckpt_path` override is considered; it takes precedence over sweeper params.
            return

        if config.hydra.sweeper.get('params') and config.hydra.sweeper.params.get('ckpt_path'):
            ckpt_path = str(config.hydra.sweeper.params.ckpt_path).strip("'\"")
            if ckpt_path.endswith(self._SUMMARY_EXTENSIONS):
                parsed_ckpt_path = self.parse_ckpt_path_from_experiment_summary(ckpt_path)
                if parsed_ckpt_path is not None:
                    config.hydra.sweeper.params.ckpt_path = parsed_ckpt_path

    def on_job_end(self, config: DictConfig, job_return, **kwargs: Any) -> None:
        """Append one row describing the finished job to the experiment summary CSV.

        The row holds the job's overrides, ``job_status``, the resolved checkpoint path and
        epoch (when available), the metrics for that epoch from the CSV logger, and, when the
        sweep came from an input summary table, that table's extra columns. A row is written
        even when no metrics are found (e.g. for a failed job). Errors are logged, not raised.
        """
        # Skip callback if job is DDP subprocess
        if "ddp" in job_return.hydra_cfg.hydra.job.name:
            return

        try:
            if config.hydra.mode == RunMode.RUN:
                summary_file_path = Path(config.hydra.run.dir) / self.filename
            elif config.hydra.mode == RunMode.MULTIRUN:
                summary_file_path = Path(config.hydra.sweep.dir) / self.filename
            else:
                raise RuntimeError('Invalid Hydra `RunMode`.')

            summary_df = pd.read_csv(summary_file_path) if summary_file_path.is_file() else pd.DataFrame()

            job_info = self._collect_job_info(job_return)
            metrics_df = self._read_job_metrics(config, job_info.get('epoch'))

            # Assigning scalars to an empty frame yields zero rows, so build the row directly
            # from the job info when no metrics were found.
            if metrics_df.empty:
                row_df = pd.DataFrame([job_info])
            else:
                row_df = metrics_df.head(1).assign(**job_info)
            row_df.index = [0]

            row_df = self._add_input_summary_info(row_df, job_info)

            summary_df = pd.concat([summary_df, row_df], ignore_index=True)
            # Drop empty columns
            summary_df = summary_df.dropna(axis=1, how='all')
            summary_df.to_csv(summary_file_path, index=False, mode='w')
            log.info(f"Experiment summary saved to {summary_file_path}")
        except Exception as e:
            log.exception("Unable to save the experiment summary due to an error.", exc_info=e)

    @staticmethod
    def _collect_job_info(job_return) -> dict:
        """Return the job's ``key=value`` overrides plus ``job_status`` and checkpoint info.

        ``ckpt_path`` becomes the checkpoint from the job's config when that file exists. The
        requested path is kept as ``previous_ckpt_path`` if the two differ. ``epoch`` is parsed
        from an ``epoch_<N>`` part of that path when present.
        """
        # Overrides without `=` (e.g. bare `~key` deletions) have no value to record.
        job_info = dict(
            override.split('=', 1) for override in job_return.overrides if '=' in override
        )
        job_info['job_status'] = job_return.status.name

        # Add checkpoint info
        if job_info.get('ckpt_path'):
            job_info['ckpt_path'] = str(job_info['ckpt_path']).strip("'\"")
        if job_return.cfg.get('ckpt_path'):
            ckpt_path = str(job_return.cfg.ckpt_path).strip("'\"")
            if Path(ckpt_path).is_file():
                if job_info.get('ckpt_path') and ckpt_path != job_info['ckpt_path']:
                    job_info['previous_ckpt_path'] = job_info['ckpt_path']
                job_info['ckpt_path'] = ckpt_path
                # Checkpoint file names without `epoch_<N>` leave `epoch` unset instead of crashing.
                epoch_match = re.search(r'epoch_(\d+)', ckpt_path)
                if epoch_match:
                    job_info['epoch'] = int(epoch_match.group(1))
        return job_info

    def _read_job_metrics(self, config: DictConfig, epoch: int | None) -> pd.DataFrame:
        """Return the CSV-logger metrics row for ``epoch`` (an empty frame if unavailable)."""
        output_dir = Path(config.hydra.runtime.output_dir).resolve()
        csv_logger_name = OmegaConf.select(config, 'logger.csv.name')
        csv_metrics_path = (
            output_dir / str(csv_logger_name) / "metrics.csv" if csv_logger_name is not None else None
        )
        if csv_metrics_path is None or not csv_metrics_path.is_file():
            log.info(f"No metrics.csv found in {output_dir}")
            return pd.DataFrame()

        log.info(f"Summarizing metrics with prefix `{self.prefix}` from {csv_metrics_path}")
        # All logged columns are kept; the rows are reduced to the one for the checkpoint epoch.
        metrics_df = pd.read_csv(csv_metrics_path)
        # Rows where any 'test/' column is not null
        test_columns = [col for col in metrics_df.columns if col.startswith('test/')]
        mask = metrics_df[test_columns].notna().any(axis=1)

        if epoch is None:
            # No checkpoint epoch to align on: collapse only the test rows into one row holding
            # the first logged value of each column (empty if nothing was tested).
            log.warning("Checkpoint epoch unknown; summarizing only rows with `test/` metrics.")
            return metrics_df.loc[mask].bfill().head(1)

        # Reset the epoch of test rows to the checkpoint epoch, then group and filter by it
        metrics_df.loc[mask, 'epoch'] = epoch
        metrics_df = metrics_df.groupby('epoch').first()
        return metrics_df[metrics_df.index == epoch]

    def _add_input_summary_info(self, row_df: pd.DataFrame, job_info: dict) -> pd.DataFrame:
        """Fill missing values in ``row_df`` from the input summary row of this job's checkpoint."""
        summary = self.input_experiment_summary
        if summary is None:
            return row_df
        # Try the final checkpoint first, then the requested one (they differ when the job
        # resolved another checkpoint file).
        for key in ('ckpt_path', 'previous_ckpt_path'):
            ckpt = job_info.get(key)
            if not ckpt:
                continue
            orig_meta = summary[summary['ckpt_path'] == ckpt].head(1).reset_index(drop=True)
            if not orig_meta.empty:
                return row_df.combine_first(orig_meta)
        return row_df

    def parse_ckpt_path_from_experiment_summary(self, ckpt_path):
        """Read the checkpoint paths of a summary table and format them as a Hydra sweep value.

        The table's columns not starting with ``self.prefix`` are stored in
        ``self.input_experiment_summary`` so later jobs can copy them into their rows.

        Args:
            ckpt_path: Path to a delimited table with a ``ckpt_path`` column. The separator is
                chosen from the file extension.

        Returns:
            Comma-separated, single-quoted, de-duplicated checkpoint paths in file order
            (e.g. ``"'a.ckpt','b.ckpt'"``), or ``None`` if the table could not be parsed.
        """
        try:
            sep = self._SUMMARY_SEPARATORS.get(Path(ckpt_path).suffix.lower(), ',')
            read_kwargs = {'sep': sep} if sep is not None else {'sep': None, 'engine': 'python'}
            summary = pd.read_csv(
                ckpt_path, usecols=lambda col: not col.startswith(self.prefix), **read_kwargs
            )
            # Rows without a checkpoint cannot be rerun.
            summary = summary.dropna(subset=['ckpt_path'])
            summary = summary.assign(ckpt_path=summary['ckpt_path'].astype(str).str.strip("'\""))
            self.input_experiment_summary = summary
            # `unique` keeps first-seen order, so the sweep order is reproducible across runs.
            return ','.join(f"'{ckpt}'" for ckpt in summary['ckpt_path'].unique())
        except Exception as e:
            log.exception(
                f'Error in parsing checkpoint paths from experiment_summary file ({ckpt_path}).',
                exc_info=e
            )
            return None
def checkpoint_rerun_config(config: DictConfig):
    """Merge the saved config of the checkpoint at ``config.ckpt_path`` into ``config``.

    The checkpoint's config is looked up at
    ``<ckpt_path>.parents[1]/<hydra.output_subdir>/config.yaml``. When it exists:

    - its ``model``, ``data``, ``task`` and ``seed`` nodes serve as defaults for ``config``,
      except the dataset-specific ``data.data_file``, ``data.split`` and
      ``data.train_val_test_split``;
    - the nodes in ``ckpt_override_keys`` are replaced wholesale by the checkpoint's values,
      so the rerun model is built exactly like the trained one;
    - the merged config is saved as ``config.yaml`` in this run's Hydra output subdirectory.

    Args:
        config: The composed config of the current run. Expected to provide ``ckpt_path``.

    Returns:
        The merged config, or ``config`` unchanged when Hydra's output subdirectory is disabled,
        ``ckpt_path`` is empty or too shallow, or no checkpoint config file is found.
    """
    hydra_cfg = HydraConfig.get()
    ckpt_path = config.get('ckpt_path')
    if hydra_cfg.output_subdir is None or not ckpt_path:
        return config

    ckpt_parents = Path(ckpt_path).parents
    # parents[1] is presumably the checkpoint's run directory (e.g. `<run>/checkpoints/x.ckpt`);
    # a path with fewer parent levels has no such directory to look in.
    if len(ckpt_parents) < 2:
        log.info(f"Checkpoint path {ckpt_path} has no run directory; skipping checkpoint config merge.")
        return config

    ckpt_cfg_path = ckpt_parents[1] / hydra_cfg.output_subdir / 'config.yaml'
    if not ckpt_cfg_path.is_file():
        return config

    hydra_output = Path(hydra_cfg.runtime.output_dir) / hydra_cfg.output_subdir
    log.info(f"Found config file for the checkpoint at {str(ckpt_cfg_path)}; "
             f"merging config overrides with checkpoint config...")
    ckpt_cfg = OmegaConf.load(ckpt_cfg_path)

    # Keep only the checkpoint nodes that describe the trained model; the dataset-specific
    # data settings come from the current run.
    ckpt_cfg = OmegaConf.masked_copy(ckpt_cfg, ['model', 'data', 'task', 'seed'])
    if isinstance(ckpt_cfg.get('data'), DictConfig):
        ckpt_cfg.data = OmegaConf.masked_copy(ckpt_cfg.data, [
            key for key in ckpt_cfg.data.keys() if key not in ['data_file', 'split', 'train_val_test_split']
        ])

    # These nodes must match the trained model exactly, so they are replaced rather than merged.
    # Merging (the OmegaConf.update default) would keep stale keys from the current run's node.
    ckpt_override_keys = ['task',
                          'data.drug_featurizer', 'data.protein_featurizer', 'data.collator',
                          'model.predictor']
    for key in ckpt_override_keys:
        OmegaConf.update(config, key, OmegaConf.select(ckpt_cfg, key), merge=False, force_add=True)

    # Remaining values: the current run's config wins over the checkpoint's defaults.
    config = OmegaConf.merge(ckpt_cfg, config)
    _save_config(config, "config.yaml", hydra_output)
    return config