lcipolina commited on
Commit
dddb842
·
verified ·
1 Parent(s): 768f7fd

Updated with Hex

Browse files
src/game_reasoning_arena/arena/envs/README.md CHANGED
@@ -0,0 +1 @@
 
 
1
+ Simulation logic for each game.
src/game_reasoning_arena/arena/envs/hex_env.py CHANGED
@@ -7,6 +7,7 @@ Hex using the OpenSpiel framework.
7
  from typing import Any, Dict, Optional
8
  from .open_spiel_env import OpenSpielEnv
9
 
 
10
  class HexEnv(OpenSpielEnv):
11
  """Environment Simulator for Hex."""
12
 
@@ -19,13 +20,13 @@ class HexEnv(OpenSpielEnv):
19
  Args:
20
  game: The OpenSpiel game object.
21
  game_name: A string representing the name of the game.
22
- player_types: A dictionary mapping player IDs to their types (e.g., human, random).
 
23
  max_game_rounds: Maximum number of rounds
24
  for iterated games (optional, default is None).
25
  """
26
  super().__init__(game, game_name, player_types, max_game_rounds, seed)
27
 
28
-
29
  def get_player_symbol(self, agent_id: int) -> str:
30
  """Returns the symbol used by a Tic Tac Toe player.
31
 
@@ -47,7 +48,9 @@ class HexEnv(OpenSpielEnv):
47
  str: Legal action numbers and a flattened board index layout.
48
  """
49
  legal = self.state.legal_actions(agent_id)
50
- size = self.game.board_size # Usually 11
 
 
51
 
52
  # Create a flat index grid (diagonal shape)
53
  grid = []
@@ -74,7 +77,9 @@ class HexEnv(OpenSpielEnv):
74
  raw = self.state.observation_string(agent_id)
75
  symbols = [char for char in raw if char in ("y", "o", ".")]
76
 
77
- size = self.game.board_size # typically 11
 
 
78
  rows = []
79
  idx = 0
80
  for row in range(size):
 
7
  from typing import Any, Dict, Optional
8
  from .open_spiel_env import OpenSpielEnv
9
 
10
+
11
  class HexEnv(OpenSpielEnv):
12
  """Environment Simulator for Hex."""
13
 
 
20
  Args:
21
  game: The OpenSpiel game object.
22
  game_name: A string representing the name of the game.
23
+ player_types: A dictionary mapping player IDs to their types
24
+ (e.g., human, random).
25
  max_game_rounds: Maximum number of rounds
26
  for iterated games (optional, default is None).
27
  """
28
  super().__init__(game, game_name, player_types, max_game_rounds, seed)
29
 
 
30
  def get_player_symbol(self, agent_id: int) -> str:
31
  """Returns the symbol used by a Tic Tac Toe player.
32
 
 
48
  str: Legal action numbers and a flattened board index layout.
49
  """
50
  legal = self.state.legal_actions(agent_id)
51
+ # Get board size from observation tensor shape
52
+ obs_shape = self.game.observation_tensor_shape()
53
+ size = obs_shape[-1] # Usually 11
54
 
55
  # Create a flat index grid (diagonal shape)
56
  grid = []
 
77
  raw = self.state.observation_string(agent_id)
78
  symbols = [char for char in raw if char in ("y", "o", ".")]
79
 
80
+ # Get board size from observation tensor shape
81
+ obs_shape = self.game.observation_tensor_shape()
82
+ size = obs_shape[-1] # typically 11
83
  rows = []
84
  idx = 0
85
  for row in range(size):
src/game_reasoning_arena/arena/envs/kuhn_poker_env.py CHANGED
@@ -9,8 +9,8 @@ game state and potential strategies.
9
  """
10
 
11
  from typing import Any, Dict, Optional
12
- from .open_spiel_env import OpenSpielEnv
13
  from game_reasoning_arena.arena.agents.llm_utils import format_prompt
 
14
 
15
 
16
  class KuhnPokerEnv(OpenSpielEnv):
 
9
  """
10
 
11
  from typing import Any, Dict, Optional
 
12
  from game_reasoning_arena.arena.agents.llm_utils import format_prompt
13
+ from .open_spiel_env import OpenSpielEnv
14
 
15
 
16
  class KuhnPokerEnv(OpenSpielEnv):
