Spaces:
Sleeping
Sleeping
| from typing import Optional, Union, Tuple | |
| import time | |
| import pickle | |
| from ditk import logging | |
| from multiprocessing import Process, Event | |
| import threading | |
| from easydict import EasyDict | |
| from ding.worker import create_comm_learner, create_comm_collector, Coordinator | |
| from ding.config import read_config_with_system, compile_config_parallel | |
| from ding.utils import set_pkg_seed | |
def parallel_pipeline(
        input_cfg: Union[str, Tuple[dict, dict, dict]],
        seed: int,
        enable_total_log: Optional[bool] = False,
        disable_flask_log: Optional[bool] = True,
) -> None:
    r"""
    Overview:
        Parallel pipeline entry.
    Arguments:
        - input_cfg (:obj:`Union[str, Tuple[dict, dict, dict]]`): Config file path, or an already-loaded \
            ``(main_cfg, create_cfg, system_cfg)`` tuple/list.
        - seed (:obj:`int`): Random seed.
        - enable_total_log (:obj:`Optional[bool]`): Whether to enable total DI-engine system log.
        - disable_flask_log (:obj:`Optional[bool]`): Whether to disable flask log.
    Raises:
        - TypeError: If ``input_cfg`` is neither a path string nor a tuple/list of three configs.
    """
    # Disable some part of DI-engine log.
    if not enable_total_log:
        coordinator_log = logging.getLogger('coordinator_logger')
        coordinator_log.disabled = True
    # Disable flask logger.
    if disable_flask_log:
        log = logging.getLogger('werkzeug')
        log.disabled = True
    # Parallel job launch: accept either a config file path or in-memory configs.
    if isinstance(input_cfg, str):
        main_cfg, create_cfg, system_cfg = read_config_with_system(input_cfg)
    elif isinstance(input_cfg, (tuple, list)):
        main_cfg, create_cfg, system_cfg = input_cfg
    else:
        raise TypeError("invalid config type: {}".format(input_cfg))
    config = compile_config_parallel(main_cfg, create_cfg=create_cfg, system_cfg=system_cfg, seed=seed)
    learner_handle = []
    collector_handle = []
    # Launch one subprocess per learner/collector entry found in the system config.
    for k, v in config.system.items():
        if 'learner' in k:
            learner_handle.append(launch_learner(config.seed, v))
        elif 'collector' in k:
            collector_handle.append(launch_collector(config.seed, v))
    # The coordinator runs in the current process and blocks until system shutdown.
    launch_coordinator(config.seed, config, learner_handle=learner_handle, collector_handle=collector_handle)
# The following functions are used to launch different components (learner, learner aggregator, collector, coordinator).
# Argument ``config`` is the dict-type config. If it is None, then ``filename`` and ``name`` must be passed instead,
# since they are used to read the corresponding config from file.
def run_learner(config, seed, start_learner_event, close_learner_event):
    """
    Overview:
        Subprocess entry for a learner: seed the packages, silence the flask
        (werkzeug) request log, bring up a comm learner, then block until the
        parent signals it to close.
    """
    set_pkg_seed(seed)
    # Suppress flask request logging inside this subprocess.
    flask_logger = logging.getLogger('werkzeug')
    flask_logger.disabled = True
    learner = create_comm_learner(config)
    learner.start()
    # Report readiness to the parent, then wait for the shutdown signal.
    start_learner_event.set()
    close_learner_event.wait()
    learner.close()
def launch_learner(
        seed: int, config: Optional[dict] = None, filename: Optional[str] = None, name: Optional[str] = None
) -> tuple:
    """
    Overview:
        Launch a learner in a new subprocess and return handles to control it.
    Arguments:
        - seed (:obj:`int`): Random seed forwarded to the subprocess.
        - config (:obj:`Optional[dict]`): Learner config; if ``None``, it is loaded from ``filename`` under key ``name``.
        - filename (:obj:`Optional[str]`): Pickle file path, used only when ``config`` is ``None``.
        - name (:obj:`Optional[str]`): Key into the pickled config dict, used only when ``config`` is ``None``.
    Returns:
        - handle (:obj:`tuple`): ``(process, start_event, close_event)`` — the learner process, \
            an event set by the subprocess once started, and an event the caller sets to request shutdown.
    """
    if config is None:
        # Fall back to reading the component config from file.
        with open(filename, 'rb') as f:
            config = pickle.load(f)[name]
    start_learner_event = Event()
    close_learner_event = Event()
    # NOTE: this is a Process, not a thread — the old name ``learner_thread`` was misleading.
    learner_process = Process(
        target=run_learner, args=(config, seed, start_learner_event, close_learner_event), name='learner_entry_process'
    )
    learner_process.start()
    return learner_process, start_learner_event, close_learner_event
