Added leaderboard with DB connectivity
Browse files- README.md +50 -0
- app.py +6 -0
- arena/board.py +63 -9
- arena/board_view.py +5 -1
- arena/c4.py +132 -22
- arena/game.py +56 -4
- arena/llm.py +54 -64
- arena/player.py +31 -7
- arena/record.py +178 -0
- connect.png +0 -0
- prototype.ipynb +0 -0
- requirements.txt +2 -1
README.md
CHANGED
@@ -15,3 +15,53 @@ short_description: Arena for playing Four-in-a-row between LLMs
|
|
15 |
# Four-in-a-row Arena
|
16 |
|
17 |
### A battleground for pitting LLMs against each other in the classic board game
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
# Four-in-a-row Arena
|
16 |
|
17 |
### A battleground for pitting LLMs against each other in the classic board game
|
18 |
+
|
19 |
+

|
20 |
+
|
21 |
+
It has been great fun making this Arena and watching LLMs duke it out!
|
22 |
+
|
23 |
+
Quick links:
|
24 |
+
- The [GitHub repo](https://github.com/ed-donner/connect) for the code
|
25 |
+
- The [HuggingFace Spaces](https://huggingface.co/spaces/ed-donner/connect) where it's running
|
26 |
+
- My [LinkedIn](https://www.linkedin.com/in/eddonner/) - I love connecting!
|
27 |
+
|
28 |
+
If you'd like to learn more about this:
|
29 |
+
- I have a best-selling intensive 8-week [Mastering LLM engineering](https://www.udemy.com/course/llm-engineering-master-ai-and-large-language-models/?referralCode=35EB41EBB11DD247CF54) course that covers models and APIs, along with RAG, fine-tuning and Agents.
|
30 |
+
- I'm running a number of [Live Events](https://www.oreilly.com/search/?q=author%3A%20%22Ed%20Donner%22) with O'Reilly and Pearson
|
31 |
+
|
32 |
+
## Installing the code
|
33 |
+
|
34 |
+
1. Clone the repo with `git clone https://github.com/ed-donner/connect.git`
|
35 |
+
2. Change to the project directory with `cd connect`
|
36 |
+
3. Create a python virtualenv with `python -m venv venv`
|
37 |
+
4. Activate your environment with either `venv\Scripts\activate` on Windows, or `source venv/bin/activate` on Mac/Linux
|
38 |
+
5. Then run `pip install -r requirements.txt` to install the packages
|
39 |
+
|
40 |
+
## Setting up your API keys
|
41 |
+
|
42 |
+
Please create a file with the exact name `.env` in the project root directory (connect).
|
43 |
+
|
44 |
+
You would typically use Notepad (Windows) or nano (Mac) for this.
|
45 |
+
|
46 |
+
If you're not familiar with setting up a .env file this way, ask ChatGPT! It will give much more eloquent instructions than me. 😂
|
47 |
+
|
48 |
+
Your .env file should contain the following; add whichever keys you would like to use.
|
49 |
+
|
50 |
+
```
|
51 |
+
OPENAI_API_KEY=sk-proj-...
|
52 |
+
ANTHROPIC_API_KEY=sk-ant-...
|
53 |
+
DEEPSEEK_API_KEY=sk...
|
54 |
+
GROQ_API_KEY=...
|
55 |
+
```
|
56 |
+
|
57 |
+
## Optional - using Ollama
|
58 |
+
|
59 |
+
You can run Ollama locally, and the Arena will connect to run local models.
|
60 |
+
1. Download and install Ollama from https://ollama.com, noting that on a PC you might need to have administrator permissions for the install to work properly
|
61 |
+
2. On a PC, start a Command prompt / Powershell (Press Win + R, type `cmd`, and press Enter). On a Mac, start a Terminal (Applications > Utilities > Terminal).
|
62 |
+
3. Run `ollama run llama3.2` or for smaller machines try `ollama run llama3.2:1b`
|
63 |
+
4. If this doesn't work, you may need to run `ollama serve` in another Powershell (Windows) or Terminal (Mac), and try step 3 again
|
64 |
+
|
65 |
+
|
66 |
+
|
67 |
+
|
app.py
CHANGED
@@ -1,3 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from arena.c4 import make_display
|
2 |
from dotenv import load_dotenv
|
3 |
|
|
|
1 |
+
"""
|
2 |
+
The main entry-point for the Spaces application
|
3 |
+
Create a Gradio app and launch it
|
4 |
+
"""
|
5 |
+
|
6 |
+
|
7 |
from arena.c4 import make_display
|
8 |
from dotenv import load_dotenv
|
9 |
|
arena/board.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
from arena.board_view import to_svg
|
|
|
2 |
|
3 |
RED = 1
|
4 |
YELLOW = -1
|
@@ -10,8 +11,15 @@ cols = "ABCDEFG"
|
|
10 |
|
11 |
|
12 |
class Board:
|
|
|
|
|
|
|
13 |
|
14 |
def __init__(self):
|
|
|
|
|
|
|
|
|
15 |
self.cells = [[0 for _ in range(7)] for _ in range(6)]
|
16 |
self.player = RED
|
17 |
self.winner = EMPTY
|
@@ -20,6 +28,9 @@ class Board:
|
|
20 |
self.latest_x, self.latest_y = -1, -1
|
21 |
|
22 |
def __repr__(self):
|
|
|
|
|
|
|
23 |
result = ""
|
24 |
for y in range(6):
|
25 |
for x in range(7):
|
@@ -29,6 +40,9 @@ class Board:
|
|
29 |
return result
|
30 |
|
31 |
def message(self):
|
|
|
|
|
|
|
32 |
if self.winner and self.forfeit:
|
33 |
return f"{show[self.winner]} wins after an illegal move by {show[-1*self.winner]}\n"
|
34 |
elif self.winner:
|
@@ -39,16 +53,24 @@ class Board:
|
|
39 |
return f"{show[self.player]} to play\n"
|
40 |
|
41 |
def html(self):
|
|
|
|
|
|
|
42 |
result = '<div style="text-align: center;font-size:24px">'
|
43 |
result += self.__repr__().replace("\n", "<br/>")
|
44 |
result += "</div>"
|
45 |
return result
|
46 |
|
47 |
def svg(self):
|
48 |
-
"""
|
|
|
|
|
49 |
return to_svg(self)
|
50 |
|
51 |
def json(self):
|
|
|
|
|
|
|
52 |
result = "{\n"
|
53 |
result += ' "Column names": ["A", "B", "C", "D", "E", "F", "G"],\n'
|
54 |
for y in range(6):
|
@@ -60,6 +82,9 @@ class Board:
|
|
60 |
return result
|
61 |
|
62 |
def alternative(self):
|
|
|
|
|
|
|
63 |
result = "ABCDEFG\n"
|
64 |
for y in range(6):
|
65 |
for x in range(7):
|
@@ -67,19 +92,32 @@ class Board:
|
|
67 |
result += "\n"
|
68 |
return result
|
69 |
|
70 |
-
def height(self, x):
|
|
|
|
|
|
|
71 |
height = 0
|
72 |
while height < 6 and self.cells[height][x] != EMPTY:
|
73 |
height += 1
|
74 |
return height
|
75 |
|
76 |
-
def legal_moves(self):
|
|
|
|
|
|
|
77 |
return [cols[x] for x in range(7) if self.height(x) < 6]
|
78 |
|
79 |
-
def illegal_moves(self):
|
|
|
|
|
|
|
80 |
return [cols[x] for x in range(7) if self.height(x) == 6]
|
81 |
|
82 |
-
def winning_line(self, x, y, dx, dy):
|
|
|
|
|
|
|
|
|
83 |
color = self.cells[y][x]
|
84 |
for pointer in range(1, 4):
|
85 |
xp = x + dx * pointer
|
@@ -88,20 +126,33 @@ class Board:
|
|
88 |
return EMPTY
|
89 |
return color
|
90 |
|
91 |
-
def winning_cell(self, x, y):
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
for dx, dy in ((0, 1), (1, 1), (1, 0), (1, -1)):
|
93 |
if winner := self.winning_line(x, y, dx, dy):
|
94 |
return winner
|
95 |
return EMPTY
|
96 |
|
97 |
-
def wins(self):
|
|
|
|
|
|
|
|
|
98 |
for y in range(6):
|
99 |
for x in range(7):
|
100 |
if winner := self.winning_cell(x, y):
|
101 |
return winner
|
102 |
return EMPTY
|
103 |
|
104 |
-
def move(self, x):
|
|
|
|
|
|
|
105 |
y = self.height(x)
|
106 |
self.cells[y][x] = self.player
|
107 |
self.latest_x, self.latest_y = x, y
|
@@ -113,5 +164,8 @@ class Board:
|
|
113 |
self.player = -1 * self.player
|
114 |
return self
|
115 |
|
116 |
-
def is_active(self):
|
|
|
|
|
|
|
117 |
return not self.winner and not self.draw
|
|
|
1 |
from arena.board_view import to_svg
|
2 |
+
from typing import List
|
3 |
|
4 |
RED = 1
|
5 |
YELLOW = -1
|
|
|
11 |
|
12 |
|
13 |
class Board:
|
14 |
+
"""
|
15 |
+
A class to represent a Four-in-a-row Board
|
16 |
+
"""
|
17 |
|
18 |
def __init__(self):
|
19 |
+
"""
|
20 |
+
Initialize this instance, starting with empty cells, RED to play
|
21 |
+
The latest x,y is used to track the most recent move, so it animates on the display
|
22 |
+
"""
|
23 |
self.cells = [[0 for _ in range(7)] for _ in range(6)]
|
24 |
self.player = RED
|
25 |
self.winner = EMPTY
|
|
|
28 |
self.latest_x, self.latest_y = -1, -1
|
29 |
|
30 |
def __repr__(self):
|
31 |
+
"""
|
32 |
+
A visual representation
|
33 |
+
"""
|
34 |
result = ""
|
35 |
for y in range(6):
|
36 |
for x in range(7):
|
|
|
40 |
return result
|
41 |
|
42 |
def message(self):
|
43 |
+
"""
|
44 |
+
A summary of the status
|
45 |
+
"""
|
46 |
if self.winner and self.forfeit:
|
47 |
return f"{show[self.winner]} wins after an illegal move by {show[-1*self.winner]}\n"
|
48 |
elif self.winner:
|
|
|
53 |
return f"{show[self.player]} to play\n"
|
54 |
|
55 |
def html(self):
|
56 |
+
"""
|
57 |
+
Return an HTML representation
|
58 |
+
"""
|
59 |
result = '<div style="text-align: center;font-size:24px">'
|
60 |
result += self.__repr__().replace("\n", "<br/>")
|
61 |
result += "</div>"
|
62 |
return result
|
63 |
|
64 |
def svg(self):
|
65 |
+
"""
|
66 |
+
Return an SVG representation
|
67 |
+
"""
|
68 |
return to_svg(self)
|
69 |
|
70 |
def json(self):
|
71 |
+
"""
|
72 |
+
Return a json representation
|
73 |
+
"""
|
74 |
result = "{\n"
|
75 |
result += ' "Column names": ["A", "B", "C", "D", "E", "F", "G"],\n'
|
76 |
for y in range(6):
|
|
|
82 |
return result
|
83 |
|
84 |
def alternative(self):
|
85 |
+
"""
|
86 |
+
An alternative representation, used in prompting so that the LLM sees this 2 ways
|
87 |
+
"""
|
88 |
result = "ABCDEFG\n"
|
89 |
for y in range(6):
|
90 |
for x in range(7):
|
|
|
92 |
result += "\n"
|
93 |
return result
|
94 |
|
95 |
+
def height(self, x: int) -> int:
|
96 |
+
"""
|
97 |
+
Return the height of the given column
|
98 |
+
"""
|
99 |
height = 0
|
100 |
while height < 6 and self.cells[height][x] != EMPTY:
|
101 |
height += 1
|
102 |
return height
|
103 |
|
104 |
+
def legal_moves(self) -> List[str]:
|
105 |
+
"""
|
106 |
+
Return the names of columns that are not full
|
107 |
+
"""
|
108 |
return [cols[x] for x in range(7) if self.height(x) < 6]
|
109 |
|
110 |
+
def illegal_moves(self) -> List[str]:
|
111 |
+
"""
|
112 |
+
Return the names of columns that are full
|
113 |
+
"""
|
114 |
return [cols[x] for x in range(7) if self.height(x) == 6]
|
115 |
|
116 |
+
def winning_line(self, x: int, y: int, dx: int, dy: int) -> int:
|
117 |
+
"""
|
118 |
+
Return RED or YELLOW if this cell is the start of a 4-in-a-row going in the direction dx, dy
|
119 |
+
Or EMPTY if not
|
120 |
+
"""
|
121 |
color = self.cells[y][x]
|
122 |
for pointer in range(1, 4):
|
123 |
xp = x + dx * pointer
|
|
|
126 |
return EMPTY
|
127 |
return color
|
128 |
|
129 |
+
def winning_cell(self, x: int, y: int) -> int:
|
130 |
+
"""
|
131 |
+
Return RED or YELLOW if this cell is the start of a 4-in-a-row
|
132 |
+
Or EMPTY if not
|
133 |
+
For performance reasons, only look in 4 of the possible 8 directions,
|
134 |
+
(because this test will run on both sides of the 4-in-a-row)
|
135 |
+
"""
|
136 |
for dx, dy in ((0, 1), (1, 1), (1, 0), (1, -1)):
|
137 |
if winner := self.winning_line(x, y, dx, dy):
|
138 |
return winner
|
139 |
return EMPTY
|
140 |
|
141 |
+
def wins(self) -> int:
|
142 |
+
"""
|
143 |
+
Return RED or YELLOW if there is a 4-in-a-row of that color on the board
|
144 |
+
Or EMPTY if not
|
145 |
+
"""
|
146 |
for y in range(6):
|
147 |
for x in range(7):
|
148 |
if winner := self.winning_cell(x, y):
|
149 |
return winner
|
150 |
return EMPTY
|
151 |
|
152 |
+
def move(self, x: int):
|
153 |
+
"""
|
154 |
+
Make a move in the given column
|
155 |
+
"""
|
156 |
y = self.height(x)
|
157 |
self.cells[y][x] = self.player
|
158 |
self.latest_x, self.latest_y = x, y
|
|
|
164 |
self.player = -1 * self.player
|
165 |
return self
|
166 |
|
167 |
+
def is_active(self) -> bool:
|
168 |
+
"""
|
169 |
+
Return true if the game has not yet ended
|
170 |
+
"""
|
171 |
return not self.winner and not self.draw
|
arena/board_view.py
CHANGED
@@ -3,7 +3,11 @@ YELLOW = -1
|
|
3 |
EMPTY = 0
|
4 |
|
5 |
def to_svg(board):
|
6 |
-
"""
|
|
|
|
|
|
|
|
|
7 |
svg = '''
|
8 |
<div style="display: flex; justify-content: center;">
|
9 |
<svg width="450" height="420" viewBox="0 0 450 420">
|
|
|
3 |
EMPTY = 0
|
4 |
|
5 |
def to_svg(board):
|
6 |
+
"""
|
7 |
+
Create an SVG representation of the board, with the latest piece dropping down via SVG
|
8 |
+
I must confess that this function was written almost entirely by Claude; done in 15 mins,
|
9 |
+
when it would have taken me a couple of hours. Amazing!
|
10 |
+
"""
|
11 |
svg = '''
|
12 |
<div style="display: flex; justify-content: center;">
|
13 |
<svg width="450" height="420" viewBox="0 0 450 420">
|
arena/c4.py
CHANGED
@@ -4,7 +4,13 @@ from arena.llm import LLM
|
|
4 |
import gradio as gr
|
5 |
|
6 |
|
7 |
-
css = "
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
js = """
|
10 |
function refresh() {
|
@@ -18,20 +24,70 @@ function refresh() {
|
|
18 |
"""
|
19 |
|
20 |
|
21 |
-
def message_html(game):
|
|
|
|
|
|
|
22 |
return (
|
23 |
f'<div style="text-align: center;font-size:18px">{game.board.message()}</div>'
|
24 |
)
|
25 |
|
26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
def load_callback(red_llm, yellow_llm):
|
|
|
|
|
|
|
28 |
game = Game(red_llm, yellow_llm)
|
29 |
enabled = gr.Button(interactive=True)
|
30 |
message = message_html(game)
|
31 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
|
34 |
def move_callback(game):
|
|
|
|
|
|
|
35 |
game.move()
|
36 |
message = message_html(game)
|
37 |
if_active = gr.Button(interactive=game.board.is_active())
|
@@ -47,8 +103,13 @@ def move_callback(game):
|
|
47 |
|
48 |
|
49 |
def run_callback(game):
|
|
|
|
|
|
|
|
|
50 |
enabled = gr.Button(interactive=True)
|
51 |
disabled = gr.Button(interactive=False)
|
|
|
52 |
message = message_html(game)
|
53 |
yield game, game.board.svg(), message, game.thoughts(RED), game.thoughts(
|
54 |
YELLOW
|
@@ -59,26 +120,39 @@ def run_callback(game):
|
|
59 |
yield game, game.board.svg(), message, game.thoughts(RED), game.thoughts(
|
60 |
YELLOW
|
61 |
), disabled, disabled, disabled
|
|
|
62 |
yield game, game.board.svg(), message, game.thoughts(RED), game.thoughts(
|
63 |
YELLOW
|
64 |
), disabled, disabled, enabled
|
65 |
|
66 |
|
67 |
def model_callback(player_name, game, new_model_name):
|
|
|
|
|
|
|
68 |
player = game.players[player_name]
|
69 |
player.switch_model(new_model_name)
|
70 |
return game
|
71 |
|
72 |
|
73 |
def red_model_callback(game, new_model_name):
|
|
|
|
|
|
|
74 |
return model_callback(RED, game, new_model_name)
|
75 |
|
76 |
|
77 |
def yellow_model_callback(game, new_model_name):
|
|
|
|
|
|
|
78 |
return model_callback(YELLOW, game, new_model_name)
|
79 |
|
80 |
|
81 |
def player_section(name, default):
|
|
|
|
|
|
|
82 |
all_model_names = LLM.all_model_names()
|
83 |
with gr.Row():
|
84 |
gr.HTML(f'<div style="text-align: center;font-size:18px">{name} Player</div>')
|
@@ -94,6 +168,9 @@ def player_section(name, default):
|
|
94 |
|
95 |
|
96 |
def make_display():
|
|
|
|
|
|
|
97 |
with gr.Blocks(
|
98 |
title="C4 Battle",
|
99 |
css=css,
|
@@ -103,31 +180,60 @@ def make_display():
|
|
103 |
|
104 |
game = gr.State()
|
105 |
|
106 |
-
with gr.
|
107 |
-
gr.
|
108 |
-
|
109 |
-
)
|
110 |
-
with gr.Row():
|
111 |
-
with gr.Column(scale=1):
|
112 |
-
red_thoughts, red_dropdown = player_section("Red", "gpt-4o-mini")
|
113 |
-
with gr.Column(scale=2):
|
114 |
with gr.Row():
|
115 |
-
|
116 |
-
'<div style="text-align: center;font-size:
|
117 |
)
|
118 |
-
with gr.Row():
|
119 |
-
board_display = gr.HTML()
|
120 |
with gr.Row():
|
121 |
with gr.Column(scale=1):
|
122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
with gr.Column(scale=1):
|
124 |
-
|
|
|
|
|
|
|
|
|
125 |
with gr.Column(scale=1):
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
|
132 |
blocks.load(
|
133 |
load_callback,
|
@@ -191,4 +297,8 @@ def make_display():
|
|
191 |
],
|
192 |
)
|
193 |
|
|
|
|
|
|
|
|
|
194 |
return blocks
|
|
|
4 |
import gradio as gr
|
5 |
|
6 |
|
7 |
+
css = """
|
8 |
+
.dataframe-fix .table-wrap {
|
9 |
+
min-height: 800px;
|
10 |
+
max-height: 800px;
|
11 |
+
}
|
12 |
+
footer{display:none !important}
|
13 |
+
"""
|
14 |
|
15 |
js = """
|
16 |
function refresh() {
|
|
|
24 |
"""
|
25 |
|
26 |
|
27 |
+
def message_html(game) -> str:
|
28 |
+
"""
|
29 |
+
Return the message for the top of the UI
|
30 |
+
"""
|
31 |
return (
|
32 |
f'<div style="text-align: center;font-size:18px">{game.board.message()}</div>'
|
33 |
)
|
34 |
|
35 |
|
36 |
+
def format_records_for_table(games):
|
37 |
+
"""
|
38 |
+
Turn the results objects into a list of lists for the Gradio Dataframe
|
39 |
+
"""
|
40 |
+
return [
|
41 |
+
[
|
42 |
+
game.when,
|
43 |
+
game.red_player,
|
44 |
+
game.yellow_player,
|
45 |
+
"Red" if game.red_won else "Yellow" if game.yellow_won else "Draw",
|
46 |
+
]
|
47 |
+
for game in reversed(games)
|
48 |
+
]
|
49 |
+
|
50 |
+
|
51 |
+
def format_ratings_for_table(ratings):
|
52 |
+
"""
|
53 |
+
Turn the ratings into a List of Lists for the Gradio Dataframe
|
54 |
+
"""
|
55 |
+
items = sorted(ratings.items(), key=lambda x: x[1], reverse=True)
|
56 |
+
return [[item[0], int(round(item[1]))] for item in items]
|
57 |
+
|
58 |
+
|
59 |
def load_callback(red_llm, yellow_llm):
|
60 |
+
"""
|
61 |
+
Callback called when the game is started. Create a new Game object for the state.
|
62 |
+
"""
|
63 |
game = Game(red_llm, yellow_llm)
|
64 |
enabled = gr.Button(interactive=True)
|
65 |
message = message_html(game)
|
66 |
+
return (
|
67 |
+
game,
|
68 |
+
game.board.svg(),
|
69 |
+
message,
|
70 |
+
"",
|
71 |
+
"",
|
72 |
+
enabled,
|
73 |
+
enabled,
|
74 |
+
enabled,
|
75 |
+
)
|
76 |
+
|
77 |
+
|
78 |
+
def leaderboard_callback(game):
|
79 |
+
"""
|
80 |
+
Callback called when the user switches to the Leaderboard tab. Load in the results.
|
81 |
+
"""
|
82 |
+
records_df = format_records_for_table(Game.get_games())
|
83 |
+
ratings_df = format_ratings_for_table(Game.get_ratings())
|
84 |
+
return records_df, ratings_df
|
85 |
|
86 |
|
87 |
def move_callback(game):
|
88 |
+
"""
|
89 |
+
Callback called when the user clicks to do a single move.
|
90 |
+
"""
|
91 |
game.move()
|
92 |
message = message_html(game)
|
93 |
if_active = gr.Button(interactive=game.board.is_active())
|
|
|
103 |
|
104 |
|
105 |
def run_callback(game):
|
106 |
+
"""
|
107 |
+
Callback called when the user runs an entire game. Reset the board, run the game, store results.
|
108 |
+
Yield interim results so the UI updates.
|
109 |
+
"""
|
110 |
enabled = gr.Button(interactive=True)
|
111 |
disabled = gr.Button(interactive=False)
|
112 |
+
game.reset()
|
113 |
message = message_html(game)
|
114 |
yield game, game.board.svg(), message, game.thoughts(RED), game.thoughts(
|
115 |
YELLOW
|
|
|
120 |
yield game, game.board.svg(), message, game.thoughts(RED), game.thoughts(
|
121 |
YELLOW
|
122 |
), disabled, disabled, disabled
|
123 |
+
game.record()
|
124 |
yield game, game.board.svg(), message, game.thoughts(RED), game.thoughts(
|
125 |
YELLOW
|
126 |
), disabled, disabled, enabled
|
127 |
|
128 |
|
129 |
def model_callback(player_name, game, new_model_name):
|
130 |
+
"""
|
131 |
+
Callback when the user changes the model
|
132 |
+
"""
|
133 |
player = game.players[player_name]
|
134 |
player.switch_model(new_model_name)
|
135 |
return game
|
136 |
|
137 |
|
138 |
def red_model_callback(game, new_model_name):
|
139 |
+
"""
|
140 |
+
Callback when red model is changed
|
141 |
+
"""
|
142 |
return model_callback(RED, game, new_model_name)
|
143 |
|
144 |
|
145 |
def yellow_model_callback(game, new_model_name):
|
146 |
+
"""
|
147 |
+
Callback when yellow model is changed
|
148 |
+
"""
|
149 |
return model_callback(YELLOW, game, new_model_name)
|
150 |
|
151 |
|
152 |
def player_section(name, default):
|
153 |
+
"""
|
154 |
+
Create the left and right sections of the UI
|
155 |
+
"""
|
156 |
all_model_names = LLM.all_model_names()
|
157 |
with gr.Row():
|
158 |
gr.HTML(f'<div style="text-align: center;font-size:18px">{name} Player</div>')
|
|
|
168 |
|
169 |
|
170 |
def make_display():
|
171 |
+
"""
|
172 |
+
The Gradio UI to show the Game, with event handlers
|
173 |
+
"""
|
174 |
with gr.Blocks(
|
175 |
title="C4 Battle",
|
176 |
css=css,
|
|
|
180 |
|
181 |
game = gr.State()
|
182 |
|
183 |
+
with gr.Tabs():
|
184 |
+
with gr.TabItem("Game"):
|
185 |
+
|
|
|
|
|
|
|
|
|
|
|
186 |
with gr.Row():
|
187 |
+
gr.HTML(
|
188 |
+
'<div style="text-align: center;font-size:24px">Four-in-a-row LLM Showdown</div>'
|
189 |
)
|
|
|
|
|
190 |
with gr.Row():
|
191 |
with gr.Column(scale=1):
|
192 |
+
red_thoughts, red_dropdown = player_section(
|
193 |
+
"Red", "gpt-4o-mini"
|
194 |
+
)
|
195 |
+
with gr.Column(scale=2):
|
196 |
+
with gr.Row():
|
197 |
+
message = gr.HTML(
|
198 |
+
'<div style="text-align: center;font-size:18px">The Board</div>'
|
199 |
+
)
|
200 |
+
with gr.Row():
|
201 |
+
board_display = gr.HTML()
|
202 |
+
with gr.Row():
|
203 |
+
with gr.Column(scale=1):
|
204 |
+
move_button = gr.Button("Next move")
|
205 |
+
with gr.Column(scale=1):
|
206 |
+
run_button = gr.Button("Run game", variant="primary")
|
207 |
+
with gr.Column(scale=1):
|
208 |
+
reset_button = gr.Button("Start Over", variant="stop")
|
209 |
with gr.Column(scale=1):
|
210 |
+
yellow_thoughts, yellow_dropdown = player_section(
|
211 |
+
"Yellow", "claude-3-5-sonnet-latest"
|
212 |
+
)
|
213 |
+
with gr.TabItem("Leaderboard") as leaderboard_tab:
|
214 |
+
with gr.Row():
|
215 |
with gr.Column(scale=1):
|
216 |
+
ratings_df = gr.Dataframe(
|
217 |
+
headers=["Player", "ELO"],
|
218 |
+
label="Ratings",
|
219 |
+
column_widths=[2, 1],
|
220 |
+
wrap=True,
|
221 |
+
col_count=2,
|
222 |
+
row_count=10,
|
223 |
+
max_height=800,
|
224 |
+
elem_classes=["dataframe-fix"],
|
225 |
+
)
|
226 |
+
with gr.Column(scale=2):
|
227 |
+
results_df = gr.Dataframe(
|
228 |
+
headers=["When", "Red Player", "Yellow Player", "Winner"],
|
229 |
+
label="Game History",
|
230 |
+
column_widths=[2, 2, 2, 1],
|
231 |
+
wrap=True,
|
232 |
+
col_count=4,
|
233 |
+
row_count=10,
|
234 |
+
max_height=800,
|
235 |
+
elem_classes=["dataframe-fix"],
|
236 |
+
)
|
237 |
|
238 |
blocks.load(
|
239 |
load_callback,
|
|
|
297 |
],
|
298 |
)
|
299 |
|
300 |
+
leaderboard_tab.select(
|
301 |
+
leaderboard_callback, inputs=[game], outputs=[results_df, ratings_df]
|
302 |
+
)
|
303 |
+
|
304 |
return blocks
|
arena/game.py
CHANGED
@@ -1,26 +1,78 @@
|
|
1 |
-
from arena.board import Board, RED, YELLOW
|
2 |
from arena.player import Player
|
|
|
|
|
|
|
3 |
|
4 |
|
5 |
class Game:
|
|
|
|
|
|
|
6 |
|
7 |
-
def __init__(self, model_red, model_yellow):
|
|
|
|
|
|
|
8 |
self.board = Board()
|
9 |
self.players = {
|
10 |
RED: Player(model_red, RED),
|
11 |
YELLOW: Player(model_yellow, YELLOW),
|
12 |
}
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
def move(self):
|
|
|
|
|
|
|
15 |
self.players[self.board.player].move(self.board)
|
16 |
|
17 |
-
def is_active(self):
|
|
|
|
|
|
|
18 |
return self.board.is_active()
|
19 |
|
20 |
-
def thoughts(self, player):
|
|
|
|
|
|
|
21 |
return self.players[player].thoughts()
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
def run(self):
|
|
|
|
|
|
|
24 |
while self.is_active():
|
25 |
self.move()
|
26 |
print(self.board)
|
|
|
1 |
+
from arena.board import Board, RED, YELLOW
|
2 |
from arena.player import Player
|
3 |
+
from arena.record import get_games, Result, record_game, ratings
|
4 |
+
from datetime import datetime
|
5 |
+
from typing import List
|
6 |
|
7 |
|
8 |
class Game:
|
9 |
+
"""
|
10 |
+
A Game consists of a Board and 2 players
|
11 |
+
"""
|
12 |
|
13 |
+
def __init__(self, model_red: str, model_yellow: str):
|
14 |
+
"""
|
15 |
+
Initialize this Game; a new board, and new Player objects
|
16 |
+
"""
|
17 |
self.board = Board()
|
18 |
self.players = {
|
19 |
RED: Player(model_red, RED),
|
20 |
YELLOW: Player(model_yellow, YELLOW),
|
21 |
}
|
22 |
|
23 |
+
def reset(self):
|
24 |
+
"""
|
25 |
+
Restart the game by resetting the board; keep players the same
|
26 |
+
"""
|
27 |
+
self.board = Board()
|
28 |
+
|
29 |
def move(self):
|
30 |
+
"""
|
31 |
+
Make the next move. Delegate to the current player to make a move on this board.
|
32 |
+
"""
|
33 |
self.players[self.board.player].move(self.board)
|
34 |
|
35 |
+
def is_active(self) -> bool:
|
36 |
+
"""
|
37 |
+
Return true if the game hasn't yet ended
|
38 |
+
"""
|
39 |
return self.board.is_active()
|
40 |
|
41 |
+
def thoughts(self, player) -> str:
|
42 |
+
"""
|
43 |
+
Return the inner thoughts of the given player
|
44 |
+
"""
|
45 |
return self.players[player].thoughts()
|
46 |
|
47 |
+
@staticmethod
|
48 |
+
def get_games() -> List:
|
49 |
+
"""
|
50 |
+
Return all the games stored in the db
|
51 |
+
"""
|
52 |
+
return get_games()
|
53 |
+
|
54 |
+
@staticmethod
|
55 |
+
def get_ratings():
|
56 |
+
"""
|
57 |
+
Return the ELO ratings of all players
|
58 |
+
"""
|
59 |
+
return ratings()
|
60 |
+
|
61 |
+
def record(self):
|
62 |
+
"""
|
63 |
+
Store the results of this game in the DB
|
64 |
+
"""
|
65 |
+
red_player = self.players[RED].llm.model_name
|
66 |
+
yellow_player = self.players[YELLOW].llm.model_name
|
67 |
+
red_won = self.board.winner == RED
|
68 |
+
yellow_won = self.board.winner == YELLOW
|
69 |
+
result = Result(red_player, yellow_player, red_won, yellow_won, datetime.now())
|
70 |
+
record_game(result)
|
71 |
+
|
72 |
def run(self):
|
73 |
+
"""
|
74 |
+
If being used outside gradio; move and print in a loop
|
75 |
+
"""
|
76 |
while self.is_active():
|
77 |
self.move()
|
78 |
print(self.board)
|
arena/llm.py
CHANGED
@@ -48,7 +48,11 @@ class LLM(ABC):
|
|
48 |
return result
|
49 |
|
50 |
def protected_send(self, system: str, user: str, max_tokens: int = 3000) -> str:
|
51 |
-
|
|
|
|
|
|
|
|
|
52 |
while retries:
|
53 |
retries -= 1
|
54 |
try:
|
@@ -61,9 +65,27 @@ class LLM(ABC):
|
|
61 |
return "{}"
|
62 |
|
63 |
def _send(self, system: str, user: str, max_tokens: int = 3000) -> str:
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
-
def api_model_name(self):
|
|
|
|
|
|
|
67 |
if " " in self.model_name:
|
68 |
return self.model_name.split(" ")[0]
|
69 |
else:
|
@@ -83,6 +105,10 @@ class LLM(ABC):
|
|
83 |
|
84 |
@classmethod
|
85 |
def all_model_names(cls) -> List[str]:
|
|
|
|
|
|
|
|
|
86 |
models = list(cls.model_map().keys())
|
87 |
allowed = os.getenv("MODELS")
|
88 |
if allowed:
|
@@ -153,28 +179,10 @@ class GPT(LLM):
|
|
153 |
super().__init__(model_name, temperature)
|
154 |
self.client = OpenAI()
|
155 |
|
156 |
-
def _send(self, system: str, user: str, max_tokens: int = 3000) -> str:
|
157 |
-
"""
|
158 |
-
Send a message to GPT
|
159 |
-
:param system: the context in which this message is to be taken
|
160 |
-
:param user: the prompt
|
161 |
-
:param max_tokens: max number of tokens to generate
|
162 |
-
:return: the response from the AI
|
163 |
-
"""
|
164 |
-
response = self.client.chat.completions.create(
|
165 |
-
model=self.api_model_name(),
|
166 |
-
messages=[
|
167 |
-
{"role": "system", "content": system},
|
168 |
-
{"role": "user", "content": user},
|
169 |
-
],
|
170 |
-
response_format={"type": "json_object"},
|
171 |
-
)
|
172 |
-
return response.choices[0].message.content
|
173 |
-
|
174 |
|
175 |
class O1(LLM):
|
176 |
"""
|
177 |
-
A class to act as an interface to the remote AI, in this case
|
178 |
"""
|
179 |
|
180 |
model_names = ["o1-mini"]
|
@@ -188,7 +196,7 @@ class O1(LLM):
|
|
188 |
|
189 |
def _send(self, system: str, user: str, max_tokens: int = 3000) -> str:
|
190 |
"""
|
191 |
-
Send a message to
|
192 |
:param system: the context in which this message is to be taken
|
193 |
:param user: the prompt
|
194 |
:param max_tokens: max number of tokens to generate
|
@@ -206,7 +214,7 @@ class O1(LLM):
|
|
206 |
|
207 |
class O3(LLM):
|
208 |
"""
|
209 |
-
A class to act as an interface to the remote AI, in this case
|
210 |
"""
|
211 |
|
212 |
model_names = ["o3-mini"]
|
@@ -225,7 +233,7 @@ class O3(LLM):
|
|
225 |
|
226 |
def _send(self, system: str, user: str, max_tokens: int = 3000) -> str:
|
227 |
"""
|
228 |
-
Send a message to
|
229 |
:param system: the context in which this message is to be taken
|
230 |
:param user: the prompt
|
231 |
:param max_tokens: max number of tokens to generate
|
@@ -241,6 +249,25 @@ class O3(LLM):
|
|
241 |
return response.choices[0].message.content
|
242 |
|
243 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
244 |
class Ollama(LLM):
|
245 |
"""
|
246 |
A class to act as an interface to the remote AI, in this case Ollama via the OpenAI client
|
@@ -250,7 +277,7 @@ class Ollama(LLM):
|
|
250 |
|
251 |
def __init__(self, model_name: str, temperature: float):
|
252 |
"""
|
253 |
-
Create a new instance of the OpenAI client
|
254 |
"""
|
255 |
super().__init__(model_name, temperature)
|
256 |
self.client = OpenAI(base_url="http://localhost:11434/v1", api_key="ollama")
|
@@ -296,25 +323,6 @@ class DeepSeekAPI(LLM):
|
|
296 |
api_key=deepseek_api_key, base_url="https://api.deepseek.com"
|
297 |
)
|
298 |
|
299 |
-
def _send(self, system: str, user: str, max_tokens: int = 3000) -> str:
|
300 |
-
"""
|
301 |
-
Send a message to DeepSeek
|
302 |
-
:param system: the context in which this message is to be taken
|
303 |
-
:param user: the prompt
|
304 |
-
:param max_tokens: max number of tokens to generate
|
305 |
-
:return: the response from the AI
|
306 |
-
"""
|
307 |
-
|
308 |
-
response = self.client.chat.completions.create(
|
309 |
-
model=self.api_model_name(),
|
310 |
-
messages=[
|
311 |
-
{"role": "system", "content": system},
|
312 |
-
{"role": "user", "content": user},
|
313 |
-
],
|
314 |
-
)
|
315 |
-
reply = response.choices[0].message.content
|
316 |
-
return reply
|
317 |
-
|
318 |
|
319 |
class DeepSeekLocal(LLM):
|
320 |
"""
|
@@ -367,25 +375,7 @@ class GroqAPI(LLM):
|
|
367 |
|
368 |
def __init__(self, model_name: str, temperature: float):
|
369 |
"""
|
370 |
-
Create a new instance of the
|
371 |
"""
|
372 |
super().__init__(model_name, temperature)
|
373 |
self.client = Groq()
|
374 |
-
|
375 |
-
def _send(self, system: str, user: str, max_tokens: int = 3000) -> str:
|
376 |
-
"""
|
377 |
-
Send a message to GPT
|
378 |
-
:param system: the context in which this message is to be taken
|
379 |
-
:param user: the prompt
|
380 |
-
:param max_tokens: max number of tokens to generate
|
381 |
-
:return: the response from the AI
|
382 |
-
"""
|
383 |
-
response = self.client.chat.completions.create(
|
384 |
-
model=self.api_model_name(),
|
385 |
-
messages=[
|
386 |
-
{"role": "system", "content": system},
|
387 |
-
{"role": "user", "content": user},
|
388 |
-
],
|
389 |
-
response_format={"type": "json_object"},
|
390 |
-
)
|
391 |
-
return response.choices[0].message.content
|
|
|
48 |
return result
|
49 |
|
50 |
def protected_send(self, system: str, user: str, max_tokens: int = 3000) -> str:
|
51 |
+
"""
|
52 |
+
Wrap the send call in an exception handler, giving the LLM 3 chances in total, in case
|
53 |
+
of overload errors. If it fails 3 times, then it forfeits!
|
54 |
+
"""
|
55 |
+
retries = 3
|
56 |
while retries:
|
57 |
retries -= 1
|
58 |
try:
|
|
|
65 |
return "{}"
|
66 |
|
67 |
def _send(self, system: str, user: str, max_tokens: int = 3000) -> str:
|
68 |
+
"""
|
69 |
+
Send a message to the model - this default implementation follows the OpenAI API structure
|
70 |
+
:param system: the context in which this message is to be taken
|
71 |
+
:param user: the prompt
|
72 |
+
:param max_tokens: max number of tokens to generate
|
73 |
+
:return: the response from the AI
|
74 |
+
"""
|
75 |
+
response = self.client.chat.completions.create(
|
76 |
+
model=self.api_model_name(),
|
77 |
+
messages=[
|
78 |
+
{"role": "system", "content": system},
|
79 |
+
{"role": "user", "content": user},
|
80 |
+
],
|
81 |
+
response_format={"type": "json_object"},
|
82 |
+
)
|
83 |
+
return response.choices[0].message.content
|
84 |
|
85 |
+
def api_model_name(self) -> str:
|
86 |
+
"""
|
87 |
+
Return the actual model_name to be used in the call to the API; strip out anything after a space
|
88 |
+
"""
|
89 |
if " " in self.model_name:
|
90 |
return self.model_name.split(" ")[0]
|
91 |
else:
|
|
|
105 |
|
106 |
@classmethod
|
107 |
def all_model_names(cls) -> List[str]:
|
108 |
+
"""
|
109 |
+
Return a list of all the model names supported.
|
110 |
+
Use the ones specified in the model_map, but also check if there's an env variable set that restricts the models
|
111 |
+
"""
|
112 |
models = list(cls.model_map().keys())
|
113 |
allowed = os.getenv("MODELS")
|
114 |
if allowed:
|
|
|
179 |
super().__init__(model_name, temperature)
|
180 |
self.client = OpenAI()
|
181 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
|
183 |
class O1(LLM):
|
184 |
"""
|
185 |
+
A class to act as an interface to the remote AI, in this case O1
|
186 |
"""
|
187 |
|
188 |
model_names = ["o1-mini"]
|
|
|
196 |
|
197 |
def _send(self, system: str, user: str, max_tokens: int = 3000) -> str:
|
198 |
"""
|
199 |
+
Send a message to O1
|
200 |
:param system: the context in which this message is to be taken
|
201 |
:param user: the prompt
|
202 |
:param max_tokens: max number of tokens to generate
|
|
|
214 |
|
215 |
class O3(LLM):
|
216 |
"""
|
217 |
+
A class to act as an interface to the remote AI, in this case O3
|
218 |
"""
|
219 |
|
220 |
model_names = ["o3-mini"]
|
|
|
233 |
|
234 |
def _send(self, system: str, user: str, max_tokens: int = 3000) -> str:
|
235 |
"""
|
236 |
+
Send a message to O3
|
237 |
:param system: the context in which this message is to be taken
|
238 |
:param user: the prompt
|
239 |
:param max_tokens: max number of tokens to generate
|
|
|
249 |
return response.choices[0].message.content
|
250 |
|
251 |
|
252 |
+
class Gemini(LLM):
    """
    A class to act as an interface to the remote AI, in this case Gemini via the OpenAI client
    """

    model_names = ["gemini-2.0-flash", "gemini-1.5-flash"]

    def __init__(self, model_name: str, temperature: float):
        """
        Create an OpenAI client that points at the Gemini endpoint,
        using the key held in the GOOGLE_API_KEY environment variable
        """
        super().__init__(model_name, temperature)
        self.client = OpenAI(
            base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
            api_key=os.getenv("GOOGLE_API_KEY"),
        )
|
269 |
+
|
270 |
+
|
271 |
class Ollama(LLM):
|
272 |
"""
|
273 |
A class to act as an interface to the remote AI, in this case Ollama via the OpenAI client
|
|
|
277 |
|
278 |
def __init__(self, model_name: str, temperature: float):
|
279 |
"""
|
280 |
+
Create a new instance of the OpenAI client for Ollama
|
281 |
"""
|
282 |
super().__init__(model_name, temperature)
|
283 |
self.client = OpenAI(base_url="http://localhost:11434/v1", api_key="ollama")
|
|
|
323 |
api_key=deepseek_api_key, base_url="https://api.deepseek.com"
|
324 |
)
|
325 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
326 |
|
327 |
class DeepSeekLocal(LLM):
|
328 |
"""
|
|
|
375 |
|
376 |
def __init__(self, model_name: str, temperature: float):
|
377 |
"""
|
378 |
+
Create a new instance of the Groq client
|
379 |
"""
|
380 |
super().__init__(model_name, temperature)
|
381 |
self.client = Groq()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arena/player.py
CHANGED
@@ -5,8 +5,15 @@ import random
|
|
5 |
|
6 |
|
7 |
class Player:
|
8 |
-
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
self.color = color
|
11 |
self.model = model
|
12 |
self.llm = LLM.create(self.model)
|
@@ -15,7 +22,10 @@ class Player:
|
|
15 |
self.opportunities = ""
|
16 |
self.strategy = ""
|
17 |
|
18 |
-
def system(self, board, legal_moves, illegal_moves):
|
|
|
|
|
|
|
19 |
return f"""You are playing the board game Connect 4.
|
20 |
Players take turns to drop counters into one of 7 columns A, B, C, D, E, F, G.
|
21 |
The winner is the first player to get 4 counters in a row in any direction.
|
@@ -33,7 +43,10 @@ You should respond in JSON according to this spec:
|
|
33 |
|
34 |
You must pick one of these letters for your move_column: {legal_moves}{illegal_moves}"""
|
35 |
|
36 |
-
def user(self, board, legal_moves, illegal_moves):
|
|
|
|
|
|
|
37 |
return f"""It is your turn to make a move as {pieces[self.color]}.
|
38 |
Here is the current board, with row 1 at the bottom of the board:
|
39 |
|
@@ -78,8 +91,10 @@ Now make your decision.
|
|
78 |
You must pick one of these letters for your move_column: {legal_moves}{illegal_moves}
|
79 |
"""
|
80 |
|
81 |
-
def process_move(self, reply, board):
|
82 |
-
|
|
|
|
|
83 |
try:
|
84 |
if len(reply) == 3 and reply[0] == "{" and reply[2] == "}":
|
85 |
reply = f'{{"move_column": "{reply[1]}"}}'
|
@@ -100,6 +115,9 @@ You must pick one of these letters for your move_column: {legal_moves}{illegal_m
|
|
100 |
board.winner = -1 * board.player
|
101 |
|
102 |
def move(self, board):
|
|
|
|
|
|
|
103 |
legal_moves = ", ".join(board.legal_moves())
|
104 |
if illegal := board.illegal_moves():
|
105 |
illegal_moves = (
|
@@ -114,6 +132,9 @@ You must pick one of these letters for your move_column: {legal_moves}{illegal_m
|
|
114 |
self.process_move(reply, board)
|
115 |
|
116 |
def thoughts(self):
|
|
|
|
|
|
|
117 |
result = '<div style="text-align: left;font-size:14px"><br/>'
|
118 |
result += f"<b>Evaluation:</b><br/>{self.evaluation}<br/><br/>"
|
119 |
result += f"<b>Threats:</b><br/>{self.threats}<br/><br/>"
|
@@ -122,5 +143,8 @@ You must pick one of these letters for your move_column: {legal_moves}{illegal_m
|
|
122 |
result += "</div>"
|
123 |
return result
|
124 |
|
125 |
-
def switch_model(self, new_model_name):
|
|
|
|
|
|
|
126 |
self.llm = LLM.create(new_model_name)
|
|
|
5 |
|
6 |
|
7 |
class Player:
|
8 |
+
"""
|
9 |
+
This class represents one AI player in the game, and is responsible for managing the prompts
|
10 |
+
Delegating to an LLM instance to connect to the LLM
|
11 |
+
"""
|
12 |
+
|
13 |
+
def __init__(self, model: str, color: int):
|
14 |
+
"""
|
15 |
+
Set up this instance for the given model and player color
|
16 |
+
"""
|
17 |
self.color = color
|
18 |
self.model = model
|
19 |
self.llm = LLM.create(self.model)
|
|
|
22 |
self.opportunities = ""
|
23 |
self.strategy = ""
|
24 |
|
25 |
+
def system(self, board, legal_moves: str, illegal_moves: str) -> str:
|
26 |
+
"""
|
27 |
+
Return the system prompt for this move
|
28 |
+
"""
|
29 |
return f"""You are playing the board game Connect 4.
|
30 |
Players take turns to drop counters into one of 7 columns A, B, C, D, E, F, G.
|
31 |
The winner is the first player to get 4 counters in a row in any direction.
|
|
|
43 |
|
44 |
You must pick one of these letters for your move_column: {legal_moves}{illegal_moves}"""
|
45 |
|
46 |
+
def user(self, board, legal_moves: str, illegal_moves: str) -> str:
|
47 |
+
"""
|
48 |
+
Return the user prompt for this move
|
49 |
+
"""
|
50 |
return f"""It is your turn to make a move as {pieces[self.color]}.
|
51 |
Here is the current board, with row 1 at the bottom of the board:
|
52 |
|
|
|
91 |
You must pick one of these letters for your move_column: {legal_moves}{illegal_moves}
|
92 |
"""
|
93 |
|
94 |
+
def process_move(self, reply: str, board):
|
95 |
+
"""
|
96 |
+
Interpret the reply and make the move; if the move is illegal, then the current player loses
|
97 |
+
"""
|
98 |
try:
|
99 |
if len(reply) == 3 and reply[0] == "{" and reply[2] == "}":
|
100 |
reply = f'{{"move_column": "{reply[1]}"}}'
|
|
|
115 |
board.winner = -1 * board.player
|
116 |
|
117 |
def move(self, board):
|
118 |
+
"""
|
119 |
+
Have the underlying LLM make a move, and process the result
|
120 |
+
"""
|
121 |
legal_moves = ", ".join(board.legal_moves())
|
122 |
if illegal := board.illegal_moves():
|
123 |
illegal_moves = (
|
|
|
132 |
self.process_move(reply, board)
|
133 |
|
134 |
def thoughts(self):
|
135 |
+
"""
|
136 |
+
Return HTML to describe the inner thoughts
|
137 |
+
"""
|
138 |
result = '<div style="text-align: left;font-size:14px"><br/>'
|
139 |
result += f"<b>Evaluation:</b><br/>{self.evaluation}<br/><br/>"
|
140 |
result += f"<b>Threats:</b><br/>{self.threats}<br/><br/>"
|
|
|
143 |
result += "</div>"
|
144 |
return result
|
145 |
|
146 |
+
def switch_model(self, new_model_name: str):
    """
    Replace this player's LLM with one created for the given model name
    """
    # NOTE(review): self.model is left as set in __init__ - confirm nothing reads it after a switch
    replacement = LLM.create(new_model_name)
    self.llm = replacement
|
arena/record.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import math
|
4 |
+
from datetime import datetime
|
5 |
+
from typing import List, Dict
|
6 |
+
from dataclasses import dataclass, asdict
|
7 |
+
from pymongo import MongoClient
|
8 |
+
from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError
|
9 |
+
|
10 |
+
|
11 |
+
@dataclass
class Result:
    """
    The outcome of one game between a red player and a yellow player.

    The two *_player fields hold the player names, the two *_won flags say
    which side (if either) won, and `when` is the datetime stored with the game.
    """

    red_player: str
    yellow_player: str
    red_won: bool
    yellow_won: bool
    when: datetime
|
18 |
+
|
19 |
+
|
20 |
+
COLLECTION = "connect"
|
21 |
+
|
22 |
+
|
23 |
+
def _get_collection():
    """
    Helper function to get the MongoDB collection, with error handling.

    Returns:
        The collection named by COLLECTION in the "outsmart" database, or None
        when MONGO_URI is not set or the server cannot be reached
    """
    mongo_uri = os.getenv("MONGO_URI")
    if not mongo_uri:
        # No database configured: callers treat None as "database unavailable"
        return None

    client = None
    try:
        client = MongoClient(mongo_uri, serverSelectionTimeoutMS=5000)
        # Quick check that we can actually connect; "ping" replaces the
        # deprecated "ismaster" command for this purpose
        client.admin.command("ping")
        return client.outsmart[COLLECTION]
    except (ConnectionFailure, ServerSelectionTimeoutError):
        # Don't leave a client behind for a server we could not reach
        if client is not None:
            client.close()
        return None
|
35 |
+
|
36 |
+
|
37 |
+
def record_game(result: Result) -> bool:
    """
    Save one game result to the database when a collection is available.

    Args:
        result: The game outcome to store

    Returns:
        True if the document was inserted; False if no collection is
        available or the insert raised
    """
    games = _get_collection()
    if games is None:
        return False

    # MongoDB stores documents, so turn the dataclass into a plain dict first
    document = asdict(result)

    try:
        games.insert_one(document)
    except Exception as error:
        logging.error("Failed to record a game in the database")
        logging.exception(error)
        return False
    else:
        return True
|
56 |
+
|
57 |
+
|
58 |
+
def get_games() -> List[Result]:
    """
    Load every stored game in the order the games were inserted.

    Returns:
        The stored games as Result objects, or an empty list when no
        collection is available or reading the games fails
    """
    source = _get_collection()
    if source is None:
        return []

    try:
        # Ascending _id keeps the documents in insertion order
        cursor = source.find().sort("_id", 1)

        # Rebuild each Result from its document, leaving out MongoDB's _id field
        return [
            Result(**{key: value for key, value in doc.items() if key != "_id"})
            for doc in cursor
        ]
    except Exception as error:
        logging.error("Error getting games")
        logging.exception(error)
        return []
|
83 |
+
|
84 |
+
|
85 |
+
class EloCalculator:
    """
    Keeps a table of ELO ratings by player name and updates it one game at a time.
    """

    def __init__(self, k_factor: float = 32, default_rating: int = 1000):
        """
        Create a calculator with an empty ratings table.

        Args:
            k_factor: Scales how far ratings move after each game
            default_rating: Rating used for a player with no stored rating
        """
        self.ratings: Dict[str, float] = {}
        self.default_rating = default_rating
        self.k_factor = k_factor

    def get_player_rating(self, player: str) -> float:
        """Return the stored rating for a player, or the default rating if none is stored."""
        if player in self.ratings:
            return self.ratings[player]
        return self.default_rating

    def calculate_expected_score(self, rating_a: float, rating_b: float) -> float:
        """
        Return the expected score (win probability) of player A against player B.

        ELO formula: 1 / (1 + 10^((rating_b - rating_a) / 400))
        """
        exponent = (rating_b - rating_a) / 400
        return 1 / (1 + math.pow(10, exponent))

    def update_ratings(
        self, player_a: str, player_b: str, score_a: float, score_b: float
    ) -> None:
        """
        Apply the outcome of one game to both players' ratings.

        Args:
            player_a: Name of first player
            player_b: Name of second player
            score_a: Actual score for player A (1 for win, 0.5 for draw, 0 for loss)
            score_b: Actual score for player B (1 for win, 0.5 for draw, 0 for loss)
        """
        # Read both ratings before writing either, so the update uses pre-game values
        before_a = self.get_player_rating(player_a)
        before_b = self.get_player_rating(player_b)

        expected_a = self.calculate_expected_score(before_a, before_b)
        expected_b = 1 - expected_a

        # ELO update: R' = R + K * (S - E), with R the current rating,
        # K the k-factor, S the actual score and E the expected score
        outcomes = (
            (player_a, before_a, score_a, expected_a),
            (player_b, before_b, score_b, expected_b),
        )
        for name, before, actual, expected in outcomes:
            self.ratings[name] = before + self.k_factor * (actual - expected)
|
135 |
+
|
136 |
+
|
137 |
+
def calculate_elo_ratings(
    results: List[Result], exclude_self_play: bool = True
) -> Dict[str, float]:
    """
    Replay a list of game results through a fresh EloCalculator.

    Args:
        results: List of game results, sorted by date
        exclude_self_play: If True, ignore games where both sides are the same player

    Returns:
        Dictionary mapping player names to their ELO ratings after the last game
    """
    calculator = EloCalculator()

    for game in results:
        if exclude_self_play and game.red_player == game.yellow_player:
            continue

        # Scores are 1 for a win, 0.5 for a draw, 0 for a loss.
        # Both flags set or both clear is scored as a draw.
        red_won, yellow_won = bool(game.red_won), bool(game.yellow_won)
        if red_won == yellow_won:
            scores = (0.5, 0.5)
        elif red_won:
            scores = (1.0, 0.0)
        else:
            scores = (0.0, 1.0)

        calculator.update_ratings(game.red_player, game.yellow_player, *scores)

    return calculator.ratings
|
171 |
+
|
172 |
+
|
173 |
+
def ratings() -> Dict[str, float]:
    """
    Compute the ELO ratings from every game currently stored in the DB
    """
    return calculate_elo_ratings(get_games())
|
connect.png
ADDED
![]() |
prototype.ipynb
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
requirements.txt
CHANGED
@@ -7,4 +7,5 @@ gradio
|
|
7 |
google.generativeai
|
8 |
anthropic
|
9 |
groq
|
10 |
-
black
|
|
|
|
7 |
google.generativeai
|
8 |
anthropic
|
9 |
groq
|
10 |
+
black
|
11 |
+
pymongo
|