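"""
Command-line entry point for the ``ding`` CLI: parses the options wired to
``cli()`` below and dispatches to the serial / parallel / dist / eval pipelines.
"""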
from typing import List, Union
import os
import copy
import click
from click.core import Context, Option
import numpy as np
from ding import __TITLE__, __VERSION__, __AUTHOR__, __AUTHOR_EMAIL__
from ding.config import read_config
from .predefined_config import get_predefined_config


def print_version(ctx: Context, param: Option, value: bool) -> None:
    # Eager click-option callback: print package version/author info, then exit.
    if not value or ctx.resilient_parsing:
        return
    click.echo('{title}, version {version}.'.format(title=__TITLE__, version=__VERSION__))
    click.echo('Developed by {author}, {email}.'.format(author=__AUTHOR__, email=__AUTHOR_EMAIL__))
    ctx.exit()


def print_registry(ctx: Context, param: Option, value: str) -> None:
    # Eager click-option callback: list the modules registered under the given
    # registry name (e.g. policy, env), then exit.
    if value is None:
        return
    from ding.utils import registries  # noqa
    if value not in registries:
        click.echo('[ERROR]: unsupported registry name: {}'.format(value))
    else:
        registered_info = registries[value].query_details()
        click.echo('Available {}: [{}]'.format(value, '|'.join(registered_info.keys())))
        for alias, info in registered_info.items():
            click.echo('\t{}: registered at {}#{}'.format(alias, info[0], info[1]))
    ctx.exit()


CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])

# NOTE: the platform/coordinator/learner/collector/aggregator arguments below are
# only applied to dist mode.
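# The full set of @click.option decorators (one per cli() parameter below) was
# not preserved here. A minimal sketch follows, assuming the standard click
# eager-callback wiring for the two helpers above; the flag names are
# assumptions, not the package's exact interface:
@click.command(context_settings=CONTEXT_SETTINGS)
@click.option(
    '-v', '--version', is_flag=True, callback=print_version, expose_value=False, is_eager=True,
    help="Show package's version information."
)
@click.option(
    '-q', '--query-registry', type=str, callback=print_registry, expose_value=False, is_eager=True,
    help='Query the registered modules under the given registry name.'
)
# ... remaining @click.option declarations elided ...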
def cli(
    # serial/eval
    mode: str,
    config: str,
    seed: Union[int, List],
    exp_name: str,
    env: str,
    policy: str,
    train_iter: str,  # passed as a string (e.g. '1e6'); converted to int below
    env_step: str,  # same as train_iter
    load_path: str,
    replay_path: str,
    # parallel/dist
    platform: str,
    coordinator_host: str,
    coordinator_port: int,
    learner_host: str,
    learner_port: int,
    collector_host: str,
    collector_port: int,
    aggregator_host: str,
    aggregator_port: int,
    enable_total_log: bool,
    disable_flask_log: bool,
    module: str,
    module_name: str,
    # add/delete/restart
    add: str,
    delete: str,
    restart: str,
    kubeconfig: str,
    coordinator_name: str,
    namespace: str,
    replicas: int,
    cpus: int,
    gpus: int,
    memory: str,
    restart_pod_name: str,
    profile: str,
):
    if profile is not None:
        from ..utils.profiler_helper import Profiler
        profiler = Profiler()
        profiler.profile(profile)
    # train_iter/env_step arrive as strings, possibly in scientific notation (e.g. '1e6').
    train_iter = int(float(train_iter))
    env_step = int(float(env_step))

    def run_single_pipeline(seed, config):
        # Fall back to a predefined config for the given env/policy pair when
        # no config file is specified.
        if config is None:
            config = get_predefined_config(env, policy)
        else:
            config = read_config(config)
        if exp_name is not None:
            config[0].exp_name = exp_name
        if mode == 'serial':
            from .serial_entry import serial_pipeline
            serial_pipeline(config, seed, max_train_iter=train_iter, max_env_step=env_step)
        elif mode == 'serial_onpolicy':
            from .serial_entry_onpolicy import serial_pipeline_onpolicy
            serial_pipeline_onpolicy(config, seed, max_train_iter=train_iter, max_env_step=env_step)
        elif mode == 'serial_sqil':
            from .serial_entry_sqil import serial_pipeline_sqil
            expert_config = input("Enter the name of the config you used to generate your expert model: ")
            serial_pipeline_sqil(config, expert_config, seed, max_train_iter=train_iter, max_env_step=env_step)
        elif mode == 'serial_reward_model':
            from .serial_entry_reward_model_offpolicy import serial_pipeline_reward_model_offpolicy
            serial_pipeline_reward_model_offpolicy(config, seed, max_train_iter=train_iter, max_env_step=env_step)
        elif mode == 'serial_gail':
            from .serial_entry_gail import serial_pipeline_gail
            expert_config = input("Enter the name of the config you used to generate your expert model: ")
            serial_pipeline_gail(
                config, expert_config, seed, max_train_iter=train_iter, max_env_step=env_step, collect_data=True
            )
        elif mode == 'serial_dqfd':
            from .serial_entry_dqfd import serial_pipeline_dqfd
            expert_config = input("Enter the name of the config you used to generate your expert model: ")
            assert expert_config == config[:config.find('_dqfd')] + '_dqfd_config.py', (
                "DQFD only supports models trained by Q-learning for now; however, one should still pass "
                "the DQFD config here, i.e., {}{}".format(config[:config.find('_dqfd')], '_dqfd_config.py')
            )
            serial_pipeline_dqfd(config, expert_config, seed, max_train_iter=train_iter, max_env_step=env_step)
        elif mode == 'serial_trex':
            from .serial_entry_trex import serial_pipeline_trex
            serial_pipeline_trex(config, seed, max_train_iter=train_iter, max_env_step=env_step)
        elif mode == 'serial_trex_onpolicy':
            from .serial_entry_trex_onpolicy import serial_pipeline_trex_onpolicy
            serial_pipeline_trex_onpolicy(config, seed, max_train_iter=train_iter, max_env_step=env_step)
        elif mode == 'serial_offline':
            from .serial_entry_offline import serial_pipeline_offline
            serial_pipeline_offline(config, seed, max_train_iter=train_iter)
        elif mode == 'serial_ngu':
            from .serial_entry_ngu import serial_pipeline_ngu
            serial_pipeline_ngu(config, seed, max_train_iter=train_iter)
        elif mode == 'parallel':
            from .parallel_entry import parallel_pipeline
            parallel_pipeline(config, seed, enable_total_log, disable_flask_log)
        elif mode == 'dist':
            from .dist_entry import dist_launch_coordinator, dist_launch_collector, dist_launch_learner, \
                dist_prepare_config, dist_launch_learner_aggregator, dist_launch_spawn_learner, \
                dist_add_replicas, dist_delete_replicas, dist_restart_replicas
            if module == 'config':
                dist_prepare_config(
                    config, seed, platform, coordinator_host, learner_host, collector_host, coordinator_port,
                    learner_port, collector_port
                )
            elif module == 'coordinator':
                dist_launch_coordinator(config, seed, coordinator_port, disable_flask_log)
            elif module == 'learner_aggregator':
                dist_launch_learner_aggregator(
                    config, seed, aggregator_host, aggregator_port, module_name, disable_flask_log
                )
            elif module == 'collector':
                dist_launch_collector(config, seed, collector_port, module_name, disable_flask_log)
            elif module == 'learner':
                dist_launch_learner(config, seed, learner_port, module_name, disable_flask_log)
            elif module == 'spawn_learner':
                dist_launch_spawn_learner(config, seed, learner_port, module_name, disable_flask_log)
            elif add in ['collector', 'learner']:
                dist_add_replicas(add, kubeconfig, replicas, coordinator_name, namespace, cpus, gpus, memory)
            elif delete in ['collector', 'learner']:
                dist_delete_replicas(delete, kubeconfig, replicas, coordinator_name, namespace)
            elif restart in ['collector', 'learner']:
                dist_restart_replicas(restart, kubeconfig, coordinator_name, namespace, restart_pod_name)
            else:
                raise Exception("invalid dist mode arguments: module={}".format(module))
        elif mode == 'eval':
            from .application_entry import eval
            eval(config, seed, load_path=load_path, replay_path=replay_path)

    if mode is None:
        raise RuntimeError("Please indicate at least one argument.")

    if isinstance(seed, (list, tuple)):
        assert len(seed) > 0, "Please input at least 1 seed"
        if len(seed) == 1:  # single seed runs directly in the current directory, without a per-seed subdirectory
            run_single_pipeline(seed[0], config)
        else:
            if exp_name is None:
                multi_exp_root = os.path.basename(config).split('.')[0] + '_result'
            else:
                multi_exp_root = exp_name
            if not os.path.exists(multi_exp_root):
                os.makedirs(multi_exp_root)
            abs_config_path = os.path.abspath(config)
            origin_root = os.getcwd()
            # Run each seed in its own working directory so experiment outputs do not collide.
            for s in seed:
                seed_exp_root = os.path.join(multi_exp_root, 'seed{}'.format(s))
                if not os.path.exists(seed_exp_root):
                    os.makedirs(seed_exp_root)
                os.chdir(seed_exp_root)
                run_single_pipeline(s, abs_config_path)
                os.chdir(origin_root)
    else:
        raise TypeError("invalid seed type: {}".format(type(seed)))
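

# Example invocations (a sketch: the exact flags depend on the elided
# @click.option declarations above and are assumptions here,
# e.g. -m/--mode, -c/--config, -s/--seed):
#   ding -m serial -c cartpole_dqn_config.py -s 0
#   ding -m eval -c cartpole_dqn_config.py -s 0 --load-path ./ckpt/ckpt_best.pth.tar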