def run_collector(config, seed, start_collector_event, close_collector_event):
    """
    Overview:
        Subprocess entry for a collector: seed the packages, silence the flask
        (werkzeug) request log, bring up a comm collector, then block until the
        parent signals it to close.
    """
    set_pkg_seed(seed)
    # Suppress flask request logging inside this subprocess.
    flask_logger = logging.getLogger('werkzeug')
    flask_logger.disabled = True
    collector = create_comm_collector(config)
    collector.start()
    # Report readiness to the parent, then wait for the shutdown signal.
    start_collector_event.set()
    close_collector_event.wait()
    collector.close()
def launch_collector(
        seed: int, config: Optional[dict] = None, filename: Optional[str] = None, name: Optional[str] = None
) -> tuple:
    """
    Overview:
        Launch a collector in a new subprocess and return handles to control it.
    Arguments:
        - seed (:obj:`int`): Random seed forwarded to the subprocess.
        - config (:obj:`Optional[dict]`): Collector config; if ``None``, it is loaded from ``filename`` under key ``name``.
        - filename (:obj:`Optional[str]`): Pickle file path, used only when ``config`` is ``None``.
        - name (:obj:`Optional[str]`): Key into the pickled config dict, used only when ``config`` is ``None``.
    Returns:
        - handle (:obj:`tuple`): ``(process, start_event, close_event)`` — the collector process, \
            an event set by the subprocess once started, and an event the caller sets to request shutdown.
    """
    if config is None:
        # Fall back to reading the component config from file.
        with open(filename, 'rb') as f:
            config = pickle.load(f)[name]
    start_collector_event = Event()
    close_collector_event = Event()
    # NOTE: this is a Process, not a thread — the old name ``collector_thread`` was misleading.
    collector_process = Process(
        target=run_collector,
        args=(config, seed, start_collector_event, close_collector_event),
        name='collector_entry_process'
    )
    collector_process.start()
    return collector_process, start_collector_event, close_collector_event
def launch_coordinator(
        seed: int,
        config: Optional[EasyDict] = None,
        filename: Optional[str] = None,
        learner_handle: Optional[list] = None,
        collector_handle: Optional[list] = None
) -> None:
    """
    Overview:
        Launch the coordinator in the current process and block until the whole
        parallel system has shut down.
    Arguments:
        - seed (:obj:`int`): Random seed.
        - config (:obj:`Optional[EasyDict]`): Full compiled config; if ``None``, it is unpickled from ``filename``.
        - filename (:obj:`Optional[str]`): Pickle file path, used only when ``config`` is ``None``.
        - learner_handle (:obj:`Optional[list]`): ``(process, start_event, close_event)`` tuples from ``launch_learner``.
        - collector_handle (:obj:`Optional[list]`): ``(process, start_event, close_event)`` tuples from \
            ``launch_collector``.
    """
    set_pkg_seed(seed)
    if config is None:
        with open(filename, 'rb') as f:
            config = pickle.load(f)
    coordinator = Coordinator(config)
    # Wait until every learner and collector subprocess has reported startup
    # before starting the coordinator itself.
    for _, start_event, _ in learner_handle:
        start_event.wait()
    for _, start_event, _ in collector_handle:
        start_event.wait()
    coordinator.start()
    system_shutdown_event = threading.Event()

    # Monitor thread: Coordinator will remain running until its ``system_shutdown_flag`` is set to False.
    # Polls the flag every 3 seconds; once set, closes the coordinator first,
    # then signals all learner/collector subprocesses to close.
    def shutdown_monitor():
        while True:
            time.sleep(3)
            if coordinator.system_shutdown_flag:
                coordinator.close()
                for _, _, close_event in learner_handle:
                    close_event.set()
                for _, _, close_event in collector_handle:
                    close_event.set()
                # Wake up the main thread blocked in ``system_shutdown_event.wait()``.
                system_shutdown_event.set()
                break

    shutdown_monitor_thread = threading.Thread(target=shutdown_monitor, args=(), daemon=True, name='shutdown_monitor')
    shutdown_monitor_thread.start()
    # Block the caller until the monitor thread observes system shutdown.
    system_shutdown_event.wait()
    print(
        "[DI-engine parallel pipeline]Your RL agent is converged, you can refer to 'log' and 'tensorboard' for details"
    )