Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from ...dist_utils import master_only | |
| from ..hook import HOOKS | |
| from .base import LoggerHook | |
@HOOKS.register_module()
class NeptuneLoggerHook(LoggerHook):
    """Class to log metrics to NeptuneAI.

    It requires `neptune-client` to be installed.

    Args:
        init_kwargs (dict, optional): a dict containing the initialization
            keys as below. If None or empty, ``neptune.init()`` is called
            without arguments.

            - project (str): Name of a project in a form of
              namespace/project_name. If None, the value of
              NEPTUNE_PROJECT environment variable will be taken.
            - api_token (str): User's API token.
              If None, the value of NEPTUNE_API_TOKEN environment
              variable will be taken. Note: It is strongly recommended
              to use NEPTUNE_API_TOKEN environment variable rather than
              placing your API token in plain text in your source code.
            - name (str, optional, default is 'Untitled'): Editable name of
              the run. Name is displayed in the run's Details and in
              Runs table as a column.

            Check https://docs.neptune.ai/api-reference/neptune#init for
            more init arguments.
        interval (int): Logging interval (every k iterations).
        ignore_last (bool): Ignore the log of last iterations in each epoch
            if less than `interval`.
        reset_flag (bool): Whether to clear the output buffer after logging.
        with_step (bool): If True, every value is logged with an explicit
            ``step`` equal to the current iteration. If False, values are
            logged without a step and the current iteration is recorded
            separately in a ``global_step`` series.
        by_epoch (bool): Whether EpochBasedRunner is used.

    .. _NeptuneAI:
        https://docs.neptune.ai/you-should-know/logging-metadata
    """

    def __init__(self,
                 init_kwargs=None,
                 interval=10,
                 ignore_last=True,
                 reset_flag=True,
                 with_step=True,
                 by_epoch=True):
        super(NeptuneLoggerHook, self).__init__(interval, ignore_last,
                                                reset_flag, by_epoch)
        self.import_neptune()
        self.init_kwargs = init_kwargs
        self.with_step = with_step

    def import_neptune(self):
        """Import ``neptune.new`` lazily and reset the run handle.

        Raises:
            ImportError: if ``neptune-client`` is not installed.
        """
        try:
            import neptune.new as neptune
        except ImportError as err:
            raise ImportError(
                'Please run "pip install neptune-client" to install neptune'
            ) from err
        self.neptune = neptune
        # The run is created lazily in ``before_run``.
        self.run = None

    @master_only
    def before_run(self, runner):
        """Create the Neptune run, forwarding ``init_kwargs`` if given."""
        if self.init_kwargs:
            self.run = self.neptune.init(**self.init_kwargs)
        else:
            self.run = self.neptune.init()

    @master_only
    def log(self, runner):
        """Send the runner's loggable tags to the Neptune run.

        Each tag is appended to the series of the same name. The current
        iteration is either attached as ``step`` (``with_step=True``) or
        logged once into a separate ``global_step`` series.
        """
        tags = self.get_loggable_tags(runner)
        if not tags:
            return
        step = self.get_iter(runner)
        for tag_name, tag_value in tags.items():
            if self.with_step:
                self.run[tag_name].log(tag_value, step=step)
            else:
                self.run[tag_name].log(tag_value)
        if not self.with_step:
            # Logged outside the loop: inserting into ``tags`` while
            # iterating it would raise "dictionary changed size during
            # iteration".
            self.run['global_step'].log(step)

    @master_only
    def after_run(self, runner):
        """Stop the Neptune run if one was started."""
        if self.run is not None:
            self.run.stop()