File size: 10,981 Bytes
fda5d3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
''' Base class for simulating games.'''

import os
import json
from typing import Dict, Any, List
from abc import ABC
import random
from utils.llm_utils import generate_prompt, llm_decide_move
from enum import Enum, unique


@unique
class PlayerId(Enum):
    """Negative sentinel IDs used instead of a regular player index.

    Regular players are identified by non-negative integers and therefore
    have no member in this enum.
    """

    CHANCE = -1
    SIMULTANEOUS = -2
    INVALID = -3
    TERMINAL = -4
    MEAN_FIELD = -5

    @classmethod
    def from_value(cls, value: int):
        """Translate an integer player ID into its sentinel member, if any.

        Args:
            value (int): The numerical player ID to look up.

        Returns:
            PlayerId: The member whose value equals ``value``, or ``None``
            when ``value`` is non-negative (an ordinary player index).

        Raises:
            ValueError: If ``value`` is negative and matches no member.
        """
        match = next((member for member in cls if member.value == value), None)
        if match is not None:
            return match
        if value >= 0:
            # Zero and above are ordinary player indices, not sentinels.
            return None
        raise ValueError(f"Unknown player ID value: {value}")


class PlayerType(Enum):
    """Kinds of agents that can occupy a seat; values are the config strings."""

    # Action is read interactively from standard input.
    HUMAN = "human"
    # Action is drawn uniformly at random from the legal actions.
    RANDOM_BOT = "random_bot"
    # Action is chosen by the language model registered for the player.
    LLM = "llm"
    # NOTE(review): no branch in GameSimulator._get_action handles this value
    # yet; confirm how self-play seats are meant to be driven.
    SELF_PLAY = "self_play"


