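"""Entry point for evaluating a trained checkpoint on the test split.

Composes the Hydra `test.yaml` config, instantiates the datamodule, model, callbacks,
loggers, and trainer from it, runs `Trainer.test` on the given `ckpt_path`, and optionally
returns an `objective_metric` value for Hydra-based hyperparameter sweeps.
"""
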
from typing import List, Tuple

import hydra
from lightning import LightningDataModule, LightningModule, Trainer, Callback
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig

from deepscreen.utils.hydra import checkpoint_rerun_config
from deepscreen.utils import get_logger, job_wrapper, instantiate_callbacks, instantiate_loggers, log_hyperparameters


log = get_logger(__name__)


# def fix_dict_config(cfg: DictConfig):
#     """fix all vars in the cfg config
#     this is an in-place operation"""
#     keys = list(cfg.keys())
#     for k in keys:
#         if type(cfg[k]) is DictConfig:
#             fix_dict_config(cfg[k])
#         else:
#             setattr(cfg, k, getattr(cfg, k))


@job_wrapper(extra_utils=True)
def test(cfg: DictConfig) -> Tuple[dict, dict]:
    """Evaluates a given checkpoint on the datamodule's test set.

    This function is wrapped in the optional @job_wrapper decorator, which controls behavior
    on failure. This is useful for multiruns, saving info about the crash, etc.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.

    Returns:
        Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
    """
    # fix_dict_config(cfg)

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks.")
    callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers.")
    logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger, callbacks=callbacks)

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters.")
        log_hyperparameters(object_dict)

    log.info("Start testing.")
    trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)

    metric_dict = trainer.callback_metrics

    return metric_dict, object_dict
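
# A minimal sketch of calling `test` programmatically rather than through the CLI entry point
# below, assuming the config tree lives at ../configs (as declared in `@hydra.main`) and that
# `job_wrapper` preserves the wrapped call signature; the checkpoint path is a placeholder.
#
#   from hydra import compose, initialize
#
#   with initialize(version_base="1.3", config_path="../configs"):
#       cfg = compose(config_name="test.yaml", overrides=["ckpt_path=/path/to/last.ckpt"])
#   metric_dict, object_dict = test(cfg)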


@hydra.main(version_base="1.3", config_path="../configs", config_name="test.yaml")
def main(cfg: DictConfig):
    assert cfg.ckpt_path, "Checkpoint path (`ckpt_path`) must be specified for testing."
    cfg = checkpoint_rerun_config(cfg)
    # from hydra.core.hydra_config import HydraConfig
    #
    # hydra_cfg = HydraConfig.get()
    # if hydra_cfg.output_subdir is not None:
    #     ckpt_cfg_path = Path(cfg.ckpt_path).parents[1] / hydra_cfg.output_subdir / 'config.yaml'
    #     ckpt_hydra_cfg_path = Path(cfg.ckpt_path).parents[1] / hydra_cfg.output_subdir / 'hydra.yaml'
    #     hydra_output = Path(hydra_cfg.runtime.output_dir) / hydra_cfg.output_subdir
    #
    #     if ckpt_cfg_path.is_file():
    #         log.info(f"Found config file for the checkpoint at {str(ckpt_cfg_path)}; "
    #                  f"merging config overrides with checkpoint config...")
    #         ckpt_cfg = OmegaConf.load(ckpt_cfg_path)
    #
    #         if ckpt_hydra_cfg_path.is_file():
    #             ckpt_hydra_cfg = OmegaConf.load(ckpt_hydra_cfg_path)
    #             override_dirname = sanitize_path(ckpt_hydra_cfg.job.override_dirname)
    #
    #         # Merge checkpoint config with test config by overriding specified nodes.
    #         ckpt_cfg = OmegaConf.masked_copy(ckpt_cfg, ['model', 'data', 'task'])
    #         ckpt_cfg.data = OmegaConf.masked_copy(ckpt_cfg.data, [
    #             key for key in ckpt_cfg.data.keys() if key not in ['data_file', 'split', 'train_val_test_split']
    #         ])
    #         cfg = OmegaConf.merge(ckpt_cfg, cfg)
    #
    #         _save_config(cfg, "config.yaml", hydra_output)

    # evaluate the model
    metric_dict, _ = test(cfg)

    # safely retrieve metric value for hydra-based hyperparameter optimization
    objective_metric = cfg.get("objective_metric")

    if not objective_metric:
        return None

    if objective_metric not in metric_dict:
        raise Exception(
            f"Unable to find `objective_metric` ({objective_metric}) in `metric_dict`.\n"
            "Make sure the `objective_metric` name in the `sweep` config is correct."
        )

    metric_value = metric_dict[objective_metric].item()
    log.info(f"Retrieved objective metric <{objective_metric}={metric_value}>.")

    # return optimized metric
    return metric_value
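
# Example invocation (the module path and checkpoint location below are illustrative only):
#
#   python -m deepscreen.test ckpt_path=/path/to/checkpoints/last.ckpt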


if __name__ == "__main__":
    main()