Spaces:
Sleeping
Sleeping
| import gym | |
| from gym import spaces | |
| from typing import Optional, List | |
| from . import state | |
| from .const import WORDLE_N, REWARD, WORDLE_CHARS | |
| from .words import complete_vocabulary, target_vocabulary | |
| import random | |
def _load_words(
    limit: Optional[int] = None,
    complete: Optional[bool] = False
) -> List[str]:
    """Return the vocabulary used to build an environment.

    Args:
        limit: Maximum number of leading words to keep. Any falsy value
            (``None`` or ``0``) returns the whole vocabulary.
        complete: When truthy, read from ``complete_vocabulary``;
            otherwise read from ``target_vocabulary``.

    Returns:
        The selected vocabulary itself (not a copy) when no limit
        applies, otherwise a new list holding its first ``limit`` entries.
    """
    if complete:
        vocabulary = complete_vocabulary
    else:
        vocabulary = target_vocabulary

    if limit:
        return vocabulary[:limit]
    return vocabulary
def get_env(env_id='WordleEnvFull-v0'):
    """Create an environment through ``gym.make``.

    Args:
        env_id: Gym environment id to instantiate. The id must already be
            registered with Gym; registration is not done in this function.

    Returns:
        Whatever ``gym.make(env_id)`` returns.
    """
    environment = gym.make(env_id)
    return environment
class WordleEnvBase(gym.Env):
    """Gym environment for playing Wordle over a fixed word list.

    Actions:
        An action is an integer index into ``words``, so the action space
        is ``Discrete(len(words))`` (about 13k for the full vocabulary,
        per the original notes).

    State space (per the original notes):
        * 6 possibilities for turns (WORDLE_TURNS)
        * For each in VALID_CHARS [A-Z], one of 3^WORDLE_N states
          (No, Maybe, Yes); for the full game this is (3^5)^26
        * Each state has 1 + 5*26 entries
        The actual observation space is
        ``MultiDiscrete(state.get_nvec(max_turns))``.

    Reward:
        * Guessing the goal word yields ``REWARD``, except when
          ``state.remaining_steps`` equals ``max_turns - 1`` after the
          update (a win on what is presumably the very first guess),
          which yields 0.
        * Running out of turns without guessing yields ``-REWARD``.
        * Any other step yields the reward returned by the state updater
          (the original notes describe this as 1 per correctly guessed
          letter -- TODO confirm against ``state.update``).

    Starting state:
        ``reset()`` picks a random goal among the first
        ``allowable_words`` entries of ``words`` and creates a fresh
        state via ``state.new(max_turns)``.
    """

    def __init__(self, words: List[str],
                 max_turns: int = 6,
                 allowable_words: Optional[int] = None,
                 mask_based_state_updates: bool = False):
        """Configure the environment; ``reset()`` must be called before ``step()``.

        Args:
            words: Vocabulary; every word must have length ``WORDLE_N``.
            max_turns: Value passed to ``state.get_nvec`` and ``state.new``.
            allowable_words: Goal words are drawn from the first
                ``allowable_words`` entries of ``words``. A falsy value
                (``None`` or ``0``) means the whole list.
            mask_based_state_updates: Use ``state.update_mask`` instead of
                ``state.update`` to advance the state.
        """
        assert all(
            len(w) == WORDLE_N for w in words
        ), f'Not all words of length {WORDLE_N}, {words}'
        self.words = words
        self.max_turns = max_turns
        self.allowable_words = allowable_words
        self.mask_based_state_updates = mask_based_state_updates
        if not self.allowable_words:
            self.allowable_words = len(self.words)

        self.action_space = spaces.Discrete(self.words_as_action_space())
        self.observation_space = spaces.MultiDiscrete(
            state.get_nvec(self.max_turns))

        # Start in the "done" state so that step() refuses to run until
        # reset() has chosen a goal word and created a state.
        self.done = True
        self.goal_word: int = -1
        self.state: Optional[state.WordleState] = None

        self.state_updater = state.update
        if self.mask_based_state_updates:
            self.state_updater = state.update_mask

    def step(self, action: int):
        """Play the word at index ``action`` and advance the episode.

        Args:
            action: Index into ``self.words`` of the guessed word.

        Returns:
            A tuple ``(observation, reward, done, info)`` where
            ``observation`` is a copy of the updated state and ``info``
            is ``{"goal_id": <index of the goal word>}``.

        Raises:
            ValueError: If called while the episode is done (including
                before the first ``reset()``).
        """
        if self.done:
            raise ValueError(
                "You are calling 'step()' even though this "
                "environment has already returned done = True. You "
                "should always call 'reset()' once you receive 'done = "
                "True' -- any further steps are undefined behavior."
            )
        word = self.words[action]
        goal_word = self.words[self.goal_word]
        self.state, r = self.state_updater(state=self.state,
                                           word=word,
                                           goal_word=goal_word)

        reward = r
        if action == self.goal_word:
            self.done = True
            if state.remaining_steps(self.state) == self.max_turns - 1:
                # No reward for guessing off the bat.
                reward = 0
            else:
                reward = REWARD
        elif state.remaining_steps(self.state) == 0:
            # Out of turns without finding the goal word.
            self.done = True
            reward = -REWARD

        goal_dict = {"goal_id": self.goal_word}
        return self.state.copy(), reward, self.done, goal_dict

    def reset(self):
        """Start a new episode with a random goal word.

        Returns:
            A copy of the freshly created state.
        """
        self.state = state.new(self.max_turns)
        self.done = False

        # Draw the goal index directly. Slicing a range follows the same
        # rules as slicing the word list (clamping, negative bounds) but
        # avoids copying the list and the linear ``list.index`` search on
        # every episode.
        goal_candidates = range(len(self.words))[:self.allowable_words]
        self.goal_word = random.choice(goal_candidates)

        return self.state.copy()

    def set_goal_word(self, goal_word: str):
        """Set the goal by word; raises ``ValueError`` if it is not in ``words``."""
        self.goal_word = self.words.index(goal_word)

    def set_goal_encoded(self, goal_encoded: int):
        """Set the goal by its index into ``words`` (stored without validation)."""
        self.goal_word = goal_encoded

    def words_as_action_space(self):
        """Return the number of actions, i.e. ``len(self.words)``."""
        return len(self.words)
