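"""
arena/llm.py - an abstraction layer over multiple LLM providers (OpenAI, Anthropic, Groq,
Gemini, DeepSeek, and local models served by Ollama), giving them a common send() interface
that returns JSON. New providers are supported by subclassing LLM.
"""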

import logging
import os
import time
from abc import ABC
from typing import Dict, List, Self, Type

from anthropic import Anthropic
from groq import Groq
from openai import OpenAI

logger = logging.getLogger(__name__)


class LLMException(Exception):
    pass


class LLM(ABC):
    """
    An abstract superclass for interacting with LLMs - subclass for each provider, like Claude and GPT
    """

    model_names: List[str] = []

    def __init__(self, model_name: str, temperature: float):
        self.model_name = model_name
        self.client = None
        self.temperature = temperature

    def send(self, system: str, user: str, max_tokens: int = 3000) -> str:
        """
        Send a message and extract the JSON object from the response
        :param system: the context in which this message is to be taken
        :param user: the prompt
        :param max_tokens: max number of tokens to generate
        :return: the JSON portion of the response from the AI
        """
        result = self.protected_send(system, user, max_tokens)
        left = result.find("{")
        right = result.rfind("}")
        if left > -1 and right > -1:
            result = result[left : right + 1]
        return result
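
    # For example, a raw reply like 'Sure! Here is the json: {"move": "rock"} Enjoy!' is
    # trimmed to '{"move": "rock"}' (a hypothetical reply, for illustration)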

    def protected_send(self, system: str, user: str, max_tokens: int = 3000) -> str:
        """
        Wrap the send call in an exception handler, giving the LLM 3 chances in total in case
        of overload errors. If it fails all 3 times, it forfeits and an empty JSON object is returned.
        """
        retries = 3
        while retries:
            retries -= 1
            try:
                return self._send(system, user, max_tokens)
            except Exception as e:
                logger.error(f"Exception when calling LLM: {e}")
                if retries:
                    logger.warning("Waiting 2s and retrying")
                    time.sleep(2)
        return "{}"

    def _send(self, system: str, user: str, max_tokens: int = 3000) -> str:
        """
        Send a message to the model - this default implementation follows the OpenAI API structure
        :param system: the context in which this message is to be taken
        :param user: the prompt
        :param max_tokens: max number of tokens to generate
        :return: the response from the AI
        """
        response = self.client.chat.completions.create(
            model=self.api_model_name(),
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": user},
            ],
            temperature=self.temperature,
            max_tokens=max_tokens,
            response_format={"type": "json_object"},
        )
        return response.choices[0].message.content

    def api_model_name(self) -> str:
        """
        Return the actual model_name to be used in the call to the API; strip out anything after a space
        """
        return self.model_name.split(" ")[0]
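
    # e.g. "llama3.2 local" becomes "llama3.2", while "gpt-4o-mini" passes through unchanged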

    @classmethod
    def model_map(cls) -> Dict[str, Type[Self]]:
        """
        Generate a mapping of model names to LLM classes, by looking at all direct subclasses of this one
        :return: a mapping dictionary from model name to LLM subclass
        """
        mapping = {}
        for llm in cls.__subclasses__():
            for model_name in llm.model_names:
                mapping[model_name] = llm
        return mapping
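
    # e.g. {"gpt-4o-mini": GPT, "claude-3-7-sonnet-latest": Claude, "o1-mini": O1, ...}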

    @classmethod
    def all_model_names(cls) -> List[str]:
        """
        Return a list of all the model names supported.
        Uses the ones registered in model_map, optionally restricted to the comma-separated
        names in the MODELS environment variable, if set.
        """
        models = list(cls.model_map().keys())
        allowed = os.getenv("MODELS")
        if allowed:
            allowed_models = allowed.split(",")
            return [model for model in models if model in allowed_models]
        return models
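
    # e.g. MODELS="gpt-4o-mini,claude-3-7-sonnet-latest" restricts the arena to those two models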

    @classmethod
    def create(cls, model_name: str, temperature: float = 0.5) -> Self:
        """
        Return an instance of the subclass that corresponds to this model_name
        :param model_name: a string to describe this model
        :param temperature: the creativity setting
        :return: a new instance of a subclass of LLM
        """
        subclass = cls.model_map().get(model_name)
        if not subclass:
            raise LLMException(f"Unrecognized LLM model name specified: {model_name}")
        return subclass(model_name, temperature)
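
    # Typical usage (a sketch - the prompts here are placeholders):
    #   llm = LLM.create("gpt-4o-mini", temperature=0.7)
    #   reply = llm.send("You are a player in a game...", "Make your move...")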


class Claude(LLM):
    """
    A class to act as an interface to the remote AI, in this case Claude
    """

    model_names = ["claude-3-5-sonnet-latest", "claude-3-7-sonnet-latest"]

    def __init__(self, model_name: str, temperature: float):
        """
        Create a new instance of the Anthropic client
        """
        super().__init__(model_name, temperature)
        self.client = Anthropic()

    def _send(self, system: str, user: str, max_tokens: int = 3000) -> str:
        """
        Send a message to Claude; the Anthropic API takes the system prompt as a separate argument
        :param system: the context in which this message is to be taken
        :param user: the prompt
        :param max_tokens: max number of tokens to generate
        :return: the response from the AI
        """
        response = self.client.messages.create(
            model=self.api_model_name(),
            max_tokens=max_tokens,
            temperature=self.temperature,
            system=system,
            messages=[
                {"role": "user", "content": user},
            ],
        )
        return response.content[0].text


class GPT(LLM):
    """
    A class to act as an interface to the remote AI, in this case GPT
    """

    model_names = ["gpt-4o-mini", "gpt-4o"]

    def __init__(self, model_name: str, temperature: float):
        """
        Create a new instance of the OpenAI client
        """
        super().__init__(model_name, temperature)
        self.client = OpenAI()


class O1(LLM):
    """
    A class to act as an interface to the remote AI, in this case O1
    """

    model_names = ["o1-mini"]

    def __init__(self, model_name: str, temperature: float):
        """
        Create a new instance of the OpenAI client
        """
        super().__init__(model_name, temperature)
        self.client = OpenAI()

    def _send(self, system: str, user: str, max_tokens: int = 3000) -> str:
        """
        Send a message to O1
        :param system: the context in which this message is to be taken
        :param user: the prompt
        :param max_tokens: max number of tokens to generate
        :return: the response from the AI
        """
        # o1 models don't accept a system message, so fold the system prompt into the user
        # prompt; temperature and response_format are omitted as they are also unsupported
        message = system + "\n\n" + user
        response = self.client.chat.completions.create(
            model=self.api_model_name(),
            messages=[
                {"role": "user", "content": message},
            ],
        )
        return response.choices[0].message.content


class O3(LLM):
    """
    A class to act as an interface to the remote AI, in this case O3
    """

    model_names = ["o3-mini"]

    def __init__(self, model_name: str, temperature: float):
        """
        Create a new instance of the OpenAI client, using an override API key with o3 access if provided
        """
        super().__init__(model_name, temperature)
        override = os.getenv("OPENAI_API_KEY_O3")
        if override:
            logger.info("Using special key with o3 access")
            self.client = OpenAI(api_key=override)
        else:
            self.client = OpenAI()

    def _send(self, system: str, user: str, max_tokens: int = 3000) -> str:
        """
        Send a message to O3
        :param system: the context in which this message is to be taken
        :param user: the prompt
        :param max_tokens: max number of tokens to generate
        :return: the response from the AI
        """
        # As with o1, fold the system prompt into the user prompt and omit temperature
        message = system + "\n\n" + user
        response = self.client.chat.completions.create(
            model=self.api_model_name(),
            messages=[
                {"role": "user", "content": message},
            ],
        )
        return response.choices[0].message.content


class Gemini(LLM):
    """
    A class to act as an interface to the remote AI, in this case Gemini
    """

    model_names = ["gemini-2.0-flash", "gemini-1.5-flash"]

    def __init__(self, model_name: str, temperature: float):
        """
        Create a new instance of the OpenAI client, pointed at Gemini's OpenAI-compatible endpoint
        """
        super().__init__(model_name, temperature)
        google_api_key = os.getenv("GOOGLE_API_KEY")
        self.client = OpenAI(
            api_key=google_api_key,
            base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
        )


class Ollama(LLM):
    """
    A class to act as an interface to local models served by Ollama, via the OpenAI client
    """

    model_names = ["llama3.2 local", "gemma2 local", "qwen2.5 local", "phi4 local"]

    def __init__(self, model_name: str, temperature: float):
        """
        Create a new instance of the OpenAI client, pointed at the local Ollama server
        """
        super().__init__(model_name, temperature)
        self.client = OpenAI(base_url="http://localhost:11434/v1", api_key="ollama")

    def _send(self, system: str, user: str, max_tokens: int = 3000) -> str:
        """
        Send a message to a model running on Ollama
        :param system: the context in which this message is to be taken
        :param user: the prompt
        :param max_tokens: max number of tokens to generate
        :return: the response from the AI, with any <think> reasoning stripped out
        """
        response = self.client.chat.completions.create(
            model=self.api_model_name(),
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": user},
            ],
            temperature=self.temperature,
            max_tokens=max_tokens,
            response_format={"type": "json_object"},
        )
        reply = response.choices[0].message.content
        # Reasoning models wrap their chain of thought in <think> tags; log it and strip it
        if "</think>" in reply:
            logger.info(
                "Thoughts:\n" + reply.split("</think>")[0].replace("<think>", "")
            )
            reply = reply.split("</think>")[1]
        return reply
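
    # For example, a reply of '<think>rock beats scissors...</think>{"move": "rock"}' logs
    # the thoughts and returns just '{"move": "rock"}' (a hypothetical reply, for illustration)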


class DeepSeekAPI(LLM):
    """
    A class to act as an interface to the remote AI, in this case DeepSeek via the OpenAI client
    """

    model_names = ["deepseek-chat V3", "deepseek-reasoner R1"]

    def __init__(self, model_name: str, temperature: float):
        """
        Create a new instance of the OpenAI client, pointed at the DeepSeek endpoint
        """
        super().__init__(model_name, temperature)
        deepseek_api_key = os.getenv("DEEPSEEK_API_KEY")
        self.client = OpenAI(
            api_key=deepseek_api_key, base_url="https://api.deepseek.com"
        )


class DeepSeekLocal(LLM):
    """
    A class to act as an interface to a local DeepSeek model served by Ollama, via the OpenAI client
    """

    model_names = ["deepseek-r1:14b local"]

    def __init__(self, model_name: str, temperature: float):
        """
        Create a new instance of the OpenAI client, pointed at the local Ollama server
        """
        super().__init__(model_name, temperature)
        self.client = OpenAI(base_url="http://localhost:11434/v1", api_key="ollama")

    def _send(self, system: str, user: str, max_tokens: int = 3000) -> str:
        """
        Send a message to a local DeepSeek model running on Ollama
        :param system: the context in which this message is to be taken
        :param user: the prompt
        :param max_tokens: max number of tokens to generate
        :return: the response from the AI, with any <think> reasoning stripped out
        """
        # Nudge the reasoning model to keep its chain of thought short and to reply in JSON
        nudge = (
            "\nImportant: avoid overthinking. Think briefly and decisively. "
            "The final response must follow the given json format or you forfeit the game. "
            "Do not overthink. Respond with json."
        )
        system += nudge
        user += nudge
        response = self.client.chat.completions.create(
            model=self.api_model_name(),
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": user},
            ],
        )
        reply = response.choices[0].message.content
        # Log and strip the <think>...</think> reasoning block if present
        if "</think>" in reply:
            logger.info(
                "Thoughts:\n" + reply.split("</think>")[0].replace("<think>", "")
            )
            reply = reply.split("</think>")[1]
        return reply


class GroqAPI(LLM):
    """
    A class to act as an interface to the remote AI, in this case Groq
    """

    model_names = [
        "deepseek-r1-distill-llama-70b via Groq",
        "llama-3.3-70b-versatile via Groq",
        "mixtral-8x7b-32768 via Groq",
    ]

    def __init__(self, model_name: str, temperature: float):
        """
        Create a new instance of the Groq client
        """
        super().__init__(model_name, temperature)
        self.client = Groq()
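

if __name__ == "__main__":
    # Minimal smoke test (a sketch - assumes the relevant API key, e.g. OPENAI_API_KEY,
    # is set in the environment; "gpt-4o-mini" is one of the registered model names)
    print("Supported models:", LLM.all_model_names())
    llm = LLM.create("gpt-4o-mini")
    print(llm.send("You are a helpful assistant who responds in json", 'Reply with {"greeting": "hello"}'))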