import os
from abc import ABC, abstractmethod
from functools import cached_property
from typing import ClassVar, Literal, Optional, Union

import httpx
from httpx import Limits, Timeout
from openai import AsyncOpenAI
from openai.types.chat.chat_completion import (
    ChatCompletion,
)
from pydantic import BaseModel

from proxy_lite.history import MessageHistory
from proxy_lite.logger import logger
from proxy_lite.serializer import (
    BaseSerializer,
    OpenAICompatibleSerializer,
)
from proxy_lite.tools import Tool


class BaseClientConfig(BaseModel):
    http_timeout: float = 50
    http_concurrent_connections: int = 50


class BaseClient(BaseModel, ABC):
    config: BaseClientConfig
    serializer: ClassVar[BaseSerializer]

    @abstractmethod
    async def create_completion(
        self,
        messages: MessageHistory,
        temperature: float = 0.7,
        seed: Optional[int] = None,
        tools: Optional[list[Tool]] = None,
        response_format: Optional[type[BaseModel]] = None,
    ) -> ChatCompletion:
        """
        Create a completion from the model.

        Subclasses adapt this to endpoints that handle requests differently,
        and should raise appropriate warnings where behaviour diverges.

        Returns:
            ChatCompletion: OpenAI ChatCompletion format, for consistency.
        """
        ...

    @classmethod
    def create(cls, config: BaseClientConfig) -> "BaseClient":
        supported_clients = {
            "openai": OpenAIClient,
            "openai-azure": OpenAIClient,
            "convergence": ConvergenceClient,
            "gemini": GeminiClient,
        }
        # Concrete config subclasses carry a `name` discriminator field.
        config_name = getattr(config, "name", None)
        if config_name not in supported_clients:
            error_message = f"Unsupported model: {config_name}."
            raise ValueError(error_message)
        return supported_clients[config_name](config=config)
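
    # Factory usage sketch (illustrative, not part of the original module):
    #   client = BaseClient.create(OpenAIClientConfig())       # -> OpenAIClient
    #   client = BaseClient.create(ConvergenceClientConfig())  # -> ConvergenceClient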

    @cached_property
    def http_client(self) -> httpx.AsyncClient:
        return httpx.AsyncClient(
            timeout=Timeout(self.config.http_timeout),
            limits=Limits(
                max_connections=self.config.http_concurrent_connections,
                max_keepalive_connections=self.config.http_concurrent_connections,
            ),
        )


class OpenAIClientConfig(BaseClientConfig):
    name: Literal["openai"] = "openai"
    model_id: str = "gpt-4o"
    api_key: str = os.environ.get("OPENAI_API_KEY", "")
    api_base: Optional[str] = None


class OpenAIClient(BaseClient):
    config: OpenAIClientConfig
    serializer: ClassVar[OpenAICompatibleSerializer] = OpenAICompatibleSerializer()

    @cached_property
    def external_client(self) -> AsyncOpenAI:
        client_params = {
            "api_key": self.config.api_key,
            "http_client": self.http_client,
        }
        if self.config.api_base:
            client_params["base_url"] = self.config.api_base
        return AsyncOpenAI(**client_params)

    async def create_completion(
        self,
        messages: MessageHistory,
        temperature: float = 0.7,
        seed: Optional[int] = None,
        tools: Optional[list[Tool]] = None,
        response_format: Optional[type[BaseModel]] = None,
    ) -> ChatCompletion:
        base_params = {
            "model": self.config.model_id,
            "messages": self.serializer.serialize_messages(messages),
            "temperature": temperature,
        }
        optional_params = {
            "seed": seed,
            "tools": self.serializer.serialize_tools(tools) if tools else None,
            "tool_choice": "required" if tools else None,
            "response_format": {"type": "json_object"} if response_format else {"type": "text"},
        }
        base_params.update({k: v for k, v in optional_params.items() if v is not None})
        return await self.external_client.chat.completions.create(**base_params)


class ConvergenceClientConfig(BaseClientConfig):
    name: Literal["convergence"] = "convergence"
    model_id: str = "convergence-ai/proxy-lite-7b"
    api_base: str = "http://localhost:8000/v1"
    api_key: str = "none"


class ConvergenceClient(OpenAIClient):
    config: ConvergenceClientConfig
    serializer: ClassVar[OpenAICompatibleSerializer] = OpenAICompatibleSerializer()
    _model_validated: bool = False

    async def _validate_model(self) -> None:
        try:
            response = await self.external_client.models.list()
            assert self.config.model_id in [model.id for model in response.data], (
                f"Model {self.config.model_id} not found in {response.data}"
            )
            self._model_validated = True
            logger.debug(f"Model {self.config.model_id} validated and connected to cluster")
        except Exception as e:
            logger.error(f"Error retrieving model: {e}")
            raise e

    @cached_property
    def external_client(self) -> AsyncOpenAI:
        return AsyncOpenAI(
            api_key=self.config.api_key,
            base_url=self.config.api_base,
            http_client=self.http_client,
        )

    async def create_completion(
        self,
        messages: MessageHistory,
        temperature: float = 0.7,
        seed: Optional[int] = None,
        tools: Optional[list[Tool]] = None,
        response_format: Optional[type[BaseModel]] = None,
    ) -> ChatCompletion:
        if not self._model_validated:
            await self._validate_model()
        base_params = {
            "model": self.config.model_id,
            "messages": self.serializer.serialize_messages(messages),
            "temperature": temperature,
        }
        optional_params = {
            "seed": seed,
            "tools": self.serializer.serialize_tools(tools) if tools else None,
            "tool_choice": "auto" if tools else None,  # vLLM does not support "required"
            "response_format": response_format if response_format else {"type": "text"},
        }
        base_params.update({k: v for k, v in optional_params.items() if v is not None})
        return await self.external_client.chat.completions.create(**base_params)


class GeminiClientConfig(BaseClientConfig):
    name: Literal["gemini"] = "gemini"
    model_id: str = "gemini-2.0-flash-001"
    api_key: str = ""


class GeminiClient(BaseClient):
    config: GeminiClientConfig
    serializer: ClassVar[OpenAICompatibleSerializer] = OpenAICompatibleSerializer()

    def _convert_messages_to_gemini_format(self, messages):
        """Convert OpenAI-format messages to Gemini format."""
        gemini_parts = []
        for msg in messages:
            if msg["role"] == "user":
                gemini_parts.append({"text": msg["content"]})
            elif msg["role"] == "assistant":
                gemini_parts.append({"text": msg["content"]})
            # System messages are skipped (or folded into the first user message).
        return gemini_parts

    def _clean_schema_for_gemini(self, schema):
        """Clean up a JSON schema for Gemini function calling: inline $refs and drop $defs."""
        if not isinstance(schema, dict):
            return schema
        # Inline $ref entries first, while the $defs table is still available.
        if "$defs" in schema:
            schema = self._inline_definitions(schema, schema["$defs"])
        cleaned = {}
        for key, value in schema.items():
            if key in ("$defs", "$ref"):
                # Drop the definitions table and any unresolved references.
                continue
            elif isinstance(value, dict):
                cleaned[key] = self._clean_schema_for_gemini(value)
            elif isinstance(value, list):
                cleaned[key] = [self._clean_schema_for_gemini(item) for item in value]
            else:
                cleaned[key] = value
        return cleaned

    def _inline_definitions(self, schema, definitions):
        """Inline $ref definitions into the schema."""
        if not isinstance(schema, dict):
            return schema
        if "$ref" in schema:
            # Extract the reference name (e.g. "#/$defs/TypeEntry" -> "TypeEntry").
            ref_name = schema["$ref"].split("/")[-1]
            if ref_name in definitions:
                # Replace the $ref with the actual definition.
                return self._inline_definitions(definitions[ref_name], definitions)
            else:
                # If the definition cannot be found, drop the $ref.
                return {k: v for k, v in schema.items() if k != "$ref"}
        # Recursively process nested objects.
        inlined = {}
        for key, value in schema.items():
            if isinstance(value, dict):
                inlined[key] = self._inline_definitions(value, definitions)
            elif isinstance(value, list):
                inlined[key] = [self._inline_definitions(item, definitions) for item in value]
            else:
                inlined[key] = value
        return inlined
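
    # Illustrative example (not in the original source): the two helpers above turn a
    # Pydantic-exported schema such as
    #   {"$defs": {"Point": {"type": "object", "properties": {"x": {"type": "number"}}}},
    #    "type": "object", "properties": {"p": {"$ref": "#/$defs/Point"}}}
    # into the self-contained form expected by Gemini function declarations:
    #   {"type": "object", "properties": {"p": {"type": "object",
    #                                           "properties": {"x": {"type": "number"}}}}}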

    async def create_completion(
        self,
        messages: MessageHistory,
        temperature: float = 0.7,
        seed: Optional[int] = None,
        tools: Optional[list[Tool]] = None,
        response_format: Optional[type[BaseModel]] = None,
    ) -> ChatCompletion:
        import json
        import time

        from openai.types.chat.chat_completion import ChatCompletion, Choice
        from openai.types.chat.chat_completion_message import ChatCompletionMessage
        from openai.types.chat.chat_completion_message_tool_call import (
            ChatCompletionMessageToolCall,
            Function,
        )
        from openai.types.completion_usage import CompletionUsage
        # Convert messages to the format expected by Gemini.
        serialized_messages = self.serializer.serialize_messages(messages)

        # For the Gemini API, contents must be formatted with the proper roles.
        contents = []
        current_user_text = ""
        for msg in serialized_messages:
            # Extract the actual text content from the serialized message.
            content_text = ""
            if isinstance(msg["content"], list):
                # Handle complex content format.
                for item in msg["content"]:
                    if isinstance(item, dict) and "text" in item:
                        content_text += item["text"]
                    elif isinstance(item, str):
                        content_text += item
            elif isinstance(msg["content"], str):
                content_text = msg["content"]
            if msg["role"] == "user":
                # Accumulate user messages.
                current_user_text += content_text + "\n"
            elif msg["role"] == "assistant":
                # If there is accumulated user text, add it first.
                if current_user_text.strip():
                    contents.append({
                        "role": "user",
                        "parts": [{"text": current_user_text.strip()}],
                    })
                    current_user_text = ""
                # Add the assistant message with role "model".
                contents.append({
                    "role": "model",
                    "parts": [{"text": content_text}],
                })
            elif msg["role"] == "tool":
                # Add tool messages as user messages so they are included in context,
                # formatted so the agent can recognise completed actions.
                current_user_text += f"[ACTION COMPLETED] {content_text}\n"
        # Add any remaining user text.
        if current_user_text.strip():
            contents.append({
                "role": "user",
                "parts": [{"text": current_user_text.strip()}],
            })
        payload = {
            "contents": contents,
            "generationConfig": {
                "temperature": temperature,
            },
        }
        # Add function-calling support if tools are provided.
        if tools:
            # Convert tools to the Gemini function-declaration format.
            function_declarations = []
            for tool in tools:
                for tool_schema in tool.schema:
                    # Clean up the schema for Gemini: remove $defs and $ref.
                    cleaned_parameters = self._clean_schema_for_gemini(tool_schema["parameters"])
                    function_declarations.append({
                        "name": tool_schema["name"],
                        "description": tool_schema["description"],
                        "parameters": cleaned_parameters,
                    })
            payload["tools"] = [{
                "function_declarations": function_declarations,
            }]
        # Make a direct HTTP request to the native Gemini API.
        url = (
            "https://generativelanguage.googleapis.com/v1beta/models/"
            f"{self.config.model_id}:generateContent?key={self.config.api_key}"
        )
        response = await self.http_client.post(
            url,
            json=payload,
            headers={"Content-Type": "application/json"},
        )
        response.raise_for_status()
        response_data = response.json()
        # Convert the Gemini response to the OpenAI ChatCompletion format.
        if "candidates" in response_data and len(response_data["candidates"]) > 0:
            candidate = response_data["candidates"][0]
            # Extract text and any function calls from the response.
            content = ""
            tool_calls = []
            if "content" in candidate and "parts" in candidate["content"]:
                for part in candidate["content"]["parts"]:
                    if "text" in part:
                        content += part["text"]
                    elif "functionCall" in part:
                        # Handle a function call.
                        func_call = part["functionCall"]
                        tool_call = ChatCompletionMessageToolCall(
                            id=f"call_{hash(str(func_call))}"[:16],
                            type="function",
                            function=Function(
                                name=func_call["name"],
                                arguments=json.dumps(func_call.get("args", {})),
                            ),
                        )
                        tool_calls.append(tool_call)
            choice = Choice(
                index=0,
                message=ChatCompletionMessage(
                    role="assistant",
                    content=content if content else None,
                    tool_calls=tool_calls if tool_calls else None,
                ),
                finish_reason="stop",
            )
            # Create a mock ChatCompletion response.
            completion = ChatCompletion(
                id="gemini-" + str(hash(content))[:8],
                choices=[choice],
                created=int(time.time()),
                model=self.config.model_id,
                object="chat.completion",
                usage=CompletionUsage(
                    completion_tokens=len(content.split()),
                    prompt_tokens=sum(len(str(msg.get("content", "")).split()) for msg in serialized_messages),
                    total_tokens=len(content.split())
                    + sum(len(str(msg.get("content", "")).split()) for msg in serialized_messages),
                ),
            )
            return completion
        else:
            raise Exception(f"No valid response from Gemini API: {response_data}")


ClientConfigTypes = Union[OpenAIClientConfig, ConvergenceClientConfig, GeminiClientConfig]
ClientTypes = Union[OpenAIClient, ConvergenceClient, GeminiClient]
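

# Minimal usage sketch (illustrative only, not part of the original module; assumes a
# reachable endpoint, a valid API key, and a `history` built elsewhere as a MessageHistory):
#
#     import asyncio
#
#     async def _demo(history: MessageHistory) -> None:
#         client = BaseClient.create(OpenAIClientConfig())  # selects OpenAIClient via `name`
#         completion = await client.create_completion(messages=history, temperature=0.2)
#         print(completion.choices[0].message.content)
#
#     # asyncio.run(_demo(history))  # run with a MessageHistory constructed by the caller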