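"""
Self-play PPO demo for the DI-engine league example: two PPO policies train
against each other on ``GameEnv``, while two battle evaluators measure the
first policy against a fixed opponent and a uniformly random opponent.
"""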
import os
import gym
import numpy as np
import copy
import torch
from tensorboardX import SummaryWriter

from ding.config import compile_config
from ding.worker import BaseLearner, BattleInteractionSerialEvaluator, NaiveReplayBuffer
from ding.envs import BaseEnvManager, DingEnvWrapper
from ding.policy import PPOPolicy
from ding.model import VAC
from ding.utils import set_pkg_seed
from dizoo.league_demo.game_env import GameEnv
from dizoo.league_demo.league_demo_collector import LeagueDemoCollector
from dizoo.league_demo.selfplay_demo_ppo_config import selfplay_demo_ppo_config
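
# Fixed opponent used by the first evaluator: it always plays action 0.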
class EvalPolicy1:

    def forward(self, data: dict) -> dict:
        return {env_id: {'action': torch.zeros(1)} for env_id in data.keys()}

    def reset(self, data_id: list = []) -> None:
        pass
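
# Random opponent used by the second evaluator: samples action 0 or 1 uniformly.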
class EvalPolicy2:

    def forward(self, data: dict) -> dict:
        return {
            env_id: {
                'action': torch.from_numpy(np.random.choice([0, 1], p=[0.5, 0.5], size=(1, )))
            }
            for env_id in data.keys()
        }

    def reset(self, data_id: list = []) -> None:
        pass
def main(cfg, seed=0, max_train_iter=int(1e8), max_env_step=int(1e8)):
    cfg = compile_config(
        cfg,
        BaseEnvManager,
        PPOPolicy,
        BaseLearner,
        LeagueDemoCollector,
        BattleInteractionSerialEvaluator,
        NaiveReplayBuffer,
        save_cfg=True
    )
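    # compile_config merges the user config with the default configs of the components
    # passed above; with save_cfg=True the resulting full config is saved for the run.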
    env_type = cfg.env.env_type
    collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
    collector_env = BaseEnvManager(
        env_fn=[lambda: GameEnv(env_type) for _ in range(collector_env_num)], cfg=cfg.env.manager
    )
    evaluator_env1 = BaseEnvManager(
        env_fn=[lambda: GameEnv(env_type) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
    )
    evaluator_env2 = BaseEnvManager(
        env_fn=[lambda: GameEnv(env_type) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
    )
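    # One vectorized env manager collects self-play battles; two separate managers are
    # created for evaluation so each evaluator runs against its own opponent.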
    collector_env.seed(seed)
    evaluator_env1.seed(seed, dynamic_seed=False)
    evaluator_env2.seed(seed, dynamic_seed=False)
    set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
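    # dynamic_seed=False keeps the evaluation environments on a fixed seed so evaluation
    # episodes stay comparable across iterations; collector envs keep dynamic seeding.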
    model1 = VAC(**cfg.policy.model)
    policy1 = PPOPolicy(cfg.policy, model=model1)
    model2 = VAC(**cfg.policy.model)
    policy2 = PPOPolicy(cfg.policy, model=model2)
    eval_policy1 = EvalPolicy1()
    eval_policy2 = EvalPolicy2()
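    # Two independent VAC models and PPO policies share the same config but keep separate
    # parameters, so the two self-play players can diverge during training.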
    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
    learner1 = BaseLearner(
        cfg.policy.learn.learner, policy1.learn_mode, tb_logger, exp_name=cfg.exp_name, instance_name='learner1'
    )
    learner2 = BaseLearner(
        cfg.policy.learn.learner, policy2.learn_mode, tb_logger, exp_name=cfg.exp_name, instance_name='learner2'
    )
    collector = LeagueDemoCollector(
        cfg.policy.collect.collector,
        collector_env, [policy1.collect_mode, policy2.collect_mode],
        tb_logger,
        exp_name=cfg.exp_name
    )
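    # The collector steps both collect_mode policies in the same battles and returns one
    # trajectory list per policy (consumed below as train_data[0] / train_data[1]).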
    # In collect_mode, PPO samples actions from the categorical (multinomial) policy
    # distribution, which is why policy1.collect_mode is also passed to the evaluators.
    evaluator1_cfg = copy.deepcopy(cfg.policy.eval.evaluator)
    evaluator1_cfg.stop_value = cfg.env.stop_value[0]
    evaluator1 = BattleInteractionSerialEvaluator(
        evaluator1_cfg,
        evaluator_env1, [policy1.collect_mode, eval_policy1],
        tb_logger,
        exp_name=cfg.exp_name,
        instance_name='fixed_evaluator'
    )
    evaluator2_cfg = copy.deepcopy(cfg.policy.eval.evaluator)
    evaluator2_cfg.stop_value = cfg.env.stop_value[1]
    evaluator2 = BattleInteractionSerialEvaluator(
        evaluator2_cfg,
        evaluator_env2, [policy1.collect_mode, eval_policy2],
        tb_logger,
        exp_name=cfg.exp_name,
        instance_name='uniform_evaluator'
    )
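    # Evaluator 1 benchmarks policy1 against the fixed opponent, evaluator 2 against the
    # random one; training stops only when both reach their configured stop_value.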
    # Main self-play loop: evaluate, collect battles, then train both learners.
    stop_flag1, stop_flag2 = False, False  # avoid a NameError if an evaluation is skipped
    while True:
        if evaluator1.should_eval(learner1.train_iter):
            stop_flag1, _ = evaluator1.eval(learner1.save_checkpoint, learner1.train_iter, collector.envstep)
        if evaluator2.should_eval(learner1.train_iter):
            stop_flag2, _ = evaluator2.eval(learner1.save_checkpoint, learner1.train_iter, collector.envstep)
        if stop_flag1 and stop_flag2:
            break
        train_data, _ = collector.collect(train_iter=learner1.train_iter)
        # This demo skips GAE and uses the raw battle reward directly as the advantage.
        for data in train_data:
            for d in data:
                d['adv'] = d['reward']
        for _ in range(cfg.policy.learn.update_per_collect):
            learner1.train(train_data[0], collector.envstep)
            learner2.train(train_data[1], collector.envstep)
        if collector.envstep >= max_env_step or learner1.train_iter >= max_train_iter:
            break
if __name__ == "__main__":
    main(selfplay_demo_ppo_config)
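
# Example invocation with a custom seed and tighter budgets (hypothetical values):
#   main(selfplay_demo_ppo_config, seed=123, max_train_iter=int(1e4), max_env_step=int(1e6))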