import os
from abc import ABC, abstractmethod
from functools import cached_property
from typing import ClassVar, Literal, Optional, Union

import httpx
from httpx import Limits, Timeout
from openai import AsyncOpenAI
from openai.types.chat.chat_completion import ChatCompletion
from pydantic import BaseModel

from proxy_lite.history import MessageHistory
from proxy_lite.logger import logger
from proxy_lite.serializer import BaseSerializer, OpenAICompatibleSerializer
from proxy_lite.tools import Tool


class BaseClientConfig(BaseModel):
    http_timeout: float = 50
    http_concurrent_connections: int = 50


class BaseClient(BaseModel, ABC):
    config: BaseClientConfig
    serializer: ClassVar[BaseSerializer]

    @abstractmethod
    async def create_completion(
        self,
        messages: MessageHistory,
        temperature: float = 0.7,
        seed: Optional[int] = None,
        tools: Optional[list[Tool]] = None,
        response_format: Optional[type[BaseModel]] = None,
    ) -> ChatCompletion:
        """
        Create a completion from the model.

        Subclasses adapt this to endpoints that handle requests differently;
        they should raise appropriate warnings where behaviour diverges.

        Returns:
            ChatCompletion: OpenAI ChatCompletion format, for consistency.
        """
        ...

    @classmethod
    def create(cls, config: BaseClientConfig) -> "BaseClient":
        supported_clients = {
            "openai": OpenAIClient,
            "openai-azure": OpenAIClient,
            "convergence": ConvergenceClient,
            "gemini": GeminiClient,
        }
        # The concrete config subclasses carry a `name` discriminator field.
        config_name = getattr(config, "name", None)
        if config_name not in supported_clients:
            error_message = f"Unsupported model: {config_name}."
            raise ValueError(error_message)
        return supported_clients[config_name](config=config)

    @property
    def http_client(self) -> httpx.AsyncClient:
        return httpx.AsyncClient(
            timeout=Timeout(self.config.http_timeout),
            limits=Limits(
                max_connections=self.config.http_concurrent_connections,
                max_keepalive_connections=self.config.http_concurrent_connections,
            ),
        )


class OpenAIClientConfig(BaseClientConfig):
    name: Literal["openai"] = "openai"
    model_id: str = "gpt-4o"
    api_key: str = os.environ.get("OPENAI_API_KEY", "")
    api_base: Optional[str] = None


class OpenAIClient(BaseClient):
    config: OpenAIClientConfig
    serializer: ClassVar[OpenAICompatibleSerializer] = OpenAICompatibleSerializer()

    @cached_property
    def external_client(self) -> AsyncOpenAI:
        client_params = {
            "api_key": self.config.api_key,
            "http_client": self.http_client,
        }
        if self.config.api_base:
            client_params["base_url"] = self.config.api_base
        return AsyncOpenAI(**client_params)

    async def create_completion(
        self,
        messages: MessageHistory,
        temperature: float = 0.7,
        seed: Optional[int] = None,
        tools: Optional[list[Tool]] = None,
        response_format: Optional[type[BaseModel]] = None,
    ) -> ChatCompletion:
        base_params = {
            "model": self.config.model_id,
            "messages": self.serializer.serialize_messages(messages),
            "temperature": temperature,
        }
        optional_params = {
            "seed": seed,
            "tools": self.serializer.serialize_tools(tools) if tools else None,
            "tool_choice": "required" if tools else None,
            "response_format": {"type": "json_object"} if response_format else {"type": "text"},
        }
        base_params.update({k: v for k, v in optional_params.items() if v is not None})
        return await self.external_client.chat.completions.create(**base_params)


class ConvergenceClientConfig(BaseClientConfig):
    name: Literal["convergence"] = "convergence"
    model_id: str = "convergence-ai/proxy-lite-7b"
    api_base: str = "http://localhost:8000/v1"
    api_key: str = "none"


