# Copyright (c) OpenMMLab. All rights reserved.
import copy
import logging
import os.path as osp
import warnings
from abc import ABCMeta, abstractmethod

import torch
from torch.optim import Optimizer

import annotator.uniformer.mmcv as mmcv
from ..parallel import is_module_wrapper
from .checkpoint import load_checkpoint
from .dist_utils import get_dist_info
from .hooks import HOOKS, Hook
from .log_buffer import LogBuffer
from .priority import Priority, get_priority
from .utils import get_time_str


class BaseRunner(metaclass=ABCMeta):
    """The base class of Runner, a training helper for PyTorch.

    All subclasses should implement the following APIs:

    - ``run()``
    - ``train()``
    - ``val()``
    - ``save_checkpoint()``

    Args:
        model (:obj:`torch.nn.Module`): The model to be run.
        batch_processor (callable): A callable method that processes a data
            batch. The interface of this method should be
            ``batch_processor(model, data, train_mode) -> dict``.
        optimizer (dict or :obj:`torch.optim.Optimizer`): It can be either an
            optimizer (in most cases) or a dict of optimizers (in models that
            require more than one optimizer, e.g., GAN).
        work_dir (str, optional): The working directory to save checkpoints
            and logs. Defaults to None.
        logger (:obj:`logging.Logger`): Logger used during training.
            Defaults to None. (The default value is just for backward
            compatibility.)
        meta (dict | None): A dict that records some important information,
            such as environment info and seed, which will be logged in the
            logger hook. Defaults to None.
        max_epochs (int, optional): Total training epochs.
        max_iters (int, optional): Total training iterations.
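
    Example:
        A minimal sketch of a concrete subclass; ``MyRunner``, ``my_model``
        and ``my_optimizer`` are illustrative placeholders, not part of this
        module (``my_model`` is assumed to implement ``train_step()``)::

            >>> class MyRunner(BaseRunner):
            ...     def train(self):
            ...         pass
            ...     def val(self):
            ...         pass
            ...     def run(self, data_loaders, workflow, **kwargs):
            ...         pass
            ...     def save_checkpoint(self, out_dir, filename_tmpl,
            ...                         save_optimizer=True, meta=None,
            ...                         create_symlink=True):
            ...         pass
            >>> runner = MyRunner(
            ...     my_model,
            ...     optimizer=my_optimizer,
            ...     work_dir='./work_dir',
            ...     logger=logging.getLogger(__name__),
            ...     max_epochs=12)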
| """ | |
| def __init__(self, | |
| model, | |
| batch_processor=None, | |
| optimizer=None, | |
| work_dir=None, | |
| logger=None, | |
| meta=None, | |
| max_iters=None, | |
| max_epochs=None): | |
| if batch_processor is not None: | |
| if not callable(batch_processor): | |
| raise TypeError('batch_processor must be callable, ' | |
| f'but got {type(batch_processor)}') | |
| warnings.warn('batch_processor is deprecated, please implement ' | |
| 'train_step() and val_step() in the model instead.') | |
| # raise an error is `batch_processor` is not None and | |
| # `model.train_step()` exists. | |
| if is_module_wrapper(model): | |
| _model = model.module | |
| else: | |
| _model = model | |
| if hasattr(_model, 'train_step') or hasattr(_model, 'val_step'): | |
| raise RuntimeError( | |
| 'batch_processor and model.train_step()/model.val_step() ' | |
| 'cannot be both available.') | |
| else: | |
| assert hasattr(model, 'train_step') | |
| # check the type of `optimizer` | |
| if isinstance(optimizer, dict): | |
| for name, optim in optimizer.items(): | |
| if not isinstance(optim, Optimizer): | |
| raise TypeError( | |
| f'optimizer must be a dict of torch.optim.Optimizers, ' | |
| f'but optimizer["{name}"] is a {type(optim)}') | |
| elif not isinstance(optimizer, Optimizer) and optimizer is not None: | |
| raise TypeError( | |
| f'optimizer must be a torch.optim.Optimizer object ' | |
| f'or dict or None, but got {type(optimizer)}') | |
| # check the type of `logger` | |
| if not isinstance(logger, logging.Logger): | |
| raise TypeError(f'logger must be a logging.Logger object, ' | |
| f'but got {type(logger)}') | |
| # check the type of `meta` | |
| if meta is not None and not isinstance(meta, dict): | |
| raise TypeError( | |
| f'meta must be a dict or None, but got {type(meta)}') | |
| self.model = model | |
| self.batch_processor = batch_processor | |
| self.optimizer = optimizer | |
| self.logger = logger | |
| self.meta = meta | |
| # create work_dir | |
| if mmcv.is_str(work_dir): | |
| self.work_dir = osp.abspath(work_dir) | |
| mmcv.mkdir_or_exist(self.work_dir) | |
| elif work_dir is None: | |
| self.work_dir = None | |
| else: | |
| raise TypeError('"work_dir" must be a str or None') | |
| # get model name from the model class | |
| if hasattr(self.model, 'module'): | |
| self._model_name = self.model.module.__class__.__name__ | |
| else: | |
| self._model_name = self.model.__class__.__name__ | |
| self._rank, self._world_size = get_dist_info() | |
| self.timestamp = get_time_str() | |
| self.mode = None | |
| self._hooks = [] | |
| self._epoch = 0 | |
| self._iter = 0 | |
| self._inner_iter = 0 | |
| if max_epochs is not None and max_iters is not None: | |
| raise ValueError( | |
| 'Only one of `max_epochs` or `max_iters` can be set.') | |
| self._max_epochs = max_epochs | |
| self._max_iters = max_iters | |
| # TODO: Redesign LogBuffer, it is not flexible and elegant enough | |
| self.log_buffer = LogBuffer() | |

    @property
    def model_name(self):
        """str: Name of the model, usually the module class name."""
        return self._model_name

    @property
    def rank(self):
        """int: Rank of current process. (distributed training)"""
        return self._rank

    @property
    def world_size(self):
        """int: Number of processes participating in the job.
        (distributed training)"""
        return self._world_size

    @property
    def hooks(self):
        """list[:obj:`Hook`]: A list of registered hooks."""
        return self._hooks

    @property
    def epoch(self):
        """int: Current epoch."""
        return self._epoch

    @property
    def iter(self):
        """int: Current iteration."""
        return self._iter

    @property
    def inner_iter(self):
        """int: Iteration in an epoch."""
        return self._inner_iter

    @property
    def max_epochs(self):
        """int: Maximum training epochs."""
        return self._max_epochs

    @property
    def max_iters(self):
        """int: Maximum training iterations."""
        return self._max_iters

    @abstractmethod
    def train(self):
        pass

    @abstractmethod
    def val(self):
        pass

    @abstractmethod
    def run(self, data_loaders, workflow, **kwargs):
        pass

    @abstractmethod
    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl,
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        pass

    def current_lr(self):
        """Get current learning rates.

        Returns:
            list[float] | dict[str, list[float]]: Current learning rates of
                all param groups. If the runner has a dict of optimizers,
                this method will return a dict.
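
        Example:
            An illustrative sketch of the two possible return shapes; the
            numbers and optimizer names ('generator', 'discriminator') are
            made up, not produced by this module::

                >>> runner.current_lr()   # single optimizer
                [0.01, 0.01]
                >>> runner.current_lr()   # dict of optimizers (e.g., a GAN)
                {'generator': [0.0002], 'discriminator': [0.0001]}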
| """ | |
| if isinstance(self.optimizer, torch.optim.Optimizer): | |
| lr = [group['lr'] for group in self.optimizer.param_groups] | |
| elif isinstance(self.optimizer, dict): | |
| lr = dict() | |
| for name, optim in self.optimizer.items(): | |
| lr[name] = [group['lr'] for group in optim.param_groups] | |
| else: | |
| raise RuntimeError( | |
| 'lr is not applicable because optimizer does not exist.') | |
| return lr | |

    def current_momentum(self):
        """Get current momentums.

        Returns:
            list[float] | dict[str, list[float]]: Current momentums of all
                param groups. If the runner has a dict of optimizers, this
                method will return a dict.
        """

        def _get_momentum(optimizer):
            momentums = []
            for group in optimizer.param_groups:
                if 'momentum' in group.keys():
                    momentums.append(group['momentum'])
                elif 'betas' in group.keys():
                    momentums.append(group['betas'][0])
                else:
                    momentums.append(0)
            return momentums

        if self.optimizer is None:
            raise RuntimeError(
                'momentum is not applicable because optimizer does not exist.')
        elif isinstance(self.optimizer, torch.optim.Optimizer):
            momentums = _get_momentum(self.optimizer)
        elif isinstance(self.optimizer, dict):
            momentums = dict()
            for name, optim in self.optimizer.items():
                momentums[name] = _get_momentum(optim)
        return momentums

    def register_hook(self, hook, priority='NORMAL'):
        """Register a hook into the hook list.

        The hook will be inserted into a priority queue, with the specified
        priority (see :class:`Priority` for details of priorities).
        For hooks with the same priority, they will be triggered in the same
        order as they are registered.

        Args:
            hook (:obj:`Hook`): The hook to be registered.
            priority (int or str or :obj:`Priority`): Hook priority.
                Lower value means higher priority.
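
        Example:
            A minimal sketch; ``MyHook`` is an illustrative :obj:`Hook`
            subclass, not part of this module::

                >>> class MyHook(Hook):
                ...     def before_train_epoch(self, runner):
                ...         runner.logger.info('starting a new epoch')
                >>> runner.register_hook(MyHook(), priority='LOW')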
| """ | |
| assert isinstance(hook, Hook) | |
| if hasattr(hook, 'priority'): | |
| raise ValueError('"priority" is a reserved attribute for hooks') | |
| priority = get_priority(priority) | |
| hook.priority = priority | |
| # insert the hook to a sorted list | |
| inserted = False | |
| for i in range(len(self._hooks) - 1, -1, -1): | |
| if priority >= self._hooks[i].priority: | |
| self._hooks.insert(i + 1, hook) | |
| inserted = True | |
| break | |
| if not inserted: | |
| self._hooks.insert(0, hook) | |

    def register_hook_from_cfg(self, hook_cfg):
        """Register a hook from its cfg.

        Args:
            hook_cfg (dict): Hook config. It should have at least keys 'type'
                and 'priority' indicating its type and priority.

        Notes:
            The specific hook class to register should not use 'type' and
            'priority' arguments during initialization.
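
        Example:
            An illustrative config; 'CheckpointHook' is one of the hook types
            registered by default elsewhere in this runner, and ``interval``
            is assumed to be an argument accepted by that hook::

                >>> runner.register_hook_from_cfg(
                ...     dict(type='CheckpointHook', priority='NORMAL',
                ...          interval=1))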
| """ | |
| hook_cfg = hook_cfg.copy() | |
| priority = hook_cfg.pop('priority', 'NORMAL') | |
| hook = mmcv.build_from_cfg(hook_cfg, HOOKS) | |
| self.register_hook(hook, priority=priority) | |
| def call_hook(self, fn_name): | |
| """Call all hooks. | |
| Args: | |
| fn_name (str): The function name in each hook to be called, such as | |
| "before_train_epoch". | |
| """ | |
| for hook in self._hooks: | |
| getattr(hook, fn_name)(self) | |
| def get_hook_info(self): | |
| # Get hooks info in each stage | |
| stage_hook_map = {stage: [] for stage in Hook.stages} | |
| for hook in self.hooks: | |
| try: | |
| priority = Priority(hook.priority).name | |
| except ValueError: | |
| priority = hook.priority | |
| classname = hook.__class__.__name__ | |
| hook_info = f'({priority:<12}) {classname:<35}' | |
| for trigger_stage in hook.get_triggered_stages(): | |
| stage_hook_map[trigger_stage].append(hook_info) | |
| stage_hook_infos = [] | |
| for stage in Hook.stages: | |
| hook_infos = stage_hook_map[stage] | |
| if len(hook_infos) > 0: | |
| info = f'{stage}:\n' | |
| info += '\n'.join(hook_infos) | |
| info += '\n -------------------- ' | |
| stage_hook_infos.append(info) | |
| return '\n'.join(stage_hook_infos) | |

    def load_checkpoint(self,
                        filename,
                        map_location='cpu',
                        strict=False,
                        revise_keys=[(r'^module.', '')]):
        return load_checkpoint(
            self.model,
            filename,
            map_location,
            strict,
            self.logger,
            revise_keys=revise_keys)

    def resume(self,
               checkpoint,
               resume_optimizer=True,
               map_location='default'):
        if map_location == 'default':
            if torch.cuda.is_available():
                device_id = torch.cuda.current_device()
                checkpoint = self.load_checkpoint(
                    checkpoint,
                    map_location=lambda storage, loc: storage.cuda(device_id))
            else:
                checkpoint = self.load_checkpoint(checkpoint)
        else:
            checkpoint = self.load_checkpoint(
                checkpoint, map_location=map_location)

        self._epoch = checkpoint['meta']['epoch']
        self._iter = checkpoint['meta']['iter']
        if self.meta is None:
            self.meta = {}
        self.meta.setdefault('hook_msgs', {})
        # load `last_ckpt`, `best_score`, `best_ckpt`, etc. for hook messages
        self.meta['hook_msgs'].update(checkpoint['meta'].get('hook_msgs', {}))

        # Re-calculate the number of iterations when resuming
        # models with different number of GPUs
        if 'config' in checkpoint['meta']:
            config = mmcv.Config.fromstring(
                checkpoint['meta']['config'], file_format='.py')
            previous_gpu_ids = config.get('gpu_ids', None)
            if previous_gpu_ids and len(previous_gpu_ids) > 0 and len(
                    previous_gpu_ids) != self.world_size:
                self._iter = int(self._iter * len(previous_gpu_ids) /
                                 self.world_size)
                self.logger.info('the iteration number is changed due to '
                                 'change of GPU number')

        # resume meta information
        self.meta = checkpoint['meta']

        if 'optimizer' in checkpoint and resume_optimizer:
            if isinstance(self.optimizer, Optimizer):
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            elif isinstance(self.optimizer, dict):
                for k in self.optimizer.keys():
                    self.optimizer[k].load_state_dict(
                        checkpoint['optimizer'][k])
            else:
                raise TypeError(
                    'Optimizer should be dict or torch.optim.Optimizer '
                    f'but got {type(self.optimizer)}')

        self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)

    def register_lr_hook(self, lr_config):
        if lr_config is None:
            return
        elif isinstance(lr_config, dict):
            assert 'policy' in lr_config
            policy_type = lr_config.pop('policy')
            # If the type of policy is all in lower case, e.g., 'cyclic',
            # then its first letter will be capitalized, e.g., to be 'Cyclic'.
            # This is for the convenient usage of Lr updater.
            # Since this is not applicable for `CosineAnnealingLrUpdater`,
            # the string will not be changed if it contains capital letters.
            if policy_type == policy_type.lower():
                policy_type = policy_type.title()
            hook_type = policy_type + 'LrUpdaterHook'
            lr_config['type'] = hook_type
            hook = mmcv.build_from_cfg(lr_config, HOOKS)
        else:
            hook = lr_config
        self.register_hook(hook, priority='VERY_HIGH')

    def register_momentum_hook(self, momentum_config):
        if momentum_config is None:
            return
        if isinstance(momentum_config, dict):
            assert 'policy' in momentum_config
            policy_type = momentum_config.pop('policy')
            # If the type of policy is all in lower case, e.g., 'cyclic',
            # then its first letter will be capitalized, e.g., to be 'Cyclic'.
            # This is for the convenient usage of momentum updater.
            # Since this is not applicable for
            # `CosineAnnealingMomentumUpdater`,
            # the string will not be changed if it contains capital letters.
            if policy_type == policy_type.lower():
                policy_type = policy_type.title()
            hook_type = policy_type + 'MomentumUpdaterHook'
            momentum_config['type'] = hook_type
            hook = mmcv.build_from_cfg(momentum_config, HOOKS)
        else:
            hook = momentum_config
        self.register_hook(hook, priority='HIGH')

    def register_optimizer_hook(self, optimizer_config):
        if optimizer_config is None:
            return
        if isinstance(optimizer_config, dict):
            optimizer_config.setdefault('type', 'OptimizerHook')
            hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
        else:
            hook = optimizer_config
        self.register_hook(hook, priority='ABOVE_NORMAL')

    def register_checkpoint_hook(self, checkpoint_config):
        if checkpoint_config is None:
            return
        if isinstance(checkpoint_config, dict):
            checkpoint_config.setdefault('type', 'CheckpointHook')
            hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
        else:
            hook = checkpoint_config
        self.register_hook(hook, priority='NORMAL')

    def register_logger_hooks(self, log_config):
        if log_config is None:
            return
        log_interval = log_config['interval']
        for info in log_config['hooks']:
            logger_hook = mmcv.build_from_cfg(
                info, HOOKS, default_args=dict(interval=log_interval))
            self.register_hook(logger_hook, priority='VERY_LOW')

    def register_timer_hook(self, timer_config):
        if timer_config is None:
            return
        if isinstance(timer_config, dict):
            timer_config_ = copy.deepcopy(timer_config)
            hook = mmcv.build_from_cfg(timer_config_, HOOKS)
        else:
            hook = timer_config
        self.register_hook(hook, priority='LOW')

    def register_custom_hooks(self, custom_config):
        if custom_config is None:
            return
        if not isinstance(custom_config, list):
            custom_config = [custom_config]
        for item in custom_config:
            if isinstance(item, dict):
                self.register_hook_from_cfg(item)
            else:
                self.register_hook(item, priority='NORMAL')

    def register_profiler_hook(self, profiler_config):
        if profiler_config is None:
            return
        if isinstance(profiler_config, dict):
            profiler_config.setdefault('type', 'ProfilerHook')
            hook = mmcv.build_from_cfg(profiler_config, HOOKS)
        else:
            hook = profiler_config
        self.register_hook(hook)

    def register_training_hooks(self,
                                lr_config,
                                optimizer_config=None,
                                checkpoint_config=None,
                                log_config=None,
                                momentum_config=None,
                                timer_config=dict(type='IterTimerHook'),
                                custom_hooks_config=None):
        """Register default and custom hooks for training.

        Default and custom hooks include:

        +----------------------+-------------------------+
        | Hooks                | Priority                |
        +======================+=========================+
        | LrUpdaterHook        | VERY_HIGH (10)          |
        +----------------------+-------------------------+
        | MomentumUpdaterHook  | HIGH (30)               |
        +----------------------+-------------------------+
        | OptimizerStepperHook | ABOVE_NORMAL (40)       |
        +----------------------+-------------------------+
        | CheckpointSaverHook  | NORMAL (50)             |
        +----------------------+-------------------------+
        | IterTimerHook        | LOW (70)                |
        +----------------------+-------------------------+
        | LoggerHook(s)        | VERY_LOW (90)           |
        +----------------------+-------------------------+
        | CustomHook(s)        | defaults to NORMAL (50) |
        +----------------------+-------------------------+

        If custom hooks have the same priority as default hooks, custom hooks
        will be triggered after default hooks.
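
        Example:
            An illustrative call; the config values are placeholders and
            'TextLoggerHook' is assumed to be a registered logger hook type::

                >>> runner.register_training_hooks(
                ...     lr_config=dict(policy='step', step=[8, 11]),
                ...     optimizer_config=dict(grad_clip=None),
                ...     checkpoint_config=dict(interval=1),
                ...     log_config=dict(
                ...         interval=50,
                ...         hooks=[dict(type='TextLoggerHook')]))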
| """ | |
| self.register_lr_hook(lr_config) | |
| self.register_momentum_hook(momentum_config) | |
| self.register_optimizer_hook(optimizer_config) | |
| self.register_checkpoint_hook(checkpoint_config) | |
| self.register_timer_hook(timer_config) | |
| self.register_logger_hooks(log_config) | |
| self.register_custom_hooks(custom_hooks_config) | |