import os
from abc import ABC, abstractmethod
from functools import cached_property
from typing import ClassVar, Literal, Optional, Union

import httpx
from httpx import Limits, Timeout
from openai import AsyncOpenAI
from openai.types.chat.chat_completion import (
    ChatCompletion,
)
from pydantic import BaseModel

from proxy_lite.history import MessageHistory
from proxy_lite.logger import logger
from proxy_lite.serializer import (
    BaseSerializer,
    OpenAICompatibleSerializer,
)
from proxy_lite.tools import Tool


class BaseClientConfig(BaseModel):
    http_timeout: float = 50
    http_concurrent_connections: int = 50


class BaseClient(BaseModel, ABC):
    config: BaseClientConfig
    serializer: ClassVar[BaseSerializer]

    @abstractmethod
    async def create_completion(
        self,
        messages: MessageHistory,
        temperature: float = 0.7,
        seed: Optional[int] = None,
        tools: Optional[list[Tool]] = None,
        response_format: Optional[type[BaseModel]] = None,
    ) -> ChatCompletion:
        """
        Create a completion from the model.

        Subclasses adapt this to endpoints that handle requests differently
        and should raise appropriate warnings where behaviour diverges.

        Returns:
            ChatCompletion: response in the OpenAI ChatCompletion format, for consistency.
        """
        ...

    @classmethod
    def create(cls, config: BaseClientConfig) -> "BaseClient":
        supported_clients = {
            "openai": OpenAIClient,
            "openai-azure": OpenAIClient,  # reuses OpenAIClient; no dedicated Azure config is defined here
            "convergence": ConvergenceClient,
            "gemini": GeminiClient,
        }
        # Subclass configs declare a `name` discriminator; fall back to None if absent.
        config_name = getattr(config, "name", None)
        if config_name not in supported_clients:
            error_message = f"Unsupported client: {config_name}."
            raise ValueError(error_message)
        return supported_clients[config_name](config=config)

    @cached_property
    def http_client(self) -> httpx.AsyncClient:
        # Cached so all requests share one connection pool rather than
        # leaking a fresh AsyncClient on every access.
        return httpx.AsyncClient(
            timeout=Timeout(self.config.http_timeout),
            limits=Limits(
                max_connections=self.config.http_concurrent_connections,
                max_keepalive_connections=self.config.http_concurrent_connections,
            ),
        )


class OpenAIClientConfig(BaseClientConfig):
    name: Literal["openai"] = "openai"
    model_id: str = "gpt-4o"
    api_key: str = os.environ.get("OPENAI_API_KEY", "")
    api_base: Optional[str] = None


class OpenAIClient(BaseClient):
    config: OpenAIClientConfig
    serializer: ClassVar[OpenAICompatibleSerializer] = OpenAICompatibleSerializer()

    @cached_property
    def external_client(self) -> AsyncOpenAI:
        client_params = {
            "api_key": self.config.api_key,
            "http_client": self.http_client,
        }
        if self.config.api_base:
            client_params["base_url"] = self.config.api_base
        return AsyncOpenAI(**client_params)

    async def create_completion(
        self,
        messages: MessageHistory,
        temperature: float = 0.7,
        seed: Optional[int] = None,
        tools: Optional[list[Tool]] = None,
        response_format: Optional[type[BaseModel]] = None,
    ) -> ChatCompletion:
        base_params = {
            "model": self.config.model_id,
            "messages": self.serializer.serialize_messages(messages),
            "temperature": temperature,
        }
        optional_params = {
            "seed": seed,
            "tools": self.serializer.serialize_tools(tools) if tools else None,
            "tool_choice": "required" if tools else None,
            # response_format is treated as a flag here: any schema requests
            # generic JSON-object output rather than strict structured output.
            "response_format": {"type": "json_object"} if response_format else {"type": "text"},
        }
        base_params.update(
            {k: v for k, v in optional_params.items() if v is not None})
        return await self.external_client.chat.completions.create(**base_params)


class ConvergenceClientConfig(BaseClientConfig):
    name: Literal["convergence"] = "convergence"
    model_id: str = "convergence-ai/proxy-lite-7b"
    api_base: str = "http://localhost:8000/v1"
    api_key: str = "none"


class ConvergenceClient(OpenAIClient):
    config: ConvergenceClientConfig
    serializer: ClassVar[OpenAICompatibleSerializer] = OpenAICompatibleSerializer()
    _model_validated: bool = False

    async def _validate_model(self) -> None:
        try:
            response = await self.external_client.models.list()
            available = [model.id for model in response.data]
            # Use an explicit check rather than assert, which is stripped under -O.
            if self.config.model_id not in available:
                raise ValueError(f"Model {self.config.model_id} not found in {available}")
            self._model_validated = True
            logger.debug(f"Model {self.config.model_id} validated and connected to cluster")
        except Exception as e:
            logger.error(f"Error retrieving model: {e}")
            raise

    @cached_property
    def external_client(self) -> AsyncOpenAI:
        return AsyncOpenAI(
            api_key=self.config.api_key,
            base_url=self.config.api_base,
            http_client=self.http_client,
        )

    async def create_completion(
        self,
        messages: MessageHistory,
        temperature: float = 0.7,
        seed: Optional[int] = None,
        tools: Optional[list[Tool]] = None,
        response_format: Optional[type[BaseModel]] = None,
    ) -> ChatCompletion:
        if not self._model_validated:
            await self._validate_model()
        base_params = {
            "model": self.config.model_id,
            "messages": self.serializer.serialize_messages(messages),
            "temperature": temperature,
        }
        optional_params = {
            "seed": seed,
            "tools": self.serializer.serialize_tools(tools) if tools else None,
            "tool_choice": "auto" if tools else None,  # vLLM does not support "required"
            # Note: the Pydantic model class is passed straight through, so the
            # server must accept a schema-like response_format for this to work.
            "response_format": response_format if response_format else {"type": "text"},
        }
        base_params.update({k: v for k, v in optional_params.items() if v is not None})
        return await self.external_client.chat.completions.create(**base_params)


class GeminiClientConfig(BaseClientConfig):
    name: Literal["gemini"] = "gemini"
    model_id: str = "gemini-2.0-flash-001"
    # Assumed fallback env var name, mirroring OPENAI_API_KEY above.
    api_key: str = os.environ.get("GEMINI_API_KEY", "")


class GeminiClient(BaseClient):
    config: GeminiClientConfig
    serializer: ClassVar[OpenAICompatibleSerializer] = OpenAICompatibleSerializer()

    def _convert_messages_to_gemini_format(self, messages):
        """Convert OpenAI-format messages into flat Gemini text parts.

        Note: superseded by the role-aware conversion inside create_completion
        below; kept for reference. System messages are dropped.
        """
        gemini_parts = []
        for msg in messages:
            if msg["role"] in ("user", "assistant"):
                gemini_parts.append({"text": msg["content"]})
        return gemini_parts
    
    def _clean_schema_for_gemini(self, schema):
        """Prepare a JSON schema for Gemini function calling.

        Gemini rejects $defs/$ref, so inline every referenced definition
        first, then drop the now-unused $defs section.
        """
        if not isinstance(schema, dict):
            return schema
        # Inline before stripping: removing $ref keys first would lose the
        # references needed to substitute their definitions.
        inlined = self._inline_definitions(schema, schema.get("$defs", {}))
        return {k: v for k, v in inlined.items() if k != "$defs"}
    
    def _inline_definitions(self, schema, definitions):
        """Inline $ref definitions into the schema"""
        if not isinstance(schema, dict):
            return schema
        
        if "$ref" in schema:
            # Extract the reference name (e.g., "#/$defs/TypeEntry" -> "TypeEntry")
            ref_name = schema["$ref"].split("/")[-1]
            if ref_name in definitions:
                # Replace the $ref with the actual definition
                return self._inline_definitions(definitions[ref_name], definitions)
            else:
                # If we can't find the definition, remove the $ref
                return {k: v for k, v in schema.items() if k != "$ref"}
        
        # Recursively process nested objects
        inlined = {}
        for key, value in schema.items():
            if isinstance(value, dict):
                inlined[key] = self._inline_definitions(value, definitions)
            elif isinstance(value, list):
                inlined[key] = [self._inline_definitions(item, definitions) for item in value]
            else:
                inlined[key] = value
        
        return inlined
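
    # Illustrative sketch of the two schema helpers above, on a hypothetical input:
    #   {"$defs": {"Point": {"type": "object", "properties": {"x": {"type": "number"}}}},
    #    "properties": {"p": {"$ref": "#/$defs/Point"}}}
    # flattens to
    #   {"properties": {"p": {"type": "object", "properties": {"x": {"type": "number"}}}}}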

    async def create_completion(
        self,
        messages: MessageHistory,
        temperature: float = 0.7,
        seed: Optional[int] = None,
        tools: Optional[list[Tool]] = None,
        response_format: Optional[type[BaseModel]] = None,
    ) -> ChatCompletion:
        import json
        import time

        from openai.types.chat.chat_completion import Choice
        from openai.types.chat.chat_completion_message import ChatCompletionMessage
        from openai.types.chat.chat_completion_message_tool_call import (
            ChatCompletionMessageToolCall,
            Function,
        )
        from openai.types.completion_usage import CompletionUsage
        
        # Convert messages to format expected by Gemini
        serialized_messages = self.serializer.serialize_messages(messages)
        
        # For Gemini API, we need to format contents correctly with proper roles
        contents = []
        current_user_text = ""
        
        for msg in serialized_messages:
            # Extract the actual text content from the serialized message
            content_text = ""
            if isinstance(msg["content"], list):
                # Handle complex content format
                for item in msg["content"]:
                    if isinstance(item, dict) and "text" in item:
                        content_text += item["text"]
                    elif isinstance(item, str):
                        content_text += item
            elif isinstance(msg["content"], str):
                content_text = msg["content"]
            
            if msg["role"] == "user":
                # Accumulate user messages
                current_user_text += content_text + "\n"
            elif msg["role"] == "assistant":
                # If we have accumulated user text, add it first
                if current_user_text.strip():
                    contents.append({
                        "role": "user",
                        "parts": [{"text": current_user_text.strip()}]
                    })
                    current_user_text = ""
                
                # Add assistant message with role "model"
                contents.append({
                    "role": "model", 
                    "parts": [{"text": content_text}]
                })
            elif msg["role"] == "tool":
                # Add tool messages as user messages so they're included in context
                # Format tool message more clearly for the agent to understand
                current_user_text += f"[ACTION COMPLETED] {content_text}\n"
        
        # Add any remaining user text
        if current_user_text.strip():
            contents.append({
                "role": "user",
                "parts": [{"text": current_user_text.strip()}]
            })
        
        payload = {
            "contents": contents,
            "generationConfig": {
                "temperature": temperature,
            }
        }
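        # For illustration, a short exchange serializes to something like:
        #   {"contents": [{"role": "user", "parts": [{"text": "Hi"}]},
        #                 {"role": "model", "parts": [{"text": "Hello!"}]},
        #                 {"role": "user", "parts": [{"text": "[ACTION COMPLETED] ..."}]}],
        #    "generationConfig": {"temperature": 0.7}}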
        
        # Add function calling support if tools are provided
        if tools:
            # Convert tools to Gemini function declaration format
            function_declarations = []
            for tool in tools:
                for tool_schema in tool.schema:
                    # Clean up the schema for Gemini - remove $defs and $ref
                    cleaned_parameters = self._clean_schema_for_gemini(tool_schema["parameters"])
                    function_declarations.append({
                        "name": tool_schema["name"],
                        "description": tool_schema["description"],
                        "parameters": cleaned_parameters
                    })
            
            payload["tools"] = [{
                "function_declarations": function_declarations
            }]
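            # For illustration, a hypothetical click tool would yield a declaration like:
            #   {"name": "click", "description": "Click an element",
            #    "parameters": {"type": "object",
            #                   "properties": {"element_id": {"type": "integer"}}}}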
        
        # Make direct HTTP request to native Gemini API
        url = f"https://generativelanguage.googleapis.com/v1beta/models/{self.config.model_id}:generateContent?key={self.config.api_key}"
        
        response = await self.http_client.post(
            url,
            json=payload,
            headers={"Content-Type": "application/json"}
        )
        
        response.raise_for_status()
        response_data = response.json()
        
        # Convert Gemini response to OpenAI ChatCompletion format
        if "candidates" in response_data and len(response_data["candidates"]) > 0:
            candidate = response_data["candidates"][0]
            
            # Extract text from response
            content = ""
            tool_calls = []
            
            if "content" in candidate and "parts" in candidate["content"]:
                for part in candidate["content"]["parts"]:
                    if "text" in part:
                        content += part["text"]
                    elif "functionCall" in part:
                        # Handle function call
                        func_call = part["functionCall"]
                        tool_call = ChatCompletionMessageToolCall(
                            # Gemini returns no tool-call id, so synthesize one.
                            id=f"call_{hash(str(func_call))}"[:16],
                            type="function",
                            function=Function(
                                name=func_call["name"],
                                arguments=json.dumps(func_call.get("args", {}))
                            )
                        )
                        tool_calls.append(tool_call)
            
            choice = Choice(
                index=0,
                message=ChatCompletionMessage(
                    role="assistant",
                    content=content if content else None,
                    tool_calls=tool_calls if tool_calls else None
                ),
                finish_reason="tool_calls" if tool_calls else "stop"
            )
            
            # Create a mock ChatCompletion response
            completion = ChatCompletion(
                id="gemini-" + str(hash(content))[:8],
                choices=[choice],
                created=int(time.time()),
                model=self.config.model_id,
                object="chat.completion",
                usage=CompletionUsage(
                    # Rough whitespace-based approximation of token counts; the
                    # Gemini response's usage metadata is not parsed here.
                    completion_tokens=len(content.split()),
                    prompt_tokens=sum(len(str(msg.get("content", "")).split()) for msg in serialized_messages),
                    total_tokens=len(content.split()) + sum(len(str(msg.get("content", "")).split()) for msg in serialized_messages)
                )
            )
            
            return completion
        else:
            raise ValueError(f"No valid response from Gemini API: {response_data}")


ClientConfigTypes = Union[OpenAIClientConfig, ConvergenceClientConfig, GeminiClientConfig]
ClientTypes = Union[OpenAIClient, ConvergenceClient, GeminiClient]
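

# A minimal usage sketch (not part of the library surface): builds a client via
# the factory and requests one completion. Assumes OPENAI_API_KEY is set and
# that MessageHistory can be constructed as shown; adjust to the real
# MessageHistory API in proxy_lite.history.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        client = BaseClient.create(OpenAIClientConfig())  # dispatches on config.name
        messages = MessageHistory()  # hypothetical empty-history constructor
        completion = await client.create_completion(messages, temperature=0.2)
        print(completion.choices[0].message.content)

    asyncio.run(_demo())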