Fix bug on env initialization
Files changed:
- a3c/discrete_A3C.py  +29 -26
- main.py  +36 -1
- wordle_env/state.py  +4 -2
- wordle_env/wordle.py  +4 -4
a3c/discrete_A3C.py
CHANGED

@@ -13,47 +13,50 @@ import torch.multiprocessing as mp
 from .utils import v_wrap, set_init, push_and_pull, record
 import numpy as np

-GAMMA = 0.
+GAMMA = 0.7

 class Net(nn.Module):
     def __init__(self, s_dim, a_dim, word_list, words_width):
         super(Net, self).__init__()
         self.s_dim = s_dim
         self.a_dim = a_dim
-        n_emb = 32
-
-        # self.pi2 = nn.Linear(128, a_dim)
-        self.v1 = nn.Linear(s_dim, 256)
-        self.v2 = nn.Linear(256, n_emb)
-        self.v3 = nn.Linear(n_emb, 1)
-        set_init([ self.v1, self.v2]) # n_emb
-        self.distribution = torch.distributions.Categorical
+        # n_emb = 32
+
         word_width = 26 * words_width
-
-
+        layers = [
+            nn.Linear(s_dim, word_width),
+            nn.Tanh(),
+            # nn.Linear(128, word_width),
+            # nn.Tanh(),
+            # nn.Linear(256, n_emb),
+            # nn.Tanh(),
+        ]
+        self.v1 = nn.Sequential(*layers)
+        self.v4 = nn.Linear(word_width, 1)
+        self.actor_head = nn.Linear(word_width, word_width)
+
+        self.distribution = torch.distributions.Categorical
+        word_array = np.zeros((word_width, len(word_list)))
         for i, word in enumerate(word_list):
             for j, c in enumerate(word):
-                word_array[
+                word_array[ j*26 + (ord(c) - ord('A')), i ] = 1
         self.words = torch.Tensor(word_array)
-        self.f_word = nn.Sequential(
-            nn.Linear(word_width, 64),
-            nn.ReLU(),
-            nn.Linear(64, n_emb),
-        )
+        # self.f_word = nn.Sequential(
+        #     nn.Linear(word_width, 64),
+        #     nn.ReLU(),
+        #     nn.Linear(64, n_emb),
+        # )

     def forward(self, x):
-        #
-
-
-
-        # logits = self.pi2(pi1)
-        v1 = torch.tanh(self.v1(x))
-        values = self.v2(v1)
+        # fw = self.f_word(
+        #     self.words.to(x.device.index),
+        # ).transpose(0, 1)
+        values = self.v1(x.float())
         logits = torch.log_softmax(
-            torch.tensordot(self.actor_head(values),
+            torch.tensordot(self.actor_head(values), self.words,
                             dims=((1,), (0,))),
             dim=-1)
-        values = self.v3(values)
+        values = self.v4(values)
         return logits, values

     def choose_action(self, s):
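The structural change above: the old network squeezed the state through a 256 -> n_emb bottleneck and learned word embeddings with f_word, while the new one keeps a word_width-sized feature vector and scores every word in the list against a fixed one-hot encoding via a single tensordot. A minimal, self-contained sketch of that scoring trick (toy word list, random features standing in for actor_head(values); not code from this commit):

import numpy as np
import torch

word_list = ["APPLE", "BERRY", "CRANE"]   # toy stand-in for the env's words
words_width = 5
word_width = 26 * words_width

# One column per word: position j, letter c -> row index j*26 + (c - 'A')
word_array = np.zeros((word_width, len(word_list)))
for i, word in enumerate(word_list):
    for j, c in enumerate(word):
        word_array[j * 26 + (ord(c) - ord('A')), i] = 1
words = torch.Tensor(word_array)          # (word_width, n_words)

features = torch.randn(4, word_width)     # stand-in for actor_head(values)
scores = torch.tensordot(features, words, dims=((1,), (0,)))  # (batch, n_words)
logits = torch.log_softmax(scores, dim=-1)
print(logits.shape)                       # torch.Size([4, 3])

Because the word matrix is fixed, the actor never needs a learned output layer per word; swapping in a different word list changes the action space without retraining that head.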
main.py
CHANGED

@@ -6,10 +6,44 @@ import torch.multiprocessing as mp

 from a3c.discrete_A3C import Net, Worker
 from a3c.shared_adam import SharedAdam
+from a3c.utils import v_wrap
 from wordle_env.wordle import WordleEnvBase

 os.environ["OMP_NUM_THREADS"] = "1"

+def evaluate(net, env):
+    print("Evaluation mode")
+    n_wins = 0
+    n_guesses = 0
+    n_win_guesses = 0
+    env = env.unwrapped
+    N = env.allowable_words
+    for goal_word in env.words[:N]:
+        win, outcomes = play(net, env)
+        if win:
+            n_wins += 1
+            n_win_guesses += len(outcomes)
+        else:
+            print("Lost!", goal_word, outcomes)
+        n_guesses += len(outcomes)
+
+    print(f"Evaluation complete, won {n_wins/N*100}% and took {n_win_guesses/n_wins} guesses per win, "
+          f"{n_guesses / N} including losses.")
+
+def play(net, env):
+    state = env.reset()
+    outcomes = []
+    win = False
+    for i in range(env.max_turns):
+        action = net.choose_action(v_wrap(state[None, :]))
+        state, reward, done, _ = env.step(action)
+        outcomes.append((env.words[action], reward))
+        if done:
+            if reward >= 0:
+                win = True
+            break
+    return win, outcomes
+
 if __name__ == "__main__":
     max_ep = int(sys.argv[1]) if len(sys.argv) > 1 else 100000
     env_id = sys.argv[2] if len(sys.argv) > 2 else 'WordleEnv100FullAction-v0'

@@ -39,4 +73,5 @@ if __name__ == "__main__":
     plt.plot(res)
     plt.ylabel('Moving average ep reward')
     plt.xlabel('Step')
-    plt.show()
+    plt.show()
+    evaluate(gnet, env)
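The new evaluate/play pair can be smoke-tested without any training. A hypothetical driver sketch, assuming that importing wordle_env.wordle registers the env ids used above, that the observation space is a flat vector, and that words are 5 letters long (all assumptions, not shown in this diff):

import gym
import wordle_env.wordle  # assumed to register the Wordle env ids

from a3c.discrete_A3C import Net
from main import evaluate

env = gym.make('WordleEnv100FullAction-v0')
e = env.unwrapped
# An untrained net will mostly lose, but this still exercises evaluate()'s
# win/guess bookkeeping end to end.
net = Net(e.observation_space.shape[0], e.action_space.n, e.words, 5)
evaluate(net, env)

One caveat visible in the diff itself: evaluate() divides by n_wins when reporting guesses per win, so a run with zero wins raises ZeroDivisionError.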
wordle_env/state.py
CHANGED

@@ -141,7 +141,7 @@ def update_mask(state: WordleState, word: str, goal_word: str) -> WordleState:

 def update(state: WordleState, word: str, goal_word: str) -> WordleState:
     state = state.copy()
-
+    reward = 0
     state[0] -= 1
     processed_letters = []
     for i, c in enumerate(word):
@@ -149,6 +149,8 @@ def update(state: WordleState, word: str, goal_word: str) -> WordleState:
         offset = 1 + cint * WORDLE_N * 3
         if goal_word[i] == c:
             # char at position i = yes, all other chars at position i == no
+            if state[offset + 3 * i:offset + 3 * i + 3][2] == 0:
+                reward += 0.1
             state[offset + 3 * i:offset + 3 * i + 3] = [0, 0, 1]
             for ocint in range(len(WORDLE_CHARS)):
                 if ocint != cint:
@@ -168,5 +170,5 @@ def update(state: WordleState, word: str, goal_word: str) -> WordleState:
             # Char at all positions = no
             state[offset:offset + 3 * WORDLE_N] = [1, 0, 0] * WORDLE_N
         processed_letters.append(c)
-    return state
+    return state, reward
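The shaped reward grants 0.1 only the first time a letter/position is confirmed: each cell of the state is a [no, maybe, yes] one-hot triple, and the bonus is skipped when the "yes" slot is already 1, so re-guessing known greens earns nothing extra. A toy illustration of that gate (not from the commit):

cell = [0, 1, 0]      # one [no, maybe, yes] triple, currently "maybe"
reward = 0.0
if cell[2] == 0:      # "yes" not yet recorded -> first confirmation
    reward += 0.1
cell[:] = [0, 0, 1]   # mark the letter/position as a confirmed green
print(reward)         # 0.1 the first time; repeating the guess adds nothing

Note that update() now returns a (state, reward) tuple even though its annotation still reads -> WordleState; the caller is updated accordingly in wordle_env/wordle.py below.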
wordle_env/wordle.py
CHANGED

@@ -73,11 +73,11 @@ class WordleEnvBase(gym.Env):
         word = self.words[action]
         goal_word = self.words[self.goal_word]
         # assert word in self.words, f'{word} not in words list'
-        self.state = self.state_updater(state=self.state,
+        self.state, r = self.state_updater(state=self.state,
                                         word=word,
                                         goal_word=goal_word)

-        reward =
+        reward = r
         if action == self.goal_word:
             self.done = True
             #reward = REWARD
@@ -159,7 +159,7 @@ class WordleEnv100fiftyAction(WordleEnvBase):

 class WordleEnv100FullAction(WordleEnvBase):
     def __init__(self):
-        super().__init__(words=_load_words(), allowable_words=100)
+        super().__init__(words=_load_words(100), allowable_words=100)


 class WordleEnv1000(WordleEnvBase):
@@ -175,7 +175,7 @@ class WordleEnv1000WithMask(WordleEnvBase):

 class WordleEnv1000FullAction(WordleEnvBase):
     def __init__(self):
-        super().__init__(words=_load_words(), allowable_words=1000)
+        super().__init__(words=_load_words(1000), allowable_words=1000)


 class WordleEnvFull(WordleEnvBase):
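The initialization fix of the commit title is the _load_words change: the word count is now threaded through, so the env's word list matches allowable_words instead of loading the full dictionary. _load_words itself is not shown in this diff; a plausible sketch of a loader with such a limit parameter (the function body, default path, and file format are assumptions):

def _load_words(limit=None, path='wordle_words.txt'):
    """Return the first `limit` words uppercased, or all words if limit is None."""
    with open(path) as f:
        words = [line.strip().upper() for line in f if line.strip()]
    return words if limit is None else words[:limit]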