class ConvergenceClient(OpenAIClient):
    config: ConvergenceClientConfig
    serializer: ClassVar[OpenAICompatibleSerializer] = OpenAICompatibleSerializer()
    _model_validated: bool = False

    async def _validate_model(self) -> None:
        try:
            response = await self.external_client.models.list()
            assert self.config.model_id in [model.id for model in response.data], (
                f"Model {self.config.model_id} not found in {response.data}"
            )
            self._model_validated = True
            logger.debug(f"Model {self.config.model_id} validated and connected to cluster")
        except Exception as e:
            logger.error(f"Error retrieving model: {e}")
            raise e

    @cached_property
    def external_client(self) -> AsyncOpenAI:
        return AsyncOpenAI(
            api_key=self.config.api_key,
            base_url=self.config.api_base,
            http_client=self.http_client,
        )

    async def create_completion(
        self,
        messages: MessageHistory,
        temperature: float = 0.7,
        seed: Optional[int] = None,
        tools: Optional[list[Tool]] = None,
        response_format: Optional[type[BaseModel]] = None,
    ) -> ChatCompletion:
        if not self._model_validated:
            await self._validate_model()
        base_params = {
            "model": self.config.model_id,
            "messages": self.serializer.serialize_messages(messages),
            "temperature": temperature,
        }
        optional_params = {
            "seed": seed,
            "tools": self.serializer.serialize_tools(tools) if tools else None,
            "tool_choice": "auto" if tools else None,  # vLLM does not support "required"
            "response_format": response_format if response_format else {"type": "text"},
        }
        base_params.update({k: v for k, v in optional_params.items() if v is not None})
        return await self.external_client.chat.completions.create(**base_params)


class GeminiClientConfig(BaseClientConfig):
    name: Literal["gemini"] = "gemini"
    model_id: str = "gemini-2.0-flash-001"
    api_key: str = ""


