Spaces:
Sleeping
Sleeping
init space
Browse files- LightZero/.gitignore +3 -8
- LightZero/lzero/mcts/ctree/ctree_efficientzero/lib/cnode.cpp +792 -0
- LightZero/lzero/mcts/ctree/ctree_efficientzero/lib/cnode.h +91 -0
- LightZero/lzero/mcts/ctree/ctree_gumbel_muzero/lib/cnode.cpp +1154 -0
- LightZero/lzero/mcts/ctree/ctree_gumbel_muzero/lib/cnode.h +109 -0
- LightZero/lzero/mcts/ctree/ctree_muzero/lib/cnode.cpp +715 -0
- LightZero/lzero/mcts/ctree/ctree_muzero/lib/cnode.h +91 -0
- LightZero/lzero/mcts/ctree/ctree_sampled_efficientzero/lib/cnode.cpp +1189 -0
- LightZero/lzero/mcts/ctree/ctree_sampled_efficientzero/lib/cnode.h +123 -0
- LightZero/lzero/mcts/ctree/ctree_stochastic_muzero/lib/cnode.cpp +787 -0
- LightZero/lzero/mcts/ctree/ctree_stochastic_muzero/lib/cnode.h +95 -0
LightZero/.gitignore
CHANGED
|
@@ -741,8 +741,8 @@ develop-eggs/
|
|
| 741 |
downloads/
|
| 742 |
eggs/
|
| 743 |
.eggs/
|
| 744 |
-
|
| 745 |
-
|
| 746 |
parts/
|
| 747 |
sdist/
|
| 748 |
var/
|
|
@@ -982,11 +982,6 @@ dist
|
|
| 982 |
### VirtualEnv template
|
| 983 |
# Virtualenv
|
| 984 |
# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
|
| 985 |
-
[Bb]in
|
| 986 |
-
[Ii]nclude
|
| 987 |
-
[Ll]ib
|
| 988 |
-
[Ll]ib64
|
| 989 |
-
[Ll]ocal
|
| 990 |
pyvenv.cfg
|
| 991 |
pip-selfcheck.json
|
| 992 |
|
|
@@ -1050,7 +1045,7 @@ Temporary Items
|
|
| 1050 |
*.gch
|
| 1051 |
|
| 1052 |
# Libraries
|
| 1053 |
-
|
| 1054 |
*.a
|
| 1055 |
*.la
|
| 1056 |
*.lo
|
|
|
|
| 741 |
downloads/
|
| 742 |
eggs/
|
| 743 |
.eggs/
|
| 744 |
+
|
| 745 |
+
|
| 746 |
parts/
|
| 747 |
sdist/
|
| 748 |
var/
|
|
|
|
| 982 |
### VirtualEnv template
|
| 983 |
# Virtualenv
|
| 984 |
# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 985 |
pyvenv.cfg
|
| 986 |
pip-selfcheck.json
|
| 987 |
|
|
|
|
| 1045 |
*.gch
|
| 1046 |
|
| 1047 |
# Libraries
|
| 1048 |
+
|
| 1049 |
*.a
|
| 1050 |
*.la
|
| 1051 |
*.lo
|
LightZero/lzero/mcts/ctree/ctree_efficientzero/lib/cnode.cpp
ADDED
|
@@ -0,0 +1,792 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// C++11
|
| 2 |
+
|
| 3 |
+
#include <iostream>
|
| 4 |
+
#include "cnode.h"
|
| 5 |
+
#include <algorithm>
|
| 6 |
+
#include <map>
|
| 7 |
+
#include <cassert>
|
| 8 |
+
|
| 9 |
+
#ifdef _WIN32
|
| 10 |
+
#include "..\..\common_lib\utils.cpp"
|
| 11 |
+
#else
|
| 12 |
+
#include "../../common_lib/utils.cpp"
|
| 13 |
+
#endif
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
namespace tree
|
| 17 |
+
{
|
| 18 |
+
|
| 19 |
+
CSearchResults::CSearchResults()
{
    /*
    Overview:
        Default-construct CSearchResults; the stored result count starts at zero.
    */
    this->num = 0;
}
|
| 27 |
+
|
| 28 |
+
CSearchResults::CSearchResults(int num)
{
    /*
    Overview:
        Construct CSearchResults that holds ``num`` (initially empty) search paths.
    Arguments:
        - num: the number of parallel search results to track.
    */
    this->num = num;
    // One empty search path per result; resize default-constructs the vectors,
    // equivalent to pushing ``num`` empty vectors one by one.
    this->search_paths.resize(num);
}
|
| 40 |
+
|
| 41 |
+
// Default destructor: all members release their own storage (Rule of Zero).
CSearchResults::~CSearchResults() {}
|
| 42 |
+
|
| 43 |
+
//*********************************************************
|
| 44 |
+
|
| 45 |
+
CNode::CNode()
{
    /*
    Overview:
        Default initialization of CNode with neutral search statistics.
    */
    this->prior = 0;
    // FIX: the original code contained ``this->legal_actions = legal_actions;``,
    // a self-assignment of the member (copied from the (prior, legal_actions)
    // constructor). It was a no-op; ``legal_actions`` is simply left
    // default-constructed (empty) here.
    this->is_reset = 0;
    this->visit_count = 0;
    this->value_sum = 0;
    this->best_action = -1;
    this->to_play = 0;
    this->value_prefix = 0.0;
    this->parent_value_prefix = 0.0;
    // FIX: initialize both indices to -1 for consistency with the
    // (prior, legal_actions) constructor; the original default constructor
    // left them uninitialized.
    this->current_latent_state_index = -1;
    this->batch_index = -1;
}
|
| 62 |
+
|
| 63 |
+
CNode::CNode(float prior, std::vector<int> &legal_actions)
{
    /*
    Overview:
        Construct a CNode from a prior probability and its set of legal actions.
    Arguments:
        - prior: the prior probability assigned to this node.
        - legal_actions: the actions available from this node.
    */
    this->prior = prior;
    this->legal_actions = legal_actions;

    // Search statistics start out neutral.
    this->visit_count = 0;
    this->value_sum = 0;
    this->to_play = 0;
    this->is_reset = 0;
    this->best_action = -1;

    // Value-prefix bookkeeping (EfficientZero-specific reward accumulation).
    this->value_prefix = 0.0;
    this->parent_value_prefix = 0.0;

    // Hidden-state coordinates are unset until expand() is called.
    this->current_latent_state_index = -1;
    this->batch_index = -1;
}
|
| 85 |
+
|
| 86 |
+
// Default destructor: children map and action vector clean themselves up.
CNode::~CNode() {}
|
| 87 |
+
|
| 88 |
+
void CNode::expand(int to_play, int current_latent_state_index, int batch_index, float value_prefix, const std::vector<float> &policy_logits)
|
| 89 |
+
{
|
| 90 |
+
/*
|
| 91 |
+
Overview:
|
| 92 |
+
Expand the child nodes of the current node.
|
| 93 |
+
Arguments:
|
| 94 |
+
- to_play: which player to play the game in the current node.
|
| 95 |
+
- current_latent_state_index: the x/first index of hidden state vector of the current node, i.e. the search depth.
|
| 96 |
+
- batch_index: the y/second index of hidden state vector of the current node, i.e. the index of batch root node, its maximum is ``batch_size``/``env_num``.
|
| 97 |
+
- value_prefix: the value prefix of the current node.
|
| 98 |
+
- policy_logits: the policy logit of the child nodes.
|
| 99 |
+
*/
|
| 100 |
+
this->to_play = to_play;
|
| 101 |
+
this->current_latent_state_index = current_latent_state_index;
|
| 102 |
+
this->batch_index = batch_index;
|
| 103 |
+
this->value_prefix = value_prefix;
|
| 104 |
+
|
| 105 |
+
int action_num = policy_logits.size();
|
| 106 |
+
if (this->legal_actions.size() == 0)
|
| 107 |
+
{
|
| 108 |
+
for (int i = 0; i < action_num; ++i)
|
| 109 |
+
{
|
| 110 |
+
this->legal_actions.push_back(i);
|
| 111 |
+
}
|
| 112 |
+
}
|
| 113 |
+
float temp_policy;
|
| 114 |
+
float policy_sum = 0.0;
|
| 115 |
+
|
| 116 |
+
#ifdef _WIN32
|
| 117 |
+
// 创建动态数组
|
| 118 |
+
float* policy = new float[action_num];
|
| 119 |
+
#else
|
| 120 |
+
float policy[action_num];
|
| 121 |
+
#endif
|
| 122 |
+
|
| 123 |
+
float policy_max = FLOAT_MIN;
|
| 124 |
+
for (auto a : this->legal_actions)
|
| 125 |
+
{
|
| 126 |
+
if (policy_max < policy_logits[a])
|
| 127 |
+
{
|
| 128 |
+
policy_max = policy_logits[a];
|
| 129 |
+
}
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
for (auto a : this->legal_actions)
|
| 133 |
+
{
|
| 134 |
+
temp_policy = exp(policy_logits[a] - policy_max);
|
| 135 |
+
policy_sum += temp_policy;
|
| 136 |
+
policy[a] = temp_policy;
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
float prior;
|
| 140 |
+
for (auto a : this->legal_actions)
|
| 141 |
+
{
|
| 142 |
+
prior = policy[a] / policy_sum;
|
| 143 |
+
std::vector<int> tmp_empty;
|
| 144 |
+
this->children[a] = CNode(prior, tmp_empty); // only for muzero/efficient zero, not support alphazero
|
| 145 |
+
}
|
| 146 |
+
#ifdef _WIN32
|
| 147 |
+
// 释放数组内存
|
| 148 |
+
delete[] policy;
|
| 149 |
+
#else
|
| 150 |
+
#endif
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
void CNode::add_exploration_noise(float exploration_fraction, const std::vector<float> &noises)
|
| 154 |
+
{
|
| 155 |
+
/*
|
| 156 |
+
Overview:
|
| 157 |
+
Add a noise to the prior of the child nodes.
|
| 158 |
+
Arguments:
|
| 159 |
+
- exploration_fraction: the fraction to add noise.
|
| 160 |
+
- noises: the vector of noises added to each child node.
|
| 161 |
+
*/
|
| 162 |
+
float noise, prior;
|
| 163 |
+
for (int i = 0; i < this->legal_actions.size(); ++i)
|
| 164 |
+
{
|
| 165 |
+
noise = noises[i];
|
| 166 |
+
CNode *child = this->get_child(this->legal_actions[i]);
|
| 167 |
+
|
| 168 |
+
prior = child->prior;
|
| 169 |
+
child->prior = prior * (1 - exploration_fraction) + noise * exploration_fraction;
|
| 170 |
+
}
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
float CNode::compute_mean_q(int isRoot, float parent_q, float discount_factor)
|
| 174 |
+
{
|
| 175 |
+
/*
|
| 176 |
+
Overview:
|
| 177 |
+
Compute the mean q value of the current node.
|
| 178 |
+
Arguments:
|
| 179 |
+
- isRoot: whether the current node is a root node.
|
| 180 |
+
- parent_q: the q value of the parent node.
|
| 181 |
+
- discount_factor: the discount_factor of reward.
|
| 182 |
+
*/
|
| 183 |
+
float total_unsigned_q = 0.0;
|
| 184 |
+
int total_visits = 0;
|
| 185 |
+
float parent_value_prefix = this->value_prefix;
|
| 186 |
+
for (auto a : this->legal_actions)
|
| 187 |
+
{
|
| 188 |
+
CNode *child = this->get_child(a);
|
| 189 |
+
if (child->visit_count > 0)
|
| 190 |
+
{
|
| 191 |
+
float true_reward = child->value_prefix - parent_value_prefix;
|
| 192 |
+
if (this->is_reset == 1)
|
| 193 |
+
{
|
| 194 |
+
true_reward = child->value_prefix;
|
| 195 |
+
}
|
| 196 |
+
float qsa = true_reward + discount_factor * child->value();
|
| 197 |
+
total_unsigned_q += qsa;
|
| 198 |
+
total_visits += 1;
|
| 199 |
+
}
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
float mean_q = 0.0;
|
| 203 |
+
if (isRoot && total_visits > 0)
|
| 204 |
+
{
|
| 205 |
+
mean_q = (total_unsigned_q) / (total_visits);
|
| 206 |
+
}
|
| 207 |
+
else
|
| 208 |
+
{
|
| 209 |
+
mean_q = (parent_q + total_unsigned_q) / (total_visits + 1);
|
| 210 |
+
}
|
| 211 |
+
return mean_q;
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
void CNode::print_out()
{
    // Intentionally a no-op; debug hook kept for interface compatibility.
}
|
| 218 |
+
|
| 219 |
+
int CNode::expanded()
|
| 220 |
+
{
|
| 221 |
+
/*
|
| 222 |
+
Overview:
|
| 223 |
+
Return whether the current node is expanded.
|
| 224 |
+
*/
|
| 225 |
+
return this->children.size() > 0;
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
float CNode::value()
|
| 229 |
+
{
|
| 230 |
+
/*
|
| 231 |
+
Overview:
|
| 232 |
+
Return the estimated value of the current tree.
|
| 233 |
+
*/
|
| 234 |
+
float true_value = 0.0;
|
| 235 |
+
if (this->visit_count == 0)
|
| 236 |
+
{
|
| 237 |
+
return true_value;
|
| 238 |
+
}
|
| 239 |
+
else
|
| 240 |
+
{
|
| 241 |
+
true_value = this->value_sum / this->visit_count;
|
| 242 |
+
return true_value;
|
| 243 |
+
}
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
std::vector<int> CNode::get_trajectory()
|
| 247 |
+
{
|
| 248 |
+
/*
|
| 249 |
+
Overview:
|
| 250 |
+
Find the current best trajectory starts from the current node.
|
| 251 |
+
Outputs:
|
| 252 |
+
- traj: a vector of node index, which is the current best trajectory from this node.
|
| 253 |
+
*/
|
| 254 |
+
std::vector<int> traj;
|
| 255 |
+
|
| 256 |
+
CNode *node = this;
|
| 257 |
+
int best_action = node->best_action;
|
| 258 |
+
while (best_action >= 0)
|
| 259 |
+
{
|
| 260 |
+
traj.push_back(best_action);
|
| 261 |
+
|
| 262 |
+
node = node->get_child(best_action);
|
| 263 |
+
best_action = node->best_action;
|
| 264 |
+
}
|
| 265 |
+
return traj;
|
| 266 |
+
}
|
| 267 |
+
|
| 268 |
+
std::vector<int> CNode::get_children_distribution()
|
| 269 |
+
{
|
| 270 |
+
/*
|
| 271 |
+
Overview:
|
| 272 |
+
Get the distribution of child nodes in the format of visit_count.
|
| 273 |
+
Outputs:
|
| 274 |
+
- distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]).
|
| 275 |
+
*/
|
| 276 |
+
std::vector<int> distribution;
|
| 277 |
+
if (this->expanded())
|
| 278 |
+
{
|
| 279 |
+
for (auto a : this->legal_actions)
|
| 280 |
+
{
|
| 281 |
+
CNode *child = this->get_child(a);
|
| 282 |
+
distribution.push_back(child->visit_count);
|
| 283 |
+
}
|
| 284 |
+
}
|
| 285 |
+
return distribution;
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
CNode *CNode::get_child(int action)
{
    /*
    Overview:
        Get the child node corresponding to the input action.
        NOTE: map::operator[] default-constructs a child if ``action`` is
        absent — same behavior as the original implementation.
    Arguments:
        - action: the action whose child to fetch.
    */
    CNode &child = this->children[action];
    return &child;
}
|
| 298 |
+
|
| 299 |
+
//*********************************************************
|
| 300 |
+
|
| 301 |
+
CRoots::CRoots()
{
    /*
    Overview:
        Default-construct CRoots with no roots.
    */
    this->root_num = 0;
}
|
| 309 |
+
|
| 310 |
+
CRoots::CRoots(int root_num, std::vector<std::vector<int> > &legal_actions_list)
{
    /*
    Overview:
        Build ``root_num`` root nodes, one per entry of ``legal_actions_list``,
        each with prior 0.
    Arguments:
        - root_num: the number of roots.
        - legal_actions_list: the legal actions of each root.
    */
    this->root_num = root_num;
    this->legal_actions_list = legal_actions_list;

    this->roots.reserve(root_num);
    for (int i = 0; i < root_num; ++i)
    {
        this->roots.emplace_back(0, this->legal_actions_list[i]);
    }
}
|
| 327 |
+
|
| 328 |
+
// Default destructor: the roots vector destroys its nodes.
CRoots::~CRoots() {}
|
| 329 |
+
|
| 330 |
+
void CRoots::prepare(float root_noise_weight, const std::vector<std::vector<float> > &noises, const std::vector<float> &value_prefixs, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch)
|
| 331 |
+
{
|
| 332 |
+
/*
|
| 333 |
+
Overview:
|
| 334 |
+
Expand the roots and add noises.
|
| 335 |
+
Arguments:
|
| 336 |
+
- root_noise_weight: the exploration fraction of roots
|
| 337 |
+
- noises: the vector of noise add to the roots.
|
| 338 |
+
- value_prefixs: the vector of value prefixs of each root.
|
| 339 |
+
- policies: the vector of policy logits of each root.
|
| 340 |
+
- to_play_batch: the vector of the player side of each root.
|
| 341 |
+
*/
|
| 342 |
+
for (int i = 0; i < this->root_num; ++i)
|
| 343 |
+
{
|
| 344 |
+
this->roots[i].expand(to_play_batch[i], 0, i, value_prefixs[i], policies[i]);
|
| 345 |
+
this->roots[i].add_exploration_noise(root_noise_weight, noises[i]);
|
| 346 |
+
this->roots[i].visit_count += 1;
|
| 347 |
+
}
|
| 348 |
+
}
|
| 349 |
+
|
| 350 |
+
void CRoots::prepare_no_noise(const std::vector<float> &value_prefixs, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch)
|
| 351 |
+
{
|
| 352 |
+
/*
|
| 353 |
+
Overview:
|
| 354 |
+
Expand the roots without noise.
|
| 355 |
+
Arguments:
|
| 356 |
+
- value_prefixs: the vector of value prefixs of each root.
|
| 357 |
+
- policies: the vector of policy logits of each root.
|
| 358 |
+
- to_play_batch: the vector of the player side of each root.
|
| 359 |
+
*/
|
| 360 |
+
for (int i = 0; i < this->root_num; ++i)
|
| 361 |
+
{
|
| 362 |
+
this->roots[i].expand(to_play_batch[i], 0, i, value_prefixs[i], policies[i]);
|
| 363 |
+
this->roots[i].visit_count += 1;
|
| 364 |
+
}
|
| 365 |
+
}
|
| 366 |
+
|
| 367 |
+
void CRoots::clear()
|
| 368 |
+
{
|
| 369 |
+
/*
|
| 370 |
+
Overview:
|
| 371 |
+
Clear the roots vector.
|
| 372 |
+
*/
|
| 373 |
+
this->roots.clear();
|
| 374 |
+
}
|
| 375 |
+
|
| 376 |
+
std::vector<std::vector<int> > CRoots::get_trajectories()
|
| 377 |
+
{
|
| 378 |
+
/*
|
| 379 |
+
Overview:
|
| 380 |
+
Find the current best trajectory starts from each root.
|
| 381 |
+
Outputs:
|
| 382 |
+
- traj: a vector of node index, which is the current best trajectory from each root.
|
| 383 |
+
*/
|
| 384 |
+
std::vector<std::vector<int> > trajs;
|
| 385 |
+
trajs.reserve(this->root_num);
|
| 386 |
+
|
| 387 |
+
for (int i = 0; i < this->root_num; ++i)
|
| 388 |
+
{
|
| 389 |
+
trajs.push_back(this->roots[i].get_trajectory());
|
| 390 |
+
}
|
| 391 |
+
return trajs;
|
| 392 |
+
}
|
| 393 |
+
|
| 394 |
+
std::vector<std::vector<int> > CRoots::get_distributions()
|
| 395 |
+
{
|
| 396 |
+
/*
|
| 397 |
+
Overview:
|
| 398 |
+
Get the children distribution of each root.
|
| 399 |
+
Outputs:
|
| 400 |
+
- distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]).
|
| 401 |
+
*/
|
| 402 |
+
std::vector<std::vector<int> > distributions;
|
| 403 |
+
distributions.reserve(this->root_num);
|
| 404 |
+
|
| 405 |
+
for (int i = 0; i < this->root_num; ++i)
|
| 406 |
+
{
|
| 407 |
+
distributions.push_back(this->roots[i].get_children_distribution());
|
| 408 |
+
}
|
| 409 |
+
return distributions;
|
| 410 |
+
}
|
| 411 |
+
|
| 412 |
+
std::vector<float> CRoots::get_values()
|
| 413 |
+
{
|
| 414 |
+
/*
|
| 415 |
+
Overview:
|
| 416 |
+
Return the estimated value of each root.
|
| 417 |
+
*/
|
| 418 |
+
std::vector<float> values;
|
| 419 |
+
for (int i = 0; i < this->root_num; ++i)
|
| 420 |
+
{
|
| 421 |
+
values.push_back(this->roots[i].value());
|
| 422 |
+
}
|
| 423 |
+
return values;
|
| 424 |
+
}
|
| 425 |
+
|
| 426 |
+
//*********************************************************
|
| 427 |
+
//
|
| 428 |
+
void update_tree_q(CNode *root, tools::CMinMaxStats &min_max_stats, float discount_factor, int players)
{
    /*
    Overview:
        Update the q value of the root and its child nodes, feeding every
        non-root node's q(s,a) into ``min_max_stats`` via an iterative DFS.
    Arguments:
        - root: the root that update q value from.
        - min_max_stats: a tool used to min-max normalize the q value.
        - discount_factor: the discount factor of reward.
        - players: the number of players (1 or 2 supported).
    */
    std::stack<CNode *> node_stack;
    node_stack.push(root);
    float parent_value_prefix = 0.0;
    int is_reset = 0;
    while (node_stack.size() > 0)
    {
        CNode *node = node_stack.top();
        node_stack.pop();

        if (node != root)
        {
            // NOTE: in self-play-mode, value_prefix is not calculated according to the perspective of current player of node,
            // but treated as 1 player, just for obtaining the true reward in the perspective of current player of node.
            // true_reward = node.value_prefix - (- parent_value_prefix)
            float true_reward = node->value_prefix - node->parent_value_prefix;

            // NOTE(review): ``is_reset`` holds the is_reset flag of the most
            // recently processed node in the DFS order, which is only this
            // node's parent when the traversal pops a child immediately after
            // its parent — confirm this ordering assumption is intended.
            if (is_reset == 1)
            {
                true_reward = node->value_prefix;
            }
            float qsa;
            if (players == 1)
            {
                qsa = true_reward + discount_factor * node->value();
            }
            else if (players == 2)
            {
                // TODO(pu): why only the last reward multiply the discount_factor?
                qsa = true_reward + discount_factor * (-1) * node->value();
            }
            // NOTE(review): qsa is uninitialized if players is neither 1 nor 2;
            // callers must respect the 1-or-2 players contract.

            min_max_stats.update(qsa);
        }

        // Push expanded children, recording this node's value prefix as their
        // parent prefix for the reward computation above.
        for (auto a : node->legal_actions)
        {
            CNode *child = node->get_child(a);
            if (child->expanded())
            {
                child->parent_value_prefix = node->value_prefix;
                node_stack.push(child);
            }
        }

        // Carry this node's reset flag forward for the next popped node.
        is_reset = node->is_reset;
    }
}
|
| 486 |
+
|
| 487 |
+
void cbackpropagate(std::vector<CNode *> &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount_factor)
{
    /*
    Overview:
        Update the value sum and visit count of nodes along the search path,
        walking from the leaf back to the root and accumulating the discounted
        bootstrap value.
    Arguments:
        - search_path: a vector of nodes on the search path (root first, leaf last).
        - min_max_stats: a tool used to min-max normalize the q value.
        - to_play: which player to play the game in the current node
          (-1 means play-with-bot mode; 1 or 2 means self-play mode).
        - value: the value to propagate along the search path.
        - discount_factor: the discount factor of reward.
    */
    assert(to_play == -1 || to_play == 1 || to_play == 2);
    if (to_play == -1)
    {
        // for play-with-bot-mode
        float bootstrap_value = value;
        int path_len = search_path.size();
        for (int i = path_len - 1; i >= 0; --i)
        {
            CNode *node = search_path[i];
            node->value_sum += bootstrap_value;
            node->visit_count += 1;

            // The reward of this step is the difference between this node's
            // value prefix and its parent's (the root's parent prefix is 0).
            float parent_value_prefix = 0.0;
            int is_reset = 0;
            if (i >= 1)
            {
                CNode *parent = search_path[i - 1];
                parent_value_prefix = parent->value_prefix;
                is_reset = parent->is_reset;
            }

            float true_reward = node->value_prefix - parent_value_prefix;
            // NOTE(review): min_max_stats is updated with the pre-reset
            // true_reward, before the is_reset correction below — confirm
            // this ordering is intended.
            min_max_stats.update(true_reward + discount_factor * node->value());

            if (is_reset == 1)
            {
                // parent is reset
                true_reward = node->value_prefix;
            }

            bootstrap_value = true_reward + discount_factor * bootstrap_value;
        }
    }
    else
    {
        // for self-play-mode
        float bootstrap_value = value;
        int path_len = search_path.size();
        for (int i = path_len - 1; i >= 0; --i)
        {
            CNode *node = search_path[i];
            // The propagated value is from the perspective of ``to_play``;
            // flip its sign for nodes owned by the opponent.
            if (node->to_play == to_play)
            {
                node->value_sum += bootstrap_value;
            }
            else
            {
                node->value_sum += -bootstrap_value;
            }
            node->visit_count += 1;

            float parent_value_prefix = 0.0;
            int is_reset = 0;
            if (i >= 1)
            {
                CNode *parent = search_path[i - 1];
                parent_value_prefix = parent->value_prefix;
                is_reset = parent->is_reset;
            }

            // NOTE: in self-play-mode, value_prefix is not calculated according to the perspective of current player of node,
            // but treated as 1 player, just for obtaining the true reward in the perspective of current player of node.
            float true_reward = node->value_prefix - parent_value_prefix;

            // NOTE(review): same pre-reset ordering as the play-with-bot
            // branch above — min_max_stats sees the uncorrected reward.
            min_max_stats.update(true_reward + discount_factor * node->value());

            if (is_reset == 1)
            {
                // parent is reset
                true_reward = node->value_prefix;
            }
            // Negate the reward when handing the bootstrap value to the
            // opponent's perspective one level up.
            if (node->to_play == to_play)
            {
                bootstrap_value = -true_reward + discount_factor * bootstrap_value;
            }
            else
            {
                bootstrap_value = true_reward + discount_factor * bootstrap_value;
            }
        }
    }
}
|
| 581 |
+
|
| 582 |
+
void cbatch_backpropagate(int current_latent_state_index, float discount_factor, const std::vector<float> &value_prefixs, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> is_reset_list, std::vector<int> &to_play_batch)
|
| 583 |
+
{
|
| 584 |
+
/*
|
| 585 |
+
Overview:
|
| 586 |
+
Expand the nodes along the search path and update the infos.
|
| 587 |
+
Arguments:
|
| 588 |
+
- current_latent_state_index: The index of latent state of the leaf node in the search path.
|
| 589 |
+
- discount_factor: the discount factor of reward.
|
| 590 |
+
- value_prefixs: the value prefixs of nodes along the search path.
|
| 591 |
+
- values: the values to propagate along the search path.
|
| 592 |
+
- policies: the policy logits of nodes along the search path.
|
| 593 |
+
- min_max_stats: a tool used to min-max normalize the q value.
|
| 594 |
+
- results: the search results.
|
| 595 |
+
- is_reset_list: the vector of is_reset nodes along the search path, where is_reset represents for whether the parent value prefix needs to be reset.
|
| 596 |
+
- to_play_batch: the batch of which player is playing on this node.
|
| 597 |
+
*/
|
| 598 |
+
for (int i = 0; i < results.num; ++i)
|
| 599 |
+
{
|
| 600 |
+
results.nodes[i]->expand(to_play_batch[i], current_latent_state_index, i, value_prefixs[i], policies[i]);
|
| 601 |
+
// reset
|
| 602 |
+
results.nodes[i]->is_reset = is_reset_list[i];
|
| 603 |
+
|
| 604 |
+
cbackpropagate(results.search_paths[i], min_max_stats_lst->stats_lst[i], to_play_batch[i], values[i], discount_factor);
|
| 605 |
+
}
|
| 606 |
+
}
|
| 607 |
+
|
| 608 |
+
int cselect_child(CNode *root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players)
{
    /*
    Overview:
        Select the child of ``root`` with the highest UCB score, breaking
        (near-)ties uniformly at random.
    Arguments:
        - root: the node whose children are scored.
        - min_max_stats: a tool used to min-max normalize the score.
        - pb_c_base: constants c2 in muzero.
        - pb_c_init: constants c1 in muzero.
        - discount_factor: the discount factor of reward.
        - mean_q: the mean q value of the parent node.
        - players: the number of players.
    Outputs:
        - action: the action to select.
    */
    const float epsilon = 0.000001;
    float best_score = FLOAT_MIN;
    std::vector<int> candidates;

    for (auto a : root->legal_actions)
    {
        CNode *child = root->get_child(a);
        float score = cucb_score(child, min_max_stats, mean_q, root->is_reset, root->visit_count - 1, root->value_prefix, pb_c_base, pb_c_init, discount_factor, players);

        if (score > best_score)
        {
            // Strictly better: restart the candidate list with this action.
            best_score = score;
            candidates.assign(1, a);
        }
        else if (score >= best_score - epsilon)
        {
            // Within epsilon of the best: keep as a tie candidate.
            candidates.push_back(a);
        }
    }

    if (candidates.empty())
    {
        return 0;
    }
    return candidates[rand() % candidates.size()];
}
|
| 653 |
+
|
| 654 |
+
float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, int is_reset, float total_children_visit_counts, float parent_value_prefix, float pb_c_base, float pb_c_init, float discount_factor, int players)
{
    /*
    Overview:
        Compute the ucb score of the child: a pUCT exploration term plus a
        min-max normalized (and clamped to [0, 1]) value term.
    Arguments:
        - child: the child node to compute ucb score.
        - min_max_stats: a tool used to min-max normalize the score.
        - parent_mean_q: the mean q value of the parent node, used for unvisited children.
        - is_reset: whether the value prefix needs to be reset.
        - total_children_visit_counts: the total visit counts of the child nodes of the parent node.
        - parent_value_prefix: the value prefix of parent node.
        - pb_c_base: constants c2 in muzero.
        - pb_c_init: constants c1 in muzero.
        - discount_factor: the discount factor of reward.
        - players: the number of players.
    Outputs:
        - ucb_value: the ucb score of the child.
    */
    // Exploration term.
    float pb_c = log((total_children_visit_counts + pb_c_base + 1) / pb_c_base) + pb_c_init;
    pb_c *= sqrt(total_children_visit_counts) / (child->visit_count + 1);
    float prior_score = pb_c * child->prior;

    // Value term (0 if players is neither 1 nor 2, matching the original).
    float value_score = 0.0;
    if (child->visit_count == 0)
    {
        value_score = parent_mean_q;
    }
    else
    {
        float true_reward = (is_reset == 1) ? child->value_prefix
                                            : child->value_prefix - parent_value_prefix;
        if (players == 1)
        {
            value_score = true_reward + discount_factor * child->value();
        }
        else if (players == 2)
        {
            value_score = true_reward + discount_factor * (-child->value());
        }
    }

    // Normalize and clamp the value term into [0, 1].
    value_score = min_max_stats.normalize(value_score);
    if (value_score < 0)
    {
        value_score = 0;
    }
    else if (value_score > 1)
    {
        value_score = 1;
    }

    return prior_score + value_score; // ucb_value
}
|
| 713 |
+
|
| 714 |
+
void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &virtual_to_play_batch)
|
| 715 |
+
{
|
| 716 |
+
/*
|
| 717 |
+
Overview:
|
| 718 |
+
Search node path from the roots.
|
| 719 |
+
Arguments:
|
| 720 |
+
- roots: the roots that search from.
|
| 721 |
+
- pb_c_base: constants c2 in muzero.
|
| 722 |
+
- pb_c_init: constants c1 in muzero.
|
| 723 |
+
- disount_factor: the discount factor of reward.
|
| 724 |
+
- min_max_stats: a tool used to min-max normalize the score.
|
| 725 |
+
- results: the search results.
|
| 726 |
+
- virtual_to_play_batch: the batch of which player is playing on this node.
|
| 727 |
+
*/
|
| 728 |
+
// set seed
|
| 729 |
+
get_time_and_set_rand_seed();
|
| 730 |
+
|
| 731 |
+
int last_action = -1;
|
| 732 |
+
float parent_q = 0.0;
|
| 733 |
+
results.search_lens = std::vector<int>();
|
| 734 |
+
|
| 735 |
+
int players = 0;
|
| 736 |
+
int largest_element = *max_element(virtual_to_play_batch.begin(), virtual_to_play_batch.end()); // 0 or 2
|
| 737 |
+
if (largest_element == -1)
|
| 738 |
+
{
|
| 739 |
+
players = 1;
|
| 740 |
+
}
|
| 741 |
+
else
|
| 742 |
+
{
|
| 743 |
+
players = 2;
|
| 744 |
+
}
|
| 745 |
+
|
| 746 |
+
for (int i = 0; i < results.num; ++i)
|
| 747 |
+
{
|
| 748 |
+
CNode *node = &(roots->roots[i]);
|
| 749 |
+
int is_root = 1;
|
| 750 |
+
int search_len = 0;
|
| 751 |
+
results.search_paths[i].push_back(node);
|
| 752 |
+
|
| 753 |
+
while (node->expanded())
|
| 754 |
+
{
|
| 755 |
+
float mean_q = node->compute_mean_q(is_root, parent_q, discount_factor);
|
| 756 |
+
is_root = 0;
|
| 757 |
+
parent_q = mean_q;
|
| 758 |
+
|
| 759 |
+
int action = cselect_child(node, min_max_stats_lst->stats_lst[i], pb_c_base, pb_c_init, discount_factor, mean_q, players);
|
| 760 |
+
if (players > 1)
|
| 761 |
+
{
|
| 762 |
+
assert(virtual_to_play_batch[i] == 1 || virtual_to_play_batch[i] == 2);
|
| 763 |
+
if (virtual_to_play_batch[i] == 1)
|
| 764 |
+
{
|
| 765 |
+
virtual_to_play_batch[i] = 2;
|
| 766 |
+
}
|
| 767 |
+
else
|
| 768 |
+
{
|
| 769 |
+
virtual_to_play_batch[i] = 1;
|
| 770 |
+
}
|
| 771 |
+
}
|
| 772 |
+
|
| 773 |
+
node->best_action = action;
|
| 774 |
+
// next
|
| 775 |
+
node = node->get_child(action);
|
| 776 |
+
last_action = action;
|
| 777 |
+
results.search_paths[i].push_back(node);
|
| 778 |
+
search_len += 1;
|
| 779 |
+
}
|
| 780 |
+
|
| 781 |
+
CNode *parent = results.search_paths[i][results.search_paths[i].size() - 2];
|
| 782 |
+
|
| 783 |
+
results.latent_state_index_in_search_path.push_back(parent->current_latent_state_index);
|
| 784 |
+
results.latent_state_index_in_batch.push_back(parent->batch_index);
|
| 785 |
+
|
| 786 |
+
results.last_actions.push_back(last_action);
|
| 787 |
+
results.search_lens.push_back(search_len);
|
| 788 |
+
results.nodes.push_back(node);
|
| 789 |
+
results.virtual_to_play_batchs.push_back(virtual_to_play_batch[i]);
|
| 790 |
+
}
|
| 791 |
+
}
|
| 792 |
+
}
|
LightZero/lzero/mcts/ctree/ctree_efficientzero/lib/cnode.h
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// C++11
|
| 2 |
+
|
| 3 |
+
#ifndef CNODE_H
|
| 4 |
+
#define CNODE_H
|
| 5 |
+
|
| 6 |
+
#include "../../common_lib/cminimax.h"
|
| 7 |
+
#include <math.h>
|
| 8 |
+
#include <vector>
|
| 9 |
+
#include <stack>
|
| 10 |
+
#include <stdlib.h>
|
| 11 |
+
#include <time.h>
|
| 12 |
+
#include <cmath>
|
| 13 |
+
#include <sys/timeb.h>
|
| 14 |
+
#include <time.h>
|
| 15 |
+
#include <map>
|
| 16 |
+
|
| 17 |
+
const int DEBUG_MODE = 0;
|
| 18 |
+
|
| 19 |
+
namespace tree {
|
| 20 |
+
class CNode {
|
| 21 |
+
public:
|
| 22 |
+
int visit_count, to_play, current_latent_state_index, batch_index, best_action, is_reset;
|
| 23 |
+
float value_prefix, prior, value_sum;
|
| 24 |
+
float parent_value_prefix;
|
| 25 |
+
std::vector<int> children_index;
|
| 26 |
+
std::map<int, CNode> children;
|
| 27 |
+
|
| 28 |
+
std::vector<int> legal_actions;
|
| 29 |
+
|
| 30 |
+
CNode();
|
| 31 |
+
CNode(float prior, std::vector<int> &legal_actions);
|
| 32 |
+
~CNode();
|
| 33 |
+
|
| 34 |
+
void expand(int to_play, int current_latent_state_index, int batch_index, float value_prefix, const std::vector<float> &policy_logits);
|
| 35 |
+
void add_exploration_noise(float exploration_fraction, const std::vector<float> &noises);
|
| 36 |
+
float compute_mean_q(int isRoot, float parent_q, float discount_factor);
|
| 37 |
+
void print_out();
|
| 38 |
+
|
| 39 |
+
int expanded();
|
| 40 |
+
|
| 41 |
+
float value();
|
| 42 |
+
|
| 43 |
+
std::vector<int> get_trajectory();
|
| 44 |
+
std::vector<int> get_children_distribution();
|
| 45 |
+
CNode* get_child(int action);
|
| 46 |
+
};
|
| 47 |
+
|
| 48 |
+
class CRoots{
|
| 49 |
+
public:
|
| 50 |
+
int root_num;
|
| 51 |
+
std::vector<CNode> roots;
|
| 52 |
+
std::vector<std::vector<int> > legal_actions_list;
|
| 53 |
+
|
| 54 |
+
CRoots();
|
| 55 |
+
CRoots(int root_num, std::vector<std::vector<int> > &legal_actions_list);
|
| 56 |
+
~CRoots();
|
| 57 |
+
|
| 58 |
+
void prepare(float root_noise_weight, const std::vector<std::vector<float> > &noises, const std::vector<float> &value_prefixs, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch);
|
| 59 |
+
void prepare_no_noise(const std::vector<float> &value_prefixs, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch);
|
| 60 |
+
void clear();
|
| 61 |
+
std::vector<std::vector<int> > get_trajectories();
|
| 62 |
+
std::vector<std::vector<int> > get_distributions();
|
| 63 |
+
std::vector<float> get_values();
|
| 64 |
+
CNode* get_root(int index);
|
| 65 |
+
};
|
| 66 |
+
|
| 67 |
+
class CSearchResults{
|
| 68 |
+
public:
|
| 69 |
+
int num;
|
| 70 |
+
std::vector<int> latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, search_lens;
|
| 71 |
+
std::vector<int> virtual_to_play_batchs;
|
| 72 |
+
std::vector<CNode*> nodes;
|
| 73 |
+
std::vector<std::vector<CNode*> > search_paths;
|
| 74 |
+
|
| 75 |
+
CSearchResults();
|
| 76 |
+
CSearchResults(int num);
|
| 77 |
+
~CSearchResults();
|
| 78 |
+
|
| 79 |
+
};
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
//*********************************************************
|
| 83 |
+
void update_tree_q(CNode* root, tools::CMinMaxStats &min_max_stats, float discount_factor, int players);
|
| 84 |
+
void cbackpropagate(std::vector<CNode*> &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount_factor);
|
| 85 |
+
void cbatch_backpropagate(int current_latent_state_index, float discount_factor, const std::vector<float> &value_prefixs, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> is_reset_list, std::vector<int> &to_play_batch);
|
| 86 |
+
int cselect_child(CNode* root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players);
|
| 87 |
+
float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, int is_reset, float total_children_visit_counts, float parent_value_prefix, float pb_c_base, float pb_c_init, float discount_factor, int players);
|
| 88 |
+
void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &virtual_to_play_batch);
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
#endif
|
LightZero/lzero/mcts/ctree/ctree_gumbel_muzero/lib/cnode.cpp
ADDED
|
@@ -0,0 +1,1154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// C++11
|
| 2 |
+
|
| 3 |
+
#include <iostream>
|
| 4 |
+
#include "cnode.h"
|
| 5 |
+
#include <algorithm>
|
| 6 |
+
#include <map>
|
| 7 |
+
#include <cmath>
|
| 8 |
+
#include <random>
|
| 9 |
+
#include <numeric>
|
| 10 |
+
|
| 11 |
+
#ifdef _WIN32
|
| 12 |
+
#include "..\..\common_lib\utils.cpp"
|
| 13 |
+
#else
|
| 14 |
+
#include "../../common_lib/utils.cpp"
|
| 15 |
+
#endif
|
| 16 |
+
|
| 17 |
+
namespace tree{
|
| 18 |
+
|
| 19 |
+
CSearchResults::CSearchResults()
|
| 20 |
+
{
|
| 21 |
+
/*
|
| 22 |
+
Overview:
|
| 23 |
+
Initialization of CSearchResults, the default result number is set to 0.
|
| 24 |
+
*/
|
| 25 |
+
this->num = 0;
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
CSearchResults::CSearchResults(int num)
|
| 29 |
+
{
|
| 30 |
+
/*
|
| 31 |
+
Overview:
|
| 32 |
+
Initialization of CSearchResults with result number.
|
| 33 |
+
*/
|
| 34 |
+
this->num = num;
|
| 35 |
+
for (int i = 0; i < num; ++i)
|
| 36 |
+
{
|
| 37 |
+
this->search_paths.push_back(std::vector<CNode *>());
|
| 38 |
+
}
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
CSearchResults::~CSearchResults(){}
|
| 42 |
+
|
| 43 |
+
//*********************************************************
|
| 44 |
+
|
| 45 |
+
CNode::CNode()
|
| 46 |
+
{
|
| 47 |
+
/*
|
| 48 |
+
Overview:
|
| 49 |
+
Initialization of CNode.
|
| 50 |
+
*/
|
| 51 |
+
this->prior = 0;
|
| 52 |
+
this->legal_actions = legal_actions;
|
| 53 |
+
|
| 54 |
+
this->visit_count = 0;
|
| 55 |
+
this->value_sum = 0;
|
| 56 |
+
this->raw_value = 0; // the value network approximation of value
|
| 57 |
+
this->best_action = -1;
|
| 58 |
+
this->to_play = 0;
|
| 59 |
+
this->reward = 0.0;
|
| 60 |
+
|
| 61 |
+
// gumbel muzero related code
|
| 62 |
+
this->gumbel_scale = 10.0;
|
| 63 |
+
this->gumbel_rng=0.0;
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
CNode::CNode(float prior, std::vector<int> &legal_actions)
|
| 67 |
+
{
|
| 68 |
+
/*
|
| 69 |
+
Overview:
|
| 70 |
+
Initialization of CNode with prior value and legal actions.
|
| 71 |
+
Arguments:
|
| 72 |
+
- prior: the prior value of this node.
|
| 73 |
+
- legal_actions: a vector of legal actions of this node.
|
| 74 |
+
*/
|
| 75 |
+
this->prior = prior;
|
| 76 |
+
this->legal_actions = legal_actions;
|
| 77 |
+
|
| 78 |
+
this->visit_count = 0;
|
| 79 |
+
this->value_sum = 0;
|
| 80 |
+
this->raw_value = 0; // the value network approximation of value
|
| 81 |
+
this->best_action = -1;
|
| 82 |
+
this->to_play = 0;
|
| 83 |
+
this->current_latent_state_index = -1;
|
| 84 |
+
this->batch_index = -1;
|
| 85 |
+
|
| 86 |
+
// gumbel muzero related code
|
| 87 |
+
this->gumbel_scale = 10.0;
|
| 88 |
+
this->gumbel_rng=0.0;
|
| 89 |
+
this->gumbel = generate_gumbel(this->gumbel_scale, this->gumbel_rng, legal_actions.size());
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
CNode::~CNode(){}
|
| 93 |
+
|
| 94 |
+
void CNode::expand(int to_play, int current_latent_state_index, int batch_index, float reward, float value, const std::vector<float> &policy_logits)
|
| 95 |
+
{
|
| 96 |
+
/*
|
| 97 |
+
Overview:
|
| 98 |
+
Expand the child nodes of the current node.
|
| 99 |
+
Arguments:
|
| 100 |
+
- to_play: which player to play the game in the current node.
|
| 101 |
+
- current_latent_state_index: The index of latent state of the leaf node in the search path of the current node.
|
| 102 |
+
- batch_index: The index of latent state of the leaf node in the search path of the current node.
|
| 103 |
+
- reward: the reward of the current node.
|
| 104 |
+
- value: the value network approximation of current node.
|
| 105 |
+
- policy_logits: the logit of the child nodes.
|
| 106 |
+
*/
|
| 107 |
+
this->to_play = to_play;
|
| 108 |
+
this->current_latent_state_index = current_latent_state_index;
|
| 109 |
+
this->batch_index = batch_index;
|
| 110 |
+
this->reward = reward;
|
| 111 |
+
this->raw_value = value;
|
| 112 |
+
|
| 113 |
+
int action_num = policy_logits.size();
|
| 114 |
+
if (this->legal_actions.size() == 0)
|
| 115 |
+
{
|
| 116 |
+
for (int i = 0; i < action_num; ++i)
|
| 117 |
+
{
|
| 118 |
+
this->legal_actions.push_back(i);
|
| 119 |
+
}
|
| 120 |
+
}
|
| 121 |
+
float temp_policy;
|
| 122 |
+
float policy_sum = 0.0;
|
| 123 |
+
|
| 124 |
+
#ifdef _WIN32
|
| 125 |
+
// 创建动态数组
|
| 126 |
+
float* policy = new float[action_num];
|
| 127 |
+
#else
|
| 128 |
+
float policy[action_num];
|
| 129 |
+
#endif
|
| 130 |
+
|
| 131 |
+
float policy_max = FLOAT_MIN;
|
| 132 |
+
for(auto a: this->legal_actions){
|
| 133 |
+
if(policy_max < policy_logits[a]){
|
| 134 |
+
policy_max = policy_logits[a];
|
| 135 |
+
}
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
for(auto a: this->legal_actions){
|
| 139 |
+
temp_policy = exp(policy_logits[a] - policy_max);
|
| 140 |
+
policy_sum += temp_policy;
|
| 141 |
+
policy[a] = temp_policy;
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
float prior;
|
| 145 |
+
for(auto a: this->legal_actions){
|
| 146 |
+
prior = policy[a] / policy_sum;
|
| 147 |
+
std::vector<int> tmp_empty;
|
| 148 |
+
this->children[a] = CNode(prior, tmp_empty); // only for muzero/efficient zero, not support alphazero
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
#ifdef _WIN32
|
| 152 |
+
// 释放数组内存
|
| 153 |
+
delete[] policy;
|
| 154 |
+
#else
|
| 155 |
+
#endif
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
void CNode::add_exploration_noise(float exploration_fraction, const std::vector<float> &noises)
|
| 159 |
+
{
|
| 160 |
+
/*
|
| 161 |
+
Overview:
|
| 162 |
+
Add a noise to the prior of the child nodes.
|
| 163 |
+
Arguments:
|
| 164 |
+
- exploration_fraction: the fraction to add noise.
|
| 165 |
+
- noises: the vector of noises added to each child node.
|
| 166 |
+
*/
|
| 167 |
+
float noise, prior;
|
| 168 |
+
for(int i =0; i<this->legal_actions.size(); ++i){
|
| 169 |
+
noise = noises[i];
|
| 170 |
+
CNode* child = this->get_child(this->legal_actions[i]);
|
| 171 |
+
|
| 172 |
+
prior = child->prior;
|
| 173 |
+
child->prior = prior * (1 - exploration_fraction) + noise * exploration_fraction;
|
| 174 |
+
}
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
//*********************************************************
|
| 178 |
+
// Gumbel Muzero related code
|
| 179 |
+
//*********************************************************
|
| 180 |
+
|
| 181 |
+
std::vector<float> CNode::get_q(float discount_factor)
|
| 182 |
+
{
|
| 183 |
+
/*
|
| 184 |
+
Overview:
|
| 185 |
+
Compute the q value of the current node.
|
| 186 |
+
Arguments:
|
| 187 |
+
- discount_factor: the discount_factor of reward.
|
| 188 |
+
*/
|
| 189 |
+
std::vector<float> child_value;
|
| 190 |
+
for(auto a: this->legal_actions){
|
| 191 |
+
CNode* child = this->get_child(a);
|
| 192 |
+
float true_reward = child->reward;
|
| 193 |
+
float qsa = true_reward + discount_factor * child->value();
|
| 194 |
+
child_value.push_back(qsa);
|
| 195 |
+
}
|
| 196 |
+
return child_value;
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
float CNode::compute_mean_q(int isRoot, float parent_q, float discount_factor)
|
| 200 |
+
{
|
| 201 |
+
/*
|
| 202 |
+
Overview:
|
| 203 |
+
Compute the mean q value of the current node.
|
| 204 |
+
Arguments:
|
| 205 |
+
- isRoot: whether the current node is a root node.
|
| 206 |
+
- parent_q: the q value of the parent node.
|
| 207 |
+
- discount_factor: the discount_factor of reward.
|
| 208 |
+
*/
|
| 209 |
+
float total_unsigned_q = 0.0;
|
| 210 |
+
int total_visits = 0;
|
| 211 |
+
for(auto a: this->legal_actions){
|
| 212 |
+
CNode* child = this->get_child(a);
|
| 213 |
+
if(child->visit_count > 0){
|
| 214 |
+
float true_reward = child->reward;
|
| 215 |
+
float qsa = true_reward + discount_factor * child->value();
|
| 216 |
+
total_unsigned_q += qsa;
|
| 217 |
+
total_visits += 1;
|
| 218 |
+
}
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
float mean_q = 0.0;
|
| 222 |
+
if(isRoot && total_visits > 0){
|
| 223 |
+
mean_q = (total_unsigned_q) / (total_visits);
|
| 224 |
+
}
|
| 225 |
+
else{
|
| 226 |
+
mean_q = (parent_q + total_unsigned_q) / (total_visits + 1);
|
| 227 |
+
}
|
| 228 |
+
return mean_q;
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
void CNode::print_out()
|
| 232 |
+
{
|
| 233 |
+
return;
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
int CNode::expanded()
|
| 237 |
+
{
|
| 238 |
+
/*
|
| 239 |
+
Overview:
|
| 240 |
+
Return whether the current node is expanded.
|
| 241 |
+
*/
|
| 242 |
+
return this->children.size() > 0;
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
float CNode::value()
|
| 246 |
+
{
|
| 247 |
+
/*
|
| 248 |
+
Overview:
|
| 249 |
+
Return the real value of the current tree.
|
| 250 |
+
*/
|
| 251 |
+
float true_value = 0.0;
|
| 252 |
+
if (this->visit_count == 0)
|
| 253 |
+
{
|
| 254 |
+
return true_value;
|
| 255 |
+
}
|
| 256 |
+
else
|
| 257 |
+
{
|
| 258 |
+
true_value = this->value_sum / this->visit_count;
|
| 259 |
+
return true_value;
|
| 260 |
+
}
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
std::vector<int> CNode::get_trajectory()
|
| 264 |
+
{
|
| 265 |
+
/*
|
| 266 |
+
Overview:
|
| 267 |
+
Find the current best trajectory starts from the current node.
|
| 268 |
+
Outputs:
|
| 269 |
+
- traj: a vector of node index, which is the current best trajectory from this node.
|
| 270 |
+
*/
|
| 271 |
+
std::vector<int> traj;
|
| 272 |
+
|
| 273 |
+
CNode *node = this;
|
| 274 |
+
int best_action = node->best_action;
|
| 275 |
+
while (best_action >= 0)
|
| 276 |
+
{
|
| 277 |
+
traj.push_back(best_action);
|
| 278 |
+
|
| 279 |
+
node = node->get_child(best_action);
|
| 280 |
+
best_action = node->best_action;
|
| 281 |
+
}
|
| 282 |
+
return traj;
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
std::vector<int> CNode::get_children_distribution()
|
| 286 |
+
{
|
| 287 |
+
/*
|
| 288 |
+
Overview:
|
| 289 |
+
Get the distribution of child nodes in the format of visit_count.
|
| 290 |
+
Outputs:
|
| 291 |
+
- distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]).
|
| 292 |
+
*/
|
| 293 |
+
std::vector<int> distribution;
|
| 294 |
+
if (this->expanded())
|
| 295 |
+
{
|
| 296 |
+
for (auto a : this->legal_actions)
|
| 297 |
+
{
|
| 298 |
+
CNode *child = this->get_child(a);
|
| 299 |
+
distribution.push_back(child->visit_count);
|
| 300 |
+
}
|
| 301 |
+
}
|
| 302 |
+
return distribution;
|
| 303 |
+
}
|
| 304 |
+
|
| 305 |
+
//*********************************************************
|
| 306 |
+
// Gumbel Muzero related code
|
| 307 |
+
//*********************************************************
|
| 308 |
+
|
| 309 |
+
std::vector<float> CNode::get_children_value(float discount_factor, int action_space_size)
|
| 310 |
+
{
|
| 311 |
+
/*
|
| 312 |
+
Overview:
|
| 313 |
+
Get the completed value of child nodes.
|
| 314 |
+
Outputs:
|
| 315 |
+
- discount_factor: the discount_factor of reward.
|
| 316 |
+
- action_space_size: the size of action space.
|
| 317 |
+
*/
|
| 318 |
+
float infymin = -std::numeric_limits<float>::infinity();
|
| 319 |
+
std::vector<int> child_visit_count;
|
| 320 |
+
std::vector<float> child_prior;
|
| 321 |
+
for(auto a: this->legal_actions){
|
| 322 |
+
CNode* child = this->get_child(a);
|
| 323 |
+
child_visit_count.push_back(child->visit_count);
|
| 324 |
+
child_prior.push_back(child->prior);
|
| 325 |
+
}
|
| 326 |
+
assert(child_visit_count.size()==child_prior.size());
|
| 327 |
+
// compute the completed value
|
| 328 |
+
std::vector<float> completed_qvalues = qtransform_completed_by_mix_value(this, child_visit_count, child_prior, discount_factor);
|
| 329 |
+
std::vector<float> values;
|
| 330 |
+
for (int i=0;i<action_space_size;i++){
|
| 331 |
+
values.push_back(infymin);
|
| 332 |
+
}
|
| 333 |
+
for (int i=0;i<child_prior.size();i++){
|
| 334 |
+
values[this->legal_actions[i]] = completed_qvalues[i];
|
| 335 |
+
}
|
| 336 |
+
|
| 337 |
+
return values;
|
| 338 |
+
}
|
| 339 |
+
|
| 340 |
+
CNode *CNode::get_child(int action)
|
| 341 |
+
{
|
| 342 |
+
/*
|
| 343 |
+
Overview:
|
| 344 |
+
Get the child node corresponding to the input action.
|
| 345 |
+
Arguments:
|
| 346 |
+
- action: the action to get child.
|
| 347 |
+
*/
|
| 348 |
+
return &(this->children[action]);
|
| 349 |
+
}
|
| 350 |
+
|
| 351 |
+
//*********************************************************
|
| 352 |
+
// Gumbel Muzero related code
|
| 353 |
+
//*********************************************************
|
| 354 |
+
|
| 355 |
+
std::vector<float> CNode::get_policy(float discount_factor, int action_space_size){
|
| 356 |
+
/*
|
| 357 |
+
Overview:
|
| 358 |
+
Compute the improved policy of the current node.
|
| 359 |
+
Arguments:
|
| 360 |
+
- discount_factor: the discount_factor of reward.
|
| 361 |
+
- action_space_size: the action space size of environment.
|
| 362 |
+
*/
|
| 363 |
+
float infymin = -std::numeric_limits<float>::infinity();
|
| 364 |
+
std::vector<int> child_visit_count;
|
| 365 |
+
std::vector<float> child_prior;
|
| 366 |
+
for(auto a: this->legal_actions){
|
| 367 |
+
CNode* child = this->get_child(a);
|
| 368 |
+
child_visit_count.push_back(child->visit_count);
|
| 369 |
+
child_prior.push_back(child->prior);
|
| 370 |
+
}
|
| 371 |
+
assert(child_visit_count.size()==child_prior.size());
|
| 372 |
+
// compute the completed value
|
| 373 |
+
std::vector<float> completed_qvalues = qtransform_completed_by_mix_value(this, child_visit_count, child_prior, discount_factor);
|
| 374 |
+
std::vector<float> probs;
|
| 375 |
+
for (int i=0;i<action_space_size;i++){
|
| 376 |
+
probs.push_back(infymin);
|
| 377 |
+
}
|
| 378 |
+
for (int i=0;i<child_prior.size();i++){
|
| 379 |
+
probs[this->legal_actions[i]] = child_prior[i] + completed_qvalues[i];
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
+
csoftmax(probs, probs.size());
|
| 383 |
+
|
| 384 |
+
return probs;
|
| 385 |
+
}
|
| 386 |
+
|
| 387 |
+
//*********************************************************
|
| 388 |
+
|
| 389 |
+
CRoots::CRoots()
|
| 390 |
+
{
|
| 391 |
+
/*
|
| 392 |
+
Overview:
|
| 393 |
+
The initialization of CRoots.
|
| 394 |
+
*/
|
| 395 |
+
this->root_num = 0;
|
| 396 |
+
}
|
| 397 |
+
|
| 398 |
+
CRoots::CRoots(int root_num, std::vector<std::vector<int> > &legal_actions_list)
|
| 399 |
+
{
|
| 400 |
+
/*
|
| 401 |
+
Overview:
|
| 402 |
+
The initialization of CRoots with root num and legal action lists.
|
| 403 |
+
Arguments:
|
| 404 |
+
- root_num: the number of the current root.
|
| 405 |
+
- legal_action_list: the vector of the legal action of this root.
|
| 406 |
+
*/
|
| 407 |
+
this->root_num = root_num;
|
| 408 |
+
this->legal_actions_list = legal_actions_list;
|
| 409 |
+
|
| 410 |
+
for (int i = 0; i < root_num; ++i)
|
| 411 |
+
{
|
| 412 |
+
this->roots.push_back(CNode(0, this->legal_actions_list[i]));
|
| 413 |
+
}
|
| 414 |
+
}
|
| 415 |
+
|
| 416 |
+
CRoots::~CRoots() {}
|
| 417 |
+
|
| 418 |
+
void CRoots::prepare(float root_noise_weight, const std::vector<std::vector<float> > &noises, const std::vector<float> &rewards, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch)
|
| 419 |
+
{
|
| 420 |
+
/*
|
| 421 |
+
Overview:
|
| 422 |
+
Expand the roots and add noises.
|
| 423 |
+
Arguments:
|
| 424 |
+
- root_noise_weight: the exploration fraction of roots.
|
| 425 |
+
- noises: the vector of noise add to the roots.
|
| 426 |
+
- rewards: the vector of rewards of each root.
|
| 427 |
+
- values: the vector of values of each root.
|
| 428 |
+
- policies: the vector of policy logits of each root.
|
| 429 |
+
- to_play_batch: the vector of the player side of each root.
|
| 430 |
+
*/
|
| 431 |
+
for(int i = 0; i < this->root_num; ++i){
|
| 432 |
+
this->roots[i].expand(to_play_batch[i], 0, i, rewards[i], values[i], policies[i]);
|
| 433 |
+
this->roots[i].add_exploration_noise(root_noise_weight, noises[i]);
|
| 434 |
+
|
| 435 |
+
this->roots[i].visit_count += 1;
|
| 436 |
+
}
|
| 437 |
+
}
|
| 438 |
+
|
| 439 |
+
void CRoots::prepare_no_noise(const std::vector<float> &rewards, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch)
|
| 440 |
+
{
|
| 441 |
+
/*
|
| 442 |
+
Overview:
|
| 443 |
+
Expand the roots without noise.
|
| 444 |
+
Arguments:
|
| 445 |
+
- rewards: the vector of rewards of each root.
|
| 446 |
+
- values: the vector of values of each root.
|
| 447 |
+
- policies: the vector of policy logits of each root.
|
| 448 |
+
- to_play_batch: the vector of the player side of each root.
|
| 449 |
+
*/
|
| 450 |
+
for(int i = 0; i < this->root_num; ++i){
|
| 451 |
+
this->roots[i].expand(to_play_batch[i], 0, i, rewards[i], values[i], policies[i]);
|
| 452 |
+
|
| 453 |
+
this->roots[i].visit_count += 1;
|
| 454 |
+
}
|
| 455 |
+
}
|
| 456 |
+
|
| 457 |
+
void CRoots::clear()
|
| 458 |
+
{
|
| 459 |
+
/*
|
| 460 |
+
Overview:
|
| 461 |
+
Clear the roots vector.
|
| 462 |
+
*/
|
| 463 |
+
this->roots.clear();
|
| 464 |
+
}
|
| 465 |
+
|
| 466 |
+
std::vector<std::vector<int> > CRoots::get_trajectories()
|
| 467 |
+
{
|
| 468 |
+
/*
|
| 469 |
+
Overview:
|
| 470 |
+
Find the current best trajectory starts from each root.
|
| 471 |
+
Outputs:
|
| 472 |
+
- traj: a vector of node index, which is the current best trajectory from each root.
|
| 473 |
+
*/
|
| 474 |
+
std::vector<std::vector<int> > trajs;
|
| 475 |
+
trajs.reserve(this->root_num);
|
| 476 |
+
|
| 477 |
+
for (int i = 0; i < this->root_num; ++i)
|
| 478 |
+
{
|
| 479 |
+
trajs.push_back(this->roots[i].get_trajectory());
|
| 480 |
+
}
|
| 481 |
+
return trajs;
|
| 482 |
+
}
|
| 483 |
+
|
| 484 |
+
std::vector<std::vector<int> > CRoots::get_distributions()
|
| 485 |
+
{
|
| 486 |
+
/*
|
| 487 |
+
Overview:
|
| 488 |
+
Get the children distribution of each root.
|
| 489 |
+
Outputs:
|
| 490 |
+
- distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]).
|
| 491 |
+
*/
|
| 492 |
+
std::vector<std::vector<int> > distributions;
|
| 493 |
+
distributions.reserve(this->root_num);
|
| 494 |
+
|
| 495 |
+
for (int i = 0; i < this->root_num; ++i)
|
| 496 |
+
{
|
| 497 |
+
distributions.push_back(this->roots[i].get_children_distribution());
|
| 498 |
+
}
|
| 499 |
+
return distributions;
|
| 500 |
+
}
|
| 501 |
+
|
| 502 |
+
//*********************************************************
|
| 503 |
+
// Gumbel Muzero related code
|
| 504 |
+
//*********************************************************
|
| 505 |
+
|
| 506 |
+
std::vector<std::vector<float> > CRoots::get_children_values(float discount_factor, int action_space_size)
|
| 507 |
+
{
|
| 508 |
+
/*
|
| 509 |
+
Overview:
|
| 510 |
+
Compute the completed value of each root.
|
| 511 |
+
Arguments:
|
| 512 |
+
- discount_factor: the discount_factor of reward.
|
| 513 |
+
- action_space_size: the action space size of environment.
|
| 514 |
+
*/
|
| 515 |
+
std::vector<std::vector<float> > values;
|
| 516 |
+
values.reserve(this->root_num);
|
| 517 |
+
|
| 518 |
+
for (int i = 0; i < this->root_num; ++i)
|
| 519 |
+
{
|
| 520 |
+
values.push_back(this->roots[i].get_children_value(discount_factor, action_space_size));
|
| 521 |
+
}
|
| 522 |
+
return values;
|
| 523 |
+
}
|
| 524 |
+
|
| 525 |
+
std::vector<std::vector<float> > CRoots::get_policies(float discount_factor, int action_space_size)
|
| 526 |
+
{
|
| 527 |
+
/*
|
| 528 |
+
Overview:
|
| 529 |
+
Compute the improved policy of each root.
|
| 530 |
+
Arguments:
|
| 531 |
+
- discount_factor: the discount_factor of reward.
|
| 532 |
+
- action_space_size: the action space size of environment.
|
| 533 |
+
*/
|
| 534 |
+
std::vector<std::vector<float> > probs;
|
| 535 |
+
probs.reserve(this->root_num);
|
| 536 |
+
|
| 537 |
+
for(int i = 0; i < this->root_num; ++i){
|
| 538 |
+
probs.push_back(this->roots[i].get_policy(discount_factor, action_space_size));
|
| 539 |
+
}
|
| 540 |
+
return probs;
|
| 541 |
+
}
|
| 542 |
+
|
| 543 |
+
std::vector<float> CRoots::get_values()
|
| 544 |
+
{
|
| 545 |
+
/*
|
| 546 |
+
Overview:
|
| 547 |
+
Return the real value of each root.
|
| 548 |
+
*/
|
| 549 |
+
std::vector<float> values;
|
| 550 |
+
for (int i = 0; i < this->root_num; ++i)
|
| 551 |
+
{
|
| 552 |
+
values.push_back(this->roots[i].value());
|
| 553 |
+
}
|
| 554 |
+
return values;
|
| 555 |
+
}
|
| 556 |
+
|
| 557 |
+
//*********************************************************
|
| 558 |
+
//
|
| 559 |
+
void update_tree_q(CNode* root, tools::CMinMaxStats &min_max_stats, float discount_factor, int players)
{
    /*
    Overview:
        Iteratively traverse the sub-tree below ``root`` (explicit stack, no
        recursion) and refresh the min-max statistics with the Q value of
        every expanded node. The root itself is excluded.
    Arguments:
        - root: node whose sub-tree is visited.
        - min_max_stats: running min/max used to normalize Q values.
        - discount_factor: reward discount.
        - players: number of players (expected to be 1 or 2).
    */
    std::stack<CNode*> node_stack;
    node_stack.push(root);
    while (node_stack.size() > 0)
    {
        CNode* node = node_stack.top();
        node_stack.pop();

        if (node != root)
        {
            // NOTE: in 2-player mode the child's value is expressed from the
            // opponent's perspective, hence the sign flip below.
            float true_reward = node->reward;

            // Fix: initialize qsa so that an unexpected ``players`` value can
            // never feed an indeterminate (uninitialized) number into the
            // statistics; previously qsa was read uninitialized in that case.
            float qsa = 0.0;
            if (players == 1)
                qsa = true_reward + discount_factor * node->value();
            else if (players == 2)
                // TODO(pu): revisit the sign convention for 2-player games.
                qsa = true_reward + discount_factor * (-1) * node->value();

            min_max_stats.update(qsa);
        }

        // Push every expanded child so the whole sub-tree is covered.
        for (auto a : node->legal_actions)
        {
            CNode* child = node->get_child(a);
            if (child->expanded())
            {
                node_stack.push(child);
            }
        }
    }
}
|
| 604 |
+
|
| 605 |
+
void cback_propagate(std::vector<CNode*> &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount_factor)
{
    /*
    Overview:
        Back up a leaf evaluation along the search path: accumulate the
        discounted bootstrap value into each node's value sum, bump visit
        counts, and refresh the min-max statistics.
    Arguments:
        - search_path: nodes from root to leaf.
        - min_max_stats: running min/max used to normalize Q values.
        - to_play: player at the leaf; this Gumbel variant only supports the
          single-player convention (-1).
        - value: leaf value to propagate.
        - discount_factor: reward discount.
    */
    assert(to_play == -1);
    float bootstrap_value = value;
    for (int idx = (int)search_path.size() - 1; idx >= 0; --idx)
    {
        CNode* node = search_path[idx];
        node->value_sum += bootstrap_value;
        node->visit_count += 1;

        const float true_reward = node->reward;
        min_max_stats.update(true_reward + discount_factor * node->value());

        // One-step discounted return for the parent on the next iteration.
        bootstrap_value = true_reward + discount_factor * bootstrap_value;
    }
}
|
| 632 |
+
|
| 633 |
+
void cbatch_back_propagate(int current_latent_state_index, float discount_factor, const std::vector<float> &value_prefixs, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &to_play_batch)
|
| 634 |
+
{
|
| 635 |
+
/*
|
| 636 |
+
Overview:
|
| 637 |
+
Expand the nodes along the search path and update the infos.
|
| 638 |
+
Arguments:
|
| 639 |
+
- current_latent_state_index: The index of latent state of the leaf node in the search path.
|
| 640 |
+
- discount_factor: the discount factor of reward.
|
| 641 |
+
- value_prefixs: the value prefixs of nodes along the search path.
|
| 642 |
+
- values: the values to propagate along the search path.
|
| 643 |
+
- policies: the policy logits of nodes along the search path.
|
| 644 |
+
- min_max_stats: a tool used to min-max normalize the q value.
|
| 645 |
+
- results: the search results.
|
| 646 |
+
- to_play_batch: the batch of which player is playing on this node.
|
| 647 |
+
*/
|
| 648 |
+
for(int i = 0; i < results.num; ++i){
|
| 649 |
+
results.nodes[i]->expand(to_play_batch[i], current_latent_state_index, i, value_prefixs[i], values[i], policies[i]);
|
| 650 |
+
cback_propagate(results.search_paths[i], min_max_stats_lst->stats_lst[i], to_play_batch[i], values[i], discount_factor);
|
| 651 |
+
}
|
| 652 |
+
}
|
| 653 |
+
|
| 654 |
+
int cselect_child(CNode* root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players)
{
    /*
    Overview:
        Pick the child of ``root`` with the highest UCB score; ties within a
        small epsilon are broken uniformly at random.
    Arguments:
        - root: node whose children are scored.
        - min_max_stats: running min/max used to normalize the value term.
        - pb_c_base: MuZero constant c2.
        - pb_c_init: MuZero constant c1.
        - discount_factor: reward discount.
        - mean_q: mean Q of the parent, used for unvisited children.
        - players: number of players.
    Outputs:
        - the selected action id (0 if there are no legal actions).
    */
    const float epsilon = 0.000001;
    float best_score = FLOAT_MIN;
    std::vector<int> best_actions;

    for (auto candidate : root->legal_actions)
    {
        CNode* child = root->get_child(candidate);
        float score = cucb_score(child, min_max_stats, mean_q, root->visit_count - 1, pb_c_base, pb_c_init, discount_factor, players);

        if (score > best_score)
        {
            // New strict maximum: restart the tie list.
            best_score = score;
            best_actions.assign(1, candidate);
        }
        else if (score >= best_score - epsilon)
        {
            best_actions.push_back(candidate);
        }
    }

    if (best_actions.empty())
    {
        return 0;
    }
    return best_actions[rand() % best_actions.size()];
}
|
| 696 |
+
|
| 697 |
+
//*********************************************************
|
| 698 |
+
// Gumbel Muzero related code
|
| 699 |
+
//*********************************************************
|
| 700 |
+
|
| 701 |
+
int cselect_root_child(CNode* root, float discount_factor, int num_simulations, int max_num_considered_actions)
{
    /*
    Overview:
        Gumbel MuZero root action selection via sequential halving: score each
        child with gumbel + logit + completed Q, restricted (by an -inf
        penalty inside score_considered) to the children whose visit count
        matches the current phase of the visit schedule.
    Arguments:
        - root: root node whose children are scored.
        - discount_factor: reward discount.
        - num_simulations: total simulation budget.
        - max_num_considered_actions: maximum number of actions considered.
    Outputs:
        - the selected action id.
    */
    std::vector<int> child_visit_count;
    std::vector<float> child_prior;
    for (auto a : root->legal_actions)
    {
        CNode* child = root->get_child(a);
        child_visit_count.push_back(child->visit_count);
        child_prior.push_back(child->prior);
    }
    assert(child_visit_count.size() == child_prior.size());

    std::vector<float> completed_qvalues = qtransform_completed_by_mix_value(root, child_visit_count, child_prior, discount_factor);
    std::vector<std::vector<int> > visit_table = get_table_of_considered_visits(max_num_considered_actions, num_simulations);

    // Fix: removed the unused local ``num_valid_actions`` present in the
    // original implementation.
    int num_considered = std::min(max_num_considered_actions, num_simulations);
    // Total visits so far tell us how far into the schedule we are.
    int simulation_index = std::accumulate(child_visit_count.begin(), child_visit_count.end(), 0);
    int considered_visit = visit_table[num_considered][simulation_index];

    std::vector<float> score = score_considered(considered_visit, root->gumbel, child_prior, completed_qvalues, child_visit_count);

    // Argmax of the scores over the legal actions.
    float argmax = -std::numeric_limits<float>::infinity();
    int max_action = root->legal_actions[0];
    int index = 0;
    for (auto a : root->legal_actions)
    {
        if (score[index] > argmax)
        {
            argmax = score[index];
            max_action = a;
        }
        index += 1;
    }

    return max_action;
}
|
| 746 |
+
|
| 747 |
+
int cselect_interior_child(CNode* root, float discount_factor)
{
    /*
    Overview:
        Gumbel MuZero interior-node selection: form the improved policy
        softmax(prior + completed Q) and pick the action whose empirical
        visit fraction lags that policy the most.
    Arguments:
        - root: interior node whose children are scored.
        - discount_factor: reward discount.
    Outputs:
        - the selected action id.
    */
    std::vector<int> child_visit_count;
    std::vector<float> child_prior;
    for (auto a : root->legal_actions)
    {
        CNode* child = root->get_child(a);
        child_visit_count.push_back(child->visit_count);
        child_prior.push_back(child->prior);
    }
    assert(child_visit_count.size() == child_prior.size());

    std::vector<float> completed_qvalues = qtransform_completed_by_mix_value(root, child_visit_count, child_prior, discount_factor);

    // Improved policy: softmax over prior logits plus completed Q values.
    std::vector<float> improved_policy;
    improved_policy.reserve(child_prior.size());
    for (std::size_t i = 0; i < child_prior.size(); ++i)
    {
        improved_policy.push_back(child_prior[i] + completed_qvalues[i]);
    }
    csoftmax(improved_policy, (int)improved_policy.size());

    int visit_count_sum = std::accumulate(child_visit_count.begin(), child_visit_count.end(), 0);

    // Deterministic selection: argmax of policy minus visit fraction.
    float argmax = -std::numeric_limits<float>::infinity();
    int max_action = root->legal_actions[0];
    std::size_t index = 0;
    for (auto a : root->legal_actions)
    {
        float gap = improved_policy[index] - (float)child_visit_count[index] / (float)(1 + visit_count_sum);
        if (gap > argmax)
        {
            argmax = gap;
            max_action = a;
        }
        index += 1;
    }

    return max_action;
}
|
| 791 |
+
|
| 792 |
+
float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, float total_children_visit_counts, float pb_c_base, float pb_c_init, float discount_factor, int players)
{
    /*
    Overview:
        Compute the (Mu)Zero UCB score of a child: a prior exploration term
        plus a min-max-normalized value term clamped to [0, 1].
    Arguments:
        - child: child node being scored.
        - min_max_stats: running min/max used to normalize the value term.
        - parent_mean_q: mean Q of the parent, used for unvisited children.
        - total_children_visit_counts: total visits of all siblings.
        - pb_c_base: MuZero constant c2.
        - pb_c_init: MuZero constant c1.
        - discount_factor: reward discount.
        - players: number of players.
    Outputs:
        - the UCB score of the child.
    */
    // Exploration weight grows slowly with the parent's total visit count
    // and shrinks with the child's own visits.
    float pb_c = log((total_children_visit_counts + pb_c_base + 1) / pb_c_base) + pb_c_init;
    pb_c *= sqrt(total_children_visit_counts) / (child->visit_count + 1);

    const float prior_score = pb_c * child->prior;

    float value_score = 0.0;
    if (child->visit_count == 0)
    {
        value_score = parent_mean_q;
    }
    else
    {
        const float true_reward = child->reward;
        if (players == 1)
            value_score = true_reward + discount_factor * child->value();
        else if (players == 2)
            // Two-player: the child's value is from the opponent's view.
            value_score = true_reward + discount_factor * (-child->value());
    }

    value_score = min_max_stats.normalize(value_score);
    if (value_score < 0) value_score = 0;
    if (value_score > 1) value_score = 1;

    return prior_score + value_score;
}
|
| 833 |
+
|
| 834 |
+
void cbatch_traverse(CRoots *roots, int num_simulations, int max_num_considered_actions, float discount_factor, CSearchResults &results, std::vector<int> &virtual_to_play_batch)
{
    /*
    Overview:
        For every root in the batch, descend the tree — Gumbel/sequential-
        halving selection at the root, improved-policy selection at interior
        nodes — until an unexpanded node is reached, and record the search
        paths, leaves, last actions and path lengths in ``results``.
    Arguments:
        - roots: the batch of roots to search from.
        - num_simulations: total simulation budget.
        - max_num_considered_actions: maximum number of root actions
          considered by sequential halving.
        - discount_factor: reward discount.
        - results: output container for the batched search results.
        - virtual_to_play_batch: per-root player to move; copied through to
          the results.
    */
    // Seed the RNG used for random tie-breaking in child selection.
    timeval t1;
    gettimeofday(&t1, NULL);
    srand(t1.tv_usec);

    // Fix: removed the unused locals ``parent_q`` and the dead ``players``
    // computation present in the original implementation (neither value was
    // read anywhere in this function).
    int last_action = -1;
    results.search_lens = std::vector<int>();

    for (int i = 0; i < results.num; ++i)
    {
        CNode *node = &(roots->roots[i]);
        int is_root = 1;
        int search_len = 0;
        int action = 0;
        results.search_paths[i].push_back(node);

        // Walk down until we reach a node that has not been expanded yet.
        while (node->expanded())
        {
            if (is_root)
            {
                action = cselect_root_child(node, discount_factor, num_simulations, max_num_considered_actions);
            }
            else
            {
                action = cselect_interior_child(node, discount_factor);
            }
            is_root = 0;

            node->best_action = action;
            node = node->get_child(action);
            last_action = action;
            results.search_paths[i].push_back(node);
            search_len += 1;
        }

        // The leaf's parent tells us which latent state to expand from.
        CNode* parent = results.search_paths[i][results.search_paths[i].size() - 2];

        results.latent_state_index_in_search_path.push_back(parent->current_latent_state_index);
        results.latent_state_index_in_batch.push_back(parent->batch_index);

        results.last_actions.push_back(last_action);
        results.search_lens.push_back(search_len);
        results.nodes.push_back(node);
        results.virtual_to_play_batchs.push_back(virtual_to_play_batch[i]);
    }
}
|
| 899 |
+
|
| 900 |
+
//*********************************************************
|
| 901 |
+
// Gumbel Muzero related code
|
| 902 |
+
//*********************************************************
|
| 903 |
+
|
| 904 |
+
void csoftmax(std::vector<float> &input, int input_len)
{
    /*
    Overview:
        Numerically stable in-place softmax over the first ``input_len``
        entries of ``input`` (the max is subtracted before exponentiating).
    Arguments:
        - input: the vector to be transformed in place.
        - input_len: number of leading entries to transform; must be positive
          and no larger than input.size().
    */
    // Fix: the previous ``assert(input != NULL)`` compared a std::vector
    // reference against NULL, which is ill-formed; it only went unnoticed
    // because NDEBUG builds compile asserts away. Assert meaningful
    // preconditions instead.
    assert(input_len > 0);
    assert((int)input.size() >= input_len);

    // Find the maximum value for numerical stability.
    float m = input[0];
    for (int i = 1; i < input_len; i++)
    {
        if (input[i] > m)
        {
            m = input[i];
        }
    }

    float sum = 0;
    for (int i = 0; i < input_len; i++)
    {
        sum += expf(input[i] - m);
    }

    // exp(x - m) / sum, computed in log space exactly as before.
    for (int i = 0; i < input_len; i++)
    {
        input[i] = expf(input[i] - m - log(sum));
    }
}
|
| 934 |
+
|
| 935 |
+
float compute_mixed_value(float raw_value, std::vector<float> q_values, std::vector<int> &child_visit, std::vector<float> &child_prior)
{
    /*
    Overview:
        Mix the network's raw value estimate with the visit-weighted average
        Q of the visited children (Gumbel MuZero "mixed value").
        NOTE: child_prior is clamped in place (caller's vector is mutated),
        exactly as the original implementation did.
    Arguments:
        - raw_value: value-network estimate for the node.
        - q_values: per-child Q values.
        - child_visit: per-child visit counts.
        - child_prior: per-child prior probabilities (clamped from below).
    Outputs:
        - the mixed Q value.
    */
    const float min_num = -10e7;

    float visit_count_sum = 0.0;
    for (std::size_t i = 0; i < child_visit.size(); i++)
    {
        visit_count_sum += child_visit[i];
    }

    // Clamp priors from below so later divisions stay finite.
    for (std::size_t i = 0; i < child_prior.size(); i++)
    {
        child_prior[i] = std::max(child_prior[i], min_num);
    }

    // Probability mass of the visited children only.
    float probs_sum = 0.0;
    for (std::size_t i = 0; i < child_prior.size(); i++)
    {
        if (child_visit[i] > 0)
        {
            probs_sum += child_prior[i];
        }
    }

    // Prior-weighted average Q over the visited children.
    float weighted_q_sum = 0.0;
    for (std::size_t i = 0; i < child_prior.size(); i++)
    {
        if (child_visit[i] > 0)
        {
            weighted_q_sum += child_prior[i] * q_values[i] / probs_sum;
        }
    }

    return (raw_value + visit_count_sum * weighted_q_sum) / (visit_count_sum + 1);
}
|
| 971 |
+
|
| 972 |
+
void rescale_qvalues(std::vector<float> &value, float epsilon)
{
    /*
    Overview:
        Min-max normalize the vector in place: (v - min) / max(max - min,
        epsilon). The epsilon floor keeps the division well-defined when all
        entries are (nearly) equal.
    Arguments:
        - value: the vector to rescale in place.
        - epsilon: lower bound on the max-min gap.
    */
    const float hi = *std::max_element(value.begin(), value.end());
    const float lo = *std::min_element(value.begin(), value.end());
    const float gap = std::max(hi - lo, epsilon);
    for (auto &v : value)
    {
        v = (v - lo) / gap;
    }
}
|
| 988 |
+
|
| 989 |
+
std::vector<float> qtransform_completed_by_mix_value(CNode *root, std::vector<int> & child_visit, \
|
| 990 |
+
std::vector<float> & child_prior, float discount_factor, float maxvisit_init, float value_scale, \
|
| 991 |
+
bool rescale_values, float epsilon)
|
| 992 |
+
{
|
| 993 |
+
/*
|
| 994 |
+
Overview:
|
| 995 |
+
Calculate the q value with mixed value.
|
| 996 |
+
Arguments:
|
| 997 |
+
- root: the roots that search from.
|
| 998 |
+
- child_visit: the visit counts of the child nodes.
|
| 999 |
+
- child_prior: the prior of the child nodes.
|
| 1000 |
+
- discount_factor: the discount factor of reward.
|
| 1001 |
+
- maxvisit_init: the init of the maximization of visit counts.
|
| 1002 |
+
- value_cale: the scale of value.
|
| 1003 |
+
- rescale_values: whether to rescale the values.
|
| 1004 |
+
- epsilon: the lower limit of gap in max-min normalization
|
| 1005 |
+
Outputs:
|
| 1006 |
+
- completed Q value.
|
| 1007 |
+
*/
|
| 1008 |
+
assert (child_visit.size() == child_prior.size());
|
| 1009 |
+
std::vector<float> qvalues;
|
| 1010 |
+
std::vector<float> child_prior_tmp;
|
| 1011 |
+
|
| 1012 |
+
child_prior_tmp.assign(child_prior.begin(), child_prior.end());
|
| 1013 |
+
qvalues = root->get_q(discount_factor);
|
| 1014 |
+
csoftmax(child_prior_tmp, child_prior_tmp.size());
|
| 1015 |
+
// TODO: should be raw_value here
|
| 1016 |
+
float value = compute_mixed_value(root->raw_value, qvalues, child_visit, child_prior_tmp);
|
| 1017 |
+
std::vector<float> completed_qvalue;
|
| 1018 |
+
|
| 1019 |
+
for (unsigned int i = 0;i < child_prior_tmp.size();i++){
|
| 1020 |
+
if (child_visit[i] > 0){
|
| 1021 |
+
completed_qvalue.push_back(qvalues[i]);
|
| 1022 |
+
}
|
| 1023 |
+
else{
|
| 1024 |
+
completed_qvalue.push_back(value);
|
| 1025 |
+
}
|
| 1026 |
+
}
|
| 1027 |
+
|
| 1028 |
+
if (rescale_values){
|
| 1029 |
+
rescale_qvalues(completed_qvalue, epsilon);
|
| 1030 |
+
}
|
| 1031 |
+
|
| 1032 |
+
float max_visit = *max_element(child_visit.begin(), child_visit.end());
|
| 1033 |
+
float visit_scale = maxvisit_init + max_visit;
|
| 1034 |
+
|
| 1035 |
+
for (unsigned int i=0;i < completed_qvalue.size();i++){
|
| 1036 |
+
completed_qvalue[i] = completed_qvalue[i] * visit_scale * value_scale;
|
| 1037 |
+
}
|
| 1038 |
+
return completed_qvalue;
|
| 1039 |
+
|
| 1040 |
+
}
|
| 1041 |
+
|
| 1042 |
+
std::vector<int> get_sequence_of_considered_visits(int max_num_considered_actions, int num_simulations)
{
    /*
    Overview:
        Build the sequential-halving visit schedule: the visit count a child
        must have to be eligible at each simulation step.
    Arguments:
        - max_num_considered_actions: maximum number of considered actions.
        - num_simulations: total simulation budget.
    Outputs:
        - a vector of length num_simulations with the scheduled visit counts.
    */
    std::vector<int> visit_seq;
    if (max_num_considered_actions <= 1)
    {
        // Degenerate case: a single candidate is simply visited every step.
        for (int i = 0; i < num_simulations; i++)
        {
            visit_seq.push_back(i);
        }
        return visit_seq;
    }

    const int log2max = std::ceil(std::log2(max_num_considered_actions));
    std::vector<int> visits(max_num_considered_actions, 0);
    int num_considered = max_num_considered_actions;

    while ((int)visit_seq.size() < num_simulations)
    {
        // Budget per halving phase, at least one pass over the candidates.
        int num_extra_visits = std::max(1, (int)(num_simulations / (log2max * num_considered)));
        for (int i = 0; i < num_extra_visits; i++)
        {
            visit_seq.insert(visit_seq.end(), visits.begin(), visits.begin() + num_considered);
            for (int j = 0; j < num_considered; j++)
            {
                visits[j] += 1;
            }
        }
        // Halve the candidate set, but never below two.
        num_considered = std::max(2, num_considered / 2);
    }

    // Truncate to exactly the simulation budget.
    return std::vector<int>(visit_seq.begin(), visit_seq.begin() + num_simulations);
}
|
| 1078 |
+
|
| 1079 |
+
std::vector<std::vector<int> > get_table_of_considered_visits(int max_num_considered_actions, int num_simulations)
|
| 1080 |
+
{
|
| 1081 |
+
/*
|
| 1082 |
+
Overview:
|
| 1083 |
+
Calculate the table of considered visits.
|
| 1084 |
+
Arguments:
|
| 1085 |
+
- max_num_considered_actions: the maximum number of considered actions.
|
| 1086 |
+
- num_simulations: the upper limit number of simulations.
|
| 1087 |
+
Outputs:
|
| 1088 |
+
- the table of considered visits.
|
| 1089 |
+
*/
|
| 1090 |
+
std::vector<std::vector<int> > table;
|
| 1091 |
+
for (int m=0;m < max_num_considered_actions+1;m++){
|
| 1092 |
+
table.push_back(get_sequence_of_considered_visits(m, num_simulations));
|
| 1093 |
+
}
|
| 1094 |
+
return table;
|
| 1095 |
+
}
|
| 1096 |
+
|
| 1097 |
+
std::vector<float> score_considered(int considered_visit, std::vector<float> gumbel, std::vector<float> logits, std::vector<float> normalized_qvalues, std::vector<int> visit_counts)
{
    /*
    Overview:
        Score each child as max(low_logit, gumbel + logit + normalized Q),
        penalizing with -inf every child whose visit count does not match the
        scheduled ``considered_visit`` (sequential halving eligibility).
    Arguments:
        - considered_visit: visit count a child must have to be eligible.
        - gumbel: per-child Gumbel noise.
        - logits: per-child prior logits (shifted in place on a local copy).
        - normalized_qvalues: per-child normalized Q values.
        - visit_counts: per-child visit counts.
    Outputs:
        - the per-child scores.
    */
    float low_logit = -1e9;

    // Shift logits so the maximum is zero (numerical stability).
    float max_logit = *max_element(logits.begin(), logits.end());
    for (unsigned int i = 0; i < logits.size(); i++)
    {
        logits[i] -= max_logit;
    }

    std::vector<float> penalty;
    for (unsigned int i = 0; i < visit_counts.size(); i++)
    {
        // Only children with the scheduled visit count are eligible.
        if (visit_counts[i] == considered_visit)
            penalty.push_back(0);
        else
            penalty.push_back(-std::numeric_limits<float>::infinity());
    }

    // Fix: the original chained comparison
    // ``gumbel.size()==logits.size()==normalized_qvalues.size()==penalty.size()``
    // evaluates left-to-right as ((a==b)==c)==d, comparing a bool against a
    // size, so it did not check what it intended. Assert pairwise instead.
    assert(gumbel.size() == logits.size());
    assert(logits.size() == normalized_qvalues.size());
    assert(normalized_qvalues.size() == penalty.size());

    std::vector<float> score;
    for (unsigned int i = 0; i < visit_counts.size(); i++)
    {
        score.push_back(std::max(low_logit, gumbel[i] + logits[i] + normalized_qvalues[i]) + penalty[i]);
    }

    return score;
}
|
| 1133 |
+
|
| 1134 |
+
std::vector<float> generate_gumbel(float gumbel_scale, float gumbel_rng, int shape)
{
    /*
    Overview:
        Draw ``shape`` samples from a standard Gumbel distribution (via
        std::extreme_value_distribution), scaled by gumbel_scale, using a
        Mersenne-Twister generator seeded from gumbel_rng for
        reproducibility.
    Arguments:
        - gumbel_scale: multiplicative scale of each sample.
        - gumbel_rng: seed for the generator.
        - shape: number of samples to draw.
    Outputs:
        - the vector of Gumbel samples.
    */
    std::mt19937 gen(static_cast<unsigned int>(gumbel_rng));
    std::extreme_value_distribution<float> dist(0, 1);

    std::vector<float> gumbel;
    gumbel.reserve(shape);
    for (int i = 0; i < shape; i++)
    {
        gumbel.push_back(gumbel_scale * dist(gen));
    }
    return gumbel;
}
|
| 1153 |
+
|
| 1154 |
+
}
|
LightZero/lzero/mcts/ctree/ctree_gumbel_muzero/lib/cnode.h
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// C++11
|
| 2 |
+
|
| 3 |
+
#ifndef CNODE_H
|
| 4 |
+
#define CNODE_H
|
| 5 |
+
|
| 6 |
+
#include "./../common_lib/cminimax.h"
|
| 7 |
+
#include <math.h>
|
| 8 |
+
#include <vector>
|
| 9 |
+
#include <stack>
|
| 10 |
+
#include <stdlib.h>
|
| 11 |
+
#include <time.h>
|
| 12 |
+
#include <cmath>
|
| 13 |
+
#include <sys/timeb.h>
|
| 14 |
+
#include <sys/time.h>
|
| 15 |
+
#include <map>
|
| 16 |
+
|
| 17 |
+
const int DEBUG_MODE = 0;
|
| 18 |
+
|
| 19 |
+
namespace tree {
|
| 20 |
+
|
| 21 |
+
// A single MCTS node for Gumbel MuZero: holds search statistics, network
// outputs attached at expansion time, and its children keyed by action id.
class CNode {
    public:
        // visit_count: times this node was traversed; to_play: player to move;
        // current_latent_state_index / batch_index: position of this node's
        // latent state in the batched inference buffers; best_action: action
        // chosen at this node during the last traversal.
        int visit_count, to_play, current_latent_state_index, batch_index, best_action;
        // reward: predicted reward on entering this node; prior: policy prior;
        // value_sum: accumulated backed-up value; raw_value: value-network
        // estimate; gumbel_scale / gumbel_rng: Gumbel sampling parameters.
        float reward, prior, value_sum, raw_value, gumbel_scale, gumbel_rng;
        std::vector<int> children_index;
        // Children keyed by action id.
        std::map<int, CNode> children;

        // Legal actions available from this node.
        std::vector<int> legal_actions;
        // Gumbel noise attached to this node's actions.
        std::vector<float> gumbel;

        CNode();
        CNode(float prior, std::vector<int> &legal_actions);
        ~CNode();

        // Attach network outputs (reward, value, policy) and create children.
        void expand(int to_play, int current_latent_state_index, int batch_index, float reward, float value, const std::vector<float> &policy_logits);
        // Mix Dirichlet noise into the children's priors (root exploration).
        void add_exploration_noise(float exploration_fraction, const std::vector<float> &noises);
        // Per-child Q values under the given discount.
        std::vector<float> get_q(float discount);
        // Mean Q over visited children (parent_q used at the root).
        float compute_mean_q(int isRoot, float parent_q, float discount);
        void print_out();

        // Whether this node has been expanded (has children).
        int expanded();

        // Mean backed-up value: value_sum / visit_count.
        float value();

        // Best-action trajectory from this node downward.
        std::vector<int> get_trajectory();
        // Child visit-count distribution.
        std::vector<int> get_children_distribution();
        // Completed child Q values (Gumbel MuZero).
        std::vector<float> get_children_value(float discount_factor, int action_space_size);
        // Improved policy (Gumbel MuZero).
        std::vector<float> get_policy(float discount, int action_space_size);
        CNode* get_child(int action);
};
|
| 51 |
+
|
| 52 |
+
class CRoots{
|
| 53 |
+
public:
|
| 54 |
+
int root_num;
|
| 55 |
+
std::vector<CNode> roots;
|
| 56 |
+
std::vector<std::vector<int> > legal_actions_list;
|
| 57 |
+
|
| 58 |
+
CRoots();
|
| 59 |
+
CRoots(int root_num, std::vector<std::vector<int> > &legal_actions_list);
|
| 60 |
+
~CRoots();
|
| 61 |
+
|
| 62 |
+
void prepare(float root_noise_weight, const std::vector<std::vector<float> > &noises, const std::vector<float> &rewards, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch);
|
| 63 |
+
void prepare_no_noise(const std::vector<float> &rewards, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch);
|
| 64 |
+
void clear();
|
| 65 |
+
std::vector<std::vector<int> > get_trajectories();
|
| 66 |
+
std::vector<std::vector<int> > get_distributions();
|
| 67 |
+
std::vector<std::vector<float> > get_children_values(float discount, int action_space_size);
|
| 68 |
+
std::vector<std::vector<float> > get_policies(float discount, int action_space_size);
|
| 69 |
+
std::vector<float> get_values();
|
| 70 |
+
|
| 71 |
+
};
|
| 72 |
+
|
| 73 |
+
class CSearchResults{
|
| 74 |
+
public:
|
| 75 |
+
int num;
|
| 76 |
+
std::vector<int> latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, search_lens;
|
| 77 |
+
std::vector<int> virtual_to_play_batchs;
|
| 78 |
+
std::vector<CNode*> nodes;
|
| 79 |
+
std::vector<std::vector<CNode*> > search_paths;
|
| 80 |
+
|
| 81 |
+
CSearchResults();
|
| 82 |
+
CSearchResults(int num);
|
| 83 |
+
~CSearchResults();
|
| 84 |
+
|
| 85 |
+
};
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
//*********************************************************
|
| 89 |
+
void update_tree_q(CNode* root, tools::CMinMaxStats &min_max_stats, float discount, int players);
|
| 90 |
+
void cback_propagate(std::vector<CNode*> &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount);
|
| 91 |
+
void cbatch_back_propagate(int current_latent_state_index, float discount, const std::vector<float> &rewards, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &to_play_batch);
|
| 92 |
+
int cselect_root_child(CNode* root, float discount, int num_simulations, int max_num_considered_actions);
|
| 93 |
+
int cselect_interior_child(CNode* root, float discount);
|
| 94 |
+
int cselect_child(CNode* root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount, float mean_q, int players);
|
| 95 |
+
float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, float total_children_visit_counts, float pb_c_base, float pb_c_init, float discount, int players);
|
| 96 |
+
void cbatch_traverse(CRoots *roots, int num_simulations, int max_num_considered_actions, float discount, CSearchResults &results, std::vector<int> &virtual_to_play_batch);
|
| 97 |
+
void csoftmax(std::vector<float> &input, int input_len);
|
| 98 |
+
float compute_mixed_value(float raw_value, std::vector<float> q_values, std::vector<int> &child_visit, std::vector<float> &child_prior);
|
| 99 |
+
void rescale_qvalues(std::vector<float> &value, float epsilon);
|
| 100 |
+
std::vector<float> qtransform_completed_by_mix_value(CNode *root, std::vector<int> & child_visit, \
|
| 101 |
+
std::vector<float> & child_prior, float discount= 0.99, float maxvisit_init = 50.0, float value_scale = 0.1, \
|
| 102 |
+
bool rescale_values = true, float epsilon = 1e-8);
|
| 103 |
+
std::vector<int> get_sequence_of_considered_visits(int max_num_considered_actions, int num_simulations);
|
| 104 |
+
std::vector<std::vector<int> > get_table_of_considered_visits(int max_num_considered_actions, int num_simulations);
|
| 105 |
+
std::vector<float> score_considered(int considered_visit, std::vector<float> gumbel, std::vector<float> logits, std::vector<float> normalized_qvalues, std::vector<int> visit_counts);
|
| 106 |
+
std::vector<float> generate_gumbel(float gumbel_scale, float gumbel_rng, int shape);
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
#endif
|
LightZero/lzero/mcts/ctree/ctree_muzero/lib/cnode.cpp
ADDED
|
@@ -0,0 +1,715 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// C++11
|
| 2 |
+
|
| 3 |
+
#include <iostream>
|
| 4 |
+
#include "cnode.h"
|
| 5 |
+
#include <algorithm>
|
| 6 |
+
#include <map>
|
| 7 |
+
#include <cassert>
|
| 8 |
+
|
| 9 |
+
#ifdef _WIN32
|
| 10 |
+
#include "..\..\common_lib\utils.cpp"
|
| 11 |
+
#else
|
| 12 |
+
#include "../../common_lib/utils.cpp"
|
| 13 |
+
#endif
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
namespace tree
|
| 17 |
+
{
|
| 18 |
+
|
| 19 |
+
CSearchResults::CSearchResults()
{
    /*
    Overview:
        Default construction: an empty result set (num = 0).
    */
    this->num = 0;
}

CSearchResults::CSearchResults(int num)
{
    /*
    Overview:
        Construct a result set holding `num` (initially empty) search paths.
    */
    this->num = num;
    this->search_paths.resize(num);
}

CSearchResults::~CSearchResults() {}
|
| 42 |
+
|
| 43 |
+
//*********************************************************
|
| 44 |
+
|
| 45 |
+
CNode::CNode()
{
    /*
    Overview:
        Default initialization of CNode: zero statistics and sentinel indices.
    */
    this->prior = 0;
    this->visit_count = 0;
    this->value_sum = 0;
    this->best_action = -1;
    this->to_play = 0;
    this->reward = 0.0;
    // Fix: the original body contained `this->legal_actions = legal_actions;`,
    // a no-op self-assignment of the member (there is no such parameter), and
    // left the index members uninitialized. Initialize them to the same
    // sentinels used by the (prior, legal_actions) constructor.
    this->current_latent_state_index = -1;
    this->batch_index = -1;
}
|
| 60 |
+
|
| 61 |
+
CNode::CNode(float prior, std::vector<int> &legal_actions)
    : visit_count(0),
      to_play(0),
      current_latent_state_index(-1),
      batch_index(-1),
      best_action(-1),
      prior(prior),
      value_sum(0),
      legal_actions(legal_actions)
{
    /*
    Overview:
        Construct a node from a prior probability and an explicit set of
        legal actions. Remaining members (e.g. reward) are assigned later
        by expand(), exactly as in the assignment-based form.
    Arguments:
        - prior: the prior value of this node.
        - legal_actions: a vector of legal actions of this node.
    */
}

CNode::~CNode() {}
|
| 82 |
+
|
| 83 |
+
void CNode::expand(int to_play, int current_latent_state_index, int batch_index, float reward, const std::vector<float> &policy_logits)
|
| 84 |
+
{
|
| 85 |
+
/*
|
| 86 |
+
Overview:
|
| 87 |
+
Expand the child nodes of the current node.
|
| 88 |
+
Arguments:
|
| 89 |
+
- to_play: which player to play the game in the current node.
|
| 90 |
+
- current_latent_state_index: The index of latent state of the leaf node in the search path of the current node.
|
| 91 |
+
- batch_index: The index of latent state of the leaf node in the search path of the current node.
|
| 92 |
+
- reward: the reward of the current node.
|
| 93 |
+
- policy_logits: the logit of the child nodes.
|
| 94 |
+
*/
|
| 95 |
+
this->to_play = to_play;
|
| 96 |
+
this->current_latent_state_index = current_latent_state_index;
|
| 97 |
+
this->batch_index = batch_index;
|
| 98 |
+
this->reward = reward;
|
| 99 |
+
|
| 100 |
+
int action_num = policy_logits.size();
|
| 101 |
+
if (this->legal_actions.size() == 0)
|
| 102 |
+
{
|
| 103 |
+
for (int i = 0; i < action_num; ++i)
|
| 104 |
+
{
|
| 105 |
+
this->legal_actions.push_back(i);
|
| 106 |
+
}
|
| 107 |
+
}
|
| 108 |
+
float temp_policy;
|
| 109 |
+
float policy_sum = 0.0;
|
| 110 |
+
|
| 111 |
+
#ifdef _WIN32
|
| 112 |
+
// 创建动态数组
|
| 113 |
+
float* policy = new float[action_num];
|
| 114 |
+
#else
|
| 115 |
+
float policy[action_num];
|
| 116 |
+
#endif
|
| 117 |
+
|
| 118 |
+
float policy_max = FLOAT_MIN;
|
| 119 |
+
for (auto a : this->legal_actions)
|
| 120 |
+
{
|
| 121 |
+
if (policy_max < policy_logits[a])
|
| 122 |
+
{
|
| 123 |
+
policy_max = policy_logits[a];
|
| 124 |
+
}
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
for (auto a : this->legal_actions)
|
| 128 |
+
{
|
| 129 |
+
temp_policy = exp(policy_logits[a] - policy_max);
|
| 130 |
+
policy_sum += temp_policy;
|
| 131 |
+
policy[a] = temp_policy;
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
float prior;
|
| 135 |
+
for (auto a : this->legal_actions)
|
| 136 |
+
{
|
| 137 |
+
prior = policy[a] / policy_sum;
|
| 138 |
+
std::vector<int> tmp_empty;
|
| 139 |
+
this->children[a] = CNode(prior, tmp_empty); // only for muzero/efficient zero, not support alphazero
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
#ifdef _WIN32
|
| 143 |
+
// 释放数组内存
|
| 144 |
+
delete[] policy;
|
| 145 |
+
#else
|
| 146 |
+
#endif
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
void CNode::add_exploration_noise(float exploration_fraction, const std::vector<float> &noises)
|
| 150 |
+
{
|
| 151 |
+
/*
|
| 152 |
+
Overview:
|
| 153 |
+
Add a noise to the prior of the child nodes.
|
| 154 |
+
Arguments:
|
| 155 |
+
- exploration_fraction: the fraction to add noise.
|
| 156 |
+
- noises: the vector of noises added to each child node.
|
| 157 |
+
*/
|
| 158 |
+
float noise, prior;
|
| 159 |
+
for (int i = 0; i < this->legal_actions.size(); ++i)
|
| 160 |
+
{
|
| 161 |
+
noise = noises[i];
|
| 162 |
+
CNode *child = this->get_child(this->legal_actions[i]);
|
| 163 |
+
|
| 164 |
+
prior = child->prior;
|
| 165 |
+
child->prior = prior * (1 - exploration_fraction) + noise * exploration_fraction;
|
| 166 |
+
}
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
float CNode::compute_mean_q(int isRoot, float parent_q, float discount_factor)
|
| 170 |
+
{
|
| 171 |
+
/*
|
| 172 |
+
Overview:
|
| 173 |
+
Compute the mean q value of the current node.
|
| 174 |
+
Arguments:
|
| 175 |
+
- isRoot: whether the current node is a root node.
|
| 176 |
+
- parent_q: the q value of the parent node.
|
| 177 |
+
- discount_factor: the discount_factor of reward.
|
| 178 |
+
*/
|
| 179 |
+
float total_unsigned_q = 0.0;
|
| 180 |
+
int total_visits = 0;
|
| 181 |
+
for (auto a : this->legal_actions)
|
| 182 |
+
{
|
| 183 |
+
CNode *child = this->get_child(a);
|
| 184 |
+
if (child->visit_count > 0)
|
| 185 |
+
{
|
| 186 |
+
float true_reward = child->reward;
|
| 187 |
+
float qsa = true_reward + discount_factor * child->value();
|
| 188 |
+
total_unsigned_q += qsa;
|
| 189 |
+
total_visits += 1;
|
| 190 |
+
}
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
float mean_q = 0.0;
|
| 194 |
+
if (isRoot && total_visits > 0)
|
| 195 |
+
{
|
| 196 |
+
mean_q = (total_unsigned_q) / (total_visits);
|
| 197 |
+
}
|
| 198 |
+
else
|
| 199 |
+
{
|
| 200 |
+
mean_q = (parent_q + total_unsigned_q) / (total_visits + 1);
|
| 201 |
+
}
|
| 202 |
+
return mean_q;
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
void CNode::print_out()
|
| 206 |
+
{
|
| 207 |
+
return;
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
int CNode::expanded()
|
| 211 |
+
{
|
| 212 |
+
/*
|
| 213 |
+
Overview:
|
| 214 |
+
Return whether the current node is expanded.
|
| 215 |
+
*/
|
| 216 |
+
return this->children.size() > 0;
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
float CNode::value()
|
| 220 |
+
{
|
| 221 |
+
/*
|
| 222 |
+
Overview:
|
| 223 |
+
Return the real value of the current tree.
|
| 224 |
+
*/
|
| 225 |
+
float true_value = 0.0;
|
| 226 |
+
if (this->visit_count == 0)
|
| 227 |
+
{
|
| 228 |
+
return true_value;
|
| 229 |
+
}
|
| 230 |
+
else
|
| 231 |
+
{
|
| 232 |
+
true_value = this->value_sum / this->visit_count;
|
| 233 |
+
return true_value;
|
| 234 |
+
}
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
std::vector<int> CNode::get_trajectory()
|
| 238 |
+
{
|
| 239 |
+
/*
|
| 240 |
+
Overview:
|
| 241 |
+
Find the current best trajectory starts from the current node.
|
| 242 |
+
Outputs:
|
| 243 |
+
- traj: a vector of node index, which is the current best trajectory from this node.
|
| 244 |
+
*/
|
| 245 |
+
std::vector<int> traj;
|
| 246 |
+
|
| 247 |
+
CNode *node = this;
|
| 248 |
+
int best_action = node->best_action;
|
| 249 |
+
while (best_action >= 0)
|
| 250 |
+
{
|
| 251 |
+
traj.push_back(best_action);
|
| 252 |
+
|
| 253 |
+
node = node->get_child(best_action);
|
| 254 |
+
best_action = node->best_action;
|
| 255 |
+
}
|
| 256 |
+
return traj;
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
std::vector<int> CNode::get_children_distribution()
|
| 260 |
+
{
|
| 261 |
+
/*
|
| 262 |
+
Overview:
|
| 263 |
+
Get the distribution of child nodes in the format of visit_count.
|
| 264 |
+
Outputs:
|
| 265 |
+
- distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]).
|
| 266 |
+
*/
|
| 267 |
+
std::vector<int> distribution;
|
| 268 |
+
if (this->expanded())
|
| 269 |
+
{
|
| 270 |
+
for (auto a : this->legal_actions)
|
| 271 |
+
{
|
| 272 |
+
CNode *child = this->get_child(a);
|
| 273 |
+
distribution.push_back(child->visit_count);
|
| 274 |
+
}
|
| 275 |
+
}
|
| 276 |
+
return distribution;
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
CNode *CNode::get_child(int action)
|
| 280 |
+
{
|
| 281 |
+
/*
|
| 282 |
+
Overview:
|
| 283 |
+
Get the child node corresponding to the input action.
|
| 284 |
+
Arguments:
|
| 285 |
+
- action: the action to get child.
|
| 286 |
+
*/
|
| 287 |
+
return &(this->children[action]);
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
//*********************************************************
|
| 291 |
+
|
| 292 |
+
CRoots::CRoots()
{
    /*
    Overview:
        Default construction: no roots.
    */
    this->root_num = 0;
}

CRoots::CRoots(int root_num, std::vector<std::vector<int> > &legal_actions_list)
{
    /*
    Overview:
        Build `root_num` root nodes, each with prior 0 and its own
        legal-action set.
    Arguments:
        - root_num: the number of the current root.
        - legal_actions_list: the vector of the legal action of this root.
    */
    this->root_num = root_num;
    this->legal_actions_list = legal_actions_list;

    this->roots.reserve(root_num);
    for (int i = 0; i < root_num; ++i)
    {
        this->roots.emplace_back(0, this->legal_actions_list[i]);
    }
}

CRoots::~CRoots() {}
|
| 320 |
+
|
| 321 |
+
void CRoots::prepare(float root_noise_weight, const std::vector<std::vector<float> > &noises, const std::vector<float> &rewards, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch)
|
| 322 |
+
{
|
| 323 |
+
/*
|
| 324 |
+
Overview:
|
| 325 |
+
Expand the roots and add noises.
|
| 326 |
+
Arguments:
|
| 327 |
+
- root_noise_weight: the exploration fraction of roots
|
| 328 |
+
- noises: the vector of noise add to the roots.
|
| 329 |
+
- rewards: the vector of rewards of each root.
|
| 330 |
+
- policies: the vector of policy logits of each root.
|
| 331 |
+
- to_play_batch: the vector of the player side of each root.
|
| 332 |
+
*/
|
| 333 |
+
for (int i = 0; i < this->root_num; ++i)
|
| 334 |
+
{
|
| 335 |
+
this->roots[i].expand(to_play_batch[i], 0, i, rewards[i], policies[i]);
|
| 336 |
+
this->roots[i].add_exploration_noise(root_noise_weight, noises[i]);
|
| 337 |
+
|
| 338 |
+
this->roots[i].visit_count += 1;
|
| 339 |
+
}
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
void CRoots::prepare_no_noise(const std::vector<float> &rewards, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch)
|
| 343 |
+
{
|
| 344 |
+
/*
|
| 345 |
+
Overview:
|
| 346 |
+
Expand the roots without noise.
|
| 347 |
+
Arguments:
|
| 348 |
+
- rewards: the vector of rewards of each root.
|
| 349 |
+
- policies: the vector of policy logits of each root.
|
| 350 |
+
- to_play_batch: the vector of the player side of each root.
|
| 351 |
+
*/
|
| 352 |
+
for (int i = 0; i < this->root_num; ++i)
|
| 353 |
+
{
|
| 354 |
+
this->roots[i].expand(to_play_batch[i], 0, i, rewards[i], policies[i]);
|
| 355 |
+
|
| 356 |
+
this->roots[i].visit_count += 1;
|
| 357 |
+
}
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
void CRoots::clear()
|
| 361 |
+
{
|
| 362 |
+
/*
|
| 363 |
+
Overview:
|
| 364 |
+
Clear the roots vector.
|
| 365 |
+
*/
|
| 366 |
+
this->roots.clear();
|
| 367 |
+
}
|
| 368 |
+
|
| 369 |
+
std::vector<std::vector<int> > CRoots::get_trajectories()
|
| 370 |
+
{
|
| 371 |
+
/*
|
| 372 |
+
Overview:
|
| 373 |
+
Find the current best trajectory starts from each root.
|
| 374 |
+
Outputs:
|
| 375 |
+
- traj: a vector of node index, which is the current best trajectory from each root.
|
| 376 |
+
*/
|
| 377 |
+
std::vector<std::vector<int> > trajs;
|
| 378 |
+
trajs.reserve(this->root_num);
|
| 379 |
+
|
| 380 |
+
for (int i = 0; i < this->root_num; ++i)
|
| 381 |
+
{
|
| 382 |
+
trajs.push_back(this->roots[i].get_trajectory());
|
| 383 |
+
}
|
| 384 |
+
return trajs;
|
| 385 |
+
}
|
| 386 |
+
|
| 387 |
+
std::vector<std::vector<int> > CRoots::get_distributions()
|
| 388 |
+
{
|
| 389 |
+
/*
|
| 390 |
+
Overview:
|
| 391 |
+
Get the children distribution of each root.
|
| 392 |
+
Outputs:
|
| 393 |
+
- distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]).
|
| 394 |
+
*/
|
| 395 |
+
std::vector<std::vector<int> > distributions;
|
| 396 |
+
distributions.reserve(this->root_num);
|
| 397 |
+
|
| 398 |
+
for (int i = 0; i < this->root_num; ++i)
|
| 399 |
+
{
|
| 400 |
+
distributions.push_back(this->roots[i].get_children_distribution());
|
| 401 |
+
}
|
| 402 |
+
return distributions;
|
| 403 |
+
}
|
| 404 |
+
|
| 405 |
+
std::vector<float> CRoots::get_values()
|
| 406 |
+
{
|
| 407 |
+
/*
|
| 408 |
+
Overview:
|
| 409 |
+
Return the real value of each root.
|
| 410 |
+
*/
|
| 411 |
+
std::vector<float> values;
|
| 412 |
+
for (int i = 0; i < this->root_num; ++i)
|
| 413 |
+
{
|
| 414 |
+
values.push_back(this->roots[i].value());
|
| 415 |
+
}
|
| 416 |
+
return values;
|
| 417 |
+
}
|
| 418 |
+
|
| 419 |
+
//*********************************************************
|
| 420 |
+
//
|
| 421 |
+
void update_tree_q(CNode *root, tools::CMinMaxStats &min_max_stats, float discount_factor, int players)
{
    /*
    Overview:
        Walk the whole subtree (iteratively, via an explicit stack) and feed
        the q value of every expanded non-root node into min_max_stats.
    Arguments:
        - root: the root that update q value from.
        - min_max_stats: a tool used to min-max normalize the q value.
        - discount_factor: the discount factor of reward.
        - players: the number of players (1 or 2).
    */
    std::stack<CNode *> node_stack;
    node_stack.push(root);
    while (node_stack.size() > 0)
    {
        CNode *node = node_stack.top();
        node_stack.pop();

        if (node != root)
        {
            // # NOTE: in self-play-mode, value_prefix is not calculated according to the perspective of current player of node,
            // # but treated as 1 player, just for obtaining the true reward in the perspective of current player of node.
            float true_reward = node->reward;

            // Fix: `qsa` was declared uninitialized and only assigned for
            // players == 1 or 2, so any other value fed indeterminate data
            // into min_max_stats. Assert the supported counts and default to
            // the single-player form.
            assert(players == 1 || players == 2);
            float qsa = true_reward + discount_factor * node->value();
            if (players == 2)
                // TODO(pu): two-player sign flip.
                qsa = true_reward + discount_factor * (-1) * node->value();

            min_max_stats.update(qsa);
        }

        for (auto a : node->legal_actions)
        {
            CNode *child = node->get_child(a);
            if (child->expanded())
            {
                node_stack.push(child);
            }
        }
    }
}
|
| 467 |
+
|
| 468 |
+
void cbackpropagate(std::vector<CNode *> &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount_factor)
{
    /*
    Overview:
        Back up `value` along the search path from leaf to root, updating
        value sums, visit counts, and the min-max statistics.
    Arguments:
        - search_path: a vector of nodes on the search path.
        - min_max_stats: a tool used to min-max normalize the q value.
        - to_play: player id at the leaf; -1 selects play-with-bot mode.
        - value: the value to propagate along the search path.
        - discount_factor: the discount factor of reward.
    */
    assert(to_play == -1 || to_play == 1 || to_play == 2);

    float running_value = value;
    if (to_play == -1)
    {
        // Play-with-bot mode: one effective player, values keep their sign.
        for (int i = (int)search_path.size() - 1; i >= 0; --i)
        {
            CNode *node = search_path[i];
            node->value_sum += running_value;
            node->visit_count += 1;

            float true_reward = node->reward;
            min_max_stats.update(true_reward + discount_factor * node->value());

            running_value = true_reward + discount_factor * running_value;
        }
    }
    else
    {
        // Self-play mode: flip the value's sign whenever the node belongs to
        // the opponent of `to_play`.
        for (int i = (int)search_path.size() - 1; i >= 0; --i)
        {
            CNode *node = search_path[i];
            node->value_sum += (node->to_play == to_play) ? running_value : -running_value;
            node->visit_count += 1;

            // NOTE: in self-play-mode, value_prefix is not calculated according to the perspective of current player of node,
            // but treated as 1 player, just for obtaining the true reward in the perspective of current player of node.
            float true_reward = node->reward;

            // TODO(pu): why in muzero-general is - node.value
            min_max_stats.update(true_reward + discount_factor * -node->value());

            running_value = (node->to_play == to_play)
                                ? -true_reward + discount_factor * running_value
                                : true_reward + discount_factor * running_value;
        }
    }
}
|
| 528 |
+
|
| 529 |
+
void cbatch_backpropagate(int current_latent_state_index, float discount_factor, const std::vector<float> &value_prefixs, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &to_play_batch)
|
| 530 |
+
{
|
| 531 |
+
/*
|
| 532 |
+
Overview:
|
| 533 |
+
Expand the nodes along the search path and update the infos.
|
| 534 |
+
Arguments:
|
| 535 |
+
- current_latent_state_index: The index of latent state of the leaf node in the search path.
|
| 536 |
+
- discount_factor: the discount factor of reward.
|
| 537 |
+
- value_prefixs: the value prefixs of nodes along the search path.
|
| 538 |
+
- values: the values to propagate along the search path.
|
| 539 |
+
- policies: the policy logits of nodes along the search path.
|
| 540 |
+
- min_max_stats: a tool used to min-max normalize the q value.
|
| 541 |
+
- results: the search results.
|
| 542 |
+
- to_play_batch: the batch of which player is playing on this node.
|
| 543 |
+
*/
|
| 544 |
+
for (int i = 0; i < results.num; ++i)
|
| 545 |
+
{
|
| 546 |
+
results.nodes[i]->expand(to_play_batch[i], current_latent_state_index, i, value_prefixs[i], policies[i]);
|
| 547 |
+
cbackpropagate(results.search_paths[i], min_max_stats_lst->stats_lst[i], to_play_batch[i], values[i], discount_factor);
|
| 548 |
+
}
|
| 549 |
+
}
|
| 550 |
+
|
| 551 |
+
int cselect_child(CNode *root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players)
|
| 552 |
+
{
|
| 553 |
+
/*
|
| 554 |
+
Overview:
|
| 555 |
+
Select the child node of the roots according to ucb scores.
|
| 556 |
+
Arguments:
|
| 557 |
+
- root: the roots to select the child node.
|
| 558 |
+
- min_max_stats: a tool used to min-max normalize the score.
|
| 559 |
+
- pb_c_base: constants c2 in muzero.
|
| 560 |
+
- pb_c_init: constants c1 in muzero.
|
| 561 |
+
- disount_factor: the discount factor of reward.
|
| 562 |
+
- mean_q: the mean q value of the parent node.
|
| 563 |
+
- players: the number of players.
|
| 564 |
+
Outputs:
|
| 565 |
+
- action: the action to select.
|
| 566 |
+
*/
|
| 567 |
+
float max_score = FLOAT_MIN;
|
| 568 |
+
const float epsilon = 0.000001;
|
| 569 |
+
std::vector<int> max_index_lst;
|
| 570 |
+
for (auto a : root->legal_actions)
|
| 571 |
+
{
|
| 572 |
+
|
| 573 |
+
CNode *child = root->get_child(a);
|
| 574 |
+
float temp_score = cucb_score(child, min_max_stats, mean_q, root->visit_count - 1, pb_c_base, pb_c_init, discount_factor, players);
|
| 575 |
+
|
| 576 |
+
if (max_score < temp_score)
|
| 577 |
+
{
|
| 578 |
+
max_score = temp_score;
|
| 579 |
+
|
| 580 |
+
max_index_lst.clear();
|
| 581 |
+
max_index_lst.push_back(a);
|
| 582 |
+
}
|
| 583 |
+
else if (temp_score >= max_score - epsilon)
|
| 584 |
+
{
|
| 585 |
+
max_index_lst.push_back(a);
|
| 586 |
+
}
|
| 587 |
+
}
|
| 588 |
+
|
| 589 |
+
int action = 0;
|
| 590 |
+
if (max_index_lst.size() > 0)
|
| 591 |
+
{
|
| 592 |
+
int rand_index = rand() % max_index_lst.size();
|
| 593 |
+
action = max_index_lst[rand_index];
|
| 594 |
+
}
|
| 595 |
+
return action;
|
| 596 |
+
}
|
| 597 |
+
|
| 598 |
+
float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, float total_children_visit_counts, float pb_c_base, float pb_c_init, float discount_factor, int players)
{
    /*
    Overview:
        Compute the PUCT (UCB) score of ``child``: a min-max-normalized value term
        plus a prior-weighted exploration bonus.
    Arguments:
        - child: the child node to score.
        - min_max_stats: tool used to min-max normalize the value term.
        - parent_mean_q: mean Q value of the parent node (fallback for unvisited children).
        - total_children_visit_counts: total visit count over the parent's children.
        - pb_c_base: constant c2 in MuZero.
        - pb_c_init: constant c1 in MuZero.
        - discount_factor: discount factor of the reward.
        - players: number of players (1 or 2).
    Outputs:
        - ucb_value: the UCB score of the child.
    */
    // Exploration weight: (log((N + c2 + 1) / c2) + c1) * sqrt(N) / (n(a) + 1).
    float exploration = log((total_children_visit_counts + pb_c_base + 1) / pb_c_base) + pb_c_init;
    exploration *= sqrt(total_children_visit_counts) / (child->visit_count + 1);
    const float prior_score = exploration * child->prior;

    float value_score = 0.0;
    if (child->visit_count == 0)
    {
        // Unvisited child: use the parent's mean Q instead of an (undefined) child value.
        value_score = parent_mean_q;
    }
    else if (players == 1)
    {
        value_score = child->reward + discount_factor * child->value();
    }
    else if (players == 2)
    {
        // Two-player zero-sum setting: the child's value is from the opponent's
        // perspective, so it is negated here.
        value_score = child->reward + discount_factor * (-child->value());
    }

    // Normalize into the running min-max range, then clamp to [0, 1].
    value_score = min_max_stats.normalize(value_score);
    if (value_score < 0)
        value_score = 0;
    if (value_score > 1)
        value_score = 1;

    return prior_score + value_score;
}
|
| 643 |
+
|
| 644 |
+
void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &virtual_to_play_batch)
{
    /*
    Overview:
        Run the in-tree selection phase for a batch of searches: from each root,
        repeatedly pick the best child by UCB until an unexpanded leaf is reached,
        recording the traversed path, the leaf, the last action, and the player to
        move into ``results``.
    Arguments:
        - roots: the roots to search from.
        - pb_c_base: constant c2 in MuZero.
        - pb_c_init: constant c1 in MuZero.
        - discount_factor: the discount factor of reward.
        - min_max_stats_lst: per-root min-max normalizers for the UCB value term.
        - results: the search results (filled in by this function).
        - virtual_to_play_batch: which player is to play at each node; mutated in
          place as the search descends in two-player games.
    */
    // set seed
    get_time_and_set_rand_seed();

    int last_action = -1;
    // NOTE(review): parent_q is declared outside the per-root loop and therefore
    // carries over between roots; presumably compute_mean_q ignores it when
    // is_root is set -- confirm against compute_mean_q's implementation.
    float parent_q = 0.0;
    results.search_lens = std::vector<int>();

    // Single-player environments encode "no player" as -1 in the batch.
    // NOTE(review): the trailing comment "0 or 2" does not match the -1 check
    // below -- verify which sentinel the callers actually pass.
    int players = 0;
    int largest_element = *max_element(virtual_to_play_batch.begin(), virtual_to_play_batch.end()); // 0 or 2
    if (largest_element == -1)
        players = 1;
    else
        players = 2;

    for (int i = 0; i < results.num; ++i)
    {
        CNode *node = &(roots->roots[i]);
        int is_root = 1;
        int search_len = 0;
        // The root is always the first entry of its search path.
        results.search_paths[i].push_back(node);

        // Descend while the current node has been expanded (i.e. is not a leaf).
        while (node->expanded())
        {
            float mean_q = node->compute_mean_q(is_root, parent_q, discount_factor);
            is_root = 0;
            parent_q = mean_q;

            int action = cselect_child(node, min_max_stats_lst->stats_lst[i], pb_c_base, pb_c_init, discount_factor, mean_q, players);
            if (players > 1)
            {
                // Alternate the player to move at each ply of a two-player game.
                assert(virtual_to_play_batch[i] == 1 || virtual_to_play_batch[i] == 2);
                if (virtual_to_play_batch[i] == 1)
                    virtual_to_play_batch[i] = 2;
                else
                    virtual_to_play_batch[i] = 1;
            }

            node->best_action = action;
            // next
            node = node->get_child(action);
            last_action = action;
            results.search_paths[i].push_back(node);
            search_len += 1;
        }

        // Parent of the selected leaf: second-to-last node on the path.
        CNode *parent = results.search_paths[i][results.search_paths[i].size() - 2];

        // Record where the parent's latent state lives (search depth, batch slot)
        // so the caller can look up the hidden state to expand the leaf from.
        results.latent_state_index_in_search_path.push_back(parent->current_latent_state_index);
        results.latent_state_index_in_batch.push_back(parent->batch_index);

        results.last_actions.push_back(last_action);
        results.search_lens.push_back(search_len);
        results.nodes.push_back(node);
        results.virtual_to_play_batchs.push_back(virtual_to_play_batch[i]);
    }
}
|
| 714 |
+
|
| 715 |
+
}
|
LightZero/lzero/mcts/ctree/ctree_muzero/lib/cnode.h
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// C++11
|
| 2 |
+
|
| 3 |
+
#ifndef CNODE_H
|
| 4 |
+
#define CNODE_H
|
| 5 |
+
|
| 6 |
+
#include "./../common_lib/cminimax.h"
|
| 7 |
+
#include <math.h>
|
| 8 |
+
#include <vector>
|
| 9 |
+
#include <stack>
|
| 10 |
+
#include <stdlib.h>
|
| 11 |
+
#include <time.h>
|
| 12 |
+
#include <cmath>
|
| 13 |
+
#include <sys/timeb.h>
|
| 14 |
+
#include <time.h>
|
| 15 |
+
#include <map>
|
| 16 |
+
|
| 17 |
+
const int DEBUG_MODE = 0;
|
| 18 |
+
|
| 19 |
+
namespace tree {
|
| 20 |
+
|
| 21 |
+
    // A single node of the MuZero Monte-Carlo search tree.
    class CNode {
        public:
            // visit_count: number of simulations that passed through this node.
            // to_play: which player is to move at this node.
            // current_latent_state_index / batch_index: coordinates of this node's
            //   latent state in the caller's hidden-state buffers (search depth,
            //   batch slot).
            // best_action: action chosen at this node during the last traversal.
            int visit_count, to_play, current_latent_state_index, batch_index, best_action;
            // reward: predicted reward on the edge into this node.
            // prior: policy prior probability of reaching this node.
            // value_sum: accumulated backed-up value (mean is value_sum / visit_count).
            float reward, prior, value_sum;
            std::vector<int> children_index;
            // Children are owned by value, keyed by action index.
            std::map<int, CNode> children;

            // Actions that are legal at this node.
            std::vector<int> legal_actions;

            CNode();
            CNode(float prior, std::vector<int> &legal_actions);
            ~CNode();

            // Create child nodes from the policy logits produced by the model.
            void expand(int to_play, int current_latent_state_index, int batch_index, float reward, const std::vector<float> &policy_logits);
            // Mix Dirichlet noise into the children's priors (root exploration).
            void add_exploration_noise(float exploration_fraction, const std::vector<float> &noises);
            // Mean Q value over visited children (and parent Q, for non-roots).
            float compute_mean_q(int isRoot, float parent_q, float discount_factor);
            void print_out();

            // Whether this node has been expanded (has children).
            int expanded();

            // Mean backed-up value: value_sum / visit_count.
            float value();

            // Sequence of best_action choices from this node downwards.
            std::vector<int> get_trajectory();
            // Visit counts of the children, in legal-action order.
            std::vector<int> get_children_distribution();
            CNode* get_child(int action);
    };
|
| 47 |
+
|
| 48 |
+
    // Batch of root nodes, one per parallel environment/search.
    class CRoots{
        public:
            int root_num;                                        // number of roots in the batch
            std::vector<CNode> roots;                            // the root nodes themselves
            std::vector<std::vector<int> > legal_actions_list;   // legal actions per root

            CRoots();
            CRoots(int root_num, std::vector<std::vector<int> > &legal_actions_list);
            ~CRoots();

            // Expand every root with its reward/policy and mix in exploration noise.
            void prepare(float root_noise_weight, const std::vector<std::vector<float> > &noises, const std::vector<float> &rewards, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch);
            // Same as prepare(), but without exploration noise (evaluation mode).
            void prepare_no_noise(const std::vector<float> &rewards, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch);
            void clear();
            // Best-action trajectories of all roots.
            std::vector<std::vector<int> > get_trajectories();
            // Children visit-count distributions of all roots.
            std::vector<std::vector<int> > get_distributions();
            // Mean values of all roots.
            std::vector<float> get_values();

    };
|
| 66 |
+
|
| 67 |
+
    // Per-batch output of one traversal step: the selected leaves, the paths that
    // reached them, and the bookkeeping needed to expand/backpropagate them.
    class CSearchResults{
        public:
            int num;  // number of parallel searches in the batch
            // latent_state_index_in_search_path / latent_state_index_in_batch:
            //   coordinates of each leaf's parent latent state.
            // last_actions: action taken into each leaf.
            // search_lens: depth of each traversal.
            std::vector<int> latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, search_lens;
            std::vector<int> virtual_to_play_batchs;  // player to move at each leaf
            std::vector<CNode*> nodes;                // the selected leaf nodes
            std::vector<std::vector<CNode*> > search_paths;  // root-to-leaf paths

            CSearchResults();
            CSearchResults(int num);
            ~CSearchResults();

    };
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
    //*********************************************************
    // Free functions implementing the MCTS search loop (defined in cnode.cpp).

    // Recompute Q statistics for the subtree under ``root`` (implementation not
    // shown here -- see cnode.cpp).
    void update_tree_q(CNode* root, tools::CMinMaxStats &min_max_stats, float discount_factor, int players);
    // Propagate a leaf evaluation back along ``search_path`` (see cnode.cpp).
    void cbackpropagate(std::vector<CNode*> &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount_factor);
    // Expand the selected leaves of a batch of searches and backpropagate their
    // model-predicted values (see cnode.cpp).
    void cbatch_backpropagate(int current_latent_state_index, float discount_factor, const std::vector<float> &rewards, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &to_play_batch);
    // Pick the child of ``root`` with the highest UCB score (ties broken at random).
    int cselect_child(CNode* root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players);
    // PUCT score of ``child``: normalized value term plus prior exploration bonus.
    float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, float total_children_visit_counts, float pb_c_base, float pb_c_init, float discount_factor, int players);
    // Selection phase for a batch of roots; records chosen paths in ``results``.
    void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &virtual_to_play_batch);
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
#endif
|
LightZero/lzero/mcts/ctree/ctree_sampled_efficientzero/lib/cnode.cpp
ADDED
|
@@ -0,0 +1,1189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// C++11
|
| 2 |
+
|
| 3 |
+
#include <iostream>
|
| 4 |
+
#include "cnode.h"
|
| 5 |
+
#include <algorithm>
|
| 6 |
+
#include <map>
|
| 7 |
+
#include <random>
|
| 8 |
+
#include <chrono>
|
| 9 |
+
#include <iostream>
|
| 10 |
+
#include <vector>
|
| 11 |
+
#include <stack>
|
| 12 |
+
#include <math.h>
|
| 13 |
+
|
| 14 |
+
#include <stdlib.h>
|
| 15 |
+
#include <time.h>
|
| 16 |
+
#include <cmath>
|
| 17 |
+
#include <sys/timeb.h>
|
| 18 |
+
#include <time.h>
|
| 19 |
+
#include <cassert>
|
| 20 |
+
|
| 21 |
+
#ifdef _WIN32
|
| 22 |
+
#include "..\..\common_lib\utils.cpp"
|
| 23 |
+
#else
|
| 24 |
+
#include "../../common_lib/utils.cpp"
|
| 25 |
+
#endif
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
template <class T>
size_t hash_combine(std::size_t &seed, const T &val)
{
    /*
    Overview:
        Fold ``val`` into the running hash ``seed`` (boost::hash_combine idiom:
        XOR with the golden-ratio constant plus shifted copies of the seed).
        Used to build a single hash from a multi-valued key.
    Arguments:
        - seed: the current hash accumulator; updated in place.
        - val: the new value to be hashed and combined with the seed.
    */
    const std::size_t mixed = std::hash<T>{}(val) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    seed ^= mixed;
    return seed;
}
|
| 44 |
+
|
| 45 |
+
// Comparator: order (action, probability) pairs by ``second`` in descending order.
// Takes const references instead of by-value pairs so std::sort does not copy
// both operands on every comparison.
bool cmp(const std::pair<int, double> &x, const std::pair<int, double> &y)
{
    return x.second > y.second;
}
|
| 50 |
+
|
| 51 |
+
namespace tree
|
| 52 |
+
{
|
| 53 |
+
//*********************************************************
|
| 54 |
+
|
| 55 |
+
    CAction::CAction()
    {
        /*
        Overview:
            Default constructor of CAction: leaves ``value`` empty and marks the
            action as a non-root action. (The original comment mislabeled this as
            the parameterized constructor.)
        */
        this->is_root_action = 0;
    }
|
| 63 |
+
|
| 64 |
+
    CAction::CAction(std::vector<float> value, int is_root_action)
    {
        /*
        Overview:
            Parameterized constructor of CAction. (The original comment mislabeled
            this as the default constructor.)
        Arguments:
            - value: a multi-dimensional action, one float per action dimension.
            - is_root_action: whether this action belongs to a root node.
        */
        this->value = value;
        this->is_root_action = is_root_action;
    }
|
| 76 |
+
|
| 77 |
+
    CAction::~CAction() {} // Destructor: no owned resources to release.
|
| 78 |
+
|
| 79 |
+
std::vector<size_t> CAction::get_hash(void)
|
| 80 |
+
{
|
| 81 |
+
/*
|
| 82 |
+
Overview:
|
| 83 |
+
get a hash value for each dimension in the multi-dimensional action.
|
| 84 |
+
*/
|
| 85 |
+
std::vector<size_t> hash;
|
| 86 |
+
for (int i = 0; i < this->value.size(); ++i)
|
| 87 |
+
{
|
| 88 |
+
std::size_t hash_i = std::hash<std::string>()(std::to_string(this->value[i]));
|
| 89 |
+
hash.push_back(hash_i);
|
| 90 |
+
}
|
| 91 |
+
return hash;
|
| 92 |
+
}
|
| 93 |
+
size_t CAction::get_combined_hash(void)
|
| 94 |
+
{
|
| 95 |
+
/*
|
| 96 |
+
Overview:
|
| 97 |
+
get the final combined hash value from the hash values of each dimension of the multi-dimensional action.
|
| 98 |
+
*/
|
| 99 |
+
std::vector<size_t> hash = this->get_hash();
|
| 100 |
+
size_t combined_hash = hash[0];
|
| 101 |
+
|
| 102 |
+
if (hash.size() >= 1)
|
| 103 |
+
{
|
| 104 |
+
for (int i = 1; i < hash.size(); ++i)
|
| 105 |
+
{
|
| 106 |
+
combined_hash = hash_combine(combined_hash, hash[i]);
|
| 107 |
+
}
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
return combined_hash;
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
//*********************************************************
|
| 114 |
+
|
| 115 |
+
    CSearchResults::CSearchResults()
    {
        /*
        Overview:
            Default initialization of CSearchResults; the result number is set to 0
            and no search paths are allocated.
        */
        this->num = 0;
    }
|
| 123 |
+
|
| 124 |
+
CSearchResults::CSearchResults(int num)
|
| 125 |
+
{
|
| 126 |
+
/*
|
| 127 |
+
Overview:
|
| 128 |
+
Initialization of CSearchResults with result number.
|
| 129 |
+
*/
|
| 130 |
+
this->num = num;
|
| 131 |
+
for (int i = 0; i < num; ++i)
|
| 132 |
+
{
|
| 133 |
+
this->search_paths.push_back(std::vector<CNode *>());
|
| 134 |
+
}
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
CSearchResults::~CSearchResults() {}
|
| 138 |
+
|
| 139 |
+
//*********************************************************
|
| 140 |
+
|
| 141 |
+
    CNode::CNode()
    {
        /*
        Overview:
            Default initialization of CNode: zeroed search statistics and a
            default-constructed (non-root) best action.
        */
        this->prior = 0;
        // Fallback hyper-parameters; presumably overwritten by the parameterized
        // constructor / expand() in normal use -- TODO confirm against callers.
        this->action_space_size = 9;
        this->num_of_sampled_actions = 20;
        this->continuous_action_space = false;

        this->is_reset = 0;
        this->visit_count = 0;
        this->value_sum = 0;
        // Default-constructed CAction: empty value, is_root_action == 0.
        CAction best_action;
        this->best_action = best_action;

        this->to_play = 0;
        this->value_prefix = 0.0;
        this->parent_value_prefix = 0.0;
    }
|
| 162 |
+
|
| 163 |
+
    CNode::CNode(float prior, std::vector<CAction> &legal_actions, int action_space_size, int num_of_sampled_actions, bool continuous_action_space)
    {
        /*
        Overview:
            Initialization of CNode with prior, legal actions, action_space_size,
            num_of_sampled_actions and continuous_action_space.
        Arguments:
            - prior: the prior probability of this node under its parent's policy.
            - legal_actions: a vector of legal actions of this node.
            - action_space_size: the size of the action space of the current env.
            - num_of_sampled_actions: the number of sampled actions, i.e. K in the
              Sampled MuZero paper.
            - continuous_action_space: whether the action space is continuous in
              the current env.
        */
        this->prior = prior;
        this->legal_actions = legal_actions;

        this->action_space_size = action_space_size;
        this->num_of_sampled_actions = num_of_sampled_actions;
        this->continuous_action_space = continuous_action_space;
        this->is_reset = 0;
        this->visit_count = 0;
        this->value_sum = 0;
        this->to_play = 0;
        this->value_prefix = 0.0;
        this->parent_value_prefix = 0.0;
        // -1 marks "not yet placed in the hidden-state buffers"; set by expand().
        this->current_latent_state_index = -1;
        this->batch_index = -1;
    }
|
| 190 |
+
|
| 191 |
+
    CNode::~CNode() {} // Destructor: children are owned by value in std::map and freed automatically.
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
void CNode::expand(int to_play, int current_latent_state_index, int batch_index, float value_prefix, const std::vector<float> &policy_logits)
|
| 195 |
+
{
|
| 196 |
+
/*
|
| 197 |
+
Overview:
|
| 198 |
+
Expand the child nodes of the current node.
|
| 199 |
+
Arguments:
|
| 200 |
+
- to_play: which player to play the game in the current node.
|
| 201 |
+
- current_latent_state_index: the x/first index of hidden state vector of the current node, i.e. the search depth.
|
| 202 |
+
- batch_index: the y/second index of hidden state vector of the current node, i.e. the index of batch root node, its maximum is ``batch_size``/``env_num``.
|
| 203 |
+
- value_prefix: the value prefix of the current node.
|
| 204 |
+
- policy_logits: the logit of the child nodes.
|
| 205 |
+
*/
|
| 206 |
+
this->to_play = to_play;
|
| 207 |
+
this->current_latent_state_index = current_latent_state_index;
|
| 208 |
+
this->batch_index = batch_index;
|
| 209 |
+
this->value_prefix = value_prefix;
|
| 210 |
+
int action_num = policy_logits.size();
|
| 211 |
+
|
| 212 |
+
#ifdef _WIN32
|
| 213 |
+
// 创建动态数组
|
| 214 |
+
float* policy = new float[action_num];
|
| 215 |
+
#else
|
| 216 |
+
float policy[action_num];
|
| 217 |
+
#endif
|
| 218 |
+
|
| 219 |
+
std::vector<int> all_actions;
|
| 220 |
+
for (int i = 0; i < action_num; ++i)
|
| 221 |
+
{
|
| 222 |
+
all_actions.push_back(i);
|
| 223 |
+
}
|
| 224 |
+
std::vector<std::vector<float> > sampled_actions_after_tanh;
|
| 225 |
+
std::vector<float> sampled_actions_log_probs_after_tanh;
|
| 226 |
+
|
| 227 |
+
std::vector<int> sampled_actions;
|
| 228 |
+
std::vector<float> sampled_actions_log_probs;
|
| 229 |
+
std::vector<float> sampled_actions_probs;
|
| 230 |
+
std::vector<float> probs;
|
| 231 |
+
|
| 232 |
+
/*
|
| 233 |
+
Overview:
|
| 234 |
+
When the currennt env has continuous action space, sampled K actions from continuous gaussia distribution policy.
|
| 235 |
+
When the currennt env has discrete action space, sampled K actions from discrete categirical distribution policy.
|
| 236 |
+
|
| 237 |
+
*/
|
| 238 |
+
if (this->continuous_action_space == true)
|
| 239 |
+
{
|
| 240 |
+
// continuous action space for sampled algo..
|
| 241 |
+
this->action_space_size = policy_logits.size() / 2;
|
| 242 |
+
std::vector<float> mu;
|
| 243 |
+
std::vector<float> sigma;
|
| 244 |
+
for (int i = 0; i < this->action_space_size; ++i)
|
| 245 |
+
{
|
| 246 |
+
mu.push_back(policy_logits[i]);
|
| 247 |
+
sigma.push_back(policy_logits[this->action_space_size + i]);
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
// The number of nanoseconds that have elapsed since epoch(1970: 00: 00 UTC on January 1, 1970). unsigned type will truncate this value.
|
| 251 |
+
unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
|
| 252 |
+
|
| 253 |
+
// SAC-like tanh, pleasee refer to paper https://arxiv.org/abs/1801.01290.
|
| 254 |
+
std::vector<std::vector<float> > sampled_actions_before_tanh;
|
| 255 |
+
|
| 256 |
+
float sampled_action_one_dim_before_tanh;
|
| 257 |
+
std::vector<float> sampled_actions_log_probs_before_tanh;
|
| 258 |
+
|
| 259 |
+
std::default_random_engine generator(seed);
|
| 260 |
+
for (int i = 0; i < this->num_of_sampled_actions; ++i)
|
| 261 |
+
{
|
| 262 |
+
float sampled_action_prob_before_tanh = 1;
|
| 263 |
+
// TODO(pu): why here
|
| 264 |
+
std::vector<float> sampled_action_before_tanh;
|
| 265 |
+
std::vector<float> sampled_action_after_tanh;
|
| 266 |
+
std::vector<float> y;
|
| 267 |
+
|
| 268 |
+
for (int j = 0; j < this->action_space_size; ++j)
|
| 269 |
+
{
|
| 270 |
+
std::normal_distribution<float> distribution(mu[j], sigma[j]);
|
| 271 |
+
sampled_action_one_dim_before_tanh = distribution(generator);
|
| 272 |
+
// refer to python normal log_prob method
|
| 273 |
+
sampled_action_prob_before_tanh *= exp(-pow((sampled_action_one_dim_before_tanh - mu[j]), 2) / (2 * pow(sigma[j], 2)) - log(sigma[j]) - log(sqrt(2 * M_PI)));
|
| 274 |
+
sampled_action_before_tanh.push_back(sampled_action_one_dim_before_tanh);
|
| 275 |
+
sampled_action_after_tanh.push_back(tanh(sampled_action_one_dim_before_tanh));
|
| 276 |
+
y.push_back(1 - pow(tanh(sampled_action_one_dim_before_tanh), 2) + 1e-6);
|
| 277 |
+
}
|
| 278 |
+
sampled_actions_before_tanh.push_back(sampled_action_before_tanh);
|
| 279 |
+
sampled_actions_after_tanh.push_back(sampled_action_after_tanh);
|
| 280 |
+
sampled_actions_log_probs_before_tanh.push_back(log(sampled_action_prob_before_tanh));
|
| 281 |
+
float y_sum = std::accumulate(y.begin(), y.end(), 0.);
|
| 282 |
+
sampled_actions_log_probs_after_tanh.push_back(log(sampled_action_prob_before_tanh) - log(y_sum));
|
| 283 |
+
}
|
| 284 |
+
}
|
| 285 |
+
else
|
| 286 |
+
{
|
| 287 |
+
// discrete action space for sampled algo..
|
| 288 |
+
|
| 289 |
+
//========================================================
|
| 290 |
+
// python code
|
| 291 |
+
//========================================================
|
| 292 |
+
// if self.legal_actions is not None:
|
| 293 |
+
// # fisrt use the self.legal_actions to exclude the illegal actions
|
| 294 |
+
// policy_tmp = [0. for _ in range(self.action_space_size)]
|
| 295 |
+
// for index, legal_action in enumerate(self.legal_actions):
|
| 296 |
+
// policy_tmp[legal_action] = policy_logits[index]
|
| 297 |
+
// policy_logits = policy_tmp
|
| 298 |
+
// # then empty the self.legal_actions
|
| 299 |
+
// self.legal_actions = []
|
| 300 |
+
// then empty the self.legal_actions
|
| 301 |
+
// prob = torch.softmax(torch.tensor(policy_logits), dim=-1)
|
| 302 |
+
// sampled_actions = torch.multinomial(prob, self.num_of_sampled_actions, replacement=False)
|
| 303 |
+
|
| 304 |
+
//========================================================
|
| 305 |
+
// TODO(pu): legal actions
|
| 306 |
+
//========================================================
|
| 307 |
+
// std::vector<float> policy_tmp;
|
| 308 |
+
// for (int i = 0; i < this->action_space_size; ++i)
|
| 309 |
+
// {
|
| 310 |
+
// policy_tmp.push_back(0.);
|
| 311 |
+
// }
|
| 312 |
+
// for (int i = 0; i < this->legal_actions.size(); ++i)
|
| 313 |
+
// {
|
| 314 |
+
// policy_tmp[this->legal_actions[i].value] = policy_logits[i];
|
| 315 |
+
// }
|
| 316 |
+
// for (int i = 0; i < this->action_space_size; ++i)
|
| 317 |
+
// {
|
| 318 |
+
// policy_logits[i] = policy_tmp[i];
|
| 319 |
+
// }
|
| 320 |
+
// std::cout << "position 3" << std::endl;
|
| 321 |
+
|
| 322 |
+
// python code: legal_actions = []
|
| 323 |
+
std::vector<CAction> legal_actions;
|
| 324 |
+
|
| 325 |
+
// python code: probs = softmax(policy_logits)
|
| 326 |
+
float logits_exp_sum = 0;
|
| 327 |
+
for (int i = 0; i < policy_logits.size(); ++i)
|
| 328 |
+
{
|
| 329 |
+
logits_exp_sum += exp(policy_logits[i]);
|
| 330 |
+
}
|
| 331 |
+
for (int i = 0; i < policy_logits.size(); ++i)
|
| 332 |
+
{
|
| 333 |
+
probs.push_back(exp(policy_logits[i]) / (logits_exp_sum + 1e-6));
|
| 334 |
+
}
|
| 335 |
+
|
| 336 |
+
unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
|
| 337 |
+
|
| 338 |
+
// cout << "sampled_action[0]:" << sampled_action[0] <<endl;
|
| 339 |
+
|
| 340 |
+
// std::vector<int> sampled_actions;
|
| 341 |
+
// std::vector<float> sampled_actions_log_probs;
|
| 342 |
+
// std::vector<float> sampled_actions_probs;
|
| 343 |
+
std::default_random_engine generator(seed);
|
| 344 |
+
|
| 345 |
+
// 有放回抽样
|
| 346 |
+
// for (int i = 0; i < num_of_sampled_actions; ++i)
|
| 347 |
+
// {
|
| 348 |
+
// float sampled_action_prob = 1;
|
| 349 |
+
// int sampled_action;
|
| 350 |
+
|
| 351 |
+
// std::discrete_distribution<float> distribution(probs.begin(), probs.end());
|
| 352 |
+
|
| 353 |
+
// // for (float x:distribution.probabilities()) std::cout << x << " ";
|
| 354 |
+
// sampled_action = distribution(generator);
|
| 355 |
+
// // std::cout << "sampled_action: " << sampled_action << std::endl;
|
| 356 |
+
|
| 357 |
+
// sampled_actions.push_back(sampled_action);
|
| 358 |
+
// sampled_actions_probs.push_back(probs[sampled_action]);
|
| 359 |
+
// std::cout << "sampled_actions_probs" << '[' << i << ']' << sampled_actions_probs[i] << std::endl;
|
| 360 |
+
|
| 361 |
+
// sampled_actions_log_probs.push_back(log(probs[sampled_action]));
|
| 362 |
+
// std::cout << "sampled_actions_log_probs" << '[' << i << ']' << sampled_actions_log_probs[i] << std::endl;
|
| 363 |
+
// }
|
| 364 |
+
|
| 365 |
+
// 每个节点的legal_actions应该为一个固定离散集合,所以采用无放回抽样
|
| 366 |
+
// std::cout << "position uniform_distribution init" << std::endl;
|
| 367 |
+
std::uniform_real_distribution<double> uniform_distribution(0.0, 1.0); //均匀分布
|
| 368 |
+
// std::cout << "position uniform_distribution done" << std::endl;
|
| 369 |
+
std::vector<double> disturbed_probs;
|
| 370 |
+
std::vector<std::pair<int, double> > disc_action_with_probs;
|
| 371 |
+
|
| 372 |
+
// Use the reciprocal of the probability value as the exponent and a random number sampled from a uniform distribution as the base:
|
| 373 |
+
// Equivalent to adding a uniform random disturbance to the original probability value.
|
| 374 |
+
for (auto prob : probs)
|
| 375 |
+
{
|
| 376 |
+
disturbed_probs.push_back(std::pow(uniform_distribution(generator), 1. / prob));
|
| 377 |
+
}
|
| 378 |
+
|
| 379 |
+
// Sort from large to small according to the probability value after the disturbance:
|
| 380 |
+
// After sorting, the first vector is the index, and the second vector is the probability value after perturbation sorted from large to small.
|
| 381 |
+
for (size_t iter = 0; iter < disturbed_probs.size(); iter++)
|
| 382 |
+
{
|
| 383 |
+
|
| 384 |
+
#ifdef __GNUC__
|
| 385 |
+
// Use push_back for GCC
|
| 386 |
+
disc_action_with_probs.push_back(std::make_pair(iter, disturbed_probs[iter]));
|
| 387 |
+
#else
|
| 388 |
+
// Use emplace_back for other compilers
|
| 389 |
+
disc_action_with_probs.emplace_back(std::make_pair(iter, disturbed_probs[iter]));
|
| 390 |
+
#endif
|
| 391 |
+
}
|
| 392 |
+
|
| 393 |
+
std::sort(disc_action_with_probs.begin(), disc_action_with_probs.end(), cmp);
|
| 394 |
+
|
| 395 |
+
// take the fist ``num_of_sampled_actions`` actions
|
| 396 |
+
for (int k = 0; k < num_of_sampled_actions; ++k)
|
| 397 |
+
{
|
| 398 |
+
sampled_actions.push_back(disc_action_with_probs[k].first);
|
| 399 |
+
// disc_action_with_probs[k].second is disturbed_probs
|
| 400 |
+
// sampled_actions_probs.push_back(disc_action_with_probs[k].second);
|
| 401 |
+
sampled_actions_probs.push_back(probs[disc_action_with_probs[k].first]);
|
| 402 |
+
|
| 403 |
+
// TODO(pu): logging
|
| 404 |
+
// std::cout << "sampled_actions[k]: " << sampled_actions[k] << std::endl;
|
| 405 |
+
// std::cout << "sampled_actions_probs[k]: " << sampled_actions_probs[k] << std::endl;
|
| 406 |
+
}
|
| 407 |
+
|
| 408 |
+
// TODO(pu): fixed k, only for debugging
|
| 409 |
+
// Take the first ``num_of_sampled_actions`` actions: k=0,1,...,K-1
|
| 410 |
+
// for (int k = 0; k < num_of_sampled_actions; ++k)
|
| 411 |
+
// {
|
| 412 |
+
// sampled_actions.push_back(k);
|
| 413 |
+
// // disc_action_with_probs[k].second is disturbed_probs
|
| 414 |
+
// // sampled_actions_probs.push_back(disc_action_with_probs[k].second);
|
| 415 |
+
// sampled_actions_probs.push_back(probs[k]);
|
| 416 |
+
// }
|
| 417 |
+
|
| 418 |
+
disturbed_probs.clear(); // Empty the collection to prepare for the next sampling.
|
| 419 |
+
disc_action_with_probs.clear(); // Empty the collection to prepare for the next sampling.
|
| 420 |
+
}
|
| 421 |
+
|
| 422 |
+
float prior;
|
| 423 |
+
for (int i = 0; i < this->num_of_sampled_actions; ++i)
|
| 424 |
+
{
|
| 425 |
+
|
| 426 |
+
if (this->continuous_action_space == true)
|
| 427 |
+
{
|
| 428 |
+
CAction action = CAction(sampled_actions_after_tanh[i], 0);
|
| 429 |
+
std::vector<CAction> legal_actions;
|
| 430 |
+
this->children[action.get_combined_hash()] = CNode(sampled_actions_log_probs_after_tanh[i], legal_actions, this->action_space_size, this->num_of_sampled_actions, this->continuous_action_space); // only for muzero/efficient zero, not support alphazero
|
| 431 |
+
this->legal_actions.push_back(action);
|
| 432 |
+
}
|
| 433 |
+
else
|
| 434 |
+
{
|
| 435 |
+
std::vector<float> sampled_action_tmp;
|
| 436 |
+
for (size_t iter = 0; iter < 1; iter++)
|
| 437 |
+
{
|
| 438 |
+
sampled_action_tmp.push_back(float(sampled_actions[i]));
|
| 439 |
+
}
|
| 440 |
+
CAction action = CAction(sampled_action_tmp, 0);
|
| 441 |
+
std::vector<CAction> legal_actions;
|
| 442 |
+
this->children[action.get_combined_hash()] = CNode(sampled_actions_probs[i], legal_actions, this->action_space_size, this->num_of_sampled_actions, this->continuous_action_space); // only for muzero/efficient zero, not support alphazero
|
| 443 |
+
this->legal_actions.push_back(action);
|
| 444 |
+
}
|
| 445 |
+
}
|
| 446 |
+
|
| 447 |
+
#ifdef _WIN32
|
| 448 |
+
// 释放数组内存
|
| 449 |
+
delete[] policy;
|
| 450 |
+
#else
|
| 451 |
+
#endif
|
| 452 |
+
}
|
| 453 |
+
|
| 454 |
+
void CNode::add_exploration_noise(float exploration_fraction, const std::vector<float> &noises)
|
| 455 |
+
{
|
| 456 |
+
/*
|
| 457 |
+
Overview:
|
| 458 |
+
Add a noise to the prior of the child nodes.
|
| 459 |
+
Arguments:
|
| 460 |
+
- exploration_fraction: the fraction to add noise.
|
| 461 |
+
- noises: the vector of noises added to each child node.
|
| 462 |
+
*/
|
| 463 |
+
float noise, prior;
|
| 464 |
+
for (int i = 0; i < this->num_of_sampled_actions; ++i)
|
| 465 |
+
{
|
| 466 |
+
|
| 467 |
+
noise = noises[i];
|
| 468 |
+
CNode *child = this->get_child(this->legal_actions[i]);
|
| 469 |
+
prior = child->prior;
|
| 470 |
+
if (this->continuous_action_space == true)
|
| 471 |
+
{
|
| 472 |
+
// if prior is log_prob
|
| 473 |
+
child->prior = log(exp(prior) * (1 - exploration_fraction) + noise * exploration_fraction + 1e-6);
|
| 474 |
+
}
|
| 475 |
+
else
|
| 476 |
+
{
|
| 477 |
+
// if prior is prob
|
| 478 |
+
child->prior = prior * (1 - exploration_fraction) + noise * exploration_fraction;
|
| 479 |
+
}
|
| 480 |
+
}
|
| 481 |
+
}
|
| 482 |
+
|
| 483 |
+
float CNode::compute_mean_q(int isRoot, float parent_q, float discount_factor)
|
| 484 |
+
{
|
| 485 |
+
/*
|
| 486 |
+
Overview:
|
| 487 |
+
Compute the mean q value of the current node.
|
| 488 |
+
Arguments:
|
| 489 |
+
- isRoot: whether the current node is a root node.
|
| 490 |
+
- parent_q: the q value of the parent node.
|
| 491 |
+
- discount_factor: the discount_factor of reward.
|
| 492 |
+
*/
|
| 493 |
+
float total_unsigned_q = 0.0;
|
| 494 |
+
int total_visits = 0;
|
| 495 |
+
float parent_value_prefix = this->value_prefix;
|
| 496 |
+
for (auto a : this->legal_actions)
|
| 497 |
+
{
|
| 498 |
+
CNode *child = this->get_child(a);
|
| 499 |
+
if (child->visit_count > 0)
|
| 500 |
+
{
|
| 501 |
+
float true_reward = child->value_prefix - parent_value_prefix;
|
| 502 |
+
if (this->is_reset == 1)
|
| 503 |
+
{
|
| 504 |
+
true_reward = child->value_prefix;
|
| 505 |
+
}
|
| 506 |
+
float qsa = true_reward + discount_factor * child->value();
|
| 507 |
+
total_unsigned_q += qsa;
|
| 508 |
+
total_visits += 1;
|
| 509 |
+
}
|
| 510 |
+
}
|
| 511 |
+
|
| 512 |
+
float mean_q = 0.0;
|
| 513 |
+
if (isRoot && total_visits > 0)
|
| 514 |
+
{
|
| 515 |
+
mean_q = (total_unsigned_q) / (total_visits);
|
| 516 |
+
}
|
| 517 |
+
else
|
| 518 |
+
{
|
| 519 |
+
mean_q = (parent_q + total_unsigned_q) / (total_visits + 1);
|
| 520 |
+
}
|
| 521 |
+
return mean_q;
|
| 522 |
+
}
|
| 523 |
+
|
| 524 |
+
void CNode::print_out()
{
    // Intentionally a no-op: kept as a debugging hook so callers need not change.
}
int CNode::expanded()
|
| 530 |
+
{
|
| 531 |
+
/*
|
| 532 |
+
Overview:
|
| 533 |
+
Return whether the current node is expanded.
|
| 534 |
+
*/
|
| 535 |
+
return this->children.size() > 0;
|
| 536 |
+
}
|
| 537 |
+
|
| 538 |
+
float CNode::value()
|
| 539 |
+
{
|
| 540 |
+
/*
|
| 541 |
+
Overview:
|
| 542 |
+
Return the real value of the current tree.
|
| 543 |
+
*/
|
| 544 |
+
float true_value = 0.0;
|
| 545 |
+
if (this->visit_count == 0)
|
| 546 |
+
{
|
| 547 |
+
return true_value;
|
| 548 |
+
}
|
| 549 |
+
else
|
| 550 |
+
{
|
| 551 |
+
true_value = this->value_sum / this->visit_count;
|
| 552 |
+
return true_value;
|
| 553 |
+
}
|
| 554 |
+
}
|
| 555 |
+
|
| 556 |
+
std::vector<std::vector<float> > CNode::get_trajectory()
|
| 557 |
+
{
|
| 558 |
+
/*
|
| 559 |
+
Overview:
|
| 560 |
+
Find the current best trajectory starts from the current node.
|
| 561 |
+
Outputs:
|
| 562 |
+
- traj: a vector of node index, which is the current best trajectory from this node.
|
| 563 |
+
*/
|
| 564 |
+
std::vector<CAction> traj;
|
| 565 |
+
|
| 566 |
+
CNode *node = this;
|
| 567 |
+
CAction best_action = node->best_action;
|
| 568 |
+
while (best_action.is_root_action != 1)
|
| 569 |
+
{
|
| 570 |
+
traj.push_back(best_action);
|
| 571 |
+
node = node->get_child(best_action);
|
| 572 |
+
best_action = node->best_action;
|
| 573 |
+
}
|
| 574 |
+
|
| 575 |
+
std::vector<std::vector<float> > traj_return;
|
| 576 |
+
for (int i = 0; i < traj.size(); ++i)
|
| 577 |
+
{
|
| 578 |
+
traj_return.push_back(traj[i].value);
|
| 579 |
+
}
|
| 580 |
+
return traj_return;
|
| 581 |
+
}
|
| 582 |
+
|
| 583 |
+
std::vector<int> CNode::get_children_distribution()
|
| 584 |
+
{
|
| 585 |
+
/*
|
| 586 |
+
Overview:
|
| 587 |
+
Get the distribution of child nodes in the format of visit_count.
|
| 588 |
+
Outputs:
|
| 589 |
+
- distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]).
|
| 590 |
+
*/
|
| 591 |
+
std::vector<int> distribution;
|
| 592 |
+
if (this->expanded())
|
| 593 |
+
{
|
| 594 |
+
for (auto a : this->legal_actions)
|
| 595 |
+
{
|
| 596 |
+
CNode *child = this->get_child(a);
|
| 597 |
+
distribution.push_back(child->visit_count);
|
| 598 |
+
}
|
| 599 |
+
}
|
| 600 |
+
return distribution;
|
| 601 |
+
}
|
| 602 |
+
|
| 603 |
+
CNode *CNode::get_child(CAction action)
{
    /*
    Overview:
        Get the child node corresponding to the input action.
    Arguments:
        - action: the action whose child node is requested.
    */
    // Children are keyed by the action's combined hash rather than by the raw
    // (possibly continuous, multi-dimensional) action value.
    return &(this->children[action.get_combined_hash()]);
}
// Default constructor: an empty batch of roots with the default of 20 sampled actions.
CRoots::CRoots() : root_num(0), num_of_sampled_actions(20) {}
CRoots::CRoots(int root_num, std::vector<std::vector<float> > legal_actions_list, int action_space_size, int num_of_sampled_actions, bool continuous_action_space)
{
    /*
    Overview:
        Build a batch of ``root_num`` root nodes.
    Arguments:
        - root_num: the number of roots.
        - legal_actions_list: per-root legal actions; ``[0][0] == -1`` is used as a
          sentinel for "no explicit action mask".
        - action_space_size: the size of the action space of the current env.
        - num_of_sampled_actions: the number of sampled actions, i.e. K in the Sampled MuZero paper.
        - continuous_action_space: whether the action space is continuous in the current env.
    */
    this->root_num = root_num;
    this->legal_actions_list = legal_actions_list;
    this->continuous_action_space = continuous_action_space;

    // sampled related core code
    this->num_of_sampled_actions = num_of_sampled_actions;
    this->action_space_size = action_space_size;

    for (int root_idx = 0; root_idx < this->root_num; ++root_idx)
    {
        if (this->continuous_action_space == true and this->legal_actions_list[0][0] == -1)
        {
            // Continuous action space: children are created later from sampled actions.
            std::vector<CAction> empty_legal_actions;
            this->roots.push_back(CNode(0, empty_legal_actions, this->action_space_size, this->num_of_sampled_actions, this->continuous_action_space));
        }
        else if (this->continuous_action_space == false or this->legal_actions_list[0][0] == -1)
        {
            // sampled
            // Discrete action space without an explicit action mask.
            std::vector<CAction> empty_legal_actions;
            this->roots.push_back(CNode(0, empty_legal_actions, this->action_space_size, this->num_of_sampled_actions, this->continuous_action_space));
        }
        else
        {
            // TODO(pu): discrete action space with an explicit legal-action list.
            // NOTE(review): the inner index is named ``j`` here; the original shadowed
            // the outer loop variable ``i``.
            std::vector<CAction> c_legal_actions;
            for (size_t j = 0; j < this->legal_actions_list.size(); ++j)
            {
                c_legal_actions.push_back(CAction(legal_actions_list[j], 0));
            }
            this->roots.push_back(CNode(0, c_legal_actions, this->action_space_size, this->num_of_sampled_actions, this->continuous_action_space));
        }
    }
}
// Destructor: all members (vectors of nodes / legal actions) clean up automatically.
CRoots::~CRoots() = default;
void CRoots::prepare(float root_noise_weight, const std::vector<std::vector<float> > &noises, const std::vector<float> &value_prefixs, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch)
|
| 678 |
+
{
|
| 679 |
+
/*
|
| 680 |
+
Overview:
|
| 681 |
+
Expand the roots and add noises.
|
| 682 |
+
Arguments:
|
| 683 |
+
- root_noise_weight: the exploration fraction of roots
|
| 684 |
+
- noises: the vector of noise add to the roots.
|
| 685 |
+
- value_prefixs: the vector of value prefixs of each root.
|
| 686 |
+
- policies: the vector of policy logits of each root.
|
| 687 |
+
- to_play_batch: the vector of the player side of each root.
|
| 688 |
+
*/
|
| 689 |
+
|
| 690 |
+
// sampled related core code
|
| 691 |
+
for (int i = 0; i < this->root_num; ++i)
|
| 692 |
+
{
|
| 693 |
+
this->roots[i].expand(to_play_batch[i], 0, i, value_prefixs[i], policies[i]);
|
| 694 |
+
this->roots[i].add_exploration_noise(root_noise_weight, noises[i]);
|
| 695 |
+
this->roots[i].visit_count += 1;
|
| 696 |
+
}
|
| 697 |
+
}
|
| 698 |
+
|
| 699 |
+
void CRoots::prepare_no_noise(const std::vector<float> &value_prefixs, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch)
|
| 700 |
+
{
|
| 701 |
+
/*
|
| 702 |
+
Overview:
|
| 703 |
+
Expand the roots without noise.
|
| 704 |
+
Arguments:
|
| 705 |
+
- value_prefixs: the vector of value prefixs of each root.
|
| 706 |
+
- policies: the vector of policy logits of each root.
|
| 707 |
+
- to_play_batch: the vector of the player side of each root.
|
| 708 |
+
*/
|
| 709 |
+
for (int i = 0; i < this->root_num; ++i)
|
| 710 |
+
{
|
| 711 |
+
this->roots[i].expand(to_play_batch[i], 0, i, value_prefixs[i], policies[i]);
|
| 712 |
+
|
| 713 |
+
this->roots[i].visit_count += 1;
|
| 714 |
+
}
|
| 715 |
+
}
|
| 716 |
+
|
| 717 |
+
void CRoots::clear()
|
| 718 |
+
{
|
| 719 |
+
this->roots.clear();
|
| 720 |
+
}
|
| 721 |
+
|
| 722 |
+
std::vector<std::vector<std::vector<float> > > CRoots::get_trajectories()
|
| 723 |
+
{
|
| 724 |
+
/*
|
| 725 |
+
Overview:
|
| 726 |
+
Find the current best trajectory starts from each root.
|
| 727 |
+
Outputs:
|
| 728 |
+
- traj: a vector of node index, which is the current best trajectory from each root.
|
| 729 |
+
*/
|
| 730 |
+
std::vector<std::vector<std::vector<float> > > trajs;
|
| 731 |
+
trajs.reserve(this->root_num);
|
| 732 |
+
|
| 733 |
+
for (int i = 0; i < this->root_num; ++i)
|
| 734 |
+
{
|
| 735 |
+
trajs.push_back(this->roots[i].get_trajectory());
|
| 736 |
+
}
|
| 737 |
+
return trajs;
|
| 738 |
+
}
|
| 739 |
+
|
| 740 |
+
std::vector<std::vector<int> > CRoots::get_distributions()
|
| 741 |
+
{
|
| 742 |
+
/*
|
| 743 |
+
Overview:
|
| 744 |
+
Get the children distribution of each root.
|
| 745 |
+
Outputs:
|
| 746 |
+
- distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]).
|
| 747 |
+
*/
|
| 748 |
+
std::vector<std::vector<int> > distributions;
|
| 749 |
+
distributions.reserve(this->root_num);
|
| 750 |
+
|
| 751 |
+
for (int i = 0; i < this->root_num; ++i)
|
| 752 |
+
{
|
| 753 |
+
distributions.push_back(this->roots[i].get_children_distribution());
|
| 754 |
+
}
|
| 755 |
+
return distributions;
|
| 756 |
+
}
|
| 757 |
+
|
| 758 |
+
// sampled related core code
|
| 759 |
+
std::vector<std::vector<std::vector<float> > > CRoots::get_sampled_actions()
|
| 760 |
+
{
|
| 761 |
+
/*
|
| 762 |
+
Overview:
|
| 763 |
+
Get the sampled_actions of each root.
|
| 764 |
+
Outputs:
|
| 765 |
+
- python_sampled_actions: a vector of sampled_actions for each root, e.g. the size of original action space is 6, the K=3,
|
| 766 |
+
python_sampled_actions = [[1,3,0], [2,4,0], [5,4,1]].
|
| 767 |
+
*/
|
| 768 |
+
std::vector<std::vector<CAction> > sampled_actions;
|
| 769 |
+
std::vector<std::vector<std::vector<float> > > python_sampled_actions;
|
| 770 |
+
|
| 771 |
+
// sampled_actions.reserve(this->root_num);
|
| 772 |
+
|
| 773 |
+
for (int i = 0; i < this->root_num; ++i)
|
| 774 |
+
{
|
| 775 |
+
std::vector<CAction> sampled_action;
|
| 776 |
+
sampled_action = this->roots[i].legal_actions;
|
| 777 |
+
std::vector<std::vector<float> > python_sampled_action;
|
| 778 |
+
|
| 779 |
+
for (int j = 0; j < this->roots[i].legal_actions.size(); ++j)
|
| 780 |
+
{
|
| 781 |
+
python_sampled_action.push_back(sampled_action[j].value);
|
| 782 |
+
}
|
| 783 |
+
python_sampled_actions.push_back(python_sampled_action);
|
| 784 |
+
}
|
| 785 |
+
|
| 786 |
+
return python_sampled_actions;
|
| 787 |
+
}
|
| 788 |
+
|
| 789 |
+
std::vector<float> CRoots::get_values()
|
| 790 |
+
{
|
| 791 |
+
/*
|
| 792 |
+
Overview:
|
| 793 |
+
Return the estimated value of each root.
|
| 794 |
+
*/
|
| 795 |
+
std::vector<float> values;
|
| 796 |
+
for (int i = 0; i < this->root_num; ++i)
|
| 797 |
+
{
|
| 798 |
+
values.push_back(this->roots[i].value());
|
| 799 |
+
}
|
| 800 |
+
return values;
|
| 801 |
+
}
|
| 802 |
+
|
| 803 |
+
//*********************************************************
|
| 804 |
+
//
|
| 805 |
+
void update_tree_q(CNode *root, tools::CMinMaxStats &min_max_stats, float discount_factor, int players)
{
    /*
    Overview:
        Iteratively traverse the expanded subtree under ``root`` with an explicit
        stack and feed every non-root node's Q value into ``min_max_stats`` so
        later UCB scores are normalized over the whole tree.
    Arguments:
        - root: the root to update Q values from (its own Q is not recorded).
        - min_max_stats: the min-max normalizer to update.
        - discount_factor: the discount factor of reward.
        - players: the number of players (1: play-with-bot-mode, 2: self-play-mode).
    */
    std::stack<CNode *> node_stack;
    node_stack.push(root);
    int is_reset = 0;
    while (node_stack.size() > 0)
    {
        CNode *node = node_stack.top();
        node_stack.pop();

        if (node != root)
        {
            // NOTE: in self-play-mode, value_prefix is not calculated according to the perspective
            // of the current player of node, but treated as 1 player, just for obtaining the true
            // reward in the perspective of the current player of node.
            // true_reward = node.value_prefix - (- parent_value_prefix)
            float true_reward = node->value_prefix - node->parent_value_prefix;

            if (is_reset == 1)
            {
                // The parent's value prefix was reset, so the node's prefix is the raw one-step reward.
                true_reward = node->value_prefix;
            }
            // FIX: qsa is explicitly initialized; the original left it uninitialized,
            // which would be read as garbage if ``players`` were ever outside {1, 2}.
            float qsa = 0.0;
            if (players == 1)
                qsa = true_reward + discount_factor * node->value();
            else if (players == 2)
                // TODO(pu): why only the last reward multiply the discount_factor?
                qsa = true_reward + discount_factor * (-1) * node->value();

            min_max_stats.update(qsa);
        }

        for (auto a : node->legal_actions)
        {
            CNode *child = node->get_child(a);
            if (child->expanded())
            {
                child->parent_value_prefix = node->value_prefix;
                node_stack.push(child);
            }
        }

        // Remember this node's reset flag for the next popped node (DFS order).
        is_reset = node->is_reset;
    }
}
|
| 860 |
+
void cbackpropagate(std::vector<CNode *> &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount_factor)
{
    /*
    Overview:
        Update the value sum and visit count of every node along the search path,
        propagating a discounted bootstrap value from leaf to root.
    Arguments:
        - search_path: the nodes on the search path, root first.
        - min_max_stats: a tool used to min-max normalize the Q value.
        - to_play: the player to move at the leaf (-1: play-with-bot-mode).
        - value: the leaf value to propagate along the search path.
        - discount_factor: the discount factor of reward.
    */
    assert(to_play == -1 || to_play == 1 || to_play == 2);
    const int path_len = search_path.size();
    float bootstrap_value = value;

    if (to_play == -1)
    {
        // play-with-bot-mode: a single learning perspective.
        for (int i = path_len - 1; i >= 0; --i)
        {
            CNode *node = search_path[i];
            node->value_sum += bootstrap_value;
            node->visit_count += 1;

            float parent_value_prefix = 0.0;
            int is_reset = 0;
            if (i >= 1)
            {
                CNode *parent = search_path[i - 1];
                parent_value_prefix = parent->value_prefix;
                is_reset = parent->is_reset;
            }

            // The normalizer is updated with the un-reset reward, matching the
            // quantity used at selection time.
            float true_reward = node->value_prefix - parent_value_prefix;
            min_max_stats.update(true_reward + discount_factor * node->value());

            if (is_reset == 1)
            {
                // The parent was reset, so the node's prefix already is the one-step reward.
                true_reward = node->value_prefix;
            }

            bootstrap_value = true_reward + discount_factor * bootstrap_value;
        }
    }
    else
    {
        // self-play-mode: the sign of the propagated value flips with the player.
        for (int i = path_len - 1; i >= 0; --i)
        {
            CNode *node = search_path[i];
            node->value_sum += (node->to_play == to_play) ? bootstrap_value : -bootstrap_value;
            node->visit_count += 1;

            float parent_value_prefix = 0.0;
            int is_reset = 0;
            if (i >= 1)
            {
                CNode *parent = search_path[i - 1];
                parent_value_prefix = parent->value_prefix;
                is_reset = parent->is_reset;
            }

            // NOTE: in self-play-mode, value_prefix is not calculated according to the perspective
            // of the current player of node, but treated as 1 player, just for obtaining the true
            // reward in the perspective of the current player of node.
            float true_reward = node->value_prefix - parent_value_prefix;

            min_max_stats.update(true_reward + discount_factor * node->value());

            if (is_reset == 1)
            {
                // The parent was reset, so the node's prefix already is the one-step reward.
                true_reward = node->value_prefix;
            }

            bootstrap_value = (node->to_play == to_play)
                                  ? (-true_reward + discount_factor * bootstrap_value)
                                  : (true_reward + discount_factor * bootstrap_value);
        }
    }
}
|
| 947 |
+
void cbatch_backpropagate(int current_latent_state_index, float discount_factor, const std::vector<float> &value_prefixs, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> is_reset_list, std::vector<int> &to_play_batch)
|
| 948 |
+
{
|
| 949 |
+
/*
|
| 950 |
+
Overview:
|
| 951 |
+
Expand the nodes along the search path and update the infos.
|
| 952 |
+
Arguments:
|
| 953 |
+
- current_latent_state_index: The index of latent state of the leaf node in the search path.
|
| 954 |
+
- discount_factor: the discount factor of reward.
|
| 955 |
+
- value_prefixs: the value prefixs of nodes along the search path.
|
| 956 |
+
- values: the values to propagate along the search path.
|
| 957 |
+
- policies: the policy logits of nodes along the search path.
|
| 958 |
+
- min_max_stats: a tool used to min-max normalize the q value.
|
| 959 |
+
- results: the search results.
|
| 960 |
+
- is_reset_list: the vector of is_reset nodes along the search path, where is_reset represents for whether the parent value prefix needs to be reset.
|
| 961 |
+
- to_play_batch: the batch of which player is playing on this node.
|
| 962 |
+
*/
|
| 963 |
+
for (int i = 0; i < results.num; ++i)
|
| 964 |
+
{
|
| 965 |
+
results.nodes[i]->expand(to_play_batch[i], current_latent_state_index, i, value_prefixs[i], policies[i]);
|
| 966 |
+
// reset
|
| 967 |
+
results.nodes[i]->is_reset = is_reset_list[i];
|
| 968 |
+
|
| 969 |
+
cbackpropagate(results.search_paths[i], min_max_stats_lst->stats_lst[i], to_play_batch[i], values[i], discount_factor);
|
| 970 |
+
}
|
| 971 |
+
}
|
| 972 |
+
|
| 973 |
+
CAction cselect_child(CNode *root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players, bool continuous_action_space)
{
    /*
    Overview:
        Select a child of ``root`` by maximal UCB score, breaking near-ties
        uniformly at random.
    Arguments:
        - root: the node whose children are scored.
        - min_max_stats: a tool used to min-max normalize the score.
        - pb_c_base: constant c2 in muzero.
        - pb_c_init: constant c1 in muzero.
        - discount_factor: the discount factor of reward.
        - mean_q: the mean Q value of the parent node.
        - players: the number of players.
        - continuous_action_space: whether the action space is continuous in the current env.
    Outputs:
        - action: the selected action.
    */
    // sampled related core code
    // TODO(pu): Progressive widening (See https://hal.archives-ouvertes.fr/hal-00542673v2/document)
    const float epsilon = 0.000001;
    float best_score = FLOAT_MIN;
    std::vector<CAction> best_actions;

    for (auto a : root->legal_actions)
    {
        CNode *child = root->get_child(a);
        // sampled related core code
        float score = cucb_score(root, child, min_max_stats, mean_q, root->is_reset, root->visit_count - 1, root->value_prefix, pb_c_base, pb_c_init, discount_factor, players, continuous_action_space);

        if (score > best_score)
        {
            best_score = score;
            best_actions.clear();
            best_actions.push_back(a);
        }
        else if (score >= best_score - epsilon)
        {
            // Near-ties are collected and broken uniformly at random below.
            best_actions.push_back(a);
        }
    }

    CAction action;
    if (!best_actions.empty())
    {
        action = best_actions[rand() % best_actions.size()];
    }
    return action;
}
| 1025 |
+
// sampled related core code
|
| 1026 |
+
// sampled related core code
float cucb_score(CNode *parent, CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, int is_reset, float total_children_visit_counts, float parent_value_prefix, float pb_c_base, float pb_c_init, float discount_factor, int players, bool continuous_action_space)
{
    /*
    Overview:
        Compute the pUCT score of ``child`` as prior_score + normalized value_score.
    Arguments:
        - parent: the node whose children are being scored.
        - child: the child node to compute the ucb score for.
        - min_max_stats: a tool used to min-max normalize the value score.
        - parent_mean_q: the mean Q value of the parent node, used for unvisited children.
        - is_reset: whether the parent's value prefix needs to be reset.
        - total_children_visit_counts: the total visit counts of the parent's children.
        - parent_value_prefix: the value prefix of the parent node.
        - pb_c_base: constant c2 in muzero.
        - pb_c_init: constant c1 in muzero.
        - discount_factor: the discount factor of reward.
        - players: the number of players.
        - continuous_action_space: whether the action space is continuous in the current env.
    Outputs:
        - ucb_value: the ucb score of the child.
    */
    float pb_c = log((total_children_visit_counts + pb_c_base + 1) / pb_c_base) + pb_c_init;
    pb_c *= (sqrt(total_children_visit_counts) / (child->visit_count + 1));

    // sampled related core code
    // TODO(pu): support other empirical distribution types.
    float prior_score = 0.0;
    const std::string empirical_distribution_type = "density";
    // BUG FIX: std::string::compare() returns 0 on equality, so the original
    // ``if (type.compare("density"))`` selected the branch when the strings did
    // NOT match, and the "uniform" fallback ran instead of the intended density
    // weighting. Use operator== so the "density" branch is actually taken.
    if (empirical_distribution_type == "density")
    {
        // Weight the child's prior by the empirical probability mass of all
        // sampled siblings (1e-6 guards division by zero).
        float empirical_prob_sum = 0;
        if (continuous_action_space == true)
        {
            // Continuous case: priors are stored as log-probabilities.
            for (size_t i = 0; i < parent->children.size(); ++i)
            {
                empirical_prob_sum += exp(parent->get_child(parent->legal_actions[i])->prior);
            }
            prior_score = pb_c * exp(child->prior) / (empirical_prob_sum + 1e-6);
        }
        else
        {
            // Discrete case: priors are stored as probabilities.
            for (size_t i = 0; i < parent->children.size(); ++i)
            {
                empirical_prob_sum += parent->get_child(parent->legal_actions[i])->prior;
            }
            prior_score = pb_c * child->prior / (empirical_prob_sum + 1e-6);
        }
    }
    else if (empirical_distribution_type == "uniform")
    {
        prior_score = pb_c * 1 / parent->children.size();
    }

    // sampled related core code
    float value_score = 0.0;
    if (child->visit_count == 0)
    {
        // Unvisited children fall back to the parent's mean Q.
        value_score = parent_mean_q;
    }
    else
    {
        float true_reward = child->value_prefix - parent_value_prefix;
        if (is_reset == 1)
        {
            // The parent's value prefix was reset, so the child's prefix is the raw reward.
            true_reward = child->value_prefix;
        }

        if (players == 1)
            value_score = true_reward + discount_factor * child->value();
        else if (players == 2)
            // Two-player mode: the child's value is from the opponent's perspective.
            value_score = true_reward + discount_factor * (-child->value());
    }

    // Min-max normalize and clamp the value score into [0, 1].
    value_score = min_max_stats.normalize(value_score);
    if (value_score < 0)
        value_score = 0;
    if (value_score > 1)
        value_score = 1;

    return prior_score + value_score;
}
|
| 1110 |
+
void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &virtual_to_play_batch, bool continuous_action_space)
{
    /*
    Overview:
        For every root in the batch, descend along UCB-maximal children until an
        unexpanded leaf is reached, recording the path and bookkeeping in ``results``.
    Arguments:
        - roots: the batch of roots to search from.
        - pb_c_base: constant c2 in muzero.
        - pb_c_init: constant c1 in muzero.
        - discount_factor: the discount factor of reward.
        - min_max_stats_lst: one min-max normalizer per search.
        - results: the search results to fill.
        - virtual_to_play_batch: per search, which player is playing (mutated as players alternate).
        - continuous_action_space: whether the action space is continuous in the current env.
    */
    // Seed the C PRNG used for tie-breaking in cselect_child.
    get_time_and_set_rand_seed();

    std::vector<float> null_value;
    for (int i = 0; i < 1; ++i)
    {
        null_value.push_back(i + 0.1);
    }
    // NOTE(review): last_action and parent_q are deliberately declared outside the
    // root loop, preserving the original's carry-over between consecutive roots.
    std::vector<float> last_action;
    float parent_q = 0.0;
    results.search_lens = std::vector<int>();

    // Infer the number of players from the to-play markers (-1 means single player).
    int players = 0;
    int largest_element = *max_element(virtual_to_play_batch.begin(), virtual_to_play_batch.end());
    players = (largest_element == -1) ? 1 : 2;

    for (int i = 0; i < results.num; ++i)
    {
        CNode *node = &(roots->roots[i]);
        int is_root = 1;
        int search_len = 0;
        results.search_paths[i].push_back(node);

        // Descend until an unexpanded leaf is reached.
        while (node->expanded())
        {
            float mean_q = node->compute_mean_q(is_root, parent_q, discount_factor);
            is_root = 0;
            parent_q = mean_q;

            CAction action = cselect_child(node, min_max_stats_lst->stats_lst[i], pb_c_base, pb_c_init, discount_factor, mean_q, players, continuous_action_space);
            if (players > 1)
            {
                // Two-player mode: alternate the virtual player after every move.
                assert(virtual_to_play_batch[i] == 1 || virtual_to_play_batch[i] == 2);
                virtual_to_play_batch[i] = (virtual_to_play_batch[i] == 1) ? 2 : 1;
            }

            node->best_action = action; // CAction
            // Step into the selected child.
            node = node->get_child(action);
            last_action = action.value;

            results.search_paths[i].push_back(node);
            search_len += 1;
        }

        // The leaf's parent holds the latent-state indices needed for re-inference.
        CNode *parent = results.search_paths[i][results.search_paths[i].size() - 2];

        results.latent_state_index_in_search_path.push_back(parent->current_latent_state_index);
        results.latent_state_index_in_batch.push_back(parent->batch_index);

        results.last_actions.push_back(last_action);
        results.search_lens.push_back(search_len);
        results.nodes.push_back(node);
        results.virtual_to_play_batchs.push_back(virtual_to_play_batch[i]);
    }
}
|
| 1189 |
+
}
|
LightZero/lzero/mcts/ctree/ctree_sampled_efficientzero/lib/cnode.h
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// C++11

#ifndef CNODE_H
#define CNODE_H

#include "../../common_lib/cminimax.h"
#include <math.h>
#include <vector>
#include <stack>
#include <stdlib.h>
#include <time.h>
#include <cmath>
#include <sys/timeb.h>
#include <time.h>
#include <map>

// Set to a non-zero value to enable debug printing in the implementation.
const int DEBUG_MODE = 0;

namespace tree
{
    // sampled related core code
    // An action in the sampled variant: a vector of floats (one entry per action
    // dimension for continuous spaces; a single entry for discrete spaces).
    class CAction
    {
        public:
            std::vector<float> value;        // the raw action values
            std::vector<size_t> hash;        // per-component hashes of `value`
            int is_root_action;              // 1 if this action was sampled at the root

            CAction();
            CAction(std::vector<float> value, int is_root_action);
            ~CAction();

            // Hash of each component of `value` (used to key the children map).
            std::vector<size_t> get_hash(void);
            // Single combined hash over all components.
            std::size_t get_combined_hash(void);
    };

    // One node of the MCTS tree for sampled EfficientZero.
    class CNode
    {
        public:
            int visit_count, to_play, current_latent_state_index, batch_index, is_reset, action_space_size;
            // sampled related core code
            CAction best_action;             // best child action found so far
            int num_of_sampled_actions;      // how many actions are sampled per node
            float value_prefix, prior, value_sum;
            float parent_value_prefix;       // value prefix of the parent (for true-reward recovery)
            bool continuous_action_space;
            std::vector<int> children_index;
            // Children keyed by the combined hash of the action.
            std::map<size_t, CNode> children;

            std::vector<CAction> legal_actions;

            CNode();
            // sampled related core code
            CNode(float prior, std::vector<CAction> &legal_actions, int action_space_size, int num_of_sampled_actions, bool continuous_action_space);
            ~CNode();

            // Expand this node with the model outputs (value prefix + policy logits).
            void expand(int to_play, int current_latent_state_index, int batch_index, float value_prefix, const std::vector<float> &policy_logits);
            // Mix exploration noise into the children priors.
            void add_exploration_noise(float exploration_fraction, const std::vector<float> &noises);
            // Mean Q over visited children (optionally mixing in the parent's Q).
            float compute_mean_q(int isRoot, float parent_q, float discount_factor);
            void print_out();

            // Whether this node has been expanded (has children).
            int expanded();

            // Mean backed-up value of this node.
            float value();

            // sampled related core code
            std::vector<std::vector<float> > get_trajectory();
            std::vector<int> get_children_distribution();
            CNode *get_child(CAction action);
    };

    // A batch of root nodes, one per environment in the batch.
    class CRoots
    {
        public:
            int root_num;
            int num_of_sampled_actions;
            int action_space_size;
            std::vector<CNode> roots;
            std::vector<std::vector<float> > legal_actions_list;
            bool continuous_action_space;

            CRoots();
            CRoots(int root_num, std::vector<std::vector<float> > legal_actions_list, int action_space_size, int num_of_sampled_actions, bool continuous_action_space);
            ~CRoots();

            // Expand all roots (with / without exploration noise).
            void prepare(float root_noise_weight, const std::vector<std::vector<float> > &noises, const std::vector<float> &value_prefixs, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch);
            void prepare_no_noise(const std::vector<float> &value_prefixs, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch);
            void clear();
            // sampled related core code
            std::vector<std::vector<std::vector<float> > > get_trajectories();
            std::vector<std::vector<std::vector<float> > > get_sampled_actions();

            // Visit-count distribution of each root's children.
            std::vector<std::vector<int> > get_distributions();

            // Mean search value of each root.
            std::vector<float> get_values();
    };

    // Scratch results produced by one batched traverse step.
    class CSearchResults
    {
        public:
            int num;
            std::vector<int> latent_state_index_in_search_path, latent_state_index_in_batch, search_lens;
            std::vector<int> virtual_to_play_batchs;
            std::vector<std::vector<float> > last_actions;

            std::vector<CNode *> nodes;                      // the leaf node of each path
            std::vector<std::vector<CNode *> > search_paths; // full root-to-leaf paths

            CSearchResults();
            CSearchResults(int num);
            ~CSearchResults();
    };

    //*********************************************************
    void update_tree_q(CNode *root, tools::CMinMaxStats &min_max_stats, float discount_factor, int players);
    void cbackpropagate(std::vector<CNode *> &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount_factor);
    void cbatch_backpropagate(int current_latent_state_index, float discount_factor, const std::vector<float> &value_prefixs, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> is_reset_list, std::vector<int> &to_play_batch);
    CAction cselect_child(CNode *root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players, bool continuous_action_space);
    float cucb_score(CNode *parent, CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, int is_reset, float total_children_visit_counts, float parent_value_prefix, float pb_c_base, float pb_c_init, float discount_factor, int players, bool continuous_action_space);
    void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &virtual_to_play_batch, bool continuous_action_space);
}

#endif
|
LightZero/lzero/mcts/ctree/ctree_stochastic_muzero/lib/cnode.cpp
ADDED
|
@@ -0,0 +1,787 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// C++11
|
| 2 |
+
|
| 3 |
+
#include <iostream>
|
| 4 |
+
#include "cnode.h"
|
| 5 |
+
#include <algorithm>
|
| 6 |
+
#include <map>
|
| 7 |
+
#include <cassert>
|
| 8 |
+
#include <numeric>
|
| 9 |
+
#include <iostream>
|
| 10 |
+
#include <vector>
|
| 11 |
+
#include <map>
|
| 12 |
+
#include <random>
|
| 13 |
+
#include <algorithm>
|
| 14 |
+
#include <iterator>
|
| 15 |
+
|
| 16 |
+
#ifdef _WIN32
|
| 17 |
+
#include "..\..\common_lib\utils.cpp"
|
| 18 |
+
#else
|
| 19 |
+
#include "../../common_lib/utils.cpp"
|
| 20 |
+
#endif
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
namespace tree
|
| 24 |
+
{
|
| 25 |
+
|
| 26 |
+
CSearchResults::CSearchResults()
|
| 27 |
+
{
|
| 28 |
+
/*
|
| 29 |
+
Overview:
|
| 30 |
+
Initialization of CSearchResults, the default result number is set to 0.
|
| 31 |
+
*/
|
| 32 |
+
this->num = 0;
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
CSearchResults::CSearchResults(int num)
|
| 36 |
+
{
|
| 37 |
+
/*
|
| 38 |
+
Overview:
|
| 39 |
+
Initialization of CSearchResults with result number.
|
| 40 |
+
*/
|
| 41 |
+
this->num = num;
|
| 42 |
+
for (int i = 0; i < num; ++i)
|
| 43 |
+
{
|
| 44 |
+
this->search_paths.push_back(std::vector<CNode *>());
|
| 45 |
+
}
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
    CSearchResults::~CSearchResults() {} // all members are self-cleaning standard containers
|
| 49 |
+
|
| 50 |
+
//*********************************************************
|
| 51 |
+
|
| 52 |
+
CNode::CNode()
|
| 53 |
+
{
|
| 54 |
+
/*
|
| 55 |
+
Overview:
|
| 56 |
+
Initialization of CNode.
|
| 57 |
+
*/
|
| 58 |
+
this->prior = 0;
|
| 59 |
+
this->legal_actions = legal_actions;
|
| 60 |
+
|
| 61 |
+
this->visit_count = 0;
|
| 62 |
+
this->value_sum = 0;
|
| 63 |
+
this->best_action = -1;
|
| 64 |
+
this->to_play = 0;
|
| 65 |
+
this->reward = 0.0;
|
| 66 |
+
this->is_chance = false;
|
| 67 |
+
this->chance_space_size= 2;
|
| 68 |
+
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
CNode::CNode(float prior, std::vector<int> &legal_actions, bool is_chance, int chance_space_size)
|
| 72 |
+
{
|
| 73 |
+
/*
|
| 74 |
+
Overview:
|
| 75 |
+
Initialization of CNode with prior value and legal actions.
|
| 76 |
+
Arguments:
|
| 77 |
+
- prior: the prior value of this node.
|
| 78 |
+
- legal_actions: a vector of legal actions of this node.
|
| 79 |
+
*/
|
| 80 |
+
this->prior = prior;
|
| 81 |
+
this->legal_actions = legal_actions;
|
| 82 |
+
|
| 83 |
+
this->visit_count = 0;
|
| 84 |
+
this->value_sum = 0;
|
| 85 |
+
this->best_action = -1;
|
| 86 |
+
this->to_play = 0;
|
| 87 |
+
this->current_latent_state_index = -1;
|
| 88 |
+
this->batch_index = -1;
|
| 89 |
+
this->is_chance = is_chance;
|
| 90 |
+
this->chance_space_size = chance_space_size;
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
    CNode::~CNode() {} // children map owns the subtree; default destruction suffices
|
| 94 |
+
|
| 95 |
+
void CNode::expand(int to_play, int current_latent_state_index, int batch_index, float reward, const std::vector<float> &policy_logits, bool child_is_chance)
|
| 96 |
+
{
|
| 97 |
+
/*
|
| 98 |
+
Overview:
|
| 99 |
+
Expand the child nodes of the current node.
|
| 100 |
+
Arguments:
|
| 101 |
+
- to_play: which player to play the game in the current node.
|
| 102 |
+
- current_latent_state_index: The index of latent state of the leaf node in the search path of the current node.
|
| 103 |
+
- batch_index: The index of latent state of the leaf node in the search path of the current node.
|
| 104 |
+
- reward: the reward of the current node.
|
| 105 |
+
- policy_logits: the logit of the child nodes.
|
| 106 |
+
*/
|
| 107 |
+
this->to_play = to_play;
|
| 108 |
+
this->current_latent_state_index = current_latent_state_index;
|
| 109 |
+
this->batch_index = batch_index;
|
| 110 |
+
this->reward = reward;
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
// assert((this->is_chance != child_is_chance) && "is_chance and child_is_chance should be different");
|
| 114 |
+
|
| 115 |
+
if(this->is_chance == true){
|
| 116 |
+
child_is_chance = false;
|
| 117 |
+
this->reward = 0.0;
|
| 118 |
+
}
|
| 119 |
+
else{
|
| 120 |
+
child_is_chance = true;
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
int action_num = policy_logits.size();
|
| 124 |
+
if (this->legal_actions.size() == 0)
|
| 125 |
+
{
|
| 126 |
+
for (int i = 0; i < action_num; ++i)
|
| 127 |
+
{
|
| 128 |
+
this->legal_actions.push_back(i);
|
| 129 |
+
}
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
float temp_policy;
|
| 133 |
+
float policy_sum = 0.0;
|
| 134 |
+
|
| 135 |
+
#ifdef _WIN32
|
| 136 |
+
// 创建动态数组
|
| 137 |
+
float* policy = new float[action_num];
|
| 138 |
+
#else
|
| 139 |
+
float policy[action_num];
|
| 140 |
+
#endif
|
| 141 |
+
|
| 142 |
+
float policy_max = FLOAT_MIN;
|
| 143 |
+
for (auto a : this->legal_actions)
|
| 144 |
+
{
|
| 145 |
+
if (policy_max < policy_logits[a])
|
| 146 |
+
{
|
| 147 |
+
policy_max = policy_logits[a];
|
| 148 |
+
}
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
for (auto a : this->legal_actions)
|
| 152 |
+
{
|
| 153 |
+
temp_policy = exp(policy_logits[a] - policy_max);
|
| 154 |
+
policy_sum += temp_policy;
|
| 155 |
+
policy[a] = temp_policy;
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
float prior;
|
| 159 |
+
for (auto a : this->legal_actions)
|
| 160 |
+
{
|
| 161 |
+
prior = policy[a] / policy_sum;
|
| 162 |
+
std::vector<int> tmp_empty;
|
| 163 |
+
this->children[a] = CNode(prior, tmp_empty, child_is_chance, this->chance_space_size); // only for muzero/efficient zero, not support alphazero
|
| 164 |
+
// this->children[a] = CNode(prior, tmp_empty, is_chance = child_is_chance); // only for muzero/efficient zero, not support alphazero
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
#ifdef _WIN32
|
| 168 |
+
// 释放数组内存
|
| 169 |
+
delete[] policy;
|
| 170 |
+
#else
|
| 171 |
+
#endif
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
void CNode::add_exploration_noise(float exploration_fraction, const std::vector<float> &noises)
|
| 175 |
+
{
|
| 176 |
+
/*
|
| 177 |
+
Overview:
|
| 178 |
+
Add a noise to the prior of the child nodes.
|
| 179 |
+
Arguments:
|
| 180 |
+
- exploration_fraction: the fraction to add noise.
|
| 181 |
+
- noises: the vector of noises added to each child node.
|
| 182 |
+
*/
|
| 183 |
+
float noise, prior;
|
| 184 |
+
for (int i = 0; i < this->legal_actions.size(); ++i)
|
| 185 |
+
{
|
| 186 |
+
noise = noises[i];
|
| 187 |
+
CNode *child = this->get_child(this->legal_actions[i]);
|
| 188 |
+
|
| 189 |
+
prior = child->prior;
|
| 190 |
+
child->prior = prior * (1 - exploration_fraction) + noise * exploration_fraction;
|
| 191 |
+
}
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
float CNode::compute_mean_q(int isRoot, float parent_q, float discount_factor)
|
| 195 |
+
{
|
| 196 |
+
/*
|
| 197 |
+
Overview:
|
| 198 |
+
Compute the mean q value of the current node.
|
| 199 |
+
Arguments:
|
| 200 |
+
- isRoot: whether the current node is a root node.
|
| 201 |
+
- parent_q: the q value of the parent node.
|
| 202 |
+
- discount_factor: the discount_factor of reward.
|
| 203 |
+
*/
|
| 204 |
+
float total_unsigned_q = 0.0;
|
| 205 |
+
int total_visits = 0;
|
| 206 |
+
for (auto a : this->legal_actions)
|
| 207 |
+
{
|
| 208 |
+
CNode *child = this->get_child(a);
|
| 209 |
+
if (child->visit_count > 0)
|
| 210 |
+
{
|
| 211 |
+
float true_reward = child->reward;
|
| 212 |
+
float qsa = true_reward + discount_factor * child->value();
|
| 213 |
+
total_unsigned_q += qsa;
|
| 214 |
+
total_visits += 1;
|
| 215 |
+
}
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
float mean_q = 0.0;
|
| 219 |
+
if (isRoot && total_visits > 0)
|
| 220 |
+
{
|
| 221 |
+
mean_q = (total_unsigned_q) / (total_visits);
|
| 222 |
+
}
|
| 223 |
+
else
|
| 224 |
+
{
|
| 225 |
+
mean_q = (parent_q + total_unsigned_q) / (total_visits + 1);
|
| 226 |
+
}
|
| 227 |
+
return mean_q;
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
    void CNode::print_out()
    {
        // Debug hook; currently a no-op placeholder.
        return;
    }
|
| 234 |
+
|
| 235 |
+
int CNode::expanded()
|
| 236 |
+
{
|
| 237 |
+
/*
|
| 238 |
+
Overview:
|
| 239 |
+
Return whether the current node is expanded.
|
| 240 |
+
*/
|
| 241 |
+
return this->children.size() > 0;
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
float CNode::value()
|
| 245 |
+
{
|
| 246 |
+
/*
|
| 247 |
+
Overview:
|
| 248 |
+
Return the real value of the current tree.
|
| 249 |
+
*/
|
| 250 |
+
float true_value = 0.0;
|
| 251 |
+
if (this->visit_count == 0)
|
| 252 |
+
{
|
| 253 |
+
return true_value;
|
| 254 |
+
}
|
| 255 |
+
else
|
| 256 |
+
{
|
| 257 |
+
true_value = this->value_sum / this->visit_count;
|
| 258 |
+
return true_value;
|
| 259 |
+
}
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
std::vector<int> CNode::get_trajectory()
|
| 263 |
+
{
|
| 264 |
+
/*
|
| 265 |
+
Overview:
|
| 266 |
+
Find the current best trajectory starts from the current node.
|
| 267 |
+
Outputs:
|
| 268 |
+
- traj: a vector of node index, which is the current best trajectory from this node.
|
| 269 |
+
*/
|
| 270 |
+
std::vector<int> traj;
|
| 271 |
+
|
| 272 |
+
CNode *node = this;
|
| 273 |
+
int best_action = node->best_action;
|
| 274 |
+
while (best_action >= 0)
|
| 275 |
+
{
|
| 276 |
+
traj.push_back(best_action);
|
| 277 |
+
|
| 278 |
+
node = node->get_child(best_action);
|
| 279 |
+
best_action = node->best_action;
|
| 280 |
+
}
|
| 281 |
+
return traj;
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
std::vector<int> CNode::get_children_distribution()
|
| 285 |
+
{
|
| 286 |
+
/*
|
| 287 |
+
Overview:
|
| 288 |
+
Get the distribution of child nodes in the format of visit_count.
|
| 289 |
+
Outputs:
|
| 290 |
+
- distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]).
|
| 291 |
+
*/
|
| 292 |
+
std::vector<int> distribution;
|
| 293 |
+
if (this->expanded())
|
| 294 |
+
{
|
| 295 |
+
for (auto a : this->legal_actions)
|
| 296 |
+
{
|
| 297 |
+
CNode *child = this->get_child(a);
|
| 298 |
+
distribution.push_back(child->visit_count);
|
| 299 |
+
}
|
| 300 |
+
}
|
| 301 |
+
return distribution;
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
    CNode *CNode::get_child(int action)
    {
        /*
        Overview:
            Get the child node corresponding to the input action.
        Arguments:
            - action: the action to get child.
        */
        // NOTE(review): std::map::operator[] default-constructs a child when
        // `action` is absent; expand() appears to create all legal children
        // first, but confirm before tightening this to .at().
        return &(this->children[action]);
    }
|
| 314 |
+
|
| 315 |
+
//*********************************************************
|
| 316 |
+
|
| 317 |
+
CRoots::CRoots()
|
| 318 |
+
{
|
| 319 |
+
/*
|
| 320 |
+
Overview:
|
| 321 |
+
The initialization of CRoots.
|
| 322 |
+
*/
|
| 323 |
+
this->root_num = 0;
|
| 324 |
+
}
|
| 325 |
+
|
| 326 |
+
CRoots::CRoots(int root_num, std::vector<std::vector<int> > &legal_actions_list, int chance_space_size=2)
|
| 327 |
+
{
|
| 328 |
+
/*
|
| 329 |
+
Overview:
|
| 330 |
+
The initialization of CRoots with root num and legal action lists.
|
| 331 |
+
Arguments:
|
| 332 |
+
- root_num: the number of the current root.
|
| 333 |
+
- legal_action_list: the vector of the legal action of this root.
|
| 334 |
+
*/
|
| 335 |
+
this->root_num = root_num;
|
| 336 |
+
this->legal_actions_list = legal_actions_list;
|
| 337 |
+
|
| 338 |
+
for (int i = 0; i < root_num; ++i)
|
| 339 |
+
{
|
| 340 |
+
this->roots.push_back(CNode(0, this->legal_actions_list[i], false, chance_space_size));
|
| 341 |
+
// this->roots.push_back(CNode(0, this->legal_actions_list[i], false));
|
| 342 |
+
|
| 343 |
+
}
|
| 344 |
+
}
|
| 345 |
+
|
| 346 |
+
    CRoots::~CRoots() {} // roots vector destroys all trees; default destruction suffices
|
| 347 |
+
|
| 348 |
+
void CRoots::prepare(float root_noise_weight, const std::vector<std::vector<float> > &noises, const std::vector<float> &rewards, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch)
|
| 349 |
+
{
|
| 350 |
+
/*
|
| 351 |
+
Overview:
|
| 352 |
+
Expand the roots and add noises.
|
| 353 |
+
Arguments:
|
| 354 |
+
- root_noise_weight: the exploration fraction of roots
|
| 355 |
+
- noises: the vector of noise add to the roots.
|
| 356 |
+
- rewards: the vector of rewards of each root.
|
| 357 |
+
- policies: the vector of policy logits of each root.
|
| 358 |
+
- to_play_batch: the vector of the player side of each root.
|
| 359 |
+
*/
|
| 360 |
+
for (int i = 0; i < this->root_num; ++i)
|
| 361 |
+
{
|
| 362 |
+
this->roots[i].expand(to_play_batch[i], 0, i, rewards[i], policies[i], true);
|
| 363 |
+
this->roots[i].add_exploration_noise(root_noise_weight, noises[i]);
|
| 364 |
+
|
| 365 |
+
this->roots[i].visit_count += 1;
|
| 366 |
+
}
|
| 367 |
+
}
|
| 368 |
+
|
| 369 |
+
void CRoots::prepare_no_noise(const std::vector<float> &rewards, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch)
|
| 370 |
+
{
|
| 371 |
+
/*
|
| 372 |
+
Overview:
|
| 373 |
+
Expand the roots without noise.
|
| 374 |
+
Arguments:
|
| 375 |
+
- rewards: the vector of rewards of each root.
|
| 376 |
+
- policies: the vector of policy logits of each root.
|
| 377 |
+
- to_play_batch: the vector of the player side of each root.
|
| 378 |
+
*/
|
| 379 |
+
for (int i = 0; i < this->root_num; ++i)
|
| 380 |
+
{
|
| 381 |
+
this->roots[i].expand(to_play_batch[i], 0, i, rewards[i], policies[i], true);
|
| 382 |
+
|
| 383 |
+
this->roots[i].visit_count += 1;
|
| 384 |
+
}
|
| 385 |
+
}
|
| 386 |
+
|
| 387 |
+
    void CRoots::clear()
    {
        /*
        Overview:
            Clear the roots vector.
        */
        // Drops all root nodes and, recursively, their subtrees.
        // NOTE(review): root_num is intentionally left untouched here — confirm
        // callers never iterate by root_num after clear().
        this->roots.clear();
    }
|
| 395 |
+
|
| 396 |
+
std::vector<std::vector<int> > CRoots::get_trajectories()
|
| 397 |
+
{
|
| 398 |
+
/*
|
| 399 |
+
Overview:
|
| 400 |
+
Find the current best trajectory starts from each root.
|
| 401 |
+
Outputs:
|
| 402 |
+
- traj: a vector of node index, which is the current best trajectory from each root.
|
| 403 |
+
*/
|
| 404 |
+
std::vector<std::vector<int> > trajs;
|
| 405 |
+
trajs.reserve(this->root_num);
|
| 406 |
+
|
| 407 |
+
for (int i = 0; i < this->root_num; ++i)
|
| 408 |
+
{
|
| 409 |
+
trajs.push_back(this->roots[i].get_trajectory());
|
| 410 |
+
}
|
| 411 |
+
return trajs;
|
| 412 |
+
}
|
| 413 |
+
|
| 414 |
+
std::vector<std::vector<int> > CRoots::get_distributions()
|
| 415 |
+
{
|
| 416 |
+
/*
|
| 417 |
+
Overview:
|
| 418 |
+
Get the children distribution of each root.
|
| 419 |
+
Outputs:
|
| 420 |
+
- distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]).
|
| 421 |
+
*/
|
| 422 |
+
std::vector<std::vector<int> > distributions;
|
| 423 |
+
distributions.reserve(this->root_num);
|
| 424 |
+
|
| 425 |
+
for (int i = 0; i < this->root_num; ++i)
|
| 426 |
+
{
|
| 427 |
+
distributions.push_back(this->roots[i].get_children_distribution());
|
| 428 |
+
}
|
| 429 |
+
return distributions;
|
| 430 |
+
}
|
| 431 |
+
|
| 432 |
+
std::vector<float> CRoots::get_values()
|
| 433 |
+
{
|
| 434 |
+
/*
|
| 435 |
+
Overview:
|
| 436 |
+
Return the real value of each root.
|
| 437 |
+
*/
|
| 438 |
+
std::vector<float> values;
|
| 439 |
+
for (int i = 0; i < this->root_num; ++i)
|
| 440 |
+
{
|
| 441 |
+
values.push_back(this->roots[i].value());
|
| 442 |
+
}
|
| 443 |
+
return values;
|
| 444 |
+
}
|
| 445 |
+
|
| 446 |
+
//*********************************************************
|
| 447 |
+
//
|
| 448 |
+
void update_tree_q(CNode *root, tools::CMinMaxStats &min_max_stats, float discount_factor, int players)
|
| 449 |
+
{
|
| 450 |
+
/*
|
| 451 |
+
Overview:
|
| 452 |
+
Update the q value of the root and its child nodes.
|
| 453 |
+
Arguments:
|
| 454 |
+
- root: the root that update q value from.
|
| 455 |
+
- min_max_stats: a tool used to min-max normalize the q value.
|
| 456 |
+
- discount_factor: the discount factor of reward.
|
| 457 |
+
- players: the number of players.
|
| 458 |
+
*/
|
| 459 |
+
std::stack<CNode *> node_stack;
|
| 460 |
+
node_stack.push(root);
|
| 461 |
+
while (node_stack.size() > 0)
|
| 462 |
+
{
|
| 463 |
+
CNode *node = node_stack.top();
|
| 464 |
+
node_stack.pop();
|
| 465 |
+
|
| 466 |
+
if (node != root)
|
| 467 |
+
{
|
| 468 |
+
// # NOTE: in self-play-mode, value_prefix is not calculated according to the perspective of current player of node,
|
| 469 |
+
// # but treated as 1 player, just for obtaining the true reward in the perspective of current player of node.
|
| 470 |
+
// # true_reward = node.value_prefix - (- parent_value_prefix)
|
| 471 |
+
// float true_reward = node->value_prefix - node->parent_value_prefix;
|
| 472 |
+
float true_reward = node->reward;
|
| 473 |
+
|
| 474 |
+
float qsa;
|
| 475 |
+
if (players == 1)
|
| 476 |
+
qsa = true_reward + discount_factor * node->value();
|
| 477 |
+
else if (players == 2)
|
| 478 |
+
// TODO(pu):
|
| 479 |
+
qsa = true_reward + discount_factor * (-1) * node->value();
|
| 480 |
+
|
| 481 |
+
min_max_stats.update(qsa);
|
| 482 |
+
}
|
| 483 |
+
|
| 484 |
+
for (auto a : node->legal_actions)
|
| 485 |
+
{
|
| 486 |
+
CNode *child = node->get_child(a);
|
| 487 |
+
if (child->expanded())
|
| 488 |
+
{
|
| 489 |
+
node_stack.push(child);
|
| 490 |
+
}
|
| 491 |
+
}
|
| 492 |
+
}
|
| 493 |
+
}
|
| 494 |
+
|
| 495 |
+
    void cbackpropagate(std::vector<CNode *> &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount_factor)
    {
        /*
        Overview:
            Update the value sum and visit count of nodes along the search path,
            propagating the leaf value back to the root with discounting.
        Arguments:
            - search_path: a vector of nodes on the search path (root first, leaf last).
            - min_max_stats: a tool used to min-max normalize the q value.
            - to_play: which player to play the game in the current node
              (-1 selects play-with-bot mode; 1/2 select self-play mode).
            - value: the value to propagate along the search path.
            - discount_factor: the discount factor of reward.
        */
        assert(to_play == -1 || to_play == 1 || to_play == 2);
        if (to_play == -1)
        {
            // for play-with-bot-mode
            // Single perspective: every node on the path accumulates the same
            // (discounted) bootstrap value.
            float bootstrap_value = value;
            int path_len = search_path.size();
            for (int i = path_len - 1; i >= 0; --i)
            {
                CNode *node = search_path[i];
                node->value_sum += bootstrap_value;
                node->visit_count += 1;

                float true_reward = node->reward;

                // Track the Q-value range for later normalization.
                min_max_stats.update(true_reward + discount_factor * node->value());

                // Discount and add the step reward as we move toward the root.
                bootstrap_value = true_reward + discount_factor * bootstrap_value;
                // std::cout << "to_play: " << to_play << std::endl;

            }
        }
        else
        {
            // for self-play-mode
            // Two-player zero-sum: the sign of the accumulated value flips
            // depending on whether the node belongs to `to_play`.
            float bootstrap_value = value;
            int path_len = search_path.size();
            for (int i = path_len - 1; i >= 0; --i)
            {
                CNode *node = search_path[i];
                if (node->to_play == to_play)
                    node->value_sum += bootstrap_value;
                else
                    node->value_sum += -bootstrap_value;
                node->visit_count += 1;

                // NOTE: in self-play-mode, value_prefix is not calculated according to the perspective of current player of node,
                // but treated as 1 player, just for obtaining the true reward in the perspective of current player of node.
                // float true_reward = node->value_prefix - parent_value_prefix;
                float true_reward = node->reward;

                // TODO(pu): why in muzero-general is - node.value
                min_max_stats.update(true_reward + discount_factor * -node->value());

                // Sign of the reward also flips with the node's ownership.
                if (node->to_play == to_play)
                    bootstrap_value = -true_reward + discount_factor * bootstrap_value;
                else
                    bootstrap_value = true_reward + discount_factor * bootstrap_value;
            }
        }
    }
|
| 557 |
+
|
| 558 |
+
void cbatch_backpropagate(int current_latent_state_index, float discount_factor, const std::vector<float> &value_prefixs, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &to_play_batch, std::vector<bool> &is_chance_list, std::vector<int> &leaf_idx_list)
|
| 559 |
+
{
|
| 560 |
+
/*
|
| 561 |
+
Overview:
|
| 562 |
+
Expand the nodes along the search path and update the infos.
|
| 563 |
+
Arguments:
|
| 564 |
+
- current_latent_state_index: The index of latent state of the leaf node in the search path.
|
| 565 |
+
- discount_factor: the discount factor of reward.
|
| 566 |
+
- value_prefixs: the value prefixs of nodes along the search path.
|
| 567 |
+
- values: the values to propagate along the search path.
|
| 568 |
+
- policies: the policy logits of nodes along the search path.
|
| 569 |
+
- min_max_stats: a tool used to min-max normalize the q value.
|
| 570 |
+
- results: the search results.
|
| 571 |
+
- to_play_batch: the batch of which player is playing on this node.
|
| 572 |
+
*/
|
| 573 |
+
|
| 574 |
+
if (leaf_idx_list.empty()) {
|
| 575 |
+
leaf_idx_list.resize(results.num);
|
| 576 |
+
for (int i = 0; i < results.num; ++i) {
|
| 577 |
+
leaf_idx_list[i] = i;
|
| 578 |
+
}
|
| 579 |
+
}
|
| 580 |
+
|
| 581 |
+
for (auto leaf_order = 0; leaf_order < leaf_idx_list.size(); ++leaf_order) {
|
| 582 |
+
int i = leaf_idx_list[leaf_order];
|
| 583 |
+
}
|
| 584 |
+
for (int leaf_order = 0; leaf_order < leaf_idx_list.size(); ++leaf_order)
|
| 585 |
+
{
|
| 586 |
+
int i = leaf_idx_list[leaf_order];
|
| 587 |
+
results.nodes[i]->expand(to_play_batch[i], current_latent_state_index, i, value_prefixs[leaf_order], policies[leaf_order], is_chance_list[i]);
|
| 588 |
+
cbackpropagate(results.search_paths[i], min_max_stats_lst->stats_lst[i], to_play_batch[i], values[leaf_order], discount_factor);
|
| 589 |
+
}
|
| 590 |
+
|
| 591 |
+
}
|
| 592 |
+
|
| 593 |
+
int cselect_child(CNode *root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players)
{
    /*
    Overview:
        Select a child of `root`. For a chance node, the child (outcome) is
        sampled from the prior outcome distribution; for a decision node, the
        action with the highest UCB score is selected, with ties (scores within
        a small epsilon of the maximum) broken uniformly at random.
    Arguments:
        - root: the node to select a child of.
        - min_max_stats: a tool used to min-max normalize the score.
        - pb_c_base: constant c2 in muzero.
        - pb_c_init: constant c1 in muzero.
        - discount_factor: the discount factor of reward.
        - mean_q: the mean q value of the parent node.
        - players: the number of players.
    Outputs:
        - action: the selected action (or sampled chance outcome).
    */
    if (root->is_chance) {
        // Chance node: sample an outcome proportionally to the children's priors.
        std::vector<int> outcomes;
        std::vector<double> probs;
        outcomes.reserve(root->children.size());
        probs.reserve(root->children.size());

        for (const auto &kv : root->children) {
            outcomes.push_back(kv.first);
            probs.push_back(kv.second.prior);
        }

        // Seed the Mersenne Twister once per thread instead of constructing a
        // fresh std::random_device + engine on every call, which is costly and
        // can exhaust the system entropy source under heavy search load.
        static thread_local std::mt19937 gen(std::random_device{}());
        std::discrete_distribution<> dist(probs.begin(), probs.end());

        return outcomes[dist(gen)];
    }

    // Decision node: argmax of the UCB score over the legal actions.
    float max_score = FLOAT_MIN;
    const float epsilon = 0.000001;
    std::vector<int> max_index_lst;
    for (auto a : root->legal_actions)
    {
        CNode *child = root->get_child(a);
        float temp_score = cucb_score(child, min_max_stats, mean_q, root->visit_count - 1, pb_c_base, pb_c_init, discount_factor, players);

        if (max_score < temp_score)
        {
            // New strict maximum: restart the tie list from this action.
            max_score = temp_score;

            max_index_lst.clear();
            max_index_lst.push_back(a);
        }
        else if (temp_score >= max_score - epsilon)
        {
            // Within epsilon of the current maximum: keep as a tie candidate.
            max_index_lst.push_back(a);
        }
    }

    // Break ties uniformly at random; falls back to action 0 if no legal
    // actions were scored (empty legal_actions).
    int action = 0;
    if (max_index_lst.size() > 0)
    {
        int rand_index = rand() % max_index_lst.size();
        action = max_index_lst[rand_index];
    }
    return action;
}
|
| 663 |
+
|
| 664 |
+
float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, float total_children_visit_counts, float pb_c_base, float pb_c_init, float discount_factor, int players)
{
    /*
    Overview:
        Compute the pUCT score of `child`, i.e. a normalized value estimate
        plus a prior-weighted exploration bonus.
    Arguments:
        - child: the child node to score.
        - min_max_stats: a tool used to min-max normalize the value estimate.
        - parent_mean_q: the mean q value of the parent node, used as the value
          estimate for an unvisited child.
        - total_children_visit_counts: the total visit counts of the siblings
          (children of the parent node).
        - pb_c_base: constant c2 in muzero.
        - pb_c_init: constant c1 in muzero.
        - discount_factor: the discount factor of reward.
        - players: the number of players (1 or 2).
    Outputs:
        - ucb_value: the ucb score of the child.
    */
    // Exploration weight c(s) = log((N + c2 + 1) / c2) + c1, scaled by
    // sqrt(N) / (1 + n(a)) so it decays as the child accumulates visits.
    float exploration_weight = log((total_children_visit_counts + pb_c_base + 1) / pb_c_base) + pb_c_init;
    exploration_weight *= sqrt(total_children_visit_counts) / (child->visit_count + 1);

    const float prior_bonus = exploration_weight * child->prior;

    // Value estimate: for an unvisited child fall back to the parent's mean q;
    // otherwise use reward + discounted child value (negated in the two-player
    // case, since the child value is from the opponent's perspective).
    float q_estimate = 0.0;
    if (child->visit_count == 0)
    {
        q_estimate = parent_mean_q;
    }
    else if (players == 1)
    {
        q_estimate = child->reward + discount_factor * child->value();
    }
    else if (players == 2)
    {
        q_estimate = child->reward + discount_factor * (-child->value());
    }

    // Min-max normalize, then clamp into [0, 1].
    float normalized_q = min_max_stats.normalize(q_estimate);
    normalized_q = (normalized_q < 0) ? 0 : ((normalized_q > 1) ? 1 : normalized_q);

    return prior_bonus + normalized_q;
}
|
| 709 |
+
|
| 710 |
+
void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &virtual_to_play_batch)
{
    /*
    Overview:
        Traverse each tree in the batch from its root to an unexpanded leaf,
        selecting children by UCB score (or by sampling at chance nodes), and
        record the resulting search paths and leaf information in `results`.
    Arguments:
        - roots: the roots that search from.
        - pb_c_base: constant c2 in muzero.
        - pb_c_init: constant c1 in muzero.
        - discount_factor: the discount factor of reward.
        - min_max_stats_lst: the tools used to min-max normalize the score, one per tree.
        - results: the search results (filled in by this function).
        - virtual_to_play_batch: the batch of which player is playing on this node;
          mutated in place (toggled 1<->2 at every tree level in two-player mode).
    */
    // Seed the C rand() used for tie-breaking in cselect_child.
    get_time_and_set_rand_seed();

    int last_action = -1;
    // NOTE(review): parent_q is declared outside the per-root loop, so its last
    // value leaks from one root's traversal into the next; presumably harmless
    // because compute_mean_q is called with is_root == 1 first — TODO confirm.
    float parent_q = 0.0;
    results.search_lens = std::vector<int>();

    // Infer the number of players from the to-play markers: -1 means
    // play-with-bot (single-player) mode; values 1/2 mean self-play mode.
    int players = 0;
    int largest_element = *max_element(virtual_to_play_batch.begin(), virtual_to_play_batch.end()); // -1 (single-player) or 1/2 (two-player)
    if (largest_element == -1)
        players = 1;
    else
        players = 2;

    for (int i = 0; i < results.num; ++i)
    {
        CNode *node = &(roots->roots[i]);
        int is_root = 1;
        int search_len = 0;
        // The search path always starts at the root.
        results.search_paths[i].push_back(node);

        // std::cout << "root->is_chance: " <<node->is_chance<< std::endl;
        // node->is_chance=false;

        // Descend until we hit a node that has not been expanded yet.
        while (node->expanded())
        {
            float mean_q = node->compute_mean_q(is_root, parent_q, discount_factor);
            is_root = 0;
            parent_q = mean_q;
            // std::cout << "node->is_chance: " <<node->is_chance<< std::endl;

            // UCB selection at decision nodes; prior-weighted sampling at chance nodes.
            int action = cselect_child(node, min_max_stats_lst->stats_lst[i], pb_c_base, pb_c_init, discount_factor, mean_q, players);
            if (players > 1)
            {
                // Two-player mode: alternate the player to move at each level.
                assert(virtual_to_play_batch[i] == 1 || virtual_to_play_batch[i] == 2);
                if (virtual_to_play_batch[i] == 1)
                    virtual_to_play_batch[i] = 2;
                else
                    virtual_to_play_batch[i] = 1;
            }

            node->best_action = action;
            // Move to the selected child and extend the search path.
            node = node->get_child(action);
            last_action = action;
            results.search_paths[i].push_back(node);
            search_len += 1;
        }

        // The parent of the leaf is the second-to-last node on the path; its
        // latent-state indices tell the caller which hidden state to expand from.
        CNode *parent = results.search_paths[i][results.search_paths[i].size() - 2];

        results.latent_state_index_in_search_path.push_back(parent->current_latent_state_index);
        results.latent_state_index_in_batch.push_back(parent->batch_index);

        results.last_actions.push_back(last_action);
        results.search_lens.push_back(search_len);
        results.nodes.push_back(node);
        results.leaf_node_is_chance.push_back(node->is_chance);
        results.virtual_to_play_batchs.push_back(virtual_to_play_batch[i]);

    }
}
|
| 786 |
+
|
| 787 |
+
}
|
LightZero/lzero/mcts/ctree/ctree_stochastic_muzero/lib/cnode.h
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// C++11
|
| 2 |
+
|
| 3 |
+
#ifndef CNODE_H
|
| 4 |
+
#define CNODE_H
|
| 5 |
+
|
| 6 |
+
#include "./../common_lib/cminimax.h"
|
| 7 |
+
#include <math.h>
|
| 8 |
+
#include <vector>
|
| 9 |
+
#include <stack>
|
| 10 |
+
#include <stdlib.h>
|
| 11 |
+
#include <time.h>
|
| 12 |
+
#include <cmath>
|
| 13 |
+
#include <sys/timeb.h>
|
| 14 |
+
#include <time.h>
|
| 15 |
+
#include <map>
|
| 16 |
+
|
| 17 |
+
const int DEBUG_MODE = 0;
|
| 18 |
+
|
| 19 |
+
namespace tree {
|
| 20 |
+
|
| 21 |
+
// A single MCTS tree node for Stochastic MuZero. A node is either a decision
// node (children indexed by actions) or a chance node (children indexed by
// chance outcomes), distinguished by `is_chance`.
class CNode {
    public:
        // visit_count: number of simulations that passed through this node.
        // to_play: player to move at this node (-1 in play-with-bot mode).
        // current_latent_state_index / batch_index: coordinates of this node's
        //   latent state in the caller's latent-state buffers.
        // best_action: last action selected from this node during traversal.
        int visit_count, to_play, current_latent_state_index, batch_index, best_action;
        // reward: predicted reward on the transition into this node.
        // prior: prior probability assigned by the policy (or chance) head.
        // value_sum: accumulated backed-up value; value() = value_sum / visit_count.
        float reward, prior, value_sum;
        // True if this node's children are chance outcomes rather than actions.
        bool is_chance;
        // Number of possible chance outcomes at a chance node.
        int chance_space_size;
        std::vector<int> children_index;
        // Children keyed by action (decision node) or outcome (chance node).
        std::map<int, CNode> children;

        // Legal actions available at this node.
        std::vector<int> legal_actions;

        CNode();
        CNode(float prior, std::vector<int> &legal_actions, bool is_chance = false, int chance_space_size = 2);
        ~CNode();

        // Expand this leaf with model outputs (reward, policy logits) for the
        // given latent-state coordinates.
        void expand(int to_play, int current_latent_state_index, int batch_index, float reward, const std::vector<float> &policy_logits, bool is_chance);
        // Mix Dirichlet noise into the children's priors (root exploration).
        void add_exploration_noise(float exploration_fraction, const std::vector<float> &noises);
        // Mean q value over visited children (plus parent_q when not root).
        float compute_mean_q(int isRoot, float parent_q, float discount_factor);
        void print_out();

        // Whether this node has been expanded (has children).
        int expanded();

        // Mean backed-up value: value_sum / visit_count.
        float value();

        // Sequence of best_action choices from this node downward.
        std::vector<int> get_trajectory();
        // Visit counts of the children, in legal_actions order.
        std::vector<int> get_children_distribution();
        CNode* get_child(int action);
};
|
| 49 |
+
|
| 50 |
+
// A batch of root nodes, one per parallel environment/tree searched together.
class CRoots{
    public:
        // Number of roots (trees) in the batch.
        int root_num;
        std::vector<CNode> roots;
        // Legal actions for each root.
        std::vector<std::vector<int> > legal_actions_list;
        // Number of possible chance outcomes (shared by all roots).
        int chance_space_size;

        CRoots();
        CRoots(int root_num, std::vector<std::vector<int> > &legal_actions_list, int chance_space_size);
        ~CRoots();

        // Expand each root with its reward/policy and mix in exploration noise.
        void prepare(float root_noise_weight, const std::vector<std::vector<float> > &noises, const std::vector<float> &rewards, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch);
        // Same as prepare(), but without exploration noise (e.g. for evaluation).
        void prepare_no_noise(const std::vector<float> &rewards, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch);
        void clear();
        // Per-root action trajectories following best_action links.
        std::vector<std::vector<int> > get_trajectories();
        // Per-root child visit-count distributions.
        std::vector<std::vector<int> > get_distributions();
        // Per-root mean values.
        std::vector<float> get_values();

};
|
| 69 |
+
|
| 70 |
+
// Aggregated output of one batched traversal (cbatch_traverse): for each tree,
// the leaf node reached, its search path, and bookkeeping needed to run the
// model on the leaves and backpropagate afterwards.
class CSearchResults{
    public:
        // Number of trees in the batch.
        int num;
        // Per-tree: latent-state coordinates of each leaf's parent, the last
        // action taken into the leaf, and the depth of each search path.
        std::vector<int> latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, search_lens;
        // Player to move at each leaf after toggling along the path.
        std::vector<int> virtual_to_play_batchs;
        // The leaf node of each tree (non-owning pointers into the trees).
        std::vector<CNode*> nodes;
        // Whether each leaf is a chance node.
        std::vector<bool> leaf_node_is_chance;
        // Root-to-leaf node sequence for each tree (non-owning pointers).
        std::vector<std::vector<CNode*> > search_paths;

        CSearchResults();
        CSearchResults(int num);
        ~CSearchResults();

};
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
//*********************************************************
|
| 87 |
+
void update_tree_q(CNode* root, tools::CMinMaxStats &min_max_stats, float discount_factor, int players);
|
| 88 |
+
void cbackpropagate(std::vector<CNode*> &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount_factor);
|
| 89 |
+
void cbatch_backpropagate(int current_latent_state_index, float discount_factor, const std::vector<float> &rewards, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &to_play_batch, std::vector<int> & is_chance_list, std::vector<int> &leaf_idx_list);
|
| 90 |
+
int cselect_child(CNode* root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players);
|
| 91 |
+
float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, float total_children_visit_counts, float pb_c_base, float pb_c_init, float discount_factor, int players);
|
| 92 |
+
void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &virtual_to_play_batch);
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
#endif
|