class WordleEnv100OneAction(WordleEnvBase):
    """Environment over ``_load_words(100)`` with ``allowable_words=1``."""

    def __init__(self):
        vocabulary = _load_words(100)
        super().__init__(words=vocabulary, allowable_words=1)
class WordleEnv100WithMask(WordleEnvBase):
    """Environment over ``_load_words(100)`` using mask-based state updates."""

    def __init__(self):
        vocabulary = _load_words(100)
        super().__init__(words=vocabulary, mask_based_state_updates=True)
class WordleEnv100TwoAction(WordleEnvBase):
    """Environment over ``_load_words(100)`` with ``allowable_words=2``."""

    def __init__(self):
        vocabulary = _load_words(100)
        super().__init__(words=vocabulary, allowable_words=2)
class WordleEnv100fiftyAction(WordleEnvBase):
    """Environment over ``_load_words(100)`` with ``allowable_words=50``."""

    def __init__(self):
        vocabulary = _load_words(100)
        super().__init__(words=vocabulary, allowable_words=50)
class WordleEnv100FullAction(WordleEnvBase):
    """Environment over ``_load_words(100)`` with ``allowable_words=100``."""

    def __init__(self):
        vocabulary = _load_words(100)
        super().__init__(words=vocabulary, allowable_words=100)
class WordleEnv1000WithMask(WordleEnvBase):
    """Environment over ``_load_words(1000)`` using mask-based state updates."""

    def __init__(self):
        vocabulary = _load_words(1000)
        super().__init__(words=vocabulary, mask_based_state_updates=True)
class WordleEnv1000FullAction(WordleEnvBase):
    """Environment over ``_load_words(1000)`` with ``allowable_words=1000``."""

    def __init__(self):
        vocabulary = _load_words(1000)
        super().__init__(words=vocabulary, allowable_words=1000)
class WordleEnvFull(WordleEnvBase):
    """Environment over the unrestricted ``_load_words()`` vocabulary."""

    def __init__(self):
        vocabulary = _load_words()
        super().__init__(words=vocabulary)
class WordleEnvRealWithMask(WordleEnvBase):
    """Environment over ``_load_words()`` with mask-based state updates.

    ``allowable_words`` is fixed at 2315.
    NOTE(review): 2315 is presumably the number of official Wordle
    answers -- confirm it matches the loaded vocabulary.
    """

    def __init__(self):
        vocabulary = _load_words()
        super().__init__(
            words=vocabulary,
            allowable_words=2315,
            mask_based_state_updates=True,
        )