class GeminiClient(BaseClient):
    config: GeminiClientConfig
    serializer: ClassVar[OpenAICompatibleSerializer] = OpenAICompatibleSerializer()

    def _convert_messages_to_gemini_format(self, messages):
        """Convert OpenAI-format messages to Gemini parts."""
        gemini_parts = []
        for msg in messages:
            if msg["role"] == "user":
                gemini_parts.append({"text": msg["content"]})
            elif msg["role"] == "assistant":
                gemini_parts.append({"text": msg["content"]})
            # Skip system messages or add them to the first user message
        return gemini_parts

    def _clean_schema_for_gemini(self, schema):
        """Clean up a JSON schema for Gemini function calling by removing $defs and $ref."""
        if not isinstance(schema, dict):
            return schema
        cleaned = {}
        for key, value in schema.items():
            if key == "$defs":
                # Skip $defs - we'll inline the definitions
                continue
            elif key == "$ref":
                # Skip $ref - we'll inline the referenced schema
                continue
            elif isinstance(value, dict):
                cleaned[key] = self._clean_schema_for_gemini(value)
            elif isinstance(value, list):
                cleaned[key] = [self._clean_schema_for_gemini(item) for item in value]
            else:
                cleaned[key] = value
        # If we have $defs, we need to inline them
        if "$defs" in schema:
            cleaned = self._inline_definitions(cleaned, schema["$defs"])
        return cleaned

    def _inline_definitions(self, schema, definitions):
        """Inline $ref definitions into the schema."""
        if not isinstance(schema, dict):
            return schema
        if "$ref" in schema:
            # Extract the reference name (e.g., "#/$defs/TypeEntry" -> "TypeEntry")
            ref_name = schema["$ref"].split("/")[-1]
            if ref_name in definitions:
                # Replace the $ref with the actual definition
                return self._inline_definitions(definitions[ref_name], definitions)
            else:
                # If we can't find the definition, remove the $ref
                return {k: v for k, v in schema.items() if k != "$ref"}
        # Recursively process nested objects
        inlined = {}
        for key, value in schema.items():
            if isinstance(value, dict):
                inlined[key] = self._inline_definitions(value, definitions)
            elif isinstance(value, list):
                inlined[key] = [self._inline_definitions(item, definitions) for item in value]
            else:
                inlined[key] = value
        return inlined

    async def create_completion(
        self,
        messages: MessageHistory,
        temperature: float = 0.7,
        seed: Optional[int] = None,
        tools: Optional[list[Tool]] = None,
        response_format: Optional[type[BaseModel]] = None,
    ) -> ChatCompletion:
        import json
        import time

        from openai.types.chat.chat_completion import ChatCompletion, Choice
        from openai.types.chat.chat_completion_message import ChatCompletionMessage
        from openai.types.chat.chat_completion_message_tool_call import (
            ChatCompletionMessageToolCall,
            Function,
        )
        from openai.types.completion_usage import CompletionUsage

        # Convert messages to the format expected by Gemini
        serialized_messages = self.serializer.serialize_messages(messages)

        # For the Gemini API, contents must be formatted with the correct roles
        contents = []
        current_user_text = ""

        for msg in serialized_messages:
            # Extract the actual text content from the serialized message
            content_text = ""
            if isinstance(msg["content"], list):
                # Handle complex content format
                for item in msg["content"]:
                    if isinstance(item, dict) and "text" in item:
                        content_text += item["text"]
                    elif isinstance(item, str):
                        content_text += item
            elif isinstance(msg["content"], str):
                content_text = msg["content"]

            if msg["role"] == "user":
                # Accumulate user messages
                current_user_text += content_text + "\n"
            elif msg["role"] == "assistant":
                # If we have accumulated user text, add it first
                if current_user_text.strip():
                    contents.append({
                        "role": "user",
                        "parts": [{"text": current_user_text.strip()}],
                    })
                    current_user_text = ""
                # Add assistant message with role "model"
                contents.append({
                    "role": "model",
                    "parts": [{"text": content_text}],
                })
            elif msg["role"] == "tool":
                # Add tool messages as user messages so they're included in context,
                # formatted clearly so the agent can understand them
                current_user_text += f"[ACTION COMPLETED] {content_text}\n"

        # Add any remaining user text
        if current_user_text.strip():
            contents.append({
                "role": "user",
                "parts": [{"text": current_user_text.strip()}],
            })

        payload = {
            "contents": contents,
            "generationConfig": {
                "temperature": temperature,
            },
        }

        # Add function calling support if tools are provided
        if tools:
            # Convert tools to the Gemini function declaration format
            function_declarations = []
            for tool in tools:
                for tool_schema in tool.schema:
                    # Clean up the schema for Gemini - remove $defs and $ref
                    cleaned_parameters = self._clean_schema_for_gemini(tool_schema["parameters"])
                    function_declarations.append({
                        "name": tool_schema["name"],
                        "description": tool_schema["description"],
                        "parameters": cleaned_parameters,
                    })
            payload["tools"] = [{"function_declarations": function_declarations}]

        # Make a direct HTTP request to the native Gemini API
        url = (
            "https://generativelanguage.googleapis.com/v1beta/models/"
            f"{self.config.model_id}:generateContent?key={self.config.api_key}"
        )

        response = await self.http_client.post(
            url,
            json=payload,
            headers={"Content-Type": "application/json"},
        )
        response.raise_for_status()
        response_data = response.json()

        # Convert the Gemini response to the OpenAI ChatCompletion format
        if "candidates" in response_data and len(response_data["candidates"]) > 0:
            candidate = response_data["candidates"][0]

            # Extract text and any function calls from the response
            content = ""
            tool_calls = []

            if "content" in candidate and "parts" in candidate["content"]:
                for part in candidate["content"]["parts"]:
                    if "text" in part:
                        content += part["text"]
"functionCall" in part: # Handle function call func_call = part["functionCall"] tool_call = ChatCompletionMessageToolCall( id=f"call_{hash(str(func_call))}"[:16], type="function", function=Function( name=func_call["name"], arguments=json.dumps(func_call.get("args", {})) ) ) tool_calls.append(tool_call) choice = Choice( index=0, message=ChatCompletionMessage( role="assistant", content=content if content else None, tool_calls=tool_calls if tool_calls else None ), finish_reason="stop" ) # Create a mock ChatCompletion response completion = ChatCompletion( id="gemini-" + str(hash(content))[:8], choices=[choice], created=int(__import__('time').time()), model=self.config.model_id, object="chat.completion", usage=CompletionUsage( completion_tokens=len(content.split()), prompt_tokens=sum(len(str(msg.get("content", "")).split()) for msg in serialized_messages), total_tokens=len(content.split()) + sum(len(str(msg.get("content", "")).split()) for msg in serialized_messages) ) ) return completion else: raise Exception(f"No valid response from Gemini API: {response_data}") ClientConfigTypes = Union[OpenAIClientConfig, ConvergenceClientConfig, GeminiClientConfig] ClientTypes = Union[OpenAIClient, ConvergenceClient, GeminiClient]