ed-donner committed
Commit bae16e4 · 1 Parent(s): 70830d6

Restrict models

Files changed (4)
  1. app.py +3 -2
  2. arena/c4.py +2 -2
  3. arena/game.py +0 -2
  4. arena/llm.py +27 -19
app.py CHANGED
@@ -1,7 +1,8 @@
 from arena.c4 import make_display
+from dotenv import load_dotenv
 
 
-app = make_display()
-
 if __name__ == "__main__":
+    load_dotenv(override=True)
+    app = make_display()
     app.launch()
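Note on this change: load_dotenv(override=True) now runs once at process startup, before make_display() builds the UI, so the MODELS restriction read in arena/llm.py is already in os.environ when the model dropdowns are populated (this pairs with the removal of the per-game load_dotenv from arena/game.py below). A minimal sketch of the new startup behaviour, assuming python-dotenv is installed and a .env file defines MODELS; the model names are illustrative, not part of this commit:

from dotenv import load_dotenv
import os

load_dotenv(override=True)  # .env values win over inherited shell variables

# With MODELS=gpt-4o-mini,claude-3-5-haiku in .env (illustrative allowlist),
# every later call to LLM.all_model_names() sees the restriction.
print(os.getenv("MODELS"))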
arena/c4.py CHANGED
@@ -3,7 +3,6 @@ from arena.board import RED, YELLOW
 from arena.llm import LLM
 import gradio as gr
 
-all_model_names = LLM.all_model_names()
 
 css = "footer{display:none !important}"
 
@@ -80,6 +79,7 @@ def yellow_model_callback(game, new_model_name):
 
 
 def player_section(name, default):
+    all_model_names = LLM.all_model_names()
     with gr.Row():
         gr.Markdown(
             f'<div style="text-align: center;font-size:18px">{name} Player</div>'
@@ -113,7 +113,7 @@ def make_display():
         )
         with gr.Row():
             with gr.Column(scale=1):
-                red_thoughts, red_dropdown = player_section("Red", "gpt-4o")
+                red_thoughts, red_dropdown = player_section("Red", "gpt-4o-mini")
             with gr.Column(scale=2):
                 with gr.Row():
                     message = gr.Markdown(
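Moving the LLM.all_model_names() call from module level into player_section is an ordering fix as much as a refactor: arena.c4 is imported at the top of app.py before load_dotenv runs, so a module-level call would read MODELS before the .env file is loaded and the restriction would be silently ignored. A small sketch of the difference, with hypothetical values:

import os

# Import time (old): MODELS is not yet set, so no restriction applies.
# eager_names = LLM.all_model_names()  # evaluated before load_dotenv()

def player_section_sketch():
    # Call time (new): runs while the UI section is built, after
    # load_dotenv(), so the MODELS allowlist from .env is honoured.
    allowed = os.getenv("MODELS")
    return allowed.split(",") if allowed else ["<all models>"]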
arena/game.py CHANGED
@@ -1,12 +1,10 @@
 from arena.board import Board, RED, YELLOW, EMPTY, pieces
 from arena.player import Player
-from dotenv import load_dotenv
 
 
 class Game:
 
     def __init__(self, model_red, model_yellow):
-        load_dotenv(override=True)
         self.board = Board()
         self.players = {
             RED: Player(model_red, RED),
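With the per-game call gone, .env is read once per process instead of on every Game construction. Since load_dotenv only populates os.environ, which persists for the lifetime of the process, every Game created after startup still sees the same variables. A sketch under that assumption (the key name is a placeholder):

import os

# Stand-in for the one-time load_dotenv(override=True) in app.py.
os.environ.setdefault("DEEPSEEK_API_KEY", "sk-placeholder")

class GameSketch:
    def __init__(self):
        # No load_dotenv here any more; just read what startup loaded.
        self.key = os.getenv("DEEPSEEK_API_KEY")

print(GameSketch().key)  # same value for every game in this process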
arena/llm.py CHANGED
@@ -49,7 +49,6 @@ class LLM(ABC):
 
     def protected_send(self, system: str, user: str, max_tokens: int = 3000) -> str:
         retries = 5
-        done = False
         while retries:
             retries -= 1
             try:
@@ -62,7 +61,13 @@ class LLM(ABC):
         return "{}"
 
     def _send(self, system: str, user: str, max_tokens: int = 3000) -> str:
-        pass
+        raise NotImplementedError
+
+    def api_model_name(self):
+        if " " in self.model_name:
+            return self.model_name.split(" ")[0]
+        else:
+            return self.model_name
 
     @classmethod
     def model_map(cls) -> Dict[str, Type[Self]]:
@@ -78,7 +83,13 @@ class LLM(ABC):
 
     @classmethod
     def all_model_names(cls) -> List[str]:
-        return cls.model_map().keys()
+        models = list(cls.model_map().keys())
+        allowed = os.getenv("MODELS")
+        if allowed:
+            allowed_models = allowed.split(",")
+            return [model for model in models if model in allowed_models]
+        else:
+            return models
 
     @classmethod
     def create(cls, model_name: str, temperature: float = 0.5) -> Self:
@@ -117,7 +128,7 @@ class Claude(LLM):
         :return: the response from the AI
         """
         response = self.client.messages.create(
-            model=self.model_name,
+            model=self.api_model_name(),
             max_tokens=max_tokens,
             temperature=self.temperature,
             system=system,
@@ -151,7 +162,7 @@ class GPT(LLM):
         :return: the response from the AI
         """
         response = self.client.chat.completions.create(
-            model=self.model_name,
+            model=self.api_model_name(),
            messages=[
                 {"role": "system", "content": system},
                 {"role": "user", "content": user},
@@ -185,7 +196,7 @@ class O1(LLM):
         """
         message = system + "\n\n" + user
         response = self.client.chat.completions.create(
-            model=self.model_name,
+            model=self.api_model_name(),
             messages=[
                 {"role": "user", "content": message},
             ],
@@ -222,7 +233,7 @@ class O3(LLM):
         """
         message = system + "\n\n" + user
         response = self.client.chat.completions.create(
-            model=self.model_name,
+            model=self.api_model_name(),
             messages=[
                 {"role": "user", "content": message},
             ],
@@ -241,7 +252,7 @@ class Ollama(LLM):
         """
         Create a new instance of the OpenAI client
         """
-        super().__init__(model_name.replace(" local", ""), temperature)
+        super().__init__(model_name, temperature)
         self.client = OpenAI(base_url="http://localhost:11434/v1", api_key="ollama")
 
     def _send(self, system: str, user: str, max_tokens: int = 3000) -> str:
@@ -254,7 +265,7 @@ class Ollama(LLM):
         """
 
         response = self.client.chat.completions.create(
-            model=self.model_name,
+            model=self.api_model_name(),
             messages=[
                 {"role": "system", "content": system},
                 {"role": "user", "content": user},
@@ -273,15 +284,13 @@ class DeepSeekAPI(LLM):
     A class to act as an interface to the remote AI, in this case DeepSeek via the OpenAI client
     """
 
-    model_names = ["deepseek-V3", "deepseek-r1"]
-
-    model_map = {"deepseek-V3": "deepseek-chat", "deepseek-r1": "deepseek-reasoner"}
+    model_names = ["deepseek-chat V3", "deepseek-reasoner R1"]
 
     def __init__(self, model_name: str, temperature: float):
         """
         Create a new instance of the OpenAI client
         """
-        super().__init__(self.model_map[model_name], temperature)
+        super().__init__(model_name, temperature)
         deepseek_api_key = os.getenv("DEEPSEEK_API_KEY")
         self.client = OpenAI(
             api_key=deepseek_api_key, base_url="https://api.deepseek.com"
@@ -297,12 +306,11 @@ class DeepSeekAPI(LLM):
         """
 
         response = self.client.chat.completions.create(
-            model=self.model_name,
+            model=self.api_model_name(),
             messages=[
                 {"role": "system", "content": system},
                 {"role": "user", "content": user},
             ],
-            # response_format={"type": "json_object"},
         )
         reply = response.choices[0].message.content
         return reply
@@ -319,7 +327,7 @@ class DeepSeekLocal(LLM):
         """
         Create a new instance of the OpenAI client
         """
-        super().__init__(model_name.replace(" local", ""), temperature)
+        super().__init__(model_name, temperature)
         self.client = OpenAI(base_url="http://localhost:11434/v1", api_key="ollama")
 
     def _send(self, system: str, user: str, max_tokens: int = 3000) -> str:
@@ -333,7 +341,7 @@ class DeepSeekLocal(LLM):
         system += "\nImportant: avoid overthinking. Think briefly and decisively. The final response must follow the given json format or you forfeit the game. Do not overthink. Respond with json."
         user += "\nImportant: avoid overthinking. Think briefly and decisively. The final response must follow the given json format or you forfeit the game. Do not overthink. Respond with json."
         response = self.client.chat.completions.create(
-            model=self.model_name,
+            model=self.api_model_name(),
             messages=[
                 {"role": "system", "content": system},
                 {"role": "user", "content": user},
@@ -361,7 +369,7 @@ class GroqAPI(LLM):
         """
         Create a new instance of the OpenAI client
         """
-        super().__init__(model_name[:-9], temperature)
+        super().__init__(model_name, temperature)
         self.client = Groq()
 
     def _send(self, system: str, user: str, max_tokens: int = 3000) -> str:
@@ -373,7 +381,7 @@ class GroqAPI(LLM):
         :return: the response from the AI
         """
         response = self.client.chat.completions.create(
-            model=self.model_name,
+            model=self.api_model_name(),
             messages=[
                 {"role": "system", "content": system},
                 {"role": "user", "content": user},