class GameSimulator(ABC):
    """Base class for simulating games with LLMs.

    Handles common functionality like state transitions, scoring, and logging.
    """

    def __init__(self, game: Any, game_name: str, llms: Dict[str, Any],
                 player_type: Dict[str, str], max_game_rounds: int = None):
        """
        Args:
            game (Any): The OpenSpiel game object being simulated.
            game_name (str): A human-readable name for the game (for logging and reporting).
            llms (Dict[str, Any]): A dictionary mapping player names (e.g., "Player 1")
                to their corresponding LLM instances. Can be empty if no LLMs are used.
            player_type (Dict[str, str]): A dictionary mapping player names to their types.
            max_game_rounds (int): Maximum number of rounds for iterated games. Ignored by single-shot games.
        """
        self.game = game
        self.game_name = game_name
        self.llms = llms
        self.player_type = player_type
        self.max_game_rounds = max_game_rounds  # For iterated games
        self.scores = {name: 0 for name in self.llms.keys()}  # Initialize scores

    def simulate(self, rounds: int = 1, log_fn=None) -> Dict[str, Any]:
        """Play the game ``rounds`` times and tally wins, losses and ties.

        Args:
            rounds: Number of times the game should be played.
            log_fn: Optional callable invoked with the current state before
                each move is processed.

        Returns:
            Dict[str, Any]: Summary of results for all rounds, as built by
            ``_initialize_outcomes`` and updated by ``_record_outcomes``.

        Raises:
            ValueError: If the state reports a negative player ID that is not
                a chance, simultaneous or terminal node.
            NotImplementedError: Propagated from ``_handle_chance_node`` when a
                chance node is reached and the handler was not overridden.
        """
        outcomes = self._initialize_outcomes()  # Fresh tally for this call

        for _ in range(rounds):
            self.scores = {name: 0 for name in self.llms.keys()}  # Reset scores
            state = self.game.new_initial_state()

            while not state.is_terminal():
                # Iterated games (e.g. Iterated Prisoner's Dilemma) never end
                # on their own, so stop once the move counter reaches the cap.
                if (self.max_game_rounds is not None
                        and state.move_number() >= self.max_game_rounds):
                    break
                if log_fn:
                    log_fn(state)

                current_player = state.current_player()
                # Compare on the plain integer so enum-valued and int-valued
                # IDs follow the same branches below.
                player_id = self.normalize_player_id(current_player)

                if player_id == PlayerId.CHANCE.value:
                    # The environment, not a player, acts at chance nodes.
                    self._handle_chance_node(state)
                elif player_id == PlayerId.SIMULTANEOUS.value:
                    # Every player submits an action for the same step.
                    state.apply_actions(self._collect_actions(state))
                elif player_id == PlayerId.TERMINAL.value:
                    break
                elif player_id >= 0:
                    # Regular turn-based move. Using the normalized ID here
                    # (rather than the raw value) keeps an enum-valued ID from
                    # raising TypeError on the comparison.
                    legal_actions = state.legal_actions(player_id)
                    action = self._get_action(player_id, state, legal_actions)
                    state.apply_action(action)
                else:
                    raise ValueError(f"Unexpected player ID: {current_player}")

            # Record the result of this play-through.
            final_scores = state.returns()
            self._record_outcomes(final_scores, outcomes)

        return outcomes

    def _handle_chance_node(self, state: Any):
        """Handle chance nodes. Default behavior raises an error."""
        raise NotImplementedError("Chance node handling not implemented for this game.")


    def _collect_actions(self, state: Any) -> List[int]:
        """Collects actions for all players in a simultaneous-move game.

        Args:
            state: The current game state.

        Returns:
            List[int]: Actions chosen by all players.
        """
        return [
            self._get_action(player, state, state.legal_actions(player))
            for player in range(self.game.num_players())
        ]

    def _initialize_outcomes(self) -> Dict[str, Any]:
        """Initializes the outcomes dictionary."""
        return {"wins": {name: 0 for name in self.llms.keys()},
                "losses": {name: 0 for name in self.llms.keys()},
                "ties": 0
                }


    def _get_action(self, player: int, state: Any, legal_actions: List[int]) -> int:
        """Choose an action for ``player`` according to its configured type.

        Args:
            player: The index of the current player.
            state: The current game state.
            legal_actions: The legal actions available for the player.

        Returns:
            int: The action selected by the player.

        Raises:
            ValueError: If the player's type is missing from ``player_type``
                or is not human, random bot or LLM.
        """
        # Configuration is keyed by 1-based display names.
        name = f"Player {player + 1}"
        kind = self.player_type.get(name)

        if kind == PlayerType.HUMAN.value:
            chosen = self._get_human_action(state, legal_actions)
        elif kind == PlayerType.RANDOM_BOT.value:
            chosen = random.choice(legal_actions)
        elif kind == PlayerType.LLM.value:
            chosen = self._get_llm_action(player, state, legal_actions)
        else:
            raise ValueError(f"Unknown player type for {name}: {kind}")
        return chosen


    def _get_human_action(self, state: Any, legal_actions: List[int]) -> int:
        """Handles input for human players."""
        print(f"Current state of {self.game_name}:\n{state}")
        print(f"Your options: {legal_actions}") # Display legal moves to the user
        while True:
            try:
                action = int(input("Enter your action (number): "))
                if action in legal_actions: # Validate the move
                    return action
            except ValueError:
                pass
            print("Invalid action. Please choose from:", legal_actions)

    def _get_llm_action(self, player: int, state: Any, legal_actions: List[int]) -> int:
        """Ask the LLM registered for ``player`` to pick an action.

        Args:
            player: The index of the current player.
            state: The current game state; its string form goes into the prompt.
            legal_actions: The legal actions available for the player.

        Returns:
            int: Whatever ``llm_decide_move`` returns for the generated prompt.

        Raises:
            KeyError: If no LLM is registered under the player's name.
        """
        # LLMs are keyed by 1-based display names.
        model = self.llms[f"Player {player + 1}"]
        state_text = str(state)
        prompt = generate_prompt(self.game_name, state_text, legal_actions)
        return llm_decide_move(model, prompt, tuple(legal_actions))

    def _apply_default_action(self, state):
        """
        Applies a default action when the current player is invalid.
        """
        state.apply_action(random.choice(state.legal_actions()))

    def _record_outcomes(self, final_scores: List[float], outcomes: Dict[str, Any]) -> str:
        """Records the outcome of a single game round.

        Args:
            final_scores (List[float]): Final cumulative scores of all players.
            outcomes (Dict[str, Any]): Dictionary to record wins, losses, and ties.

        Returns:
            str: Name of the winner or "tie" if there is no single winner.
        """
        # Check if all scores are equal (a tie)
        if all(score == final_scores[0] for score in final_scores):
            outcomes["ties"] += 1
            return "tie"

        # Find the maximum score and determine winners
        max_score = max(final_scores)
        winners = [name for i, name in enumerate(self.llms.keys()) if final_scores[i] == max_score]

        # Track losers as players who do not have the maximum score
        losers = [name for i, name in enumerate(self.llms.keys()) if final_scores[i] != max_score]

        # If there is one winner, record it; otherwise, record as a tie
        if len(winners) == 1:
            outcomes["wins"][winners[0]] += 1
            for loser in losers:
                outcomes["losses"][loser] += 1
            return winners[0]
        else:
            outcomes["ties"] += 1
            return "tie"


    def save_results(self, state: Any, final_scores: List[float]) -> None:
        """Save simulation results to a JSON file."""
        results = self._prepare_results(state, final_scores)
        filename = self._get_results_filename()

        with open(filename, "w") as f:
            json.dump(results, f, indent=4)
        print(f"Results saved to {filename}")

    def _prepare_results(self, state: Any, final_scores: List[float]) -> Dict[str, Any]:
        """Prepares the results dictionary for JSON serialization."""
        final_scores = final_scores.tolist() if hasattr(final_scores, "tolist") else final_scores
        return {
            "game_name": self.game_name,
            "final_state": str(state),
            "scores": self.scores,
            "returns": final_scores,
            "history": state.history_str(),
        }

    def _get_results_filename(self) -> str:
        """Generates the filename for saving results."""
        results_dir = "results"
        os.makedirs(results_dir, exist_ok=True)
        return os.path.join(results_dir, f"{self.game_name.lower().replace(' ', '_')}_results.json")

    def log_progress(self, state: Any) -> None:
        """Log the current game state."""
        print(f"Current state of {self.game_name}:\n{state}")

    def normalize_player_id(self, player_id):
        """Reduce a player ID to a plain value for consistent comparisons.

        Needed because a player ID may arrive either as an integer or as a
        ``PlayerId`` enum member.

        Args:
            player_id (Union[int, PlayerId]): The player ID to normalize.

        Returns:
            int: The member's ``value`` for a ``PlayerId``; any other input is
            returned unchanged.
        """
        return player_id.value if isinstance(player_id, PlayerId) else player_id