Updated with Hex
src/game_reasoning_arena/arena/envs/README.md
CHANGED
@@ -0,0 +1 @@
+Simulation logic for each game.
src/game_reasoning_arena/arena/envs/hex_env.py
CHANGED
@@ -7,6 +7,7 @@ Hex using the OpenSpiel framework.
 from typing import Any, Dict, Optional
 from .open_spiel_env import OpenSpielEnv
 
+
 class HexEnv(OpenSpielEnv):
     """Environment Simulator for Hex."""
 
@@ -19,13 +20,13 @@ class HexEnv(OpenSpielEnv):
         Args:
             game: The OpenSpiel game object.
             game_name: A string representing the name of the game.
-            player_types: A dictionary mapping player IDs to their types
+            player_types: A dictionary mapping player IDs to their types
+                (e.g., human, random).
             max_game_rounds: Maximum number of rounds
                 for iterated games (optional, default is None).
         """
         super().__init__(game, game_name, player_types, max_game_rounds, seed)
 
-
     def get_player_symbol(self, agent_id: int) -> str:
         """Returns the symbol used by a Tic Tac Toe player.
 
@@ -47,7 +48,9 @@ class HexEnv(OpenSpielEnv):
             str: Legal action numbers and a flattened board index layout.
         """
         legal = self.state.legal_actions(agent_id)
-        size
+        # Get board size from observation tensor shape
+        obs_shape = self.game.observation_tensor_shape()
+        size = obs_shape[-1]  # Usually 11
 
         # Create a flat index grid (diagonal shape)
         grid = []
@@ -74,7 +77,9 @@ class HexEnv(OpenSpielEnv):
         raw = self.state.observation_string(agent_id)
         symbols = [char for char in raw if char in ("y", "o", ".")]
 
-        size
+        # Get board size from observation tensor shape
+        obs_shape = self.game.observation_tensor_shape()
+        size = obs_shape[-1]  # typically 11
         rows = []
         idx = 0
         for row in range(size):
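The board-size lookup added in both helpers above relies on OpenSpiel exposing the observation tensor shape on the game object. A minimal standalone sketch of that lookup, assuming only the `pyspiel` package (it is not part of the commit itself):

```python
# Standalone sketch of the board-size lookup used above (not part of the
# commit); assumes only that the pyspiel package is installed.
import pyspiel

game = pyspiel.load_game("hex")

# observation_tensor_shape() returns a list of ints; for Hex the last
# dimension is the board width (11 on the default 11x11 board).
obs_shape = game.observation_tensor_shape()
size = obs_shape[-1]

# The same size x size flat index layout the environment prints for agents.
grid = [[row * size + col for col in range(size)] for row in range(size)]

state = game.new_initial_state()
print(f"board size: {size}, legal actions: {len(state.legal_actions())}")
```

Reading the size from the game object rather than hard-coding 11 keeps both helpers correct if the game is ever loaded with a non-default board size.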
src/game_reasoning_arena/arena/envs/kuhn_poker_env.py
CHANGED
@@ -9,8 +9,8 @@ game state and potential strategies.
 """
 
 from typing import Any, Dict, Optional
-from .open_spiel_env import OpenSpielEnv
 from game_reasoning_arena.arena.agents.llm_utils import format_prompt
+from .open_spiel_env import OpenSpielEnv
 
 
 class KuhnPokerEnv(OpenSpielEnv):
src/game_reasoning_arena/arena/envs/matrix_game_env.py
CHANGED
@@ -1,8 +1,8 @@
 """Simulator for Matrix Games.
 
 This module implements the MatrixGameEnvclass, which handles various
-matrix games like Rock-Paper-Scissors and Prisoner's Dilemma using
-framework.
+matrix games like Rock-Paper-Scissors and Prisoner's Dilemma using
+the OpenSpiel framework.
 """
 
 from typing import Any, Dict, List, Optional
@@ -20,7 +20,8 @@ class MatrixGameEnv(OpenSpielEnv):
         Args:
             game: The OpenSpiel game object.
             game_name: A string representing the name of the game.
-            player_types: A dictionary mapping player IDs to their types
+            player_types: A dictionary mapping player IDs to their types
+                (e.g., human, random).
             max_game_rounds: Maximum number of rounds
                 for iterated games (optional, default is None).
         """
@@ -78,23 +79,23 @@ class MatrixGameEnv(OpenSpielEnv):
 
         prompt = f"""You are Player {agent_id} in the game: {self.game_name}
 
-
-
-
-
-
-
-
-
-
-
-
-
-
+        Available actions:
+        {action_list}
+
+        What action do you choose? Reply only with the action number.
+
+        First, think through the game strategy
+        and explain your reasoning.
+        Only after that, decide on the best action to take.
+
+        Reply only in the following JSON format:
+        {{
+            'reasoning': <str>,
+            'action': <int>
+        }}"""
 
         return prompt
 
     def render_board(self, agent_id: int) -> str:
         # Matrix games have no spatial board; return a basic description.
-
-        return "Matrix game – no board representation available"
+        return "Matrix game – no board representation available"
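The rewritten prompt asks the model to answer with a `{'reasoning': <str>, 'action': <int>}` object. A hypothetical parser for that reply format (the commit itself does not include one, and the arena's real parsing presumably lives in `llm_utils`):

```python
# Hypothetical parser for the reply format requested by the new prompt
# (not part of the commit); parse_llm_reply and its behaviour are assumptions.
import ast
import re


def parse_llm_reply(reply: str) -> dict:
    """Pull the {'reasoning': <str>, 'action': <int>} object out of a reply.

    The prompt shows single quotes, so ast.literal_eval is used rather
    than json.loads, which would reject them.
    """
    match = re.search(r"\{.*\}", reply, re.DOTALL)
    if match is None:
        raise ValueError("no JSON-like object found in the model reply")
    parsed = ast.literal_eval(match.group(0))
    return {"reasoning": str(parsed["reasoning"]), "action": int(parsed["action"])}


# Example reply shaped the way the prompt requests:
reply = "{'reasoning': 'Defecting dominates in a one-shot game.', 'action': 1}"
print(parse_llm_reply(reply))  # {'reasoning': '...', 'action': 1}
```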
src/game_reasoning_arena/arena/envs/open_spiel_env.py
CHANGED
@@ -28,13 +28,15 @@ class OpenSpielEnv(ABC):
         Args:
             game (Any): The OpenSpiel game object being simulated.
             game_name (str): A human-readable name for the game.
-            player_type (Dict[str, str]): Maps "Player 1", "Player 2", ...
-
+            player_type (Dict[str, str]): Maps "Player 1", "Player 2", ...
+                to their types (human, random, llm, etc.).
+            max_game_rounds (int): Maximum number of rounds for iterated games.
+                Ignored by single-shot games.
             seed (Optional[int]): Random seed for reproducibility.
         """
         self.game = game
         self.game_name = game_name
-        self.player_types = player_types
+        self.player_types = player_types  # List of strings
         self.max_game_rounds = max_game_rounds  # For iterated games only
         self.state = None
         self.info = {}
@@ -46,9 +48,10 @@
 
         self.state = None
 
-    def reset(self, seed: Optional[int]=None) -> Tuple[str, Dict[str, Any]]:
+    def reset(self, seed: Optional[int] = None) -> Tuple[str, Dict[str, Any]]:
         """
-        Resets the environment to an initial state and returns an
+        Resets the environment to an initial state and returns an
+        initial observation.
 
         Args:
             seed (Optional[int]): Seed for environment randomization.
@@ -63,7 +66,8 @@
         if hasattr(self.game, "set_seed"):
             self.game.set_seed(seed)
 
-
+        # Instantiates the pyspiel game state
+        self.state = self.game.new_initial_state()
         self.terminated = False
         self.truncated = False
         self.info = {}
@@ -74,21 +78,27 @@
 
         return self._state_to_observation(), self.info
 
-    def step(self, action_dict: Dict[int, int]
-
+    def step(self, action_dict: Dict[int, int]
+             ) -> Tuple[Any, float, bool, bool, Dict[str, Any]]:
+        """Applies the given action(s) to the environment
+        and returns the new state.
 
         Args:
-            action_dict (Dict[int, int]): A dictionary mapping
+            action_dict (Dict[int, int]): A dictionary mapping
+                agent IDs to actions.
                 - For turn-based games: {current_player: action}
-                - For simultaneous games:
+                - For simultaneous games:
+                  {player_0: action_0, player_1: action_1, ...}
 
         Returns:
             Tuple[Any, float, bool, bool, Dict[str, Any]]: A tuple containing:
                 - observation (Any): The resulting state after the action.
                 - reward (float): The reward obtained from this step.
                 - terminated (bool): Whether the episode has ended normally.
-                - truncated (bool): Whether the episode ended
-
+                - truncated (bool): Whether the episode ended
+                  due to `max_game_rounds`.
+                - info (Dict[str, Any]): Additional diagnostic
+                  information (e.g., final scores if done).
         """
 
         # Handle chance nodes
@@ -104,11 +114,14 @@
 
         # Move environment to the next state
         if self.state.is_simultaneous_node():
-            actions = [action_dict[player] for player in sorted(
+            actions = [action_dict[player] for player in sorted(
+                action_dict.keys()
+            )]
             self.state.apply_actions(actions)  # Multi-agent moves
         else:
             current_player = list(action_dict.keys())[0]
-
+            # Single action
+            self.state.apply_action(action_dict[current_player])
 
         # Stepwise reward for each OpenSpiel-indexed agent
         reward_dict = self._compute_reward()
@@ -122,15 +135,21 @@
                 and self.state.move_number() >= self.max_game_rounds
             )
 
-        # If the game is finished, store final scores;
+        # If the game is finished, store final scores;
+        # otherwise, update current player
         if self.terminated or self.truncated:
             print("game terminated" if self.terminated else "game truncated")
-            # Note: final rewards are
-
+            # Note: final rewards are correctly
+            # updated by the OpenSpiel rewards tracker.
+            observation_dict = {
+                agentID: None for agentID in list(action_dict.keys())
+            }  # No observation when the game ends
         else:
-
+            # Get next observation for all agents
+            observation_dict = self._state_to_observation()
 
-        return observation_dict, reward_dict, self.terminated,
+        return (observation_dict, reward_dict, self.terminated,
+                self.truncated, self.info)
 
     def render(self, mode: str = 'human'):
         """Print out the current state of the game."""
@@ -144,7 +163,7 @@
         Args:
             seed (int): The random seed.
         """
-        self.random_generator = random.Random(seed)
+        self.random_generator = random.Random(seed)
 
         # Set game seed if OpenSpiel supports it
         if hasattr(self.game, "set_seed"):
@@ -154,7 +173,8 @@
 
     def detect_illegal_moves(self, actions_dict: Dict[int, int]) -> int:
         """
-        Detects illegal moves by comparing chosen actions
+        Detects illegal moves by comparing chosen actions
+        with OpenSpiel's legal actions.
 
         Args:
             actions_dict: Dictionary mapping player IDs to chosen actions.
@@ -181,7 +201,8 @@
         """Returns the observation for each agent in the game.
 
         Returns:
-            Dict[int, Dict[str, Any]]: Mapping from agent ID
+            Dict[int, Dict[str, Any]]: Mapping from agent ID
+                to their respective observations.
         """
 
         agent_id = self.state.current_player()
@@ -189,7 +210,8 @@
             agent_id: {
                 "state_string": self.state.observation_string(agent_id),
                 "legal_actions": self.state.legal_actions(agent_id),
-
+                # Overriden in some child classes
+                "prompt": self._generate_prompt(agent_id)
             }
         }
 
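Taken together, these changes give `step()` a Gym-style five-value return and make the turn-based versus simultaneous-move dispatch explicit. A sketch of a driver loop against the `HexEnv` subclass from this commit; the constructor keywords and the `player_types` key format are taken from the docstrings above and may differ slightly in practice, and the random policy is only a stand-in for the arena's agents:

```python
# Illustrative driver loop for the updated reset()/step() API (not part of
# the commit). Assumes pyspiel is installed and HexEnv is importable from
# the package path shown above; constructor keywords and player_types keys
# follow the docstrings and are assumptions.
import random

import pyspiel
from game_reasoning_arena.arena.envs.hex_env import HexEnv

game = pyspiel.load_game("hex")
env = HexEnv(
    game=game,
    game_name="hex",
    player_types={"Player 1": "random", "Player 2": "random"},
    max_game_rounds=None,
    seed=42,
)

observation, info = env.reset(seed=42)
terminated = truncated = False
reward = {}

while not (terminated or truncated):
    current_player = env.state.current_player()
    action = random.choice(env.state.legal_actions(current_player))

    # Turn-based games pass a single {current_player: action} mapping;
    # simultaneous-move games would pass one entry per player instead.
    observation, reward, terminated, truncated, info = env.step(
        {current_player: action}
    )

print("Final rewards per agent:", reward)
```

On termination, the returned observation entries are `None` and the last reward dict carries the final OpenSpiel returns, matching the branch added in `step()`.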