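"""Model client implementations for proxy_lite.

Defines a common BaseClient interface plus concrete clients for the OpenAI
API, a locally served Convergence model (OpenAI-compatible, e.g. vLLM), and
the native Gemini API. All clients return responses in the OpenAI
ChatCompletion format so downstream code can treat them uniformly.
"""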
import os
from abc import ABC, abstractmethod
from functools import cached_property
from typing import ClassVar, Literal, Optional, Union
import httpx
from httpx import Limits, Timeout
from openai import AsyncOpenAI
from openai.types.chat.chat_completion import (
ChatCompletion,
)
from pydantic import BaseModel
from proxy_lite.history import MessageHistory
from proxy_lite.logger import logger
from proxy_lite.serializer import (
BaseSerializer,
OpenAICompatibleSerializer,
)
from proxy_lite.tools import Tool
class BaseClientConfig(BaseModel):
http_timeout: float = 50
http_concurrent_connections: int = 50
class BaseClient(BaseModel, ABC):
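    """Abstract base class for model clients.

    Subclasses implement create_completion for a specific endpoint; the
    create classmethod acts as a factory that selects the subclass from the
    config's name field.
    """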
config: BaseClientConfig
serializer: ClassVar[BaseSerializer]
@abstractmethod
async def create_completion(
self,
messages: MessageHistory,
temperature: float = 0.7,
seed: Optional[int] = None,
tools: Optional[list[Tool]] = None,
response_format: Optional[type[BaseModel]] = None,
    ) -> ChatCompletion:
        """
        Create a completion from the model.

        Subclasses are expected to adapt this to endpoints that handle
        requests differently; raise appropriate warnings where behaviour
        diverges.

        Returns:
            ChatCompletion: OpenAI ChatCompletion format, for consistency.
        """
        ...
@classmethod
def create(cls, config: BaseClientConfig) -> "BaseClient":
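        """Instantiate the client that matches config.name.

        Illustrative sketch (assumes the default OpenAI settings defined below):

            client = BaseClient.create(OpenAIClientConfig())
            # -> OpenAIClient targeting "gpt-4o", with api_key read from OPENAI_API_KEY
        """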
supported_clients = {
"openai": OpenAIClient,
"openai-azure": OpenAIClient,
"convergence": ConvergenceClient,
"gemini": GeminiClient,
}
        # Subclass configs define a `name` literal; use it to pick the client implementation.
        config_name = getattr(config, "name", None)
        if config_name not in supported_clients:
            error_message = f"Unsupported client: {config_name}."
raise ValueError(error_message)
return supported_clients[config_name](config=config)
@property
def http_client(self) -> httpx.AsyncClient:
return httpx.AsyncClient(
timeout=Timeout(self.config.http_timeout),
limits=Limits(
max_connections=self.config.http_concurrent_connections,
max_keepalive_connections=self.config.http_concurrent_connections,
),
)
class OpenAIClientConfig(BaseClientConfig):
name: Literal["openai"] = "openai"
model_id: str = "gpt-4o"
api_key: str = os.environ.get("OPENAI_API_KEY", "")
api_base: Optional[str] = None
class OpenAIClient(BaseClient):
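    """Client for the OpenAI chat completions API (or any OpenAI-compatible
    endpoint when api_base is set)."""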
config: OpenAIClientConfig
serializer: ClassVar[OpenAICompatibleSerializer] = OpenAICompatibleSerializer()
@cached_property
def external_client(self) -> AsyncOpenAI:
client_params = {
"api_key": self.config.api_key,
"http_client": self.http_client,
}
if self.config.api_base:
client_params["base_url"] = self.config.api_base
return AsyncOpenAI(**client_params)
async def create_completion(
self,
messages: MessageHistory,
temperature: float = 0.7,
seed: Optional[int] = None,
tools: Optional[list[Tool]] = None,
response_format: Optional[type[BaseModel]] = None,
) -> ChatCompletion:
base_params = {
"model": self.config.model_id,
"messages": self.serializer.serialize_messages(messages),
"temperature": temperature,
}
optional_params = {
"seed": seed,
"tools": self.serializer.serialize_tools(tools) if tools else None,
"tool_choice": "required" if tools else None,
"response_format": {"type": "json_object"} if response_format else {"type": "text"},
}
base_params.update(
{k: v for k, v in optional_params.items() if v is not None})
return await self.external_client.chat.completions.create(**base_params)
class ConvergenceClientConfig(BaseClientConfig):
name: Literal["convergence"] = "convergence"
model_id: str = "convergence-ai/proxy-lite-7b"
api_base: str = "http://localhost:8000/v1"
api_key: str = "none"
class ConvergenceClient(OpenAIClient):
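    """OpenAI-compatible client for a locally served proxy-lite model (e.g. via vLLM).

    Before the first completion request, it checks that model_id is advertised
    by the endpoint.
    """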
config: ConvergenceClientConfig
serializer: ClassVar[OpenAICompatibleSerializer] = OpenAICompatibleSerializer()
_model_validated: bool = False
async def _validate_model(self) -> None:
try:
response = await self.external_client.models.list()
assert self.config.model_id in [model.id for model in response.data], (
f"Model {self.config.model_id} not found in {response.data}"
)
self._model_validated = True
logger.debug(f"Model {self.config.model_id} validated and connected to cluster")
except Exception as e:
logger.error(f"Error retrieving model: {e}")
raise e
@cached_property
def external_client(self) -> AsyncOpenAI:
return AsyncOpenAI(
api_key=self.config.api_key,
base_url=self.config.api_base,
http_client=self.http_client,
)
async def create_completion(
self,
messages: MessageHistory,
temperature: float = 0.7,
seed: Optional[int] = None,
tools: Optional[list[Tool]] = None,
response_format: Optional[type[BaseModel]] = None,
) -> ChatCompletion:
if not self._model_validated:
await self._validate_model()
base_params = {
"model": self.config.model_id,
"messages": self.serializer.serialize_messages(messages),
"temperature": temperature,
}
optional_params = {
"seed": seed,
"tools": self.serializer.serialize_tools(tools) if tools else None,
"tool_choice": "auto" if tools else None, # vLLM does not support "required"
"response_format": response_format if response_format else {"type": "text"},
}
base_params.update({k: v for k, v in optional_params.items() if v is not None})
return await self.external_client.chat.completions.create(**base_params)
class GeminiClientConfig(BaseClientConfig):
name: Literal["gemini"] = "gemini"
model_id: str = "gemini-2.0-flash-001"
api_key: str = ""
class GeminiClient(BaseClient):
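    """Client for the native Gemini generateContent REST API.

    Sends requests over HTTP directly and adapts responses into the OpenAI
    ChatCompletion shape so callers can treat all clients uniformly.
    """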
config: GeminiClientConfig
serializer: ClassVar[OpenAICompatibleSerializer] = OpenAICompatibleSerializer()
    def _convert_messages_to_gemini_format(self, messages):
        """Flatten OpenAI-format messages into a list of Gemini text parts.

        Note: this simplified helper keeps only user/assistant text and drops
        other roles (e.g. system); create_completion below builds its own
        role-aware conversion and does not rely on this helper.
        """
        gemini_parts = []
        for msg in messages:
            if msg["role"] in ("user", "assistant"):
                gemini_parts.append({"text": msg["content"]})
            # Other roles (e.g. system) are skipped.
        return gemini_parts
def _clean_schema_for_gemini(self, schema):
"""Clean up JSON schema for Gemini function calling - remove $defs and $ref"""
if not isinstance(schema, dict):
return schema
cleaned = {}
for key, value in schema.items():
if key == "$defs":
# Skip $defs - we'll inline the definitions
continue
elif key == "$ref":
# Skip $ref - we'll inline the referenced schema
continue
elif isinstance(value, dict):
cleaned[key] = self._clean_schema_for_gemini(value)
elif isinstance(value, list):
cleaned[key] = [self._clean_schema_for_gemini(item) for item in value]
else:
cleaned[key] = value
# If we have $defs, we need to inline them
if "$defs" in schema:
cleaned = self._inline_definitions(cleaned, schema["$defs"])
return cleaned
def _inline_definitions(self, schema, definitions):
"""Inline $ref definitions into the schema"""
if not isinstance(schema, dict):
return schema
if "$ref" in schema:
# Extract the reference name (e.g., "#/$defs/TypeEntry" -> "TypeEntry")
ref_name = schema["$ref"].split("/")[-1]
if ref_name in definitions:
# Replace the $ref with the actual definition
return self._inline_definitions(definitions[ref_name], definitions)
else:
# If we can't find the definition, remove the $ref
return {k: v for k, v in schema.items() if k != "$ref"}
# Recursively process nested objects
inlined = {}
for key, value in schema.items():
if isinstance(value, dict):
inlined[key] = self._inline_definitions(value, definitions)
elif isinstance(value, list):
inlined[key] = [self._inline_definitions(item, definitions) for item in value]
else:
inlined[key] = value
return inlined
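
    # Illustrative sketch of the two helpers above on a hypothetical schema:
    #   _clean_schema_for_gemini(
    #       {"$defs": {"Entry": {"type": "string"}},
    #        "properties": {"x": {"$ref": "#/$defs/Entry"}}}
    #   )
    #   -> {"properties": {"x": {"type": "string"}}}
    # i.e. $ref entries are resolved against $defs and the $defs block is dropped.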
async def create_completion(
self,
messages: MessageHistory,
temperature: float = 0.7,
seed: Optional[int] = None,
tools: Optional[list[Tool]] = None,
response_format: Optional[type[BaseModel]] = None,
) -> ChatCompletion:
        import json
        import time
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.completion_usage import CompletionUsage
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall, Function
# Convert messages to format expected by Gemini
serialized_messages = self.serializer.serialize_messages(messages)
# For Gemini API, we need to format contents correctly with proper roles
contents = []
current_user_text = ""
for msg in serialized_messages:
# Extract the actual text content from the serialized message
content_text = ""
if isinstance(msg["content"], list):
# Handle complex content format
for item in msg["content"]:
if isinstance(item, dict) and "text" in item:
content_text += item["text"]
elif isinstance(item, str):
content_text += item
elif isinstance(msg["content"], str):
content_text = msg["content"]
if msg["role"] == "user":
# Accumulate user messages
current_user_text += content_text + "\n"
elif msg["role"] == "assistant":
# If we have accumulated user text, add it first
if current_user_text.strip():
contents.append({
"role": "user",
"parts": [{"text": current_user_text.strip()}]
})
current_user_text = ""
# Add assistant message with role "model"
contents.append({
"role": "model",
"parts": [{"text": content_text}]
})
elif msg["role"] == "tool":
# Add tool messages as user messages so they're included in context
# Format tool message more clearly for the agent to understand
current_user_text += f"[ACTION COMPLETED] {content_text}\n"
# Add any remaining user text
if current_user_text.strip():
contents.append({
"role": "user",
"parts": [{"text": current_user_text.strip()}]
})
payload = {
"contents": contents,
"generationConfig": {
"temperature": temperature,
}
}
# Add function calling support if tools are provided
if tools:
# Convert tools to Gemini function declaration format
function_declarations = []
for tool in tools:
for tool_schema in tool.schema:
# Clean up the schema for Gemini - remove $defs and $ref
cleaned_parameters = self._clean_schema_for_gemini(tool_schema["parameters"])
function_declarations.append({
"name": tool_schema["name"],
"description": tool_schema["description"],
"parameters": cleaned_parameters
})
payload["tools"] = [{
"function_declarations": function_declarations
}]
# Make direct HTTP request to native Gemini API
url = f"https://generativelanguage.googleapis.com/v1beta/models/{self.config.model_id}:generateContent?key={self.config.api_key}"
response = await self.http_client.post(
url,
json=payload,
headers={"Content-Type": "application/json"}
)
response.raise_for_status()
response_data = response.json()
# Convert Gemini response to OpenAI ChatCompletion format
if "candidates" in response_data and len(response_data["candidates"]) > 0:
candidate = response_data["candidates"][0]
# Extract text from response
content = ""
tool_calls = []
if "content" in candidate and "parts" in candidate["content"]:
for part in candidate["content"]["parts"]:
if "text" in part:
content += part["text"]
elif "functionCall" in part:
# Handle function call
func_call = part["functionCall"]
tool_call = ChatCompletionMessageToolCall(
id=f"call_{hash(str(func_call))}"[:16],
type="function",
function=Function(
name=func_call["name"],
arguments=json.dumps(func_call.get("args", {}))
)
)
tool_calls.append(tool_call)
choice = Choice(
index=0,
message=ChatCompletionMessage(
role="assistant",
content=content if content else None,
tool_calls=tool_calls if tool_calls else None
),
finish_reason="stop"
)
            # Build a ChatCompletion from the Gemini response (usage is a rough word-count estimate)
completion = ChatCompletion(
id="gemini-" + str(hash(content))[:8],
choices=[choice],
                created=int(time.time()),
model=self.config.model_id,
object="chat.completion",
usage=CompletionUsage(
completion_tokens=len(content.split()),
prompt_tokens=sum(len(str(msg.get("content", "")).split()) for msg in serialized_messages),
total_tokens=len(content.split()) + sum(len(str(msg.get("content", "")).split()) for msg in serialized_messages)
)
)
return completion
else:
raise Exception(f"No valid response from Gemini API: {response_data}")
ClientConfigTypes = Union[OpenAIClientConfig, ConvergenceClientConfig, GeminiClientConfig]
ClientTypes = Union[OpenAIClient, ConvergenceClient, GeminiClient]
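
# Illustrative usage sketch (assumes an event loop and a populated MessageHistory
# called `history`; the ConvergenceClientConfig defaults point at a local
# vLLM-style endpoint on http://localhost:8000/v1):
#
#   client = BaseClient.create(ConvergenceClientConfig())
#   completion = await client.create_completion(messages=history, temperature=0.2)
#   print(completion.choices[0].message)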