"""Checkpoint evaluation entry point.

Composes a Hydra config, instantiates the datamodule, model, callbacks,
loggers, and trainer, then runs ``trainer.test`` on a given checkpoint.
"""
from typing import List, Tuple

import hydra
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig

from deepscreen.utils import (
    get_logger,
    instantiate_callbacks,
    instantiate_loggers,
    job_wrapper,
    log_hyperparameters,
)
from deepscreen.utils.hydra import checkpoint_rerun_config

# Module-level logger, named after this module per the project convention.
log = get_logger(__name__)
# NOTE(review): dead code kept for reference. It appears intended to force
# resolution of every leaf value of a DictConfig in place (read-then-write of
# each key) — confirm before reviving; delete once it is certainly unneeded.
# def fix_dict_config(cfg: DictConfig):
#     """fix all vars in the cfg config
#     this is an in-place operation"""
#     keys = list(cfg.keys())
#     for k in keys:
#         if type(cfg[k]) is DictConfig:
#             fix_dict_config(cfg[k])
#         else:
#             setattr(cfg, k, getattr(cfg, k))
@job_wrapper  # restored: the docstring documents this wrapper and `job_wrapper` was imported but unused
def test(cfg: DictConfig) -> Tuple[dict, dict]:
    """Evaluates given checkpoint on a datamodule testset.

    This method is wrapped in optional @job_wrapper decorator, that controls the behavior during
    failure. Useful for multiruns, saving info about the crash, etc.

    Args:
        cfg (DictConfig): Configuration composed by Hydra. Must provide ``data``,
            ``model``, ``trainer`` and ``ckpt_path`` nodes; ``callbacks`` and
            ``logger`` are optional.

    Returns:
        Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
    """
    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    # `cfg.get(...)` tolerates absent callback/logger nodes (helpers receive None).
    log.info("Instantiating callbacks.")
    callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers.")
    logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger, callbacks=callbacks)

    # Everything instantiated above, returned so callers (and the hparam logger)
    # can inspect the run.
    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters.")
        log_hyperparameters(object_dict)

    log.info("Start testing.")
    trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)

    # Metrics accumulated by the trainer during the test loop.
    metric_dict = trainer.callback_metrics
    return metric_dict, object_dict
# NOTE(review): `main()` is called with no arguments under `__main__`, which only
# works when Hydra composes `cfg` via this decorator — it was evidently lost in
# extraction. config_path/config_name restored from project convention; confirm.
@hydra.main(version_base="1.3", config_path="../configs", config_name="test.yaml")
def main(cfg: DictConfig):
    """Hydra entry point: evaluate a checkpoint and optionally report a metric.

    Args:
        cfg (DictConfig): Configuration composed by Hydra; ``ckpt_path`` is required.

    Returns:
        float | None: The value of ``cfg.objective_metric`` (for hydra-based
        hyperparameter optimization), or None when no objective metric is set.

    Raises:
        ValueError: If ``ckpt_path`` is missing from the config.
        Exception: If ``objective_metric`` is not found in the test metrics.
    """
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if not cfg.ckpt_path:
        raise ValueError("Checkpoint path (`ckpt_path`) must be specified for testing.")

    # Merge/restore config saved alongside the checkpoint with current overrides.
    cfg = checkpoint_rerun_config(cfg)

    # NOTE(review): a large commented-out block that manually merged the
    # checkpoint's saved Hydra config was removed here; `checkpoint_rerun_config`
    # above appears to supersede it — recover from VCS history if needed.

    # Evaluate the model.
    metric_dict, _ = test(cfg)

    # Safely retrieve metric value for hydra-based hyperparameter optimization.
    objective_metric = cfg.get("objective_metric")
    if not objective_metric:
        return None

    if objective_metric not in metric_dict:
        raise Exception(
            f"Unable to find `objective_metric` ({objective_metric}) in `metric_dict`.\n"
            "Make sure `objective_metric` name in `sweep` config is correct."
        )

    # `.item()` unwraps the scalar tensor stored in trainer.callback_metrics.
    metric_value = metric_dict[objective_metric].item()
    log.info(f"Retrieved objective_metric. <{objective_metric}={metric_value}>")

    # Return optimized metric.
    return metric_value
if __name__ == "__main__": | |
main() | |