Spaces:
Sleeping
Sleeping
| from typing import List, Tuple | |
| import hydra | |
| from omegaconf import DictConfig | |
| from lightning import LightningDataModule, LightningModule, Trainer, Callback | |
| from deepscreen.utils.hydra import checkpoint_rerun_config | |
| from deepscreen.utils import get_logger, job_wrapper, instantiate_callbacks | |
# Module-level logger used for the progress messages emitted by `predict`.
log = get_logger(__name__)
# NOTE(review): the helper below is commented out and unused in this module; it is
# kept only for reference and could be removed.
# def fix_dict_config(cfg: DictConfig):
# """fix all vars in the cfg config
# this is an in-place operation"""
# keys = list(cfg.keys())
# for k in keys:
# if type(cfg[k]) is DictConfig:
# fix_dict_config(cfg[k])
# else:
# setattr(cfg, k, getattr(cfg, k))
def predict(cfg: DictConfig) -> Tuple[list, dict]:
    """Run prediction with a checkpoint on the datamodule's predict data.

    Instantiates the datamodule, model, callbacks and trainer from ``cfg``, then
    calls ``trainer.predict`` with ``ckpt_path=cfg.ckpt_path`` and
    ``return_predictions=True``.

    NOTE(review): ``job_wrapper`` is imported in this module but is not applied to
    this function here; if it is meant to wrap ``predict``, confirm the decorator
    was not lost.

    Args:
        cfg (DictConfig): Configuration composed by Hydra. Must provide ``data``,
            ``model`` and ``trainer`` nodes (each with a ``_target_``) and
            ``ckpt_path``; ``callbacks`` is optional.

    Returns:
        Tuple[list, dict]: The predictions returned by ``trainer.predict`` and a
        dict with all instantiated objects (keys ``cfg``, ``datamodule``,
        ``model``, ``callbacks``, ``trainer``).
    """
    log.info(f"Instantiating data <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks.")
    # `.get` lets the config omit the `callbacks` node entirely.
    callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    # No experiment logger is attached for prediction runs.
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=False, callbacks=callbacks)

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "trainer": trainer,
    }

    log.info("Start predicting.")
    # Weights are restored from `cfg.ckpt_path` by the trainer itself.
    predictions = trainer.predict(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path, return_predictions=True)

    return predictions, object_dict
def main(cfg: DictConfig):
    """Validate the config, apply ``checkpoint_rerun_config``, and run prediction.

    Args:
        cfg (DictConfig): Configuration composed by Hydra; must set ``ckpt_path``.

    Returns:
        The predictions list produced by ``predict``.

    Raises:
        ValueError: If ``ckpt_path`` is missing or empty in ``cfg``.
    """
    # Explicit check rather than `assert`, which is stripped when Python runs with -O.
    # `.get` also fails cleanly (with this message) when the key is absent.
    if not cfg.get("ckpt_path"):
        raise ValueError("Checkpoint path (`ckpt_path`) must be specified for predicting.")
    cfg = checkpoint_rerun_config(cfg)
    predictions, _ = predict(cfg)
    return predictions


# NOTE(review): `main` requires a `cfg` argument, so this no-argument call only works if
# `main` is wrapped by a config-injecting decorator (e.g. `@hydra.main`), which is not
# visible here — confirm it was not lost.
if __name__ == "__main__":
    main()