Spaces:
Sleeping
Sleeping
| from typing import List, Optional | |
| from pathlib import Path | |
| import torch | |
| import hydra | |
| from omegaconf import OmegaConf, DictConfig | |
| from pytorch_lightning import ( | |
| Callback, | |
| LightningDataModule, | |
| LightningModule, | |
| Trainer, | |
| seed_everything, | |
| ) | |
| from pytorch_lightning.loggers import LightningLoggerBase | |
| from src.utils import utils | |
| log = utils.get_logger(__name__) | |
| def remove_prefix(text: str, prefix: str): | |
| if text.startswith(prefix): | |
| return text[len(prefix) :] | |
| return text # or whatever | |
| def load_checkpoint(path, device='cpu'): | |
| path = Path(path).expanduser() | |
| if path.is_dir(): | |
| path /= 'last.ckpt' | |
| # dst = f'cuda:{torch.cuda.current_device()}' | |
| log.info(f'Loading checkpoint from {str(path)}') | |
| state_dict = torch.load(path, map_location=device) | |
| # T2T-ViT checkpoint is nested in the key 'state_dict_ema' | |
| if state_dict.keys() == {'state_dict_ema'}: | |
| state_dict = state_dict['state_dict_ema'] | |
| # Swin checkpoint is nested in the key 'model' | |
| if state_dict.keys() == {'model'}: | |
| state_dict = state_dict['model'] | |
| # Lightning checkpoint contains extra stuff, we only want the model state dict | |
| if 'pytorch-lightning_version' in state_dict: | |
| state_dict = {remove_prefix(k, 'model.'): v for k, v in state_dict['state_dict'].items()} | |
| return state_dict | |
| def evaluate(config: DictConfig) -> None: | |
| """Example of inference with trained model. | |
| It loads trained image classification model from checkpoint. | |
| Then it loads example image and predicts its label. | |
| """ | |
| # load model from checkpoint | |
| # model __init__ parameters will be loaded from ckpt automatically | |
| # you can also pass some parameter explicitly to override it | |
| # We want to add fields to config so need to call OmegaConf.set_struct | |
| OmegaConf.set_struct(config, False) | |
| # load model | |
| checkpoint_type = config.eval.get('checkpoint_type', 'pytorch') | |
| if checkpoint_type not in ['lightning', 'pytorch']: | |
| raise NotImplementedError(f'checkpoint_type ${checkpoint_type} not supported') | |
| if checkpoint_type == 'lightning': | |
| cls = hydra.utils.get_class(config.task._target_) | |
| model = cls.load_from_checkpoint(checkpoint_path=config.eval.ckpt) | |
| elif checkpoint_type == 'pytorch': | |
| model_cfg = config.model_pretrained if 'model_pretrained' in config else None | |
| trained_model: LightningModule = hydra.utils.instantiate(config.task, cfg=config, | |
| model_cfg=model_cfg, | |
| _recursive_=False) | |
| if 'ckpt' in config.eval: | |
| load_return = trained_model.model.load_state_dict( | |
| load_checkpoint(config.eval.ckpt, device=trained_model.device), strict=False | |
| ) | |
| log.info(load_return) | |
| if 'model_pretrained' in config: | |
| ... | |
| else: | |
| model = trained_model | |
| datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule) | |
| # datamodule: LightningDataModule = model._datamodule | |
| datamodule.prepare_data() | |
| datamodule.setup() | |
| # print model hyperparameters | |
| log.info(f'Model hyperparameters: {model.hparams}') | |
| # Init Lightning callbacks | |
| callbacks: List[Callback] = [] | |
| if "callbacks" in config: | |
| for _, cb_conf in config["callbacks"].items(): | |
| if cb_conf is not None and "_target_" in cb_conf: | |
| log.info(f"Instantiating callback <{cb_conf._target_}>") | |
| callbacks.append(hydra.utils.instantiate(cb_conf)) | |
| # Init Lightning loggers | |
| logger: List[LightningLoggerBase] = [] | |
| if "logger" in config: | |
| for _, lg_conf in config["logger"].items(): | |
| if lg_conf is not None and "_target_" in lg_conf: | |
| log.info(f"Instantiating logger <{lg_conf._target_}>") | |
| logger.append(hydra.utils.instantiate(lg_conf)) | |
| # Init Lightning trainer | |
| log.info(f"Instantiating trainer <{config.trainer._target_}>") | |
| trainer: Trainer = hydra.utils.instantiate( | |
| config.trainer, callbacks=callbacks, logger=logger, _convert_="partial" | |
| ) | |
| # Evaluate the model | |
| log.info("Starting evaluation!") | |
| if config.eval.get('run_val', True): | |
| trainer.validate(model=model, datamodule=datamodule) | |
| if config.eval.get('run_test', True): | |
| trainer.test(model=model, datamodule=datamodule) | |
| # Make sure everything closed properly | |
| log.info("Finalizing!") | |
| utils.finish( | |
| config=config, | |
| model=model, | |
| datamodule=datamodule, | |
| trainer=trainer, | |
| callbacks=callbacks, | |
| logger=logger, | |
| ) | |