src/game_reasoning_arena/arena/envs/matrix_game_env.py CHANGED
@@ -1,8 +1,8 @@
1
  """Simulator for Matrix Games.
2
 
3
  This module implements the MatrixGameEnv class, which handles various
4
- matrix games like Rock-Paper-Scissors and Prisoner's Dilemma using the OpenSpiel
5
- framework.
6
  """
7
 
8
  from typing import Any, Dict, List, Optional
@@ -20,7 +20,8 @@ class MatrixGameEnv(OpenSpielEnv):
20
  Args:
21
  game: The OpenSpiel game object.
22
  game_name: A string representing the name of the game.
23
- player_types: A dictionary mapping player IDs to their types (e.g., human, random).
 
24
  max_game_rounds: Maximum number of rounds
25
  for iterated games (optional, default is None).
26
  """
@@ -78,23 +79,23 @@ class MatrixGameEnv(OpenSpielEnv):
78
 
79
  prompt = f"""You are Player {agent_id} in the game: {self.game_name}
80
 
81
- Available actions:
82
- {action_list}
83
-
84
- What action do you choose? Reply only with the action number.
85
-
86
- First, think through the game strategy and explain your reasoning.
87
- Only after that, decide on the best action to take.
88
-
89
- Reply only in the following JSON format:
90
- {{
91
- 'reasoning': <str>,
92
- 'action': <int>
93
- }}"""
 
94
 
95
  return prompt
96
 
97
  def render_board(self, agent_id: int) -> str:
98
  # Matrix games have no spatial board; return a basic description.
99
-
100
- return "Matrix game – no board representation available"
 
1
  """Simulator for Matrix Games.
2
 
3
  This module implements the MatrixGameEnv class, which handles various
4
+ matrix games like Rock-Paper-Scissors and Prisoner's Dilemma using
5
+ the OpenSpiel framework.
6
  """
7
 
8
  from typing import Any, Dict, List, Optional
 
20
  Args:
21
  game: The OpenSpiel game object.
22
  game_name: A string representing the name of the game.
23
+ player_types: A dictionary mapping player IDs to their types
24
+ (e.g., human, random).
25
  max_game_rounds: Maximum number of rounds
26
  for iterated games (optional, default is None).
27
  """
 
79
 
80
  prompt = f"""You are Player {agent_id} in the game: {self.game_name}
81
 
82
+ Available actions:
83
+ {action_list}
84
+
85
+ What action do you choose? Reply only with the action number.
86
+
87
+ First, think through the game strategy
88
+ and explain your reasoning.
89
+ Only after that, decide on the best action to take.
90
+
91
+ Reply only in the following JSON format:
92
+ {{
93
+ 'reasoning': <str>,
94
+ 'action': <int>
95
+ }}"""
96
 
97
  return prompt
98
 
99
  def render_board(self, agent_id: int) -> str:
100
  # Matrix games have no spatial board; return a basic description.
101
+ return "Matrix game – no board representation available"
 
src/game_reasoning_arena/arena/envs/open_spiel_env.py CHANGED
@@ -28,13 +28,15 @@ class OpenSpielEnv(ABC):
28
  Args:
29
  game (Any): The OpenSpiel game object being simulated.
30
  game_name (str): A human-readable name for the game.
31
- player_type (Dict[str, str]): Maps "Player 1", "Player 2", ... to their types (human, random, llm, etc.).
32
- max_game_rounds (int): Maximum number of rounds for iterated games. Ignored by single-shot games.
 
 
33
  seed (Optional[int]): Random seed for reproducibility.
34
  """
35
  self.game = game
36
  self.game_name = game_name
37
- self.player_types = player_types # List of strings
38
  self.max_game_rounds = max_game_rounds # For iterated games only
39
  self.state = None
40
  self.info = {}
@@ -46,9 +48,10 @@ class OpenSpielEnv(ABC):
46
 
47
  self.state = None
48
 
49
- def reset(self, seed: Optional[int]=None) -> Tuple[str, Dict[str, Any]]:
50
  """
51
- Resets the environment to an initial state and returns an initial observation.
 
52
 
53
  Args:
54
  seed (Optional[int]): Seed for environment randomization.
@@ -63,7 +66,8 @@ class OpenSpielEnv(ABC):
63
  if hasattr(self.game, "set_seed"):
64
  self.game.set_seed(seed)
65
 
66
- self.state = self.game.new_initial_state() # Instantiates the pyspiel game state
 
67
  self.terminated = False
68
  self.truncated = False
69
  self.info = {}
