Spaces:
Sleeping
Sleeping
from typing import List, Optional, Tuple

import hydra
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig

from deepscreen.utils import get_logger, job_wrapper, instantiate_callbacks, instantiate_loggers, log_hyperparameters
# Module-level logger created via the project's `get_logger` helper.
log = get_logger(__name__)

# NOTE(review): disabled helper kept only for reference; delete it if it is no
# longer needed.
# def fix_dict_config(cfg: DictConfig):
#     """fix all vars in the cfg config
#     this is an in-place operation"""
#     keys = list(cfg.keys())
#     for k in keys:
#         if type(cfg[k]) is DictConfig:
#             fix_dict_config(cfg[k])
#         else:
#             setattr(cfg, k, getattr(cfg, k))
def test(cfg: DictConfig) -> Tuple[dict, dict]:
    """Evaluate a given checkpoint on the datamodule's test set.

    By design this function is meant to be wrapped in the optional
    ``@job_wrapper`` decorator, which controls the behavior on failure (useful
    for multiruns, saving info about the crash, etc.).
    NOTE(review): no decorator is applied in this file -- confirm whether that
    is intentional.

    Args:
        cfg (DictConfig): Configuration composed by Hydra. Read here:
            ``ckpt_path`` (required); ``data``, ``model`` and ``trainer``
            (instantiated from their ``_target_``); and the optional
            ``callbacks`` and ``logger`` nodes.

    Returns:
        Tuple[dict, dict]: ``trainer.callback_metrics`` after testing, and a dict
        of all instantiated objects (``cfg``, ``datamodule``, ``model``,
        ``callbacks``, ``logger``, ``trainer``).

    Raises:
        ValueError: If ``ckpt_path`` is missing or empty in ``cfg``.
    """
    # Explicit check instead of `assert`: asserts are stripped under `python -O`
    # and a bare AssertionError does not say what is wrong.
    ckpt_path = cfg.get("ckpt_path")
    if not ckpt_path:
        raise ValueError("`ckpt_path` must be set to the checkpoint to evaluate.")

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks.")
    callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers.")
    logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger, callbacks=callbacks)

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    # Hyperparameters are only logged when at least one logger was instantiated.
    if logger:
        log.info("Logging hyperparameters.")
        log_hyperparameters(object_dict)

    log.info("Start testing.")
    # The checkpoint path is handed to Lightning, which runs the test loop with it.
    trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)

    metric_dict = trainer.callback_metrics
    return metric_dict, object_dict
def main(cfg: DictConfig) -> Optional[float]:
    """Evaluate the configured checkpoint and return the objective metric.

    The returned value is what a Hydra-based hyperparameter sweep optimizes.
    NOTE(review): ``cfg`` is expected to be injected by Hydra (e.g. via a
    ``@hydra.main(...)`` decorator); none is visible in this file -- confirm.

    Args:
        cfg (DictConfig): Configuration composed by Hydra. ``objective_metric``
            (optional) names the test metric to return.

    Returns:
        Optional[float]: The value of ``objective_metric`` from the test metrics,
        or ``None`` when no ``objective_metric`` is configured.

    Raises:
        ValueError: If ``objective_metric`` is configured but absent from the
            metrics produced by testing.
    """
    # Evaluate the model.
    metric_dict, _ = test(cfg)

    # Safely retrieve the metric value for Hydra-based hyperparameter optimization.
    objective_metric = cfg.get("objective_metric")
    if not objective_metric:
        return None

    if objective_metric not in metric_dict:
        # ValueError subclasses Exception, so existing `except Exception` handlers still match.
        raise ValueError(
            f"Unable to find `objective_metric` ({objective_metric}) in `metric_dict`.\n"
            "Make sure `objective_metric` name in `sweep` config is correct."
        )

    # `.item()` converts the stored (presumably tensor) metric to a Python scalar.
    metric_value = metric_dict[objective_metric].item()
    log.info(f"Retrieved objective_metric. <{objective_metric}={metric_value}>")

    # Return the optimized metric.
    return metric_value
if __name__ == "__main__":
    # NOTE(review): `main` has a required `cfg` parameter, so this bare call raises
    # TypeError unless `main` is wrapped by a decorator such as `@hydra.main(...)`
    # that composes and injects the config. No such decorator is visible in this
    # file -- confirm whether it was lost, and restore it with the correct
    # config_path/config_name.
    main()