# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import atexit
import logging
import os
import socket
import time
from dataclasses import dataclass
from pathlib import Path
from subprocess import Popen
from threading import Thread
from typing import Any, List, Optional, Union

import colorama
import psutil
import torch
import torch.nn as nn

import nni.runtime.log
from nni.experiment import Experiment, TrainingServiceConfig
from nni.experiment import management, launcher, rest
from nni.experiment.config import util
from nni.experiment.config.base import ConfigBase, PathLike
from nni.experiment.pipe import Pipe
from nni.tools.nnictl.command_utils import kill_command

from ..codegen import model_to_pytorch_script
from ..converter import convert_to_graph
from ..execution import list_models, set_execution_engine
from ..execution.python import get_mutation_dict
from ..graph import Model, Evaluator
from ..integration import RetiariiAdvisor
from ..mutator import Mutator
from ..nn.pytorch.mutator import process_inline_mutation, extract_mutation_from_pt_module
from ..strategy import BaseStrategy
from ..oneshot.interface import BaseOneShotTrainer

_logger = logging.getLogger(__name__)


@dataclass(init=False)
class RetiariiExeConfig(ConfigBase):
    experiment_name: Optional[str] = None
    search_space: Any = ''  # TODO: remove
    trial_command: str = '_reserved'
    trial_code_directory: PathLike = '.'
    trial_concurrency: int
    trial_gpu_number: int = 0
    max_experiment_duration: Optional[str] = None
    max_trial_number: Optional[int] = None
    nni_manager_ip: Optional[str] = None
    debug: bool = False
    log_level: Optional[str] = None
    experiment_working_directory: PathLike = '~/nni-experiments'
    # remove configuration of tuner/assessor/advisor
    training_service: TrainingServiceConfig
    execution_engine: str = 'py'

    def __init__(self, training_service_platform: Optional[str] = None, **kwargs):
        super().__init__(**kwargs)
        if training_service_platform is not None:
            assert 'training_service' not in kwargs
            self.training_service = util.training_service_config_factory(platform=training_service_platform)
        self.__dict__['trial_command'] = 'python3 -m nni.retiarii.trial_entry py'

    def __setattr__(self, key, value):
        fixed_attrs = {'search_space': '',
                       'trial_command': '_reserved'}
        if key in fixed_attrs and fixed_attrs[key] != value:
            raise AttributeError(f'{key} is not supposed to be set in Retiarii mode by users!')
        # 'trial_code_directory' is handled differently because the path will be converted to an absolute path by us
        if key == 'trial_code_directory' and not (value == Path('.') or os.path.isabs(value)):
            raise AttributeError(f'{key} is not supposed to be set in Retiarii mode by users!')
        if key == 'execution_engine':
            assert value in ['base', 'py', 'cgo'], f'The specified execution engine "{value}" is not supported.'
            self.__dict__['trial_command'] = 'python3 -m nni.retiarii.trial_entry ' + value
        self.__dict__[key] = value

    def validate(self, initialized_tuner: bool = False) -> None:
        super().validate()

    @property
    def _canonical_rules(self):
        return _canonical_rules

    @property
    def _validation_rules(self):
        return _validation_rules


_canonical_rules = {
    'trial_code_directory': util.canonical_path,
    'max_experiment_duration': lambda value: f'{util.parse_time(value)}s' if value is not None else None,
    'experiment_working_directory': util.canonical_path
}

_validation_rules = {
    'trial_code_directory': lambda value: (Path(value).is_dir(), f'"{value}" does not exist or is not a directory'),
    'trial_concurrency': lambda value: value > 0,
    'trial_gpu_number': lambda value: value >= 0,
    'max_experiment_duration': lambda value: util.parse_time(value) > 0,
    'max_trial_number': lambda value: value > 0,
    'log_level': lambda value: value in ["trace", "debug", "info", "warning", "error", "fatal"],
    'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class')
}
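
# A minimal sketch of how this config is typically filled in, assuming the
# 'local' training service is available (the values below are illustrative,
# not defaults):
#
#   exp_config = RetiariiExeConfig('local')
#   exp_config.trial_concurrency = 2
#   exp_config.max_trial_number = 20
#   exp_config.training_service.use_active_gpu = False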


def preprocess_model(base_model, trainer, applied_mutators, full_ir=True):
    # TODO: this logic might need to be refactored into the execution engine
    if full_ir:
        try:
            script_module = torch.jit.script(base_model)
        except Exception as e:
            _logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
            raise e
        base_model_ir = convert_to_graph(script_module, base_model)
        # handle inline mutations
        mutators = process_inline_mutation(base_model_ir)
    else:
        base_model_ir, mutators = extract_mutation_from_pt_module(base_model)
    base_model_ir.evaluator = trainer

    if mutators is not None and applied_mutators:
        raise RuntimeError('Mixed usage of LayerChoice/InputChoice and mutators is not supported yet; '
                           'do not use mutators when you use LayerChoice/InputChoice.')
    if mutators is not None:
        applied_mutators = mutators
    return base_model_ir, applied_mutators
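
# Sketch of the two conversion paths above, for orientation (the call below is
# a hypothetical usage; `MyNet` and `my_evaluator` are user-defined):
#   full_ir=True  -> torch.jit.script + convert_to_graph: build a full graph IR,
#                    turning inline LayerChoice/InputChoice into mutators.
#   full_ir=False -> extract_mutation_from_pt_module: keep the PyTorch module
#                    as-is, as the pure-Python ('py') execution engine expects.
#
#   model_ir, mutators = preprocess_model(MyNet(), my_evaluator, [], full_ir=False)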


def debug_mutated_model(base_model, trainer, applied_mutators):
    """
    Locally run only one trial without launching an experiment for debugging purposes, then exit.
    For example, it can be used to quickly check for shape mismatches.

    Specifically, it applies the mutators (each choice defaults to its first candidate)
    to generate a new model, then runs this model locally.

    Parameters
    ----------
    base_model : nni.retiarii.nn.pytorch.nn.Module
        the base model
    trainer : nni.retiarii.evaluator
        the training class of the generated models
    applied_mutators : list
        a list of mutators that will be applied on the base model to generate a new model
    """
    base_model_ir, applied_mutators = preprocess_model(base_model, trainer, applied_mutators)

    from ..strategy import _LocalDebugStrategy
    strategy = _LocalDebugStrategy()
    strategy.run(base_model_ir, applied_mutators)
    _logger.info('local debug completed!')
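
# Debugging sketch (`MyNet` and `my_evaluator` are hypothetical, user-defined
# objects): this runs a single mutated model locally, surfacing e.g. shape
# errors before a full experiment is launched.
#
#   debug_mutated_model(MyNet(), my_evaluator, [])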


class RetiariiExperiment(Experiment):
    def __init__(self, base_model: nn.Module, trainer: Union[Evaluator, BaseOneShotTrainer],
                 applied_mutators: List[Mutator] = None, strategy: BaseStrategy = None):
        # TODO: The current design of the init interface of Retiarii experiment needs to be reviewed.
        self.config: RetiariiExeConfig = None
        self.port: Optional[int] = None

        self.base_model = base_model
        self.trainer = trainer
        self.applied_mutators = applied_mutators
        self.strategy = strategy

        self._dispatcher = RetiariiAdvisor()
        self._dispatcher_thread: Optional[Thread] = None
        self._proc: Optional[Popen] = None
        self._pipe: Optional[Pipe] = None

    def _start_strategy(self):
        base_model_ir, self.applied_mutators = preprocess_model(
            self.base_model, self.trainer, self.applied_mutators, full_ir=self.config.execution_engine != 'py')

        _logger.info('Start strategy...')
        self.strategy.run(base_model_ir, self.applied_mutators)
        _logger.info('Strategy exit')
        # TODO: find a proper way to show a "no more trials" message on the WebUI
        # self._dispatcher.mark_experiment_as_ending()

    def start(self, port: int = 8080, debug: bool = False) -> None:
        """
        Start the experiment in the background.
        This method will raise an exception on failure.
        If it returns, the experiment should have been successfully started.

        Parameters
        ----------
        port
            The port of web UI.
        debug
            Whether to start in debug mode.
        """
        atexit.register(self.stop)

        # we will probably need an execution engine factory to make this clean and elegant
        if self.config.execution_engine == 'base':
            from ..execution.base import BaseExecutionEngine
            engine = BaseExecutionEngine()
        elif self.config.execution_engine == 'cgo':
            from ..execution.cgo_engine import CGOExecutionEngine
            engine = CGOExecutionEngine()
        elif self.config.execution_engine == 'py':
            from ..execution.python import PurePythonExecutionEngine
            engine = PurePythonExecutionEngine()
        set_execution_engine(engine)

        self.id = management.generate_experiment_id()

        if self.config.experiment_working_directory is not None:
            log_dir = Path(self.config.experiment_working_directory, self.id, 'log')
        else:
            log_dir = Path.home() / f'nni-experiments/{self.id}/log'
        nni.runtime.log.start_experiment_log(self.id, log_dir, debug)

        self._proc, self._pipe = launcher.start_experiment_retiarii(self.id, self.config, port, debug)
        assert self._proc is not None
        assert self._pipe is not None
        self.port = port  # port will be None if startup failed

        # the dispatcher must be launched after the pipe is initialized
        # the logic that launches the dispatcher in the background should be refactored into the dispatcher API
        self._dispatcher = self._create_dispatcher()
        self._dispatcher_thread = Thread(target=self._dispatcher.run)
        self._dispatcher_thread.start()

        ips = [self.config.nni_manager_ip]
        for interfaces in psutil.net_if_addrs().values():
            for interface in interfaces:
                if interface.family == socket.AF_INET:
                    ips.append(interface.address)
        ips = [f'http://{ip}:{port}' for ip in ips if ip]
        msg = 'Web UI URLs: ' + colorama.Fore.CYAN + ' '.join(ips) + colorama.Style.RESET_ALL
        _logger.info(msg)

        exp_status_checker = Thread(target=self._check_exp_status)
        exp_status_checker.start()
        self._start_strategy()
        # TODO: the experiment should be completed when the strategy exits and there are no running jobs
        _logger.info('Waiting for experiment to become DONE (you can ctrl+c if there are no running trial jobs)...')
        exp_status_checker.join()

    def _create_dispatcher(self):
        return self._dispatcher

    def run(self, config: RetiariiExeConfig = None, port: int = 8080, debug: bool = False) -> str:
        """
        Run the experiment.
        This function will block until the experiment finishes or errors out.
        """
        if isinstance(self.trainer, BaseOneShotTrainer):
            self.trainer.fit()
        else:
            assert config is not None, 'You are using classic search mode, config cannot be None!'
            self.config = config
            self.start(port, debug)
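
    # Note (a sketch of the two code paths above, not upstream documentation):
    # with a one-shot trainer (e.g. DartsTrainer), `run()` trains in-process and
    # needs no config; otherwise an NNI manager process is launched via
    # `start()` and trials are dispatched through it.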

    def _check_exp_status(self) -> bool:
        """
        Check the experiment status in a loop.
        This function will block until the experiment finishes or errors out.
        Return `True` when the experiment is done; return `False` when it failed.
        """
        try:
            while True:
                time.sleep(10)
                # this check deals with the situation where
                # nnimanager has already been cleaned up by ctrl+c
                if self._proc.poll() is None:
                    status = self.get_status()
                else:
                    return False
                if status == 'DONE' or status == 'STOPPED':
                    return True
                if status == 'ERROR':
                    return False
        except KeyboardInterrupt:
            _logger.warning('KeyboardInterrupt detected')
        finally:
            self.stop()

    def stop(self) -> None:
        """
        Stop the background experiment.
        """
        _logger.info('Stopping experiment, please wait...')
        atexit.unregister(self.stop)

        if self.id is not None:
            nni.runtime.log.stop_experiment_log(self.id)
        if self._proc is not None:
            try:
                # this check deals with the situation where
                # nnimanager has already been cleaned up by ctrl+c
                if self._proc.poll() is None:
                    rest.delete(self.port, '/experiment')
            except Exception as e:
                _logger.exception(e)
                _logger.warning('Cannot gracefully stop experiment, killing NNI process...')
                kill_command(self._proc.pid)

        if self._pipe is not None:
            self._pipe.close()
        if self._dispatcher_thread is not None:
            self._dispatcher.stopping = True
            self._dispatcher_thread.join(timeout=1)

        self.id = None
        self.port = None
        self._proc = None
        self._pipe = None
        self._dispatcher = None
        self._dispatcher_thread = None
        _logger.info('Experiment stopped')

    def export_top_models(self, top_k: int = 1, optimize_mode: str = 'maximize', formatter: str = 'dict') -> Any:
        """
        Export several top-performing models.

        For one-shot algorithms, only top-1 is supported. For others, ``optimize_mode`` and ``formatter`` are
        available for customization.

        Parameters
        ----------
        top_k : int
            How many models are intended to be exported.
        optimize_mode : str
            ``maximize`` or ``minimize``. Not supported by one-shot algorithms.
            ``optimize_mode`` is likely to be removed and defined in strategy in future.
        formatter : str
            Supports ``code`` and ``dict``. Not supported by one-shot algorithms.
            If ``code``, the Python code of the model will be returned.
            If ``dict``, the mutation history will be returned.
        """
        if formatter == 'code':
            assert self.config.execution_engine != 'py', 'You should use the `dict` formatter when using the Python execution engine.'
        if isinstance(self.trainer, BaseOneShotTrainer):
            assert top_k == 1, 'Only top_k == 1 is supported for now.'
            return self.trainer.export()
        else:
            all_models = filter(lambda m: m.metric is not None, list_models())
            assert optimize_mode in ['maximize', 'minimize']
            all_models = sorted(all_models, key=lambda m: m.metric, reverse=optimize_mode == 'maximize')
            assert formatter in ['code', 'dict'], 'Export formatter other than "code" and "dict" is not supported yet.'
            if formatter == 'code':
                return [model_to_pytorch_script(model) for model in all_models[:top_k]]
            elif formatter == 'dict':
                return [get_mutation_dict(model) for model in all_models[:top_k]]

    def retrain_model(self, model):
        """
        Retrain the exported model, and test it to output the test accuracy.
        """
        raise NotImplementedError
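

# End-to-end usage sketch. `MyNet` and `my_evaluator` are hypothetical
# user-defined objects, and the strategy/port choices are illustrative;
# any strategy from `nni.retiarii.strategy` would do:
#
#   import nni.retiarii.strategy as strategy
#   exp = RetiariiExperiment(MyNet(), my_evaluator, [], strategy.Random())
#   exp_config = RetiariiExeConfig('local')
#   exp_config.trial_concurrency = 2
#   exp_config.max_trial_number = 10
#   exp.run(exp_config, 8081)
#   print(exp.export_top_models(formatter='dict'))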