@@ -74,21 +78,27 @@ class OpenSpielEnv(ABC):
74
 
75
  return self._state_to_observation(), self.info
76
 
77
- def step(self, action_dict: Dict[int, int]) -> Tuple[Any, float, bool,bool, Dict[str, Any]]:
78
- """Applies the given action(s) to the environment and returns the new state.
 
 
79
 
80
  Args:
81
- action_dict (Dict[int, int]): A dictionary mapping agent IDs to actions.
 
82
  - For turn-based games: {current_player: action}
83
- - For simultaneous games: {player_0: action_0, player_1: action_1, ...}
 
84
 
85
  Returns:
86
  Tuple[Any, float, bool, bool, Dict[str, Any]]: A tuple containing:
87
  - observation (Any): The resulting state after the action.
88
  - reward (float): The reward obtained from this step.
89
  - terminated (bool): Whether the episode has ended normally.
90
- - truncated (bool): Whether the episode ended due to `max_game_rounds`.
91
- - info (Dict[str, Any]): Additional diagnostic information (e.g., final scores if done).
 
 
92
  """
93
 
94
  # Handle chance nodes
@@ -104,11 +114,14 @@ class OpenSpielEnv(ABC):
104
 
105
  # Move environment to the next state
106
  if self.state.is_simultaneous_node():
107
- actions = [action_dict[player] for player in sorted(action_dict.keys())]
 
 
108
  self.state.apply_actions(actions) # Multi-agent moves
109
  else:
110
  current_player = list(action_dict.keys())[0]
111
- self.state.apply_action(action_dict[current_player]) # Single action
 
112
 
113
  # Stepwise reward for each OpenSpiel-indexed agent
114
  reward_dict = self._compute_reward()
@@ -122,15 +135,21 @@ class OpenSpielEnv(ABC):
122
  and self.state.move_number() >= self.max_game_rounds
123
  )
124
 
125
- # If the game is finished, store final scores; otherwise, update current player
 
126
  if self.terminated or self.truncated:
127
  print("game terminated" if self.terminated else "game truncated")
128
- # Note: final rewards are corectly updated by the OpenSpiel rewards tracker.
129
- observation_dict = {agentID: None for agentID in list(action_dict.keys())} # No observation when the game ends
 
 
 
130
  else:
131
- observation_dict = self._state_to_observation() # Get next observation for all agents
 
132
 
133
- return observation_dict, reward_dict, self.terminated, self.truncated, self.info
 
134
 
135
  def render(self, mode: str = 'human'):
136
  """Print out the current state of the game."""
@@ -144,7 +163,7 @@ class OpenSpielEnv(ABC):
144
  Args:
145
  seed (int): The random seed.
146
  """
147
- self.random_generator = random.Random(seed) # Ensure Python's RNG is seeded
148
 
149
  # Set game seed if OpenSpiel supports it
150
  if hasattr(self.game, "set_seed"):
@@ -154,7 +173,8 @@ class OpenSpielEnv(ABC):
154
 
155
  def detect_illegal_moves(self, actions_dict: Dict[int, int]) -> int:
156
  """
157
- Detects illegal moves by comparing chosen actions with OpenSpiel's legal actions.
 
158
 
159
  Args:
160
  actions_dict: Dictionary mapping player IDs to chosen actions.
@@ -181,7 +201,8 @@ class OpenSpielEnv(ABC):
181
  """Returns the observation for each agent in the game.
182
 
183
  Returns:
184
- Dict[int, Dict[str, Any]]: Mapping from agent ID to their respective observations.
 
185
  """
186
 
187
  agent_id = self.state.current_player()
@@ -189,7 +210,8 @@ class OpenSpielEnv(ABC):
189
  agent_id: {
190
  "state_string": self.state.observation_string(agent_id),
191
  "legal_actions": self.state.legal_actions(agent_id),
192
- "prompt": self._generate_prompt(agent_id) # Overriden in some child classes
 
193
  }
194
  }
195
 
 
28
  Args:
29
  game (Any): The OpenSpiel game object being simulated.
30
  game_name (str): A human-readable name for the game.
31
+ player_type (Dict[str, str]): Maps "Player 1", "Player 2", ...
32
+ to their types (human, random, llm, etc.).
33
+ max_game_rounds (int): Maximum number of rounds for iterated games.
34
+ Ignored by single-shot games.
35
  seed (Optional[int]): Random seed for reproducibility.
36
  """
37
  self.game = game
38
  self.game_name = game_name
39
+ self.player_types = player_types # List of strings
40
  self.max_game_rounds = max_game_rounds # For iterated games only
41
  self.state = None
42
  self.info = {}
 
48
 
49
  self.state = None
50
 
51
+ def reset(self, seed: Optional[int] = None) -> Tuple[str, Dict[str, Any]]:
52
  """
53
+ Resets the environment to an initial state and returns an
54
+ initial observation.
55
 
56
  Args:
57
  seed (Optional[int]): Seed for environment randomization.
 
66
  if hasattr(self.game, "set_seed"):
67
  self.game.set_seed(seed)
68
 
69
+ # Instantiates the pyspiel game state
70
+ self.state = self.game.new_initial_state()
71
  self.terminated = False
72
  self.truncated = False
73
  self.info = {}
 
78
 
79
  return self._state_to_observation(), self.info
80
 
81
+ def step(self, action_dict: Dict[int, int]
82
+ ) -> Tuple[Any, float, bool, bool, Dict[str, Any]]:
83
+ """Applies the given action(s) to the environment
84
+ and returns the new state.
85
 
86
  Args:
87
+ action_dict (Dict[int, int]): A dictionary mapping
88
+ agent IDs to actions.
89
  - For turn-based games: {current_player: action}
90
+ - For simultaneous games:
91
+ {player_0: action_0, player_1: action_1, ...}
92
 
93
  Returns:
94
  Tuple[Any, float, bool, bool, Dict[str, Any]]: A tuple containing:
95
  - observation (Any): The resulting state after the action.
96
  - reward (float): The reward obtained from this step.
97
  - terminated (bool): Whether the episode has ended normally.
98
+ - truncated (bool): Whether the episode ended
99
+ due to `max_game_rounds`.
100
+ - info (Dict[str, Any]): Additional diagnostic
101
+ information (e.g., final scores if done).
102
  """
103
 
104
  # Handle chance nodes
 
114
 
115
  # Move environment to the next state
116
  if self.state.is_simultaneous_node():
117
+ actions = [action_dict[player] for player in sorted(
118
+ action_dict.keys()
119
+ )]
120
  self.state.apply_actions(actions) # Multi-agent moves
121
  else:
122
  current_player = list(action_dict.keys())[0]
123
+ # Single action
124
+ self.state.apply_action(action_dict[current_player])
125
 
126
  # Stepwise reward for each OpenSpiel-indexed agent
127
  reward_dict = self._compute_reward()
 
135
  and self.state.move_number() >= self.max_game_rounds
136
  )
137
 
138
+ # If the game is finished, store final scores;
139
+ # otherwise, update current player
140
  if self.terminated or self.truncated:
141
  print("game terminated" if self.terminated else "game truncated")
142
+ # Note: final rewards are correctly
143
+ # updated by the OpenSpiel rewards tracker.
144
+ observation_dict = {
145
+ agentID: None for agentID in list(action_dict.keys())
146
+ } # No observation when the game ends
147
  else:
148
+ # Get next observation for all agents
149
+ observation_dict = self._state_to_observation()
150
 
151
+ return (observation_dict, reward_dict, self.terminated,
152
+ self.truncated, self.info)
153
 
154
  def render(self, mode: str = 'human'):
155
  """Print out the current state of the game."""
 
163
  Args:
164
  seed (int): The random seed.
165
  """
166
+ self.random_generator = random.Random(seed)
167
 
168
  # Set game seed if OpenSpiel supports it
169
  if hasattr(self.game, "set_seed"):
 
173
 
174
  def detect_illegal_moves(self, actions_dict: Dict[int, int]) -> int:
175
  """
176
+ Detects illegal moves by comparing chosen actions
177
+ with OpenSpiel's legal actions.
178
 
179
  Args:
180
  actions_dict: Dictionary mapping player IDs to chosen actions.
 
201
  """Returns the observation for each agent in the game.
202
 
203
  Returns:
204
+ Dict[int, Dict[str, Any]]: Mapping from agent ID
205
+ to their respective observations.
206
  """
207
 
208
  agent_id = self.state.current_player()
 
210
  agent_id: {
211
  "state_string": self.state.observation_string(agent_id),
212
  "legal_actions": self.state.legal_actions(agent_id),
213
+ # Overridden in some child classes
214
+ "prompt": self._generate_prompt(agent_id)
215
  }